From 06d4b21598da0162999b35429cfb567ed962d7ec Mon Sep 17 00:00:00 2001 From: Adam Treat Date: Mon, 2 Oct 2023 11:30:10 -0400 Subject: [PATCH] Fix offset into the qh and now we have working vulkan accelerated for gguff'd llama. --- kompute/op_mul_mat_q6_k.comp | 26 ++------------------------ 1 file changed, 2 insertions(+), 24 deletions(-) diff --git a/kompute/op_mul_mat_q6_k.comp b/kompute/op_mul_mat_q6_k.comp index 1e4ea37f8..c7b9aa753 100644 --- a/kompute/op_mul_mat_q6_k.comp +++ b/kompute/op_mul_mat_q6_k.comp @@ -32,28 +32,13 @@ layout (push_constant) uniform parameter { int gqa; } pcs; -block_q6_k get_unaligned_block_q6_k(uint index) { - block_q6_k fres; - [[unroll]] for (uint it = 0; it != QK_K / 2; it++) { - fres.ql[it] = inA[index + it]; - } - [[unroll]] for (uint it = 0; it != QK_K / 4; it++) { - fres.qh[it] = inA[index + QK_K/2 + it]; - } - [[unroll]] for (uint it = 0; it != QK_K / 16; it++) { - fres.scales[it] = int8_t(inA[index + QK_K/2 + QK_K/4 + it]); - } - fres.d = u8BufToFloat16(inA, index + QK_K/2 + QK_K/4 + QK_K/16); - return fres; -} - void main() { const uint8_t kmask1 = uint8_t(0x03); const uint8_t kmask2 = uint8_t(0x0C); const uint8_t kmask3 = uint8_t(0x30); const uint8_t kmask4 = uint8_t(0xC0); - const int nb = pcs.ne00/QK_K; + const uint nb = pcs.ne00/QK_K; const uint r0 = gl_WorkGroupID.x; const uint r1 = gl_WorkGroupID.y; @@ -81,8 +66,6 @@ void main() { for (uint i = ix; i < nb; i += 2) { const uint baseIndex = (x + i) * SIZE_OF_BLOCK + pcs.inAOff; -// const uint index = (x + i) * SIZE_OF_BLOCK + pcs.inAOff; -// const block_q6_k block = get_unaligned_block_q6_k(index); const uint qlIndex = q_offset_l; const uint q2Index = qlIndex + 32; @@ -91,13 +74,9 @@ void main() { float sums[4] = {0.0f, 0.0f, 0.0f, 0.0f}; for (uint l = 0; l < n; ++l) { - -// const uint8_t currentQ1 = block.ql[qlIndex + l]; -// const uint8_t currentQ2 = block.ql[q2Index + l]; -// const uint8_t currentQh = block.qh[qhIndex + l]; const uint8_t currentQ1 = inA[baseIndex + qlIndex + l]; const uint8_t currentQ2 = inA[baseIndex + q2Index + l]; - const uint8_t currentQh = inA[baseIndex + qhIndex + l]; + const uint8_t currentQh = inA[baseIndex + QK_K/2 + qhIndex + l]; sums[0] += inB[y+l+ 0] * (int8_t((currentQ1 & 0xF) | ((currentQh & kmask1) << 4)) - 32); sums[1] += inB[y+l+32] * (int8_t((currentQ2 & 0xF) | ((currentQh & kmask2) << 2)) - 32); @@ -105,7 +84,6 @@ void main() { sums[3] += inB[y+l+96] * (int8_t((currentQ2 >> 4) | ((currentQh & kmask4) >> 2)) - 32); } -// sumf += block.d * (sums[0] * block.scales[0+is] + sums[1] * block.scales[2+is] + sums[2] * block.scales[4+is] + sums[3] * block.scales[6+is]); float d = u8BufToFloat16(inA, baseIndex + QK_K/2 + QK_K/4 + QK_K/16); sumf += d * (sums[0] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + is]) + sums[1] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 2 + is]) + sums[2] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 4 + is]) + sums[3] * int8_t(inA[baseIndex + QK_K/2 + QK_K/4 + 6 + is])); }