From c5cf1d74f03330bcf1b9a76f5d60f0125c00cc05 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 10:32:15 +0200 Subject: [PATCH] cont : mul mm id ggml-ci --- ggml/src/ggml-common.h | 18 +++++ ggml/src/ggml-metal.m | 43 ++++++----- ggml/src/ggml-metal.metal | 151 +++++++++++++++++++------------------- 3 files changed, 118 insertions(+), 94 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index fcf0d997c..a20703269 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -509,6 +509,24 @@ typedef struct { int16_t r3; } ggml_metal_kargs_mul_mv; +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; +} ggml_metal_kargs_mul_mm_id; + typedef struct { int32_t nei0; int32_t nei1; diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index d3497233e..4c83aa7b7 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -2297,27 +2297,30 @@ static void ggml_metal_encode_node( default: GGML_ABORT("MUL_MAT_ID not implemented"); } + ggml_metal_kargs_mul_mm_id args = { + /*.nei0 =*/ ne20, + /*.nei1 =*/ ne21, + /*.nbi1 =*/ nb21, + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index f8591239a..561476551 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -5764,31 +5764,32 @@ kernel void kernel_mul_mm( } // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids +// TODO: this kernel needs to be reimplemented from scratch for better performance template void kernel_mul_mm_id_impl( - device const uchar * src0, - device const uchar * src1, + int32_t ne00, + int32_t ne02, + uint64_t nb01, + uint64_t nb02, + int32_t ne11, + int32_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int32_t ne0, + int32_t ne1, + int64_t ne0ne1, + device const char * src0, + device const char * src1, threadgroup ushort2 * rowids, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - int64_t ne1, - int64_t ne0ne1, - threadgroup uchar * shared_memory, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + device char * dst, + threadgroup char * shmem, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup half * sa = (threadgroup half *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + threadgroup half * sa = (threadgroup half *)(shmem); + threadgroup float * sb = (threadgroup float *)(shmem + 4096); const uint r0 = tgpig.y; const uint r1 = tgpig.x; @@ -5805,9 +5806,9 @@ void kernel_mul_mm_id_impl( simdgroup_half8x8 ma[4]; simdgroup_float8x8 mb[2]; - simdgroup_float8x8 c_res[8]; + simdgroup_float8x8 mc[8]; for (int i = 0; i < 8; i++){ - c_res[i] = make_filled_simdgroup_matrix(0.f); + mc[i] = make_filled_simdgroup_matrix(0.f); } short il = (tiitg % THREAD_PER_ROW); @@ -5845,11 +5846,14 @@ void kernel_mul_mm_id_impl( threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + #pragma unroll(BLOCK_SIZE_K/8) for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) for (int i = 0; i < 4; i++) { simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) for (int i = 0; i < 2; i++) { simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } @@ -5857,29 +5861,42 @@ void kernel_mul_mm_id_impl( lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + #pragma unroll(8) for (int i = 0; i < 8; i++){ - simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); + simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); } } } { threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ + threadgroup float * temp_str = ((threadgroup float *) shmem) \ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); } threadgroup_barrier(mem_flags::mem_threadgroup); - device float * C = dst + (BLOCK_SIZE_M * r0); if (sgitg == 0) { for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; - int joff = jid[0] * ne0 + jid[1] * ne0ne1; - for (int i = 0; i < n_rows; i++) { - *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M); + int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1; + + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff; + device float4 * D4 = (device float4 *) D; + + threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = 0; + for (; i < n_rows/4; i++) { + *(D4 + i) = *(C4 + i); + } + + i *= 4; + for (; i < n_rows; i++) { + *(D + i) = *(C + i); } } } @@ -5888,48 +5905,34 @@ void kernel_mul_mm_id_impl( template kernel void kernel_mul_mm_id( - device const uchar * src0s, - device const uchar * src1, - device float * dst, - device const uchar * ids, - constant int64_t & nei0, - constant int64_t & nei1, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mm_id & args, + device const char * src0s, + device const char * src1, + device char * dst, + device const char * ids, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const int32_t i02 = tgpig.z; + tgpig.z = 0; - device const uchar * src0 = src0s + i02*nb02; + device const char * src0 = src0s + i02*args.nb02; // row indices - threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); + threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192); // TODO: parallelize this loop int64_t _ne1 = 0; - for (ushort ii1 = 0; ii1 < nei1; ii1++) { - for (ushort ii0 = 0; ii0 < nei0; ii0++) { - int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; + for (ushort ii1 = 0; ii1 < args.nei1; ii1++) { + for (ushort ii0 = 0; ii0 < args.nei0; ii0++) { + int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0]; if (id == i02) { - //if (tiitg == 0) { + if (tiitg == 0) { rowids[_ne1] = ushort2(ii0, ii1); - //} + } _ne1++; } } @@ -5938,23 +5941,23 @@ kernel void kernel_mul_mm_id( threadgroup_barrier(mem_flags::mem_threadgroup); kernel_mul_mm_id_impl( + args.ne00, + args.ne02, + args.nb01, + args.nb02, + args.ne11, + args.ne12, + args.nb10, + args.nb11, + args.nb12, + args.ne0, + _ne1, + (int64_t)args.ne0*args.ne1, src0, src1, rowids, dst, - ne00, - ne02, - nb01, - nb02, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - _ne1, - ne0*ne1, - shared_memory, + shmem, tgpig, tiitg, sgitg);