diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 53c135496..d8dd361a6 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -192,6 +192,29 @@ typedef struct { int16_t r3; } ggml_metal_kargs_mul_mv; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; + int16_t nsg; + int16_t nxpsg; +} ggml_metal_kargs_mul_mv_ext; + typedef struct { int32_t nei0; int32_t nei1; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index ae6b25ede..0ef286f25 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1,6 +1,8 @@ #import "ggml-metal.h" #import "ggml-impl.h" +#define GGML_COMMON_DECL_C +#import "ggml-common.h" #import "ggml-backend-impl.h" #import "ggml-metal-impl.h" @@ -175,6 +177,7 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, + GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, @@ -699,6 +702,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32, mul_mv_ext_q8_0_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction); @@ -1952,6 +1956,55 @@ static void ggml_metal_encode_node( } #endif + if (src0t == GGML_TYPE_Q8_0 && (ne00%16 == 0) && (ne11 >= 4 && ne11 < 32)) { + //if (false) { + id pipeline = nil; + + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32].pipeline; + + const int nsg = 2; + const int r0pt = 2; + const int r1pt = 1; + const int nxpsg = ne11 > 1 ? 8 : 32; + const int nypsg = 32/nxpsg; + const int nr0ptg = nypsg*r0pt*nsg; + + //GGML_ASSERT(ne00%1024 == 0); + //GGML_ASSERT(ne01%nr0ptg == 0); + //printf("ne01 = %lld, nr0ptg = %d, ne00 = %lld\n", ne01, nr0ptg, ne00); + + ggml_metal_kargs_mul_mv_ext args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + /*.nsg =*/ nsg, + /*.nxpsg =*/ nxpsg, + }; + + [encoder setComputePipelineState:pipeline]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + + //printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg); + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0ptg - 1)/nr0ptg, (ne11 + r1pt - 1)/r1pt, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } else // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel if ([device supportsFamily:MTLGPUFamilyApple7] && diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index eaca38864..4b3f2acdb 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1752,6 +1752,135 @@ kernel void kernel_mul_mv_q8_0_f32( kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } +template +void kernel_mul_mv_ext_q8_0_f32_impl( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const short chpt = 1; + const short r0pt = 2; + + //const short nxpsg = (32); + const short nypsg = (32/nxpsg)*r0pt; + + const short tx = tiisg%nxpsg; + const short ty = tiisg/nxpsg; + + const int i01 = tgpig.x*(nypsg*nsg) + nypsg*sgitg + ty*r0pt; + const int i11 = tgpig.y; + const int i1m = tgpig.z; + + const int i12 = i1m%args.ne12; + const int i13 = i1m/args.ne12; + + const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + + device const block_q8_0 * xq[r0pt]; + + for (short ir0 = 0; ir0 < r0pt; ++ir0) { + xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/2 : (device const block_q8_0 *) src0; + } + + device const float4x4 * y4x4 = (device const float4x4 *) (src1 + offset1) + chpt*tx; + + float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f }; + + for (int iib = 0; (16*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) { + float4x4 lx; + +#pragma unroll(2) + for (short ir0 = 0; ir0 < r0pt; ++ir0) { +#pragma unroll + for (short ch = 0; ch < chpt; ++ch) { + dequantize_q8_0(xq[ir0] + ch/2, (chpt*tx + ch)%2, lx); + + const float4x4 ly = y4x4[ch]; + + sumf[ir0] += + dot(lx[0], ly[0]) + + dot(lx[1], ly[1]) + + dot(lx[2], ly[2]) + + dot(lx[3], ly[3]); + } + } + + y4x4 += ((16*chpt)*nxpsg)/16; + + for (short ir0 = 0; ir0 < r0pt; ++ir0) { + xq[ir0] += ((16*chpt)*nxpsg)/32; + } + } + + for (short ir0 = 0; ir0 < r0pt; ++ir0) { + if (nxpsg >= 32) { + sumf[ir0] += simd_shuffle_down(sumf[ir0], 16); + } + if (nxpsg >= 16) { + sumf[ir0] += simd_shuffle_down(sumf[ir0], 8); + } + if (nxpsg >= 8) { + sumf[ir0] += simd_shuffle_down(sumf[ir0], 4); + } + if (nxpsg >= 4) { + sumf[ir0] += simd_shuffle_down(sumf[ir0], 2); + } + if (nxpsg >= 2) { + sumf[ir0] += simd_shuffle_down(sumf[ir0], 1); + } + + //sumf[ir0] = simd_sum(sumf[ir0]); + } + + device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)i11*args.ne0; + + if (tx == 0) { + for (short ir0 = 0; ir0 < r0pt && i01 + ir0 < args.ne01; ++ir0) { + dst_f32[i01 + ir0] = sumf[ir0]; + } + } +} + +[[host_name("kernel_mul_mv_ext_q8_0_f32")]] +kernel void kernel_mul_mv_ext_q8_0_f32( + constant ggml_metal_kargs_mul_mv_ext & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + switch (args.nsg) { + case 1: + switch (args.nxpsg) { + case 4: kernel_mul_mv_ext_q8_0_f32_impl<1, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q8_0_f32_impl<1, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q8_0_f32_impl<1, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q8_0_f32_impl<1, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; + } break; + case 2: + switch (args.nxpsg) { + case 4: kernel_mul_mv_ext_q8_0_f32_impl<2, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q8_0_f32_impl<2, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q8_0_f32_impl<2, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q8_0_f32_impl<2, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; + } break; + case 4: + switch (args.nxpsg) { + case 4: kernel_mul_mv_ext_q8_0_f32_impl<4, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; + case 8: kernel_mul_mv_ext_q8_0_f32_impl<4, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; + case 16: kernel_mul_mv_ext_q8_0_f32_impl<4, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; + case 32: kernel_mul_mv_ext_q8_0_f32_impl<4, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; + } break; + } +} + #define N_MV_T_T 4 template