From c46e9e488b4c455318739d6de46fb00d5e56b7d2 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 28 Nov 2024 21:59:15 +0200 Subject: [PATCH] wip --- ggml/src/ggml-metal/ggml-metal-impl.h | 1 + ggml/src/ggml-metal/ggml-metal.m | 15 ++++- ggml/src/ggml-metal/ggml-metal.metal | 80 +++++++++++++-------------- tests/test-backend-ops.cpp | 2 +- 4 files changed, 52 insertions(+), 46 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index d8dd361a6..8c99716f4 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -213,6 +213,7 @@ typedef struct { int16_t r3; int16_t nsg; int16_t nxpsg; + int16_t r1pt; } ggml_metal_kargs_mul_mv_ext; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index d42877116..18ef53815 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -23,6 +23,14 @@ #define UNUSED(x) (void)(x) +static int up2(int x) { + int r = 1; + while (r < x) { + r *= 2; + } + return r; +} + // globals // overload of MTLGPUFamilyMetal3 (not available in some environments) @@ -1956,7 +1964,7 @@ static void ggml_metal_encode_node( } #endif - if (src0t == GGML_TYPE_Q8_0 && (ne00%16 == 0) && (ne11 >= 2 && ne11 < 32)) { + if (src0t == GGML_TYPE_Q8_0 && (ne00%256 == 0) && (ne11 >= 2 && ne11 < 16)) { //if (false) { id pipeline = nil; @@ -1964,8 +1972,8 @@ static void ggml_metal_encode_node( const int nsg = 2; const int r0pt = 1; - const int r1pt = 4; - const int nxpsg = ne11 > 1 ? 8 : 32; + const int r1pt = ne11 < 3 ? 2 : 4; + const int nxpsg = ne11 < 3 ? 16 : 8; const int nypsg = 32/nxpsg; const int nr0ptg = nypsg*r0pt*nsg; @@ -1994,6 +2002,7 @@ static void ggml_metal_encode_node( /*.r3 =*/ r3, /*.nsg =*/ nsg, /*.nxpsg =*/ nxpsg, + /*.r1pt =*/ r1pt, }; [encoder setComputePipelineState:pipeline]; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index a5f12a4c5..53cacd929 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1772,19 +1772,17 @@ kernel void kernel_mul_mv_q8_0_f32( kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_ext_q8_0_f32_impl( constant ggml_metal_kargs_mul_mv_ext & args, device const char * src0, device const char * src1, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { const short chpt = 4; const short r0pt = 1; - const short r1pt = 4; //const short nxpsg = (32); const short nypsg = (32/nxpsg)*r0pt; @@ -1802,47 +1800,42 @@ void kernel_mul_mv_ext_q8_0_f32_impl( const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - //device const float4 * y4 = (device const float4 *) (src1 + offset1) + chpt*tx; - //device const float4 * y4 = (device const float4 *) (src1 + offset1) + tx; - device const block_q8_0 * xq[r0pt]; for (short ir0 = 0; ir0 < r0pt; ++ir0) { - //xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (chpt*tx)/8 : (device const block_q8_0 *) src0; xq[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) + (tx)/8 : (device const block_q8_0 *) src0; } - device const float4 * y4[r1pt]; + device const float4 * y4[r1pt]; + for (int ir1 = 0; ir1 < r1pt; ++ir1) { - //y4[ir1] = (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx; - y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1; + y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1; } - //float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f }; float sumf[r1pt][r0pt] = { [ 0 ... r1pt - 1 ] = { [0 ... r0pt - 1] = 0.0f } }; for (int iib = 0; (4*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) { for (short ir0 = 0; ir0 < r0pt; ++ir0) { -#pragma unroll(4) +#pragma unroll(chpt) for (short ch = 0; ch < chpt; ++ch) { float4 lx; dequantize_q8_0x(xq[ir0] + (ch*nxpsg)/8, (tx)%8, lx); -#pragma unroll(4) +#pragma unroll(r1pt) for (short ir1 = 0; ir1 < r1pt; ++ir1) { sumf[ir1][ir0] += dot(lx, y4[ir1][ch*nxpsg]); } } } - for (short ir1 = 0; ir1 < r1pt; ++ir1) { - y4[ir1] += ((4*chpt)*nxpsg)/4; - } - for (short ir0 = 0; ir0 < r0pt; ++ir0) { xq[ir0] += ((4*chpt)*nxpsg)/32; } + + for (short ir1 = 0; ir1 < r1pt; ++ir1) { + y4[ir1] += ((4*chpt)*nxpsg)/4; + } } for (short ir1 = 0; ir1 < r1pt; ++ir1) { @@ -1867,8 +1860,6 @@ void kernel_mul_mv_ext_q8_0_f32_impl( } } - //device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)i11*args.ne0; - if (tx == 0) { for (short ir1 = 0; ir1 < r1pt && i11 + ir1 < args.ne11; ++ir1) { device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0; @@ -1887,32 +1878,37 @@ kernel void kernel_mul_mv_ext_q8_0_f32( device const char * src1, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - switch (args.nsg) { - case 1: - switch (args.nxpsg) { - case 4: kernel_mul_mv_ext_q8_0_f32_impl<1, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - case 8: kernel_mul_mv_ext_q8_0_f32_impl<1, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - case 16: kernel_mul_mv_ext_q8_0_f32_impl<1, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - case 32: kernel_mul_mv_ext_q8_0_f32_impl<1, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - } break; - case 2: - switch (args.nxpsg) { - case 4: kernel_mul_mv_ext_q8_0_f32_impl<2, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - case 8: kernel_mul_mv_ext_q8_0_f32_impl<2, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - case 16: kernel_mul_mv_ext_q8_0_f32_impl<2, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - case 32: kernel_mul_mv_ext_q8_0_f32_impl<2, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - } break; - case 4: - switch (args.nxpsg) { - case 4: kernel_mul_mv_ext_q8_0_f32_impl<4, 4> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - case 8: kernel_mul_mv_ext_q8_0_f32_impl<4, 8> (args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - case 16: kernel_mul_mv_ext_q8_0_f32_impl<4, 16>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - case 32: kernel_mul_mv_ext_q8_0_f32_impl<4, 32>(args, src0, src1, dst, tgpig, ntg, tiisg, sgitg); break; - } break; +#define CASE_R1PT(r1pt) \ + switch (args.nsg) { \ + case 1: \ + switch (args.nxpsg) { \ + case 4: kernel_mul_mv_ext_q8_0_f32_impl<1, 4, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \ + case 8: kernel_mul_mv_ext_q8_0_f32_impl<1, 8, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \ + case 16: kernel_mul_mv_ext_q8_0_f32_impl<1, 16, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \ + case 32: kernel_mul_mv_ext_q8_0_f32_impl<1, 32, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \ + } break; \ + case 2: \ + switch (args.nxpsg) { \ + case 4: kernel_mul_mv_ext_q8_0_f32_impl<2, 4, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \ + case 8: kernel_mul_mv_ext_q8_0_f32_impl<2, 8, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \ + case 16: kernel_mul_mv_ext_q8_0_f32_impl<2, 16, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \ + case 32: kernel_mul_mv_ext_q8_0_f32_impl<2, 32, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \ + } break; \ + case 4: \ + switch (args.nxpsg) { \ + case 4: kernel_mul_mv_ext_q8_0_f32_impl<4, 4, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \ + case 8: kernel_mul_mv_ext_q8_0_f32_impl<4, 8, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \ + case 16: kernel_mul_mv_ext_q8_0_f32_impl<4, 16, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \ + case 32: kernel_mul_mv_ext_q8_0_f32_impl<4, 32, (r1pt)>(args, src0, src1, dst, tgpig, tiisg, sgitg); break; \ + } break; \ } + + switch (args.r1pt) { + case 2: CASE_R1PT( 2); break; + case 4: CASE_R1PT( 4); break; + }; } #define N_MV_T_T 4 diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index d8657d5df..1699923e1 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3571,7 +3571,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4)); for (int i = 1; i < 64; ++i) { - test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 64, i, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 64, i, 512, { 1, 1}, {1, 1})); } #if 1