From 9622fbe373416b3e27d20a66478a9a9f8f6c5ac7 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Tue, 3 Dec 2024 19:52:36 +0000 Subject: [PATCH] Vulkan: Unroll more loops for more mul mat mat performance --- ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 3ffa8209e..d422070bf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -196,7 +196,7 @@ void main() { coopmat cache_b; coopmat sums[cms_per_row * cms_per_col]; - for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { sums[i] = coopmat(0.0f); } #else @@ -209,7 +209,7 @@ void main() { } #endif - [[dont_unroll]] for (uint block = start_k; block < end_k; block += BK) { + for (uint block = start_k; block < end_k; block += BK) { [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) { #if defined(DATA_A_F32) || defined(DATA_A_F16) @@ -506,12 +506,12 @@ void main() { pos_b += BK / LOAD_VEC_B; #ifdef COOPMAT - for (uint i = 0; i < BK; i += TK) { - for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint i = 0; i < BK; i += TK) { + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { // Load from shared into cache coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); - for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]); @@ -519,7 +519,7 @@ void main() { } } #else - for (uint i = 0; i < BK; i++) { + [[unroll]] for (uint i = 0; i < BK; i++) { // Load from shared into cache [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { [[unroll]] for (uint j = 0; j < TM; j++) {