From c704d581126b236036fbbed77ce77798c06ce043 Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Sat, 30 Nov 2024 18:59:04 +0000 Subject: [PATCH] Improve performance with better q4_k and q5_k dequant and store unrolling --- ggml/src/ggml-vulkan/ggml-vulkan.cpp | 3 +- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 54 +++++++++++-------- 2 files changed, 34 insertions(+), 23 deletions(-) diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index be198bd7a..a5f075ccd 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -1861,7 +1861,8 @@ static vk_device ggml_vk_get_device(size_t idx) { if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 && (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 && (vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 && - (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32 + (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32 && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup ) { device->coop_mat_m = prop.MSize; device->coop_mat_n = prop.NSize; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index f5dee7acb..9a4ebffaf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -365,15 +365,20 @@ void main() { const vec2 loadd = vec2(data_a[ib].d); - uint8_t sc; - uint8_t mbyte; - if (is < 4) { - sc = uint8_t(data_a[ib].scales[is ] & 63); - mbyte = uint8_t(data_a[ib].scales[is + 4] & 63); - } else { - sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4)); - mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4)); - } + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + const float d = loadd.x * sc; const float m = -loadd.y * mbyte; @@ -396,15 +401,20 @@ void main() { const vec2 loadd = vec2(data_a[ib].d); - uint8_t sc; - uint8_t mbyte; - if (is < 4) { - sc = uint8_t(data_a[ib].scales[is ] & 63); - mbyte = uint8_t(data_a[ib].scales[is + 4] & 63); - } else { - sc = uint8_t((data_a[ib].scales[is + 4] & 0xF) | ((data_a[ib].scales[is - 4] >> 6) << 4)); - mbyte = uint8_t((data_a[ib].scales[is + 4] >> 4) | ((data_a[ib].scales[is ] >> 6) << 4)); - } + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + const float d = loadd.x * sc; const float m = -loadd.y * mbyte; @@ -547,8 +557,8 @@ void main() { #ifdef COOPMAT #ifdef MUL_MAT_ID - for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { - for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); [[unroll]] for (uint col = 0; col < BN; col += storestride) { @@ -564,8 +574,8 @@ void main() { #else const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float - for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { - for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; if (is_aligned && is_in_bounds) {