perf: use bigger threadgroups in mm

This commit is contained in:
Aaron Miller 2023-10-11 16:02:53 -07:00 committed by cebtenzzre
parent 46385ee0d5
commit 3327d84a7f
2 changed files with 2 additions and 2 deletions

View File

@ -1148,7 +1148,7 @@ void ggml_vk_mul_mat_mat_q4_x(const std::vector<uint32_t>& spirv,
} else { } else {
s_algo = komputeManager()->getAlgorithm(__func__); s_algo = komputeManager()->getAlgorithm(__func__);
s_algo->setTensors({inA, inB, out}); s_algo->setTensors({inA, inB, out});
s_algo->setWorkgroup({unsigned(ne01), s_algo->setWorkgroup({unsigned(ne01)/32,
unsigned(ne11), unsigned(ne11),
unsigned(std::max(ne12, ne02)), unsigned(std::max(ne12, ne02)),
}); });

View File

@ -14,7 +14,7 @@
#extension GL_KHR_shader_subgroup_arithmetic : require #extension GL_KHR_shader_subgroup_arithmetic : require
#extension GL_EXT_debug_printf : enable #extension GL_EXT_debug_printf : enable
// layout(local_size_x = 8) in; layout(local_size_x = 32) in;
layout(binding = 0) readonly buffer tensorInA { uint8_t inA[]; }; layout(binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
layout(binding = 1) readonly buffer tensorInB { float inB[]; }; layout(binding = 1) readonly buffer tensorInB { float inB[]; };