From c1fd64548d2c8d42eaedae940c619a6cf2d9741f Mon Sep 17 00:00:00 2001 From: Aaron Miller Date: Fri, 13 Oct 2023 13:14:36 -0700 Subject: [PATCH] attempted speedups 2 --- ggml-vulkan.cpp | 24 +++++++++++++----------- kompute/op_mul_mat_mat_f16.comp | 12 ++++++++---- kompute/op_mul_mat_mat_f32.comp | 21 ++++++++++++++------- kompute/op_mul_mat_mat_q6_k.comp | 2 +- 4 files changed, 36 insertions(+), 23 deletions(-) diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 67270a3c7..010f49226 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -989,26 +989,27 @@ void ggml_vk_mul_mat_mat_f32(kp::Sequence& seq, nb1, nb2 }; + const uint32_t local_x = ggml_vk_current_device().subgroupSize; std::shared_ptr s_algo = nullptr; if (!komputeManager()->hasAlgorithm(__func__)) { - //std::cerr << "init f32 matmat shader" << std::endl; - s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), + s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne11), - unsigned(ne12)}, - {}, + unsigned(std::max(ne12, ne02)) + }, + {local_x}, {pushConsts}); } else { s_algo = komputeManager()->getAlgorithm(__func__); s_algo->setTensors({inA, inB, out}); s_algo->setWorkgroup({unsigned(ne01), unsigned(ne11), - unsigned(std::max(ne12, ne02))}); + unsigned(std::max(ne12, ne02)), + }); s_algo->setPushConstants({pushConsts}); s_algo->updateDescriptors(s_kompute_context->pool.get()); } - //seq.record({out}); seq.record(s_algo); } @@ -1038,15 +1039,16 @@ void ggml_vk_mul_mat_mat_f16(kp::Sequence& seq, nb1, nb2 }; + const uint32_t local_x = ggml_vk_current_device().subgroupSize; std::shared_ptr s_algo = nullptr; if (!komputeManager()->hasAlgorithm(__func__)) { - s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), + s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned(ne01), unsigned(ne11), unsigned(std::max(ne12, ne02)) }, - {}, + {local_x}, {pushConsts}); } else { s_algo = komputeManager()->getAlgorithm(__func__); @@ -1141,7 +1143,7 @@ void ggml_vk_mul_mat_mat_q6_k( if (!komputeManager()->hasAlgorithm(__func__)) { s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, - {unsigned(ne01)/32, + {unsigned(ne01)/256, unsigned(ne11), unsigned(std::max(ne12, ne02)) }, @@ -1150,7 +1152,7 @@ void ggml_vk_mul_mat_mat_q6_k( } else { s_algo = komputeManager()->getAlgorithm(__func__); s_algo->setTensors({inA, inB, out}); - s_algo->setWorkgroup({unsigned(ne01)/32, + s_algo->setWorkgroup({unsigned(ne01)/256, unsigned(ne11), unsigned(std::max(ne12, ne02)), }); @@ -1192,7 +1194,7 @@ void ggml_vk_mul_mat_mat_q4_x(const std::vector& spirv, {unsigned(ne01), unsigned(ne11), unsigned(std::max(ne12, ne02))}, - {local_x, 4}, + {local_x, 1}, {pushConsts}); } else { s_algo = komputeManager()->getAlgorithm(__func__); diff --git a/kompute/op_mul_mat_mat_f16.comp b/kompute/op_mul_mat_mat_f16.comp index b62f06d10..03872fed5 100644 --- a/kompute/op_mul_mat_mat_f16.comp +++ b/kompute/op_mul_mat_mat_f16.comp @@ -14,7 +14,8 @@ #extension GL_KHR_shader_subgroup_arithmetic : require #extension GL_EXT_debug_printf : enable -// layout(local_size_x = 8) in; +// device subgroup size +layout (local_size_x_id = 0) in; layout(binding = 0) readonly buffer tensorInA { float16_t inA[]; }; layout(binding = 1) readonly buffer tensorInB { float inB[]; }; @@ -40,7 +41,7 @@ pcs; void main() { - uvec3 gid = gl_GlobalInvocationID; + uvec3 gid = gl_WorkGroupID; uint bc_ab = pcs.ne12 > pcs.ne02 ? gid.z / (pcs.ne12 / pcs.ne02) : gid.z; uint bc_ba = pcs.ne02 > pcs.ne12 ? gid.z / (pcs.ne02 / pcs.ne12) : gid.z; @@ -48,9 +49,12 @@ void main() { const uint x = (gid.x*pcs.nb01 + bc_ab*pcs.nb02) / 2 + pcs.inAOff; // Based from inA const uint y = (gid.y*pcs.nb11 + bc_ba*pcs.nb12) / 4 + pcs.inBOff; // based from inB float sum = 0.0f; - for (uint i = 0; i < pcs.ne00; i ++) { + for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) { sum += float(inA[x+i]) * float(inB[y+i]); } - out_[gid.z*(pcs.nb2/4) + gid.y*(pcs.nb1/4) + gid.x + pcs.outOff] = sum; + const float all_sum = subgroupAdd(sum); + if (subgroupElect()) { + out_[gid.z*(pcs.nb2/4) + gid.y*(pcs.nb1/4) + gid.x + pcs.outOff] = all_sum; + } } \ No newline at end of file diff --git a/kompute/op_mul_mat_mat_f32.comp b/kompute/op_mul_mat_mat_f32.comp index 6234322ca..a2dba0560 100644 --- a/kompute/op_mul_mat_mat_f32.comp +++ b/kompute/op_mul_mat_mat_f32.comp @@ -14,7 +14,8 @@ #extension GL_KHR_shader_subgroup_arithmetic : require #extension GL_EXT_debug_printf : enable -// layout(local_size_x = 8) in; +// device subgroup size +layout (local_size_x_id = 0) in; layout(binding = 0) readonly buffer tensorInA { float inA[]; }; layout(binding = 1) readonly buffer tensorInB { float inB[]; }; @@ -40,14 +41,20 @@ pcs; void main() { - uvec3 gid = gl_GlobalInvocationID; + uvec3 gid = gl_WorkGroupID; - const uint x = (gid.x*pcs.nb01 + gid.z/(pcs.ne12/pcs.ne02)*pcs.nb02) / 4 + pcs.inAOff; // Based from inA - const uint y = (gid.y*pcs.nb11 + gid.z/(pcs.ne02/pcs.ne12)*pcs.nb12) / 4 + pcs.inBOff; // based from inB + uint bc_ab = pcs.ne12 > pcs.ne02 ? gid.z / (pcs.ne12 / pcs.ne02) : gid.z; + uint bc_ba = pcs.ne02 > pcs.ne12 ? gid.z / (pcs.ne02 / pcs.ne12) : gid.z; + + const uint x = (gid.x*pcs.nb01 + bc_ab*pcs.nb02) / 4 + pcs.inAOff; // Based from inA + const uint y = (gid.y*pcs.nb11 + bc_ba*pcs.nb12) / 4 + pcs.inBOff; // based from inB float sum = 0.0f; - for (uint i = 0; i < pcs.ne00; i ++) { + for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) { sum += float(inA[x+i]) * float(inB[y+i]); } - out_[gid.z*(pcs.nb2/4) + gid.y*(pcs.nb1/4) + gid.x + pcs.outOff] = sum; -} + const float all_sum = subgroupAdd(sum); + if (subgroupElect()) { + out_[gid.z*(pcs.nb2/4) + gid.y*(pcs.nb1/4) + gid.x + pcs.outOff] = all_sum; + } +} \ No newline at end of file diff --git a/kompute/op_mul_mat_mat_q6_k.comp b/kompute/op_mul_mat_mat_q6_k.comp index 127f17df6..8e3e44d7d 100644 --- a/kompute/op_mul_mat_mat_q6_k.comp +++ b/kompute/op_mul_mat_mat_q6_k.comp @@ -14,7 +14,7 @@ #extension GL_KHR_shader_subgroup_arithmetic : require #extension GL_EXT_debug_printf : enable -layout(local_size_x = 32) in; +layout(local_size_x = 256) in; layout(binding = 0) readonly buffer tensorInA { uint8_t inA[]; }; layout(binding = 1) readonly buffer tensorInB { float inB[]; };