From eefc132bb7707c3faff6a75b925c35e7948ec37c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 6 Nov 2024 15:33:30 +0200 Subject: [PATCH] metal : use F16 precision in FA kernel --- ggml/src/ggml-metal.m | 18 +++++++++++++++++- ggml/src/ggml-metal.metal | 25 +++++++++++++------------ 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index f13adee38..4e503e6ac 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -12,6 +12,9 @@ #define MIN(a, b) ((a) < (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b)) +// TODO: for now, always use F32 for flash attention to avoid compiling 2 sets of kernels +#define GGML_METAL_FORCE_FATTN_PREC_F32 + // max memory buffers that can be mapped to the device #define GGML_METAL_MAX_BUFFERS 64 @@ -496,6 +499,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de // dictionary of preprocessor macros NSMutableDictionary * prep = [NSMutableDictionary dictionary]; + // add GGML_METAL_FORCE_FATTN_PREC_F32 +#if defined(GGML_METAL_FORCE_FATTN_PREC_F32) + [prep setObject:@"1" forKey:@"GGML_METAL_FORCE_FATTN_PREC_F32"]; +#endif + MTLCompileOptions * options = [MTLCompileOptions new]; options.preprocessorMacros = prep; @@ -3216,11 +3224,19 @@ static void ggml_metal_encode_node( GGML_ASSERT(nqptg % 8 == 0); GGML_ASSERT(ncpsg % 32 == 0); +#ifdef GGML_METAL_FORCE_FATTN_PREC_F32 + const enum ggml_prec prec = GGML_PREC_DEFAULT; +#else + const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst); +#endif + + const int64_t nhalfs = prec == GGML_PREC_DEFAULT ? 
1 : 2; + // 16*32*(nsg) // the shared memory needed for the simdgroups to load the KV cache // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG // -#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + nhalfs*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16)) int64_t nsgmax = 2; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 16b5da3ff..f338e317f 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2805,13 +2805,13 @@ kernel void kernel_flash_attn_ext( const short NW = N_SIMDWIDTH; const short SH = (C + Q); // shared memory per simdgroup in (half) - const short T = D + 2*nsg*SH; // shared memory size per query in (half) - const short TF = T/2; // shared memory size per query in (float) + const short T = D + nsg*SH; // shared memory size per query in (half) + const short TF = T; // shared memory size per query in (half) - same as T now that ss is half const short T4 = T/4; // shared memory size per query in (half4) - threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data - threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 - threadgroup float * ss = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix threadgroup half * skv = (threadgroup half *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4 @@ -2840,7 +2840,7 @@ kernel 
void kernel_flash_attn_ext( // zero out shared memory SH for (short j = 0; j < Q; ++j) { for (short i = tiisg; i < SH; i += NW) { - ss[j*TF + i] = 0.0f; + ss[j*TF + i] = 0.0h; } } @@ -2905,7 +2905,7 @@ kernel void kernel_flash_attn_ext( // Q*K^T { for (short cc = 0; cc < C/8; ++cc) { - simdgroup_float8x8 mqk = make_filled_simdgroup_matrix(0.h); + simdgroup_half8x8 mqk = make_filled_simdgroup_matrix(0.h); // this is compile-time check, so it does not have runtime overhead if (is_same::value) { @@ -2977,7 +2977,7 @@ kernel void kernel_flash_attn_ext( const float m = M[j]; // scale and apply the logitcap / mask - float s = ss[j*TF + tiisg]*scale; + float s = ((float)(ss[j*TF + tiisg]))*scale; if (logit_softcap != 0.0f) { s = logit_softcap*precise::tanh(s); @@ -3013,7 +3013,7 @@ kernel void kernel_flash_attn_ext( // O = diag(ms)*O { - simdgroup_float8x8 mm; + simdgroup_half8x8 mm; simdgroup_load(mm, ss + C, TF, 0, false); for (short i = 0; i < D8; ++i) { @@ -3024,7 +3024,7 @@ kernel void kernel_flash_attn_ext( // O = O + (Q*K^T)*V { for (short cc = 0; cc < C/8; ++cc) { - simdgroup_float8x8 ms; + simdgroup_half8x8 ms; simdgroup_load(ms, ss + 8*cc, TF, 0, false); if (is_same::value) { @@ -3137,8 +3137,8 @@ kernel void kernel_flash_attn_ext( // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 { simdgroup_half8x8 t; - simdgroup_float8x8 ms0; - simdgroup_float8x8 ms1; + simdgroup_half8x8 ms0; + simdgroup_half8x8 ms1; simdgroup_load(ms0, ss + C, TF, 0, false); simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false); @@ -3219,6 +3219,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_ template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext; // NOTE: can use half instead of float precision for some extra perf +// however, by default use F32 since the op should be mostly memory bandwidth bound // D - head size, Q - queries per threadgroup, C - cache items per threadgroup template kernel void 
kernel_flash_attn_ext_vec(