From 8c5b66eeaae396878efc86a605a5b15063ea5d69 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 9 Dec 2023 15:30:34 +0200 Subject: [PATCH] metal : reduce the kernel launches for ggml_mul_mat_id --- ggml-metal.m | 50 +++++++++++++++++++++++++++++------------------- ggml-metal.metal | 26 ++++++++++++++++++------- 2 files changed, 49 insertions(+), 27 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 00dc4b0e1..a84b88dda 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1495,6 +1495,9 @@ void ggml_metal_graph_compute( const int idx = ((int32_t *) dst->op_params)[0]; + // batch size + GGML_ASSERT(ne01 == ne11); + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && @@ -1515,19 +1518,25 @@ void ggml_metal_graph_compute( default: GGML_ASSERT(false && "MUL_MAT_ID not implemented"); } const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3]; - [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; - [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; - [encoder setBytes:&idx length:sizeof(idx) atIndex:15]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; + [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; + [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:16]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:17]; + [encoder setBytes:&idx length:sizeof(idx) atIndex:18]; // TODO: how to make this an array? read Metal docs for (int j = 0; j < n_as; ++j) { struct ggml_tensor * src_cur = dst->src[2 + j]; @@ -1535,18 +1544,19 @@ void ggml_metal_graph_compute( size_t offs_src_cur = 0; id id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur); - [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j]; + [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j]; } [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - for (int64_t i01 = 0; i01 < src0->ne[1]; i01++) { - [encoder setBuffer:id_src0 offset:offs_src0 + i01*nb01 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 + i01*nb11 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst + i01*nb1 atIndex:2]; + [encoder dispatchThreadgroups:MTLSizeMake( (1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + //[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + //for (int64_t i01 = 0; i01 < src0->ne[1]; i01++) { + // [encoder setBuffer:id_src0 offset:offs_src0 + i01*nb01 atIndex:0]; + // [encoder setBuffer:id_src1 offset:offs_src1 + i01*nb11 atIndex:1]; + // [encoder setBuffer:id_dst offset:offs_dst + i01*nb1 atIndex:2]; - [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; - } + //} } } break; case GGML_OP_GET_ROWS: diff --git a/ggml-metal.metal b/ggml-metal.metal index 6723200c7..f25e813c2 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -3474,19 +3474,22 @@ kernel void kernel_mul_mm(device const uchar * src0, template kernel void kernel_mul_mm_id( - device const int32_t * ids, + device const uchar * ids, device const uchar * src1, - device float * dst, + device uchar * dst, + constant int64_t & nbi1, constant int64_t & ne00, constant int64_t & ne02, constant int64_t & nb01, constant int64_t & nb02, constant int64_t & ne12, + constant int64_t & ne13, constant int64_t & nb10, constant int64_t & nb11, constant int64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, + constant int64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx, @@ -3504,10 +3507,16 @@ kernel void kernel_mul_mm_id( uint sgitg[[simdgroup_index_in_threadgroup]]) { device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + const int64_t bid = tgpig.z/(ne12*ne13); + + tgpig.z = tgpig.z%(ne12*ne13); + + const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; + kernel_mul_mm_impl( - src0[ids[idx]], - src1, - dst, + src0[id], + src1 + bid*nb11, + (device float *) (dst + bid*nb1), ne00, ne02, nb01, @@ -3589,19 +3598,22 @@ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; typedef void (mat_mm_id_t)( - device const int32_t * ids, + device const uchar * ids, device const uchar * src1, - device float * dst, + device uchar * dst, + constant int64_t & nbi1, constant int64_t & ne00, constant int64_t & ne02, constant int64_t & nb01, constant int64_t & nb02, constant int64_t & ne12, + constant int64_t & ne13, constant int64_t & nb10, constant int64_t & nb11, constant int64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, + constant int64_t & nb1, constant uint & r2, constant uint & r3, constant int & idx,