diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 2326f56b5..86794e886 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -785,7 +785,8 @@ void ggml_vk_soft_max(kp::Sequence& seq, std::shared_ptr s_algo = nullptr; if (!komputeManager()->hasAlgorithm(__func__)) { - const uint32_t local_x = ggml_vk_current_device().subgroupSize; + // FIXME: The softmax kernel needs to be fixed to use the subgroupsize which can vary by device + const uint32_t local_x = 32; s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {in, out}, spirv, {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {local_x}, {pushConsts}); } else { s_algo = komputeManager()->getAlgorithm(__func__); @@ -981,8 +982,8 @@ void ggml_vk_mul_mat_q6_k(kp::Sequence& seq, std::shared_ptr s_algo = nullptr; if (!komputeManager()->hasAlgorithm(__func__)) { -// const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2; - s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {2,32}, {pushConsts}); + const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2; + s_algo = komputeManager()->algorithm(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts}); } else { s_algo = komputeManager()->getAlgorithm(__func__); s_algo->setTensors({inA, inB, out}); diff --git a/kompute/op_mul_mat_q6_k.comp b/kompute/op_mul_mat_q6_k.comp index c7b9aa753..6148053b2 100644 --- a/kompute/op_mul_mat_q6_k.comp +++ b/kompute/op_mul_mat_q6_k.comp @@ -44,31 +44,38 @@ void main() { const uint r1 = gl_WorkGroupID.y; const uint r2 = gl_WorkGroupID.z; - const uint row = 2 * r0 + gl_SubgroupID; + const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID); const uint offset0 = r2/pcs.gqa*(nb*pcs.ne0); const uint x = row * nb + offset0; // Based from inA without base offset const uint yy = r1*pcs.ne10 + r2*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB float sumf = 0; - const uint tid = gl_SubgroupInvocationID/2; - const uint ix = gl_SubgroupInvocationID%2; - const uint ip = tid/8; // 0 or 1 - const uint il = tid%8; - const uint n = 4; - const uint l0 = n*il; - const uint is = 8*ip + l0/16; + // bits of invocation ID for gl_SubgroupSize=32: + // x x x x x + // 4 3 2 1 0 + // ( tid ) ix + // ip ( il ) + + const uint block_stride = gl_SubgroupSize / 16; // number of blocks each subgroup processes + const uint tid = gl_SubgroupInvocationID/block_stride; // first block_stride groups have tid=0 + const uint ix = gl_SubgroupInvocationID%block_stride; // first block is 0..block_stride-1 + const uint ip = tid/8; // first or second half of block (0 or 1) + const uint il = tid%8; // each half has 8 parts, one per scale + const uint n = 4; // 4 scales at a time (and 4 sums) + const uint l0 = n*il; // offset into half-block, 0..28 + const uint is = 8*ip + l0/16; // 0, 1, 8, 9 const uint y_offset = 128*ip + l0; const uint q_offset_l = 64*ip + l0; const uint q_offset_h = 32*ip + l0; - for (uint i = ix; i < nb; i += 2) { + for (uint i = ix; i < nb; i += block_stride) { const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff; const uint qlIndex = q_offset_l; - const uint q2Index = qlIndex + 32; + const uint q2Index = qlIndex + QK_K/8; const uint qhIndex = q_offset_h; const uint y = yy + i * QK_K + y_offset; diff --git a/kompute/op_mul_mv_q_n.comp b/kompute/op_mul_mv_q_n.comp index 15bcbf765..a9b64fe16 100644 --- a/kompute/op_mul_mv_q_n.comp +++ b/kompute/op_mul_mv_q_n.comp @@ -7,6 +7,9 @@ */ void main() { + if (gl_SubgroupInvocationID > 31) + return; + const uint nb = uint(pcs.ne00/BLOCKS_IN_QUANT); const uint r0 = gl_WorkGroupID.x; const uint r1 = gl_WorkGroupID.y; @@ -28,13 +31,13 @@ void main() { // gl_NumSubgroups, gl_SubgroupID, gl_SubgroupInvocationID, gl_SubgroupSize, // gl_WorkGroupSize.x, gl_WorkGroupSize.y, gl_WorkGroupSize.z); - for (uint ib = ix; ib < nb; ib += gl_SubgroupSize/2) { + for (uint ib = ix; ib < nb; ib += 16) { for (int row = 0; row < N_ROWS; row++) { const uint block_index = x + ib + row * nb; sumf[row] += block_q_n_dot_y(block_index, yb, il); } - yb += BLOCKS_IN_QUANT * gl_SubgroupSize/2; + yb += BLOCKS_IN_QUANT * 16; } for (int row = 0; row < N_ROWS; ++row) { diff --git a/kompute/op_softmax.comp b/kompute/op_softmax.comp index d21577ac0..30b6f0260 100644 --- a/kompute/op_softmax.comp +++ b/kompute/op_softmax.comp @@ -24,6 +24,9 @@ layout(push_constant) uniform PushConstants { } pcs; void main() { + if (gl_SubgroupInvocationID > 31) + return; + const uint i03 = gl_WorkGroupID.z; const uint i02 = gl_WorkGroupID.y; const uint i01 = gl_WorkGroupID.x; @@ -34,21 +37,21 @@ void main() { // parallel max float localMax = uintBitsToFloat(0xFF800000); - for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += gl_SubgroupSize) { + for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { localMax = max(localMax, in_[psrc0 + i00]); } float max_ = subgroupMax(localMax); // parallel sum float localSum = 0.0f; - for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += gl_SubgroupSize) { + for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { const float exp_psrc0 = exp(in_[psrc0 + i00] - max_); localSum += exp_psrc0; out_[pdst + i00] = exp_psrc0; } const float sum = subgroupAdd(localSum); - for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += gl_SubgroupSize) { + for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) { out_[pdst + i00] /= sum; } }