From a6c8dbfa5df9cdc9dd9d26e24eca6bf01dc7e256 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 7 Nov 2024 18:20:25 +0200
Subject: [PATCH] wip

---
 ggml/src/ggml-metal.m     |  35 +++------
 ggml/src/ggml-metal.metal | 159 ++++++++++++++------------------
 2 files changed, 70 insertions(+), 124 deletions(-)

diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 741a4aef0..ef86d6873 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -1139,7 +1139,7 @@ static void ggml_metal_encode_node(
     const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
     const uint64_t nb21 = src2 ? src2->nb[1] : 0;
     const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-    const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+    const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
 
     const int64_t ne0 = dst ? dst->ne[0] : 0;
     const int64_t ne1 = dst ? dst->ne[1] : 0;
@@ -3241,18 +3241,15 @@ static void ggml_metal_encode_node(
            [encoder setBytes:&nb11          length:sizeof(uint64_t)      atIndex:14];
            [encoder setBytes:&nb12          length:sizeof(uint64_t)      atIndex:15];
            [encoder setBytes:&nb13          length:sizeof(uint64_t)      atIndex:16];
-           [encoder setBytes:&nb21          length:sizeof(uint64_t)      atIndex:17];
-           [encoder setBytes:&nb22          length:sizeof(uint64_t)      atIndex:18];
-           [encoder setBytes:&nb23          length:sizeof(uint64_t)      atIndex:19];
-           [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:20];
-           [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:21];
-           [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:22];
-           [encoder setBytes:&scale         length:sizeof( float)        atIndex:23];
-           [encoder setBytes:&max_bias      length:sizeof( float)        atIndex:24];
-           [encoder setBytes:&m0            length:sizeof(m0)            atIndex:25];
-           [encoder setBytes:&m1            length:sizeof(m1)            atIndex:26];
-           [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:27];
-           [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
+           [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:17];
+           [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:18];
+           [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:19];
+           [encoder setBytes:&scale         length:sizeof( float)        atIndex:20];
+           [encoder setBytes:&max_bias      length:sizeof( float)        atIndex:21];
+           [encoder setBytes:&m0            length:sizeof(m0)            atIndex:22];
+           [encoder setBytes:&m1            length:sizeof(m1)            atIndex:23];
+           [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:24];
+           [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
 
            if (!use_vec_kernel) {
                // half8x8 kernel
@@ -3263,21 +3260,11 @@ static void ggml_metal_encode_node(
                GGML_ASSERT(nqptg % 8 == 0);
                GGML_ASSERT(ncpsg % 32 == 0);
 
-#ifdef GGML_METAL_FORCE_FATTN_PREC_F16
-               const enum ggml_prec prec = GGML_PREC_DEFAULT;
-#else
-               // TODO: support both precisions
-               const enum ggml_prec prec = GGML_PREC_F32;
-               //const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst);
-#endif
-
-               const int64_t nhalfs = prec == GGML_PREC_DEFAULT ? 1 : 2;
-
 // 16*32*(nsg)
 // the shared memory needed for the simdgroups to load the KV cache
 // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
 //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + nhalfs*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
 
                int64_t nsgmax = 2;
 
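With the GGML_METAL_FORCE_FATTN_PREC_F16 branch gone, the attention scores are always accumulated in F32, so the nhalfs factor (1 for F16, 2 for F32) collapses to the constant 2 now folded into FATTN_SMEM. As a rough sanity check of the sizing, here is a minimal host-side C sketch of the macro's arithmetic; the head size, queries per threadgroup, and cache block size are illustrative assumptions, not values taken from the patch:

#include <stdio.h>
#include <stdint.h>

/* same padding helper as ggml */
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

/* the new macro, with nhalfs folded to the constant 2 (always-F32 scores) */
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))

int main(void) {
    const int64_t ne00  = 128; /* head size D (assumed) */
    const int64_t nqptg = 8;   /* queries per threadgroup (assumed) */
    const int64_t ncpsg = 32;  /* cache values per simdgroup (assumed) */

    /* sizeof(float)/2 == sizeof(half): the result is in bytes of threadgroup memory */
    for (int nsg = 1; nsg <= 4; nsg *= 2) {
        printf("nsg = %d -> %zu bytes\n", nsg, (size_t) FATTN_SMEM(nsg));
    }
    return 0;
}

For these assumed values the three iterations print 5376, 8704, and 15360 bytes, which is how the per-simdgroup cost of the K scratch shows up when the host picks nsgmax below.
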
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 2cd1f8462..d9743ce56 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -2815,14 +2815,11 @@ kernel void kernel_flash_attn_ext(
         constant uint32_t & nb02,
         constant uint32_t & nb03,
         constant  int32_t & ne11,
-        constant  int32_t & ne12,
-        constant  int32_t & ne13,
-        constant uint32_t & nb11,
-        constant uint32_t & nb12,
-        constant uint32_t & nb13,
-        constant uint32_t & nb21,
-        constant uint32_t & nb22,
-        constant uint32_t & nb23,
+        constant  int32_t & ne_12_2, // assume K and V are same shape
+        constant  int32_t & ne_12_3,
+        constant uint32_t & nb_12_1,
+        constant uint32_t & nb_12_2,
+        constant uint32_t & nb_12_3,
         constant uint32_t & nb31,
         constant  int32_t & ne1,
         constant  int32_t & ne2,
@@ -2830,13 +2827,13 @@ kernel void kernel_flash_attn_ext(
         constant    float & max_bias,
         constant    float & m0,
         constant    float & m1,
-        constant uint32_t & n_head_log2,
+        constant uint16_t & n_head_log2,
         constant    float & logit_softcap,
         threadgroup  half * shared [[threadgroup(0)]],
-        uint3   tgpig[[threadgroup_position_in_grid]],
-        uint3   ntg[[threads_per_threadgroup]],
-        uint    tiisg[[thread_index_in_simdgroup]],
-        uint    sgitg[[simdgroup_index_in_threadgroup]]) {
+        ushort3 tgpig[[threadgroup_position_in_grid]],
+        ushort3 ntg[[threads_per_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
     const short nsg = ntg.y; // number of simdgroups
 
     const int iq3 = tgpig[2];
@@ -2849,17 +2846,14 @@ kernel void kernel_flash_attn_ext(
     const short NW = N_SIMDWIDTH;
     const short SH = (2*C + Q); // shared memory per simdgroup in (half)
 
-    const short SF = sizeof(s_t)/sizeof(half);
+    const short TS = nsg*SH;   // shared memory size per query in (s_t)
+    const short T  = D + 2*TS; // shared memory size per query in (half)
 
-    const short T  = D + SF*nsg*SH; // shared memory size per query in (half)
-    const short TS = T/SF;          // shared memory size per query in (s_t)
-    const short T4 = T/4;           // shared memory size per query in (half4)
-
-    threadgroup q_t  * sq  = (threadgroup q_t  *) (shared +              0*D); // holds the query data
-    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared +              0*D); // same as above but in q4_t
-    threadgroup o_t  * so  = (threadgroup o_t  *) (shared +              0*D); // reuse query data for accumulation
-    threadgroup o4_t * so4 = (threadgroup o4_t *) (shared +              0*D); // reuse query data for accumulation
-    threadgroup s_t  * ss  = (threadgroup s_t  *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    threadgroup q_t  * sq  = (threadgroup q_t  *) (shared +             0*D); // holds the query data
+    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared +             0*D); // same as above but in q4_t
+    threadgroup o_t  * so  = (threadgroup o_t  *) (shared +             0*D); // reuse query data for accumulation
+    threadgroup o4_t * so4 = (threadgroup o4_t *) (shared +             0*D); // reuse query data for accumulation
+    threadgroup s_t  * ss  = (threadgroup s_t  *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention and diagonal matrix
 
     threadgroup k_t    * sk    = (threadgroup k_t    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
     threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
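The SF = sizeof(s_t)/sizeof(half) indirection is gone: the scratch sizes are stated directly, with TS = nsg*SH counted in s_t (float) elements and T = D + 2*TS in halfs, and the Q/O region is now addressed with a row stride of D instead of T. A small self-contained C sketch of the resulting offsets, under assumed sizes (the constants below are illustrative, not from the patch):

#include <stdio.h>

int main(void) {
    /* illustrative assumptions: head size 128, 8 queries, 32 cache entries per step, 2 simdgroups */
    const int D = 128, Q = 8, C = 32, nsg = 2;

    const int SH = 2*C + Q;  /* per-simdgroup scratch, in s_t (float) elements */
    const int TS = nsg*SH;   /* scratch per query row, in s_t */
    const int T  = D + 2*TS; /* total per query row, in half (2 halfs per float) */

    printf("T = %d halfs per query row\n", T);

    for (int sgitg = 0; sgitg < nsg; ++sgitg) {
        /* offsets in halfs from the start of the threadgroup buffer, as in the new kernel */
        const int off_sq = 0;                /* sq/so: Q rows of D halfs, stride D */
        const int off_ss = 2*sgitg*SH + Q*D; /* ss: this simdgroup's float scratch */
        printf("sg %d: sq @ %d, ss @ %d\n", sgitg, off_sq, off_ss);
    }
    return 0;
}
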
@@ -2876,9 +2870,9 @@ kernel void kernel_flash_attn_ext(
 
         for (short i = tiisg; i < D4; i += NW) {
             if (iq1 + j < ne01) {
-                sq4[j*T4 + i] = (q4_t) q4[i];
+                sq4[j*D4 + i] = (q4_t) q4[i];
             } else {
-                sq4[j*T4 + i] = (q4_t) 0.0f;
+                sq4[j*D4 + i] = (q4_t) 0.0f;
             }
         }
     }
@@ -2907,29 +2901,18 @@ kernel void kernel_flash_attn_ext(
         const short tx = tiisg%4;
         const short ty = tiisg/4;
 
-        // assume K and V are same shape
-        const short ne22 = ne12;
-        const short ne23 = ne13;
+        // broadcast kv
+        //const short rk2 = ne02/ne12;
+        //const short rk3 = ne03/ne13;
 
-        // broadcast k
-        const short rk2 = ne02/ne12;
-        const short rk3 = ne03/ne13;
-
-        const short ik2 = iq2/rk2;
-        const short ik3 = iq3/rk3;
-
-        // broadcast v
-        const short rv2 = ne02/ne22;
-        const short rv3 = ne03/ne23;
-
-        const short iv2 = iq2/rv2;
-        const short iv3 = iq3/rv3;
+        const short ikv2 = iq2/(ne02/ne_12_2);
+        const short ikv3 = iq3/(ne03/ne_12_3);
 
         // load the queries from shared memory into local memory
         q8x8_t mq[D8];
 
         for (short i = 0; i < D8; ++i) {
-            simdgroup_load(mq[i], sq + i*8, T);
+            simdgroup_load(mq[i], sq + i*8, D);
         }
 
         const bool has_mask = mask != q;
@@ -2982,18 +2965,18 @@ kernel void kernel_flash_attn_ext(
                 // this is compile-time check, so it does not have runtime overhead
                 if (is_same<kd4x4_t, k4x4_t>::value) {
                     // we can read directly from global memory
-                    device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+                    device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
 
 #pragma unroll
                     for (short i = 0; i < D8; ++i) {
                         k8x8_t mk;
-                        simdgroup_load(mk, pk + i*8, nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
+                        simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
 
                         simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
                     }
                 } else {
                     for (short ii = 0; ii < D16; ii += 4) {
-                        device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
+                        device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
 
                         if (D16%4 == 0) {
                             // the head is evenly divisible by 4*16 = 64, so no need for bound checks
@@ -3046,7 +3029,7 @@ kernel void kernel_flash_attn_ext(
 
             // online softmax
             {
-                for (short j = 0; j < Q; ++j) {
+                for (ushort j = 0; j < Q; ++j) {
                     const half m = M[j];
 
                     // scale and apply the logit softcap / mask
@@ -3095,17 +3078,17 @@ kernel void kernel_flash_attn_ext(
                 if (is_same<vd4x4_t, v4x4_t>::value) {
                     // we can read directly from global memory
-                    device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+                    device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
 
 #pragma unroll
                     for (short i = 0; i < D8; ++i) {
                         v8x8_t mv;
-                        simdgroup_load(mv, pv + i*8, nb21/sizeof(v_t), 0, false); // TODO: use ne20
+                        simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
 
                         simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
                     }
                 } else {
                     for (short ii = 0; ii < D16; ii += 4) {
-                        device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
+                        device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
 
                         if (D16%4 == 0) {
                             // no need for bound checks
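Since K and V are assumed to share one shape, the four per-tensor broadcast ratios above (rk2/rk3 and rv2/rv3) reduce to a single pair of indices: query head iq2 reads KV head iq2/(ne02/ne_12_2), and likewise along dimension 3. A minimal C sketch of that grouped-query mapping, with assumed head counts (32 query heads over 8 shared KV heads, so groups of 4):

#include <stdio.h>

int main(void) {
    const short ne02    = 32; /* number of query heads (assumed) */
    const short ne_12_2 = 8;  /* number of KV heads shared by K and V (assumed) */

    for (short iq2 = 0; iq2 < ne02; ++iq2) {
        /* same expression as the kernel: 4 consecutive query heads per KV head */
        const short ikv2 = iq2/(ne02/ne_12_2);
        printf("q head %2d -> kv head %d\n", iq2, ikv2);
    }
    return 0;
}
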
@@ -3162,7 +3145,7 @@ kernel void kernel_flash_attn_ext(
     }
 
     // reduce the warps sequentially
-    for (short sg = 1; sg < nsg; ++sg) {
+    for (ushort sg = 1; sg < nsg; ++sg) {
         half S = { 0.0f };
         half M = { -__FLT16_MAX__/2 };
 
@@ -3171,7 +3154,7 @@ kernel void kernel_flash_attn_ext(
         // each simdgroup stores its output to shared memory, reusing sq
         if (sgitg == sg) {
            for (short i = 0; i < D8; ++i) {
-                simdgroup_store(lo[i], so + i*8, T, 0, false);
+                simdgroup_store(lo[i], so + i*8, D, 0, false);
            }
        }
 
@@ -3213,7 +3196,7 @@ kernel void kernel_flash_attn_ext(
             for (short i = 0; i < D8; ++i) {
                 o8x8_t t;
 
-                simdgroup_load    (t, so + i*8, T, 0, false);
+                simdgroup_load    (t, so + i*8, D, 0, false);
                 simdgroup_multiply(t, ms1, t);
 
                 simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
@@ -3225,7 +3208,7 @@ kernel void kernel_flash_attn_ext(
     // store result to shared memory (reuse sq)
     if (sgitg == 0) {
         for (short i = 0; i < D8; ++i) {
-            simdgroup_store(lo[i], so + i*8, T, 0, false);
+            simdgroup_store(lo[i], so + i*8, D, 0, false);
         }
     }
 
@@ -3237,21 +3220,12 @@ kernel void kernel_flash_attn_ext(
             const half S = ss[j*TS + 0];
 
             for (short i = tiisg; i < D4; i += NW) {
-                dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*T4 + i]/S;
+                dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
             }
         }
     }
 }
 
-#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
-#define FA_TYPES \
-    half,  half4,   simdgroup_half8x8,  \
-    half,  half4x4, simdgroup_half8x8,  \
-    half,  half4x4, simdgroup_half8x8,  \
-    float,          simdgroup_float8x8, \
-    half,           simdgroup_half8x8,  \
-    half,  half4,   simdgroup_half8x8
-#else
 #define FA_TYPES \
     half,  half4,   simdgroup_half8x8,  \
     half,  half4x4, simdgroup_half8x8,  \
     half,  half4x4, simdgroup_half8x8,  \
     float,          simdgroup_float8x8, \
     float,          simdgroup_float8x8, \
     half,  half4,   simdgroup_half8x8
-#endif
 
 // TODO: static_assert
 
@@ -3339,24 +3312,21 @@ kernel void kernel_flash_attn_ext_vec(
         device const  char * v,
         device const  char * mask,
         device       float * dst,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant   int64_t & ne13,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant  uint64_t & nb13,
-        constant  uint64_t & nb21,
-        constant  uint64_t & nb22,
-        constant  uint64_t & nb23,
-        constant  uint64_t & nb31,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
+        constant   int32_t & ne01,
+        constant   int32_t & ne02,
+        constant   int32_t & ne03,
+        constant  uint32_t & nb01,
+        constant  uint32_t & nb02,
+        constant  uint32_t & nb03,
+        constant   int32_t & ne11,
+        constant   int32_t & ne_12_2, // assume K and V are same shape
+        constant   int32_t & ne_12_3,
+        constant  uint32_t & nb_12_1,
+        constant  uint32_t & nb_12_2,
+        constant  uint32_t & nb_12_3,
+        constant  uint32_t & nb31,
+        constant   int32_t & ne1,
+        constant   int32_t & ne2,
         constant     float & scale,
         constant     float & max_bias,
         constant     float & m0,
@@ -3426,23 +3396,12 @@ kernel void kernel_flash_attn_ext_vec(
         const short tx = tiisg%8;
         const short ty = tiisg/8;
 
-        // assume K and V are same shape
-        const short ne22 = ne12;
-        const short ne23 = ne13;
+        // broadcast kv
+        //const short rk2 = ne02/ne12;
+        //const short rk3 = ne03/ne13;
 
-        // broadcast k
-        const short rk2 = ne02/ne12;
-        const short rk3 = ne03/ne13;
-
-        const short ik2 = iq2/rk2;
-        const short ik3 = iq3/rk3;
-
-        // broadcast v
-        const short rv2 = ne02/ne22;
-        const short rv3 = ne03/ne23;
-
-        const short iv2 = iq2/rv2;
-        const short iv3 = iq3/rv3;
+        const short ikv2 = iq2/(ne02/ne_12_2);
+        const short ikv3 = iq3/(ne03/ne_12_3);
 
         // load the queries from shared memory into local memory
         k4x4_t mq[D16/NW4];
 
@@ -3480,7 +3439,7 @@ kernel void kernel_flash_attn_ext_vec(
                 for (short cc = 0; cc < C/4; ++cc) {
                     s_t mqk = 0.0;
 
-                    device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb11 + ik2*nb12 + ik3*nb13));
+                    device const block_q * pk = (device const block_q *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
 
 #pragma unroll
                     for (short ii = 0; ii < D16; ii += NW4) {
@@ -3554,7 +3513,7 @@ kernel void kernel_flash_attn_ext_vec(
         {
 #pragma unroll
             for (short cc = 0; cc < C/4; ++cc) {
-                device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb21 + iv2*nb22 + iv3*nb23));
+                device const block_q * pv4 = (device const block_q *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
 
                 const s4x4_t ms(ss[4*cc + ty]);
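For reference on the sequential warp reduction hunks above: each simdgroup carries a running maximum M and denominator S of its online softmax, and merging two partial results applies the rescale factors ms0 and ms1 to the corresponding partial outputs, which is what the simdgroup_multiply/simdgroup_multiply_accumulate pair does. A self-contained C sketch of that standard merge rule; the variable names mirror the kernel's, but the values are illustrative assumptions:

#include <math.h>
#include <stdio.h>

int main(void) {
    /* two partial (max, denominator) pairs, as two simdgroups would produce */
    const float M0 = 1.5f, S0 = 3.0f;
    const float M1 = 2.5f, S1 = 4.0f;

    /* merged maximum */
    const float M = fmaxf(M0, M1);

    /* rescale factors for the two partial outputs */
    const float ms0 = expf(M0 - M);
    const float ms1 = expf(M1 - M);

    /* merged denominator; the outputs are combined as lo = ms0*lo0 + ms1*lo1 */
    const float S = S0*ms0 + S1*ms1;

    printf("merged: M = %.3f, S = %.3f (ms0 = %.3f, ms1 = %.3f)\n", M, S, ms0, ms1);
    return 0;
}
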