From 6ccbd1777ad07c5dbc2eba9e83e7a2bfe9231c90 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 24 Jan 2024 15:45:04 +0200 Subject: [PATCH] wip --- ggml-metal.m | 8 +- ggml-metal.metal | 237 ++++++++++++++++++++++------------------------- 2 files changed, 117 insertions(+), 128 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 00a8a0e92..4431306a6 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2253,13 +2253,15 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; [encoder setBytes:&scale length:sizeof( float) atIndex:27]; - const int64_t nsg = 2; // simdgroups per threadgroup (a.k.a. warps) - const int64_t nhptg = 4; // heads per threadgroup !! sync with kernel template arguments !! + const int64_t nsg = 16; // simdgroups per threadgroup (a.k.a. warps) + const int64_t nhptg = 2; // heads per threadgroup !! sync with kernel template arguments !! const int64_t nqptg = 2; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 8; //const size_t smem = nqptg*(nhptg*ne00 + nsg*(nhptg*ne00 + 256))*(sizeof(float)/2); - const size_t smem = nqptg*(nhptg*ne00 + nsg*(256))*(sizeof(float)/2); + const size_t smem = nqptg*(nhptg*ne00 + nsg*(32*ncpsg))*(sizeof(float)/2); + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); [encoder setThreadgroupMemoryLength:smem atIndex:0]; diff --git a/ggml-metal.metal b/ggml-metal.metal index 1a6eaed14..0c91cc336 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1995,7 +1995,7 @@ typedef void (flash_attn_ext_f16_t)( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]); -template // head size, heads per threadgroup, queries per threadgroup +template // head size, heads per threadgroup, queries per threadgroup kernel void kernel_flash_attn_ext_f16( device const char * q, device const char * k, @@ -2042,48 +2042,14 @@ kernel void kernel_flash_attn_ext_f16( return; } - // assume K and V are same shape - const int64_t ne22 = ne12; - const int64_t ne23 = ne13; - - const uint64_t nb21 = nb11; - const uint64_t nb22 = nb12; - const uint64_t nb23 = nb13; - - // broadcast - const int64_t rk2 = ne02/ne12; - const int64_t rk3 = ne03/ne13; - - const int64_t rv2 = ne02/ne22; - const int64_t rv3 = ne03/ne23; - - // k indices - const int64_t ik2 = iq2 / rk2; - const int64_t ik3 = iq3 / rk3; - - // v indices - const int64_t iv2 = iq2 / rv2; - const int64_t iv3 = iq3 / rv3; - - const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; - - device const float * mp[Q]; - for (int64_t j = 0; j < Q; ++j) { - if (iq1 + j < ne01) { - mp[j] = (device const float *) (mask + ((ir + j)%ne31)*nb31); - } else { - mp[j] = nullptr; - } - } - const int64_t D4 = D/4; - const int64_t T = (H*D + nsg*(256)); // shared memory size per query in half - const int64_t T4 = T/4; // shared memory size per query in half4 + const int64_t T = (H*D + nsg*(32*C)); // shared memory size per query in half + const int64_t T4 = T/4; // shared memory size per query in half4 - threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*H*D); - threadgroup half * ss = (threadgroup half *) (shared + sgitg*(256) + 1*H*D); - threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(256) + 1*H*D); + threadgroup half4 * pq4 = (threadgroup half4 *) (shared + 0*H*D); + threadgroup half * ss = (threadgroup half *) (shared + sgitg*(32*C) + 1*H*D); + threadgroup half4 * ss4 = (threadgroup half4 *) (shared + sgitg*(32*C) + 1*H*D); const uint tiih = tiisg%tph; // thread index in head const uint hiisg = tiisg/tph; // head index in simdgroup @@ -2116,98 +2082,122 @@ kernel void kernel_flash_attn_ext_f16( half S = { 0.0h }; half M = { -INFINITY }; - for (int64_t iic = 8*sgitg; iic < ne11; iic += 8*nsg) { - half mv[Q]; + { + // assume K and V are same shape + const int64_t ne22 = ne12; + const int64_t ne23 = ne13; - bool skip = true; - for (int64_t j = 0; j < Q; ++j) { - mv[j] = mp[j][iic]; - skip = skip && (mv[j] == -INFINITY); - } - if (skip) { - continue; - } + const uint64_t nb21 = nb11; + const uint64_t nb22 = nb12; + const uint64_t nb23 = nb13; - for (int p = 0; p < 8; ++p) { - const int64_t ic = iic + p; + // broadcast + const int64_t rk2 = ne02/ne12; + const int64_t rk3 = ne03/ne13; - device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); + const int64_t rv2 = ne02/ne22; + const int64_t rv3 = ne03/ne23; - half s[Q] = { 0.0h }; - half4 pk4v[D4/tph]; + // k indices + const int64_t ik2 = iq2 / rk2; + const int64_t ik3 = iq3 / rk3; - for (int64_t i = 0; i < D4/tph; ++i) { - pk4v[i] = pk4[tph*i + tiih]; - } + // v indices + const int64_t iv2 = iq2 / rv2; + const int64_t iv3 = iq3 / rv3; + + device const float * mp[Q]; + + { + const int64_t ir = iq3*ne02*ne01 + iq2*ne01 + iq1; for (int64_t j = 0; j < Q; ++j) { - for (int64_t i = 0; i < D4/tph; ++i) { - s[j] += dot(pq4[j*T4 + hiisg*D4 + tph*i + tiih], pk4v[i]); + if (iq1 + j < ne01) { + mp[j] = (device const float *) (mask + ((ir + j)%ne31)*nb31); + } else { + mp[j] = nullptr; + } + } + } + + for (int64_t iic = C*sgitg; iic < ne11; iic += C*nsg) { + { + bool skip = true; + for (int64_t j = 0; j < Q; ++j) { + skip = skip && (mp[j][iic] == -INFINITY); + } + if (skip) { + continue; } } - for (int64_t j = 0; j < Q; ++j) { - ss[j*T + 32*p + hiisg*tph + tiih] = s[j]; + for (int p = 0; p < C; ++p) { + const int64_t ic = iic + p; + + device const half4 * pk4 = (device const half4 *) ((device char *) k + (ic*nb11 + ik2*nb12 + ik3*nb13)); + + for (int64_t j = 0; j < Q; ++j) { + half4 s4 = 0.0h; + + for (int64_t i = 0; i < D4/tph; ++i) { + s4 += pq4[j*T4 + hiisg*D4 + tph*i + tiih]*pk4[tph*i + tiih]; + } + + ss[j*T + 32*p + hiisg*tph + tiih] = s4.x + s4.y + s4.z + s4.w; + } + } + + simdgroup_barrier(mem_flags::mem_none); + + if (tiih < Q) { + const int64_t j = tiih; + + for (int p = 0; p < C; ++p) { + half4 s4 = 0.0h; + + for (int64_t i = 0; i < tph/4; ++i) { + s4 += ss4[j*T4 + 8*p + hiisg*tph/4 + i]; + } + + half s = (s4.x + s4.y + s4.z + s4.w)*scale + mp[j][iic + p]; + + const half m = M; + + M = max(M, s); + + const half ms = m == -INFINITY ? 0.0h : exp(m - M); + const half vs = s == -INFINITY ? 0.0h : exp(s - M); + + S = S*ms + vs; + + ss[j*T + 32*p + 2*hiisg + 0] = ms; + ss[j*T + 32*p + 2*hiisg + 1] = vs; + } + } + + simdgroup_barrier(mem_flags::mem_none); + + for (int p = 0; p < C; ++p) { + const int64_t ic = iic + p; + device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); + + for (int64_t j = 0; j < Q; ++j) { + const half ms = ss[j*T + 32*p + 2*hiisg + 0]; + const half vs = ss[j*T + 32*p + 2*hiisg + 1]; + + for (int64_t i = 0; i < D4/tph; ++i) { + ps4[j][i] = ps4[j][i]*ms + pv4[tph*i + tiih]*vs; + } + } } } - simdgroup_barrier(mem_flags::mem_none); - if (tiih < Q) { const int64_t j = tiih; - for (int p = 0; p < 8; ++p) { - half4 s4 = 0.0h; - - for (int64_t i = 0; i < tph/4; ++i) { - s4 += ss4[j*T4 + 8*p + hiisg*tph/4 + i]; - } - - half s = (s4.x + s4.y + s4.z + s4.w)*scale + mp[j][iic + p]; - - const half m = M; - - M = max(M, s); - - const half ms = m == -INFINITY ? 0.0h : exp(m - M); - const half vs = s == -INFINITY ? 0.0h : exp(s - M); - - S = S*ms + vs; - - ss[j*T + 32*p + 2*hiisg + 0] = ms; - ss[j*T + 32*p + 2*hiisg + 1] = vs; - } + ss[j*T + 2*hiisg + 0] = S; + ss[j*T + 2*hiisg + 1] = M; } - - simdgroup_barrier(mem_flags::mem_none); - - for (int64_t i = 0; i < D4/tph; ++i) { - half4 pv4v[8]; - - for (int p = 0; p < 8; ++p) { - const int64_t ic = iic + p; - - device const half4 * pv4 = (device const half4 *) ((device char *) v + (ic*nb21 + iv2*nb22 + iv3*nb23)); - - pv4v[p] = pv4[tph*i + tiih]; - } - - for (int64_t j = 0; j < Q; ++j) { - for (int p = 0; p < 8; ++p) { - const half ms = ss[j*T + 32*p + 2*hiisg + 0]; - const half vs = ss[j*T + 32*p + 2*hiisg + 1]; - - ps4[j][i] = ps4[j][i]*ms + pv4v[p]*vs; - } - } - } - } - - if (tiih < Q) { - const int64_t j = tiih; - - ss[j*T + 2*hiisg + 0] = S; - ss[j*T + 2*hiisg + 1] = M; } threadgroup_barrier(mem_flags::mem_threadgroup); @@ -2260,11 +2250,11 @@ kernel void kernel_flash_attn_ext_f16( if (sgitg == 0) { for (int64_t j = 0; j < Q; ++j) { - const half S0 = ss[j*T + 2*hiisg + 0]; - const half S1 = ss[j*T + sg*(256) + 2*hiisg + 0]; + const half S0 = ss[j*T + 2*hiisg + 0]; + const half S1 = ss[j*T + sg*(32*C) + 2*hiisg + 0]; - const half M0 = ss[j*T + 2*hiisg + 1]; - const half M1 = ss[j*T + sg*(256) + 2*hiisg + 1]; + const half M0 = ss[j*T + 2*hiisg + 1]; + const half M1 = ss[j*T + sg*(32*C) + 2*hiisg + 1]; M = max(M0, M1); @@ -2279,7 +2269,6 @@ kernel void kernel_flash_attn_ext_f16( } for (int64_t i = 0; i < D4/tph; ++i) { - //ps4[j*T4 + hiisg*D4 + tph*i + tiih] = ps4[j*T4 + hiisg*D4 + tph*i + tiih]*ms0 + ps4[j*T4 + sg*(256)/4 + hiisg*D4 + tph*i + tiih]*ms1; ps4[j][i] = ps4[j][i]*ms0 + pq4[j*T4 + hiisg*D4 + tph*i + tiih]*ms1; } } @@ -2292,7 +2281,6 @@ kernel void kernel_flash_attn_ext_f16( for (int64_t j = 0; j < Q; ++j) { S = ss[j*T + 2*hiisg + 0]; for (int64_t i = 0; i < D4/tph; ++i) { - //ps4[j*T4 + hiisg*D4 + tph*i + tiih] = ps4[j*T4 + hiisg*D4 + tph*i + tiih]/S; ps4[j][i] = ps4[j][i]/S; } } @@ -2305,16 +2293,15 @@ kernel void kernel_flash_attn_ext_f16( if (sgitg == 0) { for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { for (int64_t i = 0; i < D4/tph; ++i) { - //dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + tph*i + tiih] = (float4) ps4[j*T4 + hiisg*D4 + tph*i + tiih]; dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + tph*i + tiih] = (float4) ps4[j][i]; } } } } -template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 4, 2>; -template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 4, 2>; -template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 4, 2>; +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 2, 2, 8>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 2, 2, 8>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 2, 2, 8>; kernel void kernel_cpy_f16_f16( device const half * src0,