From e68e32548fa1e824f0e7bfa8414b8d853efff808 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 7 Feb 2024 23:12:22 +0200 Subject: [PATCH] metal : opts --- ggml-metal.m | 2 +- ggml-metal.metal | 110 ++++++++++++++++++++++++++--------------------- 2 files changed, 62 insertions(+), 50 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 6d051e8ab..80cfb2e22 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1343,7 +1343,7 @@ static bool ggml_metal_graph_compute( const int nsg = 8; const int nsg0 = 1; - const int nsh0 = 8; + const int nsh0 = 16; const int nsg1 = 1; const int nsh1 = 64; diff --git a/ggml-metal.metal b/ggml-metal.metal index 8926ec6bb..41d6f78ea 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -4785,7 +4785,7 @@ void kernel_mul_mm_impl( } #define NSG0 1 -#define NSH0 8 +#define NSH0 16 #define NSG1 1 #define NSH1 64 @@ -4815,33 +4815,34 @@ void kernel_mul_mm2_impl( uint sgitg[[simdgroup_index_in_threadgroup]]) { const uint nsg = ntg.y; // number of simdgroups - const int64_t im = tgpig[2]; - const int64_t i11 = tgpig[1]*(8*NSG1); - const int64_t i01 = tgpig[0]*(8*NSG0*nsg) + sgitg*(8*NSG0); + const int im = tgpig[2]; + const int i11 = tgpig[1]*(8*NSG1); + const int i01 = tgpig[0]*(8*NSG0*nsg) + sgitg*(8*NSG0); - const int64_t i12 = im%ne12; - const int64_t i13 = im/ne12; + const int i12 = im%ne12; + const int i13 = im/ne12; - const int64_t ne01 = ne0; - const int64_t ne11 = ne1; + const int ne01 = ne0; + const int ne11 = ne1; - const int64_t NW = N_SIMDWIDTH; + const int NW = N_SIMDWIDTH; - const int64_t SH0 = (8*NSG0)*(8*NSH0); // shread memory per threadgroup for src0 data in (half) - const int64_t SH04 = SH0/4; // shread memory per threadgroup for src0 data in (half4) + const int SH0 = (8*NSG0)*(8*NSH0); // shread memory per threadgroup for src0 data in (half) + const int SH04 = SH0/4; // shread memory per threadgroup for src0 data in (half4) - const int64_t SH1 = (8*NSG1)*(8*NSH1); // shread memory for src1 data in (float) - const int64_t SH14 = SH1/4; // shread memory for src1 data in (float4) + const int SH1 = (8*NSG1)*(8*NSH1); // shread memory for src1 data in (float) + const int SH14 = SH1/4; // shread memory for src1 data in (float4) - const int64_t T1 = 8*NSH1; // row of src1 in shared memory in (float) - const int64_t T14 = T1/4; // row of src1 in shared memory in (float4) + const int T1 = 8*NSH1; // row of src1 in shared memory in (float) + const int T14 = T1/4; // row of src1 in shared memory in (float4) threadgroup half * shared = (threadgroup half *) shared_u8; - threadgroup half * s0 = (threadgroup half *)(shared + sgitg*SH0); - threadgroup half4 * s04 = (threadgroup half4 *)(shared + sgitg*SH0); - threadgroup float * s1 = (threadgroup float *)(shared + nsg*SH0); - threadgroup float4 * s14 = (threadgroup float4 *)(shared + nsg*SH0); + threadgroup half * s0 = (threadgroup half *)(shared + sgitg*SH0); + threadgroup half4 * s04 = (threadgroup half4 *)(shared + sgitg*SH0); + threadgroup half4x4 * s016 = (threadgroup half4x4 *)(shared + sgitg*SH0); + threadgroup float * s1 = (threadgroup float *)(shared + nsg*SH0); + threadgroup float4 * s14 = (threadgroup float4 *)(shared + nsg*SH0); threadgroup float * r0 = (threadgroup float *)(shared + 2*sgitg*(8*NSG0)*(8*NSG1)); @@ -4850,12 +4851,12 @@ void kernel_mul_mm2_impl( simdgroup_float8x8 mr[NSG0][NSG1]; // zero out shared memory SH0 for src0 - for (int64_t i = tiisg; i < SH04; i += NW) { + for (int i = tiisg; i < SH04; i += NW) { s04[i] = 0.0h; } // zero out shared memory SH1 for src1 - for (int64_t i = tiitg; i < SH14; i += nsg*NW) { + for (int i = tiitg; i < SH14; i += nsg*NW) { s14[i] = 0.0f; } @@ -4868,24 +4869,27 @@ void kernel_mul_mm2_impl( } } - for (int64_t i00 = 0; i00 < ne00; i00 += 8*NSH1) { + for (int i00 = 0; i00 < ne00; i00 += 8*NSH1) { // load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory { threadgroup_barrier(mem_flags::mem_threadgroup); - const int64_t nload = min(8ll*NSG1, ne11 - i11) * (8*NSH1); + const int nload = MIN(8*NSG1, ne11 - i11) * (8*NSH1); - for (int64_t i = tiitg; i < nload; i += nsg*NW) { - const int64_t ic = i%(8*NSH1); - const int64_t ir = i/(8*NSH1); + const size_t offs0 = im*nb12; - // TODO: use float4 - device const float * p1 = (device const float *)(src1 + im*nb12 + (i11 + ir)*nb11 + (i00 + ic)*nb10); + for (int i = 4*tiitg; i < nload; i += 4*nsg*NW) { + const int ic = i%(8*NSH1); + const int ir = i/(8*NSH1); - if (i00 + ic < ne00) { - s1[8*NSH1*ir + ic] = *p1; + device const float4 * p1 = (device const float4 *)(src1 + offs0 + (i11 + ir)*nb11 + (i00 + ic)*nb10); + + if (i00 + ic + 4 <= ne00) { + s14[(8*NSH1*ir + ic)/4] = *p1; } else { - s1[8*NSH1*ir + ic] = 0.0f; + for (int k = 0; i00 + ic + k < ne00; k++){ + s1[8*NSH1*ir + ic + k] = (*p1)[k]; + } } } @@ -4895,28 +4899,36 @@ void kernel_mul_mm2_impl( for (int b0 = 0; b0 < NSH1/NSH0; ++b0) { // load NSG0*NSH0 8x8 blocks of src0 to threadgroup memory { - const int64_t nload = min(8ll*NSG0, ne01 - i01) * (8*NSH0); + const int nload = MIN(8*NSG0, ne01 - i01) * (8*NSH0); half4x4 tmp0; - for (int64_t i = 16*tiisg; i < nload; i += 16*NW) { - const int64_t ic = i%(8*NSH0); - const int64_t ir = i/(8*NSH0); + const size_t offs0 = (i13/r3)*(nb02*ne02) + (i12/r2)*nb02; - const int64_t icc = i00 + 8*b0*NSH0 + ic; + for (int i = 16*tiisg; i < nload; i += 16*NW) { + const int ic = i%(8*NSH0); + const int ir = i/(8*NSH0); - const int64_t ib = (icc/(16*nl)); - const int64_t il = (icc%(16*nl))/16; + const int icc = i00 + 8*b0*NSH0 + ic; - device const block_q * p0 = (device const block_q *)(src0 + (i13/r3)*(nb02*ne02) + (i12/r2)*nb02 + (i01 + ir)*nb01) + ib; + const int ib = (icc/(16*nl)); + const int il = (icc%(16*nl))/16; + + device const block_q * p0 = (device const block_q *)(src0 + offs0 + (i01 + ir)*nb01) + ib; dequantize_func(p0, il, tmp0); - for (int k = 0; k < 4; k++){ - if (icc + 4*k < ne00) { - s04[(8*NSH0*ir + ic)/4 + k] = tmp0[k]; - } else { - s04[(8*NSH0*ir + ic)/4 + k] = 0.0h; + if (icc + 16 <= ne00) { + s016[(8*NSH0*ir + ic)/16] = tmp0; + } else { + for (int k = 0; k < 4; k++){ + if (icc + 4*k <= ne00) { + s04[(8*NSH0*ir + ic)/4 + k] = tmp0[k]; + } else { + for (int p = 0; icc + 4*k + p < ne00; p++) { + s0[8*NSH0*ir + ic + 4*k + p] = tmp0[k][p]; + } + } } } } @@ -4958,12 +4970,12 @@ void kernel_mul_mm2_impl( device float * pdst = dst + im*ne1*ne0; for (int is = 0; is < NSG1; is++) { - const int64_t i1 = i11 + is*8; - const int64_t nstore = min(8ll*NSG1, ne1 - i1) * (8*NSG0); + const int i1 = i11 + is*8; + const int nstore = MIN(8*NSG1, ne1 - i1) * (8*NSG0); - for (int64_t i = tiisg; i < nstore; i += NW) { - const int64_t ic = i%(8*NSG0); - const int64_t ir = i/(8*NSG0); + for (int i = tiisg; i < nstore; i += NW) { + const int ic = i%(8*NSG0); + const int ir = i/(8*NSG0); if (i1 + ir < ne1 && i01 + ic < ne0) { pdst[(i1 + ir)*ne0 + (i01 + ic)] = r0[(8*is)*(8*NSG0) + 8*NSG0*ir + ic];