From 3c8a2a83feb01390a67762eb37655815a86f3617 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 13 Nov 2024 11:04:04 +0200 Subject: [PATCH] shmem experiments --- ggml/src/ggml-metal/ggml-metal.m | 14 ++- ggml/src/ggml-metal/ggml-metal.metal | 133 +++++++++++++++++++++++---- 2 files changed, 123 insertions(+), 24 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 895008b7a..b26adce77 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1961,14 +1961,15 @@ static void ggml_metal_encode_node( pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32].pipeline; - const int nsg = 2; - const int r0pt = 1; + const int nsg = 4; + const int r0pt = 4; const int r1pt = 1; - const int nxpsg = ne11 > 1 ? 8 : 32; + //const int nxpsg = ne11 > 1 ? 8 : 32; + const int nxpsg = 32; const int nypsg = 32/nxpsg; const int nr0ptg = nypsg*r0pt*nsg; - //GGML_ASSERT(ne00%1024 == 0); + //GGML_ASSERT(ne00%4096 == 0); //GGML_ASSERT(ne01%nr0ptg == 0); //printf("ne01 = %lld, nr0ptg = %d, ne00 = %lld\n", ne01, nr0ptg, ne00); @@ -2003,6 +2004,11 @@ static void ggml_metal_encode_node( //printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg); [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0ptg - 1)/nr0ptg, (ne11 + r1pt - 1)/r1pt, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + + [encoder setThreadgroupMemoryLength:2*8192 atIndex:0]; + + //printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg); + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0ptg - 1)/nr0ptg, (ne11 + r1pt - 1)/r1pt, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; } else // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f9f1186f7..7662ba463 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -190,6 +190,27 @@ void dequantize_q8_0x(device const block_q8_0 *xb, short il, thread type4 & reg) } } +template +void dequantize_q8_0s(threadgroup const block_q8_0 * xb, short il, thread type4 & reg) { + threadgroup const int8_t * qs = ((threadgroup const int8_t *) xb->qs); + const float d = xb->d; + + for (int i = 0; i < 4; i++) { + reg[i] = (qs[4*(il%4) + i + 16*(il/4)]*d); + } +} + +//template +//type4 dequantize_q8_0x(device const int8_t * qs, float d, short il) { +// thread type4 reg; +// for (int i = 0; i < 4; i++) { +// reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d); +// //reg[i] = qs[i/2]; +// } +// +// return reg; +//} + template void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { const float d = xb->d; @@ -1778,12 +1799,13 @@ void kernel_mul_mv_ext_q8_0_f32_impl( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - const short chpt = 4; - const short r0pt = 1; + const short chpt = 8; + const short r0pt = 4; //const short nxpsg = (32); const short nypsg = (32/nxpsg)*r0pt; @@ -1802,10 +1824,12 @@ void kernel_mul_mv_ext_q8_0_f32_impl( const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q8_0 * xq[r0pt]; + device const block_q8_0 * xq0[r0pt]; for (short ir0 = 0; ir0 < r0pt; ++ir0) { //xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/8 : (device const block_q8_0 *) src0; - xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (tx)/8 : (device const block_q8_0 *) src0; + //xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (tx)/8 : (device const block_q8_0 *) src0; + xq0[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) : (device const block_q8_0 *) src0; } //device const float4 * y4 = (device const float4 *) (src1 + offset1) + chpt*tx; @@ -1813,23 +1837,63 @@ void kernel_mul_mv_ext_q8_0_f32_impl( float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f }; + threadgroup block_q8_0 * shmem_q = (threadgroup block_q8_0 *) shmem + (((4*chpt)*nxpsg)/32)*r0pt*sgitg; + for (int iib = 0; (4*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) { + //shmem_q[(4*chpt)*(tiisg/16 ) + tiisg%16] = xq0[tiisg/16 ][16*iib + tiisg%16]; + //shmem_q[(4*chpt)*(tiisg/16 + 2) + tiisg%16] = xq0[tiisg/16 + 2][16*iib + tiisg%16]; + //shmem_q[(4*chpt)*(tiisg/16 + 4) + tiisg%16] = xq0[tiisg/16 + 4][16*iib + tiisg%16]; + //shmem_q[(4*chpt)*(tiisg/16 + 6) + tiisg%16] = xq0[tiisg/16 + 6][16*iib + tiisg%16]; + //shmem_q[(4*chpt)*2 + tiisg] = xq0[2][32*iib + tiisg]; + //shmem_q[(4*chpt)*3 + tiisg] = xq0[3][32*iib + tiisg]; + + shmem_q[((4*chpt))*(tiisg/32 ) + tiisg%32] = xq0[tiisg/32 ][32*iib + tiisg%32]; + shmem_q[((4*chpt))*(tiisg/32 + 1) + tiisg%32] = xq0[tiisg/32 + 1][32*iib + tiisg%32]; + shmem_q[((4*chpt))*(tiisg/32 + 2) + tiisg%32] = xq0[tiisg/32 + 2][32*iib + tiisg%32]; + shmem_q[((4*chpt))*(tiisg/32 + 3) + tiisg%32] = xq0[tiisg/32 + 3][32*iib + tiisg%32]; + + //if (chpt == 2) { + // shmem_q[(4*chpt)*(tiisg/8 ) + tiisg%8] = xq0[tiisg/8 ][8*iib + tiisg%8]; + //} + + simdgroup_barrier(mem_flags::mem_threadgroup); + for (short ir0 = 0; ir0 < r0pt; ++ir0) { -#pragma unroll(4) + //const float d = xq[ir0]->d; + //device const int8_t * qs = ((device const int8_t *) xq[ir0]->qs); + +// float d[chpt]; +// device const int8_t * qs[chpt]; +//#pragma unroll(chpt) +// for (short ch = 0; ch < chpt; ++ch) { +// device const block_q8_0 * xc = xq[ir0] + (ch*nxpsg)/8; +// d[ch] = xc->d; +// qs[ch] = xc->qs; +// } +#pragma unroll(chpt) for (short ch = 0; ch < chpt; ++ch) { float4 lx; - dequantize_q8_0x(xq[ir0] + (ch*nxpsg)/8, (tx)%8, lx); + //float4 lx = dequantize_q8_0x(qs, d, (chpt*tx + ch)%8); + //dequantize_q8_0x(xq[ir0] + ch/8, (chpt*tx + ch)%8, lx); + //float4 lx = dequantize_q8_0x(qs, d, (tx)%8); + //float4 lx = dequantize_q8_0x(qs[ch], d[ch], (tx)%8); + //dequantize_q8_0x(xq[ir0] + (ch*nxpsg)/8, (tx)%8, lx); + //dequantize_q8_0x(xq0[ir0] + 8*iib + (ch*nxpsg)/8 + tx/8, (tx)%8, lx); + dequantize_q8_0s(shmem_q + (((4*chpt)*nxpsg)/32)*ir0 + (ch*nxpsg)/8 + tx/8, (tx)%8, lx); + //dequantize_q8_0s(shmem_q + 8*ir0 + (ch*nxpsg)/8 + tx/8, (tx)%8, lx); + + //sumf[ir0] += dot(lx, y4[ch]); sumf[ir0] += dot(lx, y4[ch*nxpsg]); } } y4 += ((4*chpt)*nxpsg)/4; - for (short ir0 = 0; ir0 < r0pt; ++ir0) { - xq[ir0] += ((4*chpt)*nxpsg)/32; - } + //for (short ir0 = 0; ir0 < r0pt; ++ir0) { + // xq[ir0] += ((4*chpt)*nxpsg)/32; + //} } for (short ir0 = 0; ir0 < r0pt; ++ir0) { @@ -1867,6 +1931,7 @@ kernel void kernel_mul_mv_ext_q8_0_f32( device const char * src0, device const char * src1, device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], @@ -1874,24 +1939,52 @@ kernel void kernel_mul_mv_ext_q8_0_f32( switch (args.nsg) { case 1: switch (args.nxpsg) { - case 4: kernel_mul_mv_ext_q8_0_f32_impl<1, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - case 8: kernel_mul_mv_ext_q8_0_f32_impl<1, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - case 16: kernel_mul_mv_ext_q8_0_f32_impl<1, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - case 32: kernel_mul_mv_ext_q8_0_f32_impl<1, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; + case 4: kernel_mul_mv_ext_q8_0_f32_impl<1, 4> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q8_0_f32_impl<1, 8> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q8_0_f32_impl<1, 16>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q8_0_f32_impl<1, 32>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; } break; case 2: switch (args.nxpsg) { - case 4: kernel_mul_mv_ext_q8_0_f32_impl<2, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - case 8: kernel_mul_mv_ext_q8_0_f32_impl<2, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - case 16: kernel_mul_mv_ext_q8_0_f32_impl<2, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - case 32: kernel_mul_mv_ext_q8_0_f32_impl<2, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; + case 4: kernel_mul_mv_ext_q8_0_f32_impl<2, 4> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q8_0_f32_impl<2, 8> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q8_0_f32_impl<2, 16>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q8_0_f32_impl<2, 32>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; } break; case 4: switch (args.nxpsg) { - case 4: kernel_mul_mv_ext_q8_0_f32_impl<4, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - case 8: kernel_mul_mv_ext_q8_0_f32_impl<4, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - case 16: kernel_mul_mv_ext_q8_0_f32_impl<4, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - case 32: kernel_mul_mv_ext_q8_0_f32_impl<4, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; + case 4: kernel_mul_mv_ext_q8_0_f32_impl<4, 4> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q8_0_f32_impl<4, 8> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q8_0_f32_impl<4, 16>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q8_0_f32_impl<4, 32>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + } break; + case 6: + switch (args.nxpsg) { + case 4: kernel_mul_mv_ext_q8_0_f32_impl<6, 4> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q8_0_f32_impl<6, 8> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q8_0_f32_impl<6, 16>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q8_0_f32_impl<6, 32>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + } break; + case 8: + switch (args.nxpsg) { + case 4: kernel_mul_mv_ext_q8_0_f32_impl<8, 4> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q8_0_f32_impl<8, 8> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q8_0_f32_impl<8, 16>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q8_0_f32_impl<8, 32>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + } break; + case 12: + switch (args.nxpsg) { + case 4: kernel_mul_mv_ext_q8_0_f32_impl<12, 4> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q8_0_f32_impl<12, 8> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q8_0_f32_impl<12, 16>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q8_0_f32_impl<12, 32>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + } break; + case 16: + switch (args.nxpsg) { + case 4: kernel_mul_mv_ext_q8_0_f32_impl<16, 4> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q8_0_f32_impl<16, 8> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q8_0_f32_impl<16, 16>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q8_0_f32_impl<16, 32>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break; } break; } }