diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 81c795ea6..0619f9855 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1943,7 +1943,7 @@ static void ggml_metal_encode_node( pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32].pipeline; const int nsg = 2; - const int r0pt = 2; + const int r0pt = 1; const int r1pt = 1; const int nxpsg = ne11 > 1 ? 8 : 32; const int nypsg = 32/nxpsg; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index e82071f2f..1a5ac2dd5 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -165,6 +165,26 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg reg = (type4x4) reg_f; } +template +void dequantize_q4_0x(device const block_q4_0 *xb, short il, thread type4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + const half d = xb->d; + + for (int i = 0; i < 4; i++) { + reg[i] = qs[0]; + } +} + +template +void dequantize_q8_0x(device const block_q8_0 *xb, short il, thread type4 & reg) { + device const int8_t * qs = ((device const int8_t *)xb->qs); + const half d = xb->d; + + for (int i = 0; i < 4; i++) { + reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d); + } +} + template void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { const float d = xb->d; @@ -1749,8 +1769,8 @@ void kernel_mul_mv_ext_q8_0_f32_impl( ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - const short chpt = 1; - const short r0pt = 2; + const short chpt = 4; + const short r0pt = 1; //const short nxpsg = (32); const short nypsg = (32/nxpsg)*r0pt; @@ -1771,36 +1791,31 @@ void kernel_mul_mv_ext_q8_0_f32_impl( device const block_q8_0 * xq[r0pt]; for (short ir0 = 0; ir0 < r0pt; ++ir0) { - xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/2 : (device const block_q8_0 *) src0; + //xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/8 : (device const block_q8_0 *) src0; + xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (tx)/8 : (device const block_q8_0 *) src0; } - device const float4x4 * y4x4 = (device const float4x4 *) (src1 + offset1) + chpt*tx; + //device const float4 * y4 = (device const float4 *) (src1 + offset1) + chpt*tx; + device const float4 * y4 = (device const float4 *) (src1 + offset1) + tx; float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f }; - for (int iib = 0; (16*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) { - float4x4 lx; - -#pragma unroll(2) + for (int iib = 0; (4*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) { for (short ir0 = 0; ir0 < r0pt; ++ir0) { -#pragma unroll +#pragma unroll(4) for (short ch = 0; ch < chpt; ++ch) { - dequantize_q8_0(xq[ir0] + ch/2, (chpt*tx + ch)%2, lx); + float4 lx; - const float4x4 ly = y4x4[ch]; + dequantize_q8_0x(xq[ir0] + (ch*nxpsg)/8, (tx)%8, lx); - sumf[ir0] += - dot(lx[0], ly[0]) + - dot(lx[1], ly[1]) + - dot(lx[2], ly[2]) + - dot(lx[3], ly[3]); + sumf[ir0] += dot(lx, y4[ch*nxpsg]); } } - y4x4 += ((16*chpt)*nxpsg)/16; + y4 += ((4*chpt)*nxpsg)/4; for (short ir0 = 0; ir0 < r0pt; ++ir0) { - xq[ir0] += ((16*chpt)*nxpsg)/32; + xq[ir0] += ((4*chpt)*nxpsg)/32; } }