From 54c859e195ca2ea5e00711eb20dd6b819647f375 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 13:16:54 +0200 Subject: [PATCH] metal : GGML_OP_ADD, GGML_OP_SUB, GGML_OP_MUL, GGML_OP_DIV --- ggml/src/ggml-common.h | 28 +++++ ggml/src/ggml-metal.m | 124 ++++++++++---------- ggml/src/ggml-metal.metal | 234 ++++++++++++-------------------------- 3 files changed, 164 insertions(+), 222 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index cd54c18ed..6d7e07ee6 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -447,6 +447,34 @@ typedef struct { int32_t dim; } ggml_metal_kargs_concat; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + uint64_t offs; +} ggml_metal_kargs_bin; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 2058ea25e..8a96e09ba 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1243,8 +1243,6 @@ static void ggml_metal_encode_node( bool bcast_row = false; - int64_t nb = ne00; // used by the "row" kernels - id pipeline = nil; if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { @@ -1253,7 +1251,6 @@ static void ggml_metal_encode_node( // src1 is a row GGML_ASSERT(ne11 == 1); - nb = ne00 / 4; switch (dst->op) { case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break; @@ -1273,36 +1270,39 @@ static void ggml_metal_encode_node( } } + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.offs =*/ offs, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; if (bcast_row) { const int64_t n = ggml_nelements(dst)/4; @@ -1400,35 +1400,39 @@ static void ggml_metal_encode_node( const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ pnb1, + /*.nb02 =*/ pnb2, + /*.nb03 =*/ pnb3, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ pnb1, + /*.nb2 =*/ pnb2, + /*.nb3 =*/ pnb3, + /*.offs =*/ offs, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index b11af928b..51bba1aef 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -493,200 +493,106 @@ enum ggml_sort_order { // pros: works for non-contiguous tensors, supports broadcast across all dims // cons: not very efficient kernel void kernel_add( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int64_t & offs, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10)); } } kernel void kernel_sub( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int64_t & offs, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10)); } } kernel void kernel_mul( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); } } kernel void kernel_div( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); } } @@ -740,38 +646,42 @@ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat // assumption: src1 is a row // broadcast src1 into src0 kernel void kernel_add_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] + src1[tpig % nb]; } kernel void kernel_sub_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] - src1[tpig % nb]; } kernel void kernel_mul_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] * src1[tpig % nb]; } kernel void kernel_div_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] / src1[tpig % nb]; }