diff --git a/Makefile b/Makefile
index eb1da90f1..097598455 100644
--- a/Makefile
+++ b/Makefile
@@ -876,6 +876,11 @@ endif # GGML_HIPBLAS
 
 ifdef GGML_METAL
     MK_CPPFLAGS += -DGGML_USE_METAL
+
+ifdef GGML_METAL_FORCE_FATTN_PREC_F16
+    MK_CPPFLAGS += -DGGML_METAL_FORCE_FATTN_PREC_F16
+endif # GGML_METAL_FORCE_FATTN_PREC_F16
+
     MK_LDFLAGS  += -framework Foundation -framework Metal -framework MetalKit
     OBJ_GGML    += ggml/src/ggml-metal.o
 ifdef GGML_METAL_NDEBUG
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index cfa6e3f70..2b902a3d8 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -153,6 +153,7 @@ option(GGML_VULKAN_VALIDATE "ggml: enable Vulkan validation"
 option(GGML_VULKAN_RUN_TESTS            "ggml: run Vulkan tests"                       OFF)
 option(GGML_KOMPUTE                     "ggml: use Kompute"                            OFF)
 option(GGML_METAL                       "ggml: use Metal"                              ${GGML_METAL_DEFAULT})
+option(GGML_METAL_FORCE_FATTN_PREC_F16  "ggml: force F16 accumulators for FA kernels"  OFF)
 option(GGML_METAL_NDEBUG                "ggml: disable Metal debugging"                OFF)
 option(GGML_METAL_SHADER_DEBUG          "ggml: compile Metal with -fno-fast-math"      OFF)
 option(GGML_METAL_EMBED_LIBRARY         "ggml: embed Metal library"                    ${GGML_METAL})
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 34b81bd7f..f0030bf5e 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -58,6 +58,10 @@ if (GGML_METAL)
         add_compile_definitions(GGML_METAL_NDEBUG)
     endif()
 
+    if (GGML_METAL_FORCE_FATTN_PREC_F16)
+        add_compile_definitions(GGML_METAL_FORCE_FATTN_PREC_F16)
+    endif()
+
     # copy ggml-common.h and ggml-metal.metal to bin directory
     configure_file(ggml-common.h    ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h    COPYONLY)
     configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 4e503e6ac..df77eb96e 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -12,9 +12,6 @@
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
-// TODO: for now, always use F32 for flash attention to avoid compiling 2 sets of kernels
-#define GGML_METAL_FORCE_FATTN_PREC_F32
-
 // max memory buffers that can be mapped to the device
 #define GGML_METAL_MAX_BUFFERS 64
 
@@ -499,9 +496,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     // dictionary of preprocessor macros
     NSMutableDictionary * prep = [NSMutableDictionary dictionary];
 
-    // add GGML_METAL_FORCE_FATTN_PREC_F32
-#if defined(GGML_METAL_FORCE_FATTN_PREC_F32)
-    [prep setObject:@"1" forKey:@"GGML_METAL_FORCE_FATTN_PREC_F32"];
+#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
+    [prep setObject:@"1" forKey:@"GGML_METAL_FORCE_FATTN_PREC_F16"];
 #endif
 
     MTLCompileOptions * options = [MTLCompileOptions new];
@@ -554,6 +550,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         }
     }
 
+#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
+    GGML_LOG_INFO("%s: GGML_METAL_FORCE_FATTN_PREC_F16 = yes\n", __func__);
+#else
+    GGML_LOG_INFO("%s: GGML_METAL_FORCE_FATTN_PREC_F16 = no\n", __func__);
+#endif
     GGML_LOG_INFO("%s: simdgroup reduction   = %s\n", __func__, ctx_dev->has_simdgroup_reduction ? "true" : "false");
     GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm        ? "true" : "false");
     GGML_LOG_INFO("%s: bfloat                = %s\n", __func__, ctx_dev->has_bfloat              ? "true" : "false");
@@ -3224,10 +3225,12 @@ static void ggml_metal_encode_node(
             GGML_ASSERT(nqptg % 8  == 0);
             GGML_ASSERT(ncpsg % 32 == 0);
 
-#ifdef GGML_METAL_FORCE_FATTN_PREC_F32
+#ifdef GGML_METAL_FORCE_FATTN_PREC_F16
             const enum ggml_prec prec = GGML_PREC_DEFAULT;
 #else
-            const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst);
+            // TODO: support both precisions
+            const enum ggml_prec prec = GGML_PREC_F32;
+            //const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst);
 #endif
 
             const int64_t nhalfs = prec == GGML_PREC_DEFAULT ? 1 : 2;
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index f338e317f..41354dc68 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -2755,8 +2755,16 @@ kernel void kernel_leaky_relu_f32(
 }
 
 // ref: https://arxiv.org/pdf/2307.08691.pdf
-// D - head size, Q - queries per threadgroup, KV - key/value processed per each simdgroup, C - cache items per threadgroup
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &), short D, short Q = 8, short KV = 8, short C = 32>
+template<
+        typename block_q,
+        short nl,
+        void (*dequantize_func)(device const block_q *, short, thread half4x4 &),
+        typename s_t,    // attention accumulation types
+        typename s8x8_t,
+        short D,         // head size
+        short Q  = 8,    // queries per threadgroup
+        short KV = 8,    // key/value processed per each simdgroup
+        short C  = 32>   // cache items per threadgroup
 kernel void kernel_flash_attn_ext(
         device const char * q,
         device const char * k,
@@ -2805,13 +2813,15 @@ kernel void kernel_flash_attn_ext(
     const short NW = N_SIMDWIDTH;
     const short SH = (C + Q); // shared memory per simdgroup in (half)
 
-    const short T  = D + nsg*SH; // shared memory size per query in (half)
-    const short TF = T;          // shared memory size per query in (float)
-    const short T4 = T/4;        // shared memory size per query in (half4)
+    const short SF = sizeof(s_t)/sizeof(half);
 
-    threadgroup half  * sq  = (threadgroup half  *) (shared +            0*D); // holds the query data
-    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +            0*D); // same as above but in half4
-    threadgroup half  * ss  = (threadgroup half  *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    const short T  = D + SF*nsg*SH; // shared memory size per query in (half)
+    const short TS = T/SF;          // shared memory size per query in (s_t)
+    const short T4 = T/4;           // shared memory size per query in (half4)
+
+    threadgroup half  * sq  = (threadgroup half  *) (shared +               0*D); // holds the query data
+    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +               0*D); // same as above but in half4
+    threadgroup s_t   * ss  = (threadgroup s_t   *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
 
     threadgroup half    * skv  = (threadgroup half    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
     threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
@@ -2840,7 +2850,7 @@ kernel void kernel_flash_attn_ext(
 
     // zero out shared memory SH
     for (short j = 0; j < Q; ++j) {
         for (short i = tiisg; i < SH; i += NW) {
-            ss[j*TF + i] = 0.0h;
+            ss[j*TS + i] = 0.0f;
         }
     }
@@ -2905,7 +2915,7 @@ kernel void kernel_flash_attn_ext(
         // Q*K^T
         {
             for (short cc = 0; cc < C/8; ++cc) {
-                simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, 8>(0.h);
+                s8x8_t mqk = make_filled_simdgroup_matrix<s_t, 8>(0.0f);
 
                 // this is compile-time check, so it does not have runtime overhead
                 if (is_same<block_q, half4x4>::value) {
@@ -2962,7 +2972,7 @@ kernel void kernel_flash_attn_ext(
                     }
                 }
 
-                simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
+                simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
             }
         }
 
@@ -2977,7 +2987,7 @@ kernel void kernel_flash_attn_ext(
                 const float m = M[j];
 
                 // scale and apply the logitcap / mask
-                float s = ((float)(ss[j*TF + tiisg]))*scale;
+                float s = ((float)(ss[j*TS + tiisg]))*scale;
 
                 if (logit_softcap != 0.0f) {
                     s = logit_softcap*precise::tanh(s);
@@ -2997,12 +3007,12 @@ kernel void kernel_flash_attn_ext(
                 S[j] = S[j]*ms[j] + simd_sum(vs);
 
                 // the P matrix from the paper (Q rows, C columns)
-                ss[j*TF + tiisg] = vs;
+                ss[j*TS + tiisg] = vs;
             }
 
             // create a QxQ diagonal matrix for rescaling the output
             if (tiisg < Q) {
-                ss[tiisg*TF + C + tiisg] = ms[tiisg];
+                ss[tiisg*TS + C + tiisg] = ms[tiisg];
             }
         }
@@ -3013,8 +3023,8 @@ kernel void kernel_flash_attn_ext(
 
         // O = diag(ms)*O
         {
-            simdgroup_half8x8 mm;
-            simdgroup_load(mm, ss + C, TF, 0, false);
+            s8x8_t mm;
+            simdgroup_load(mm, ss + C, TS, 0, false);
 
             for (short i = 0; i < D8; ++i) {
                 simdgroup_multiply(lo[i], mm, lo[i]);
@@ -3024,8 +3034,8 @@ kernel void kernel_flash_attn_ext(
         // O = O + (Q*K^T)*V
         {
             for (short cc = 0; cc < C/8; ++cc) {
-                simdgroup_half8x8 ms;
-                simdgroup_load(ms, ss + 8*cc, TF, 0, false);
+                s8x8_t ms;
+                simdgroup_load(ms, ss + 8*cc, TS, 0, false);
 
                 if (is_same<block_q, half4x4>::value) {
                     // we can read directly from global memory
@@ -3087,8 +3097,8 @@ kernel void kernel_flash_attn_ext(
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
         for (short j = 0; j < Q; ++j) {
             if (tiisg == 0) {
-                ss[j*TF + 0] = S[j];
-                ss[j*TF + 1] = M[j];
+                ss[j*TS + 0] = S[j];
+                ss[j*TS + 1] = M[j];
             }
         }
     }
@@ -3112,11 +3122,11 @@ kernel void kernel_flash_attn_ext(
     // the first simdgroup accumulates the results from the other simdgroups
     if (sgitg == 0) {
         for (short j = 0; j < Q; ++j) {
-            const float S0 = ss[j*TF +         0];
-            const float S1 = ss[j*TF + sg*SH + 0];
+            const float S0 = ss[j*TS +         0];
+            const float S1 = ss[j*TS + sg*SH + 0];
 
-            const float M0 = ss[j*TF +         1];
-            const float M1 = ss[j*TF + sg*SH + 1];
+            const float M0 = ss[j*TS +         1];
+            const float M1 = ss[j*TS + sg*SH + 1];
 
             M = max(M0, M1);
 
@@ -3126,22 +3136,23 @@ kernel void kernel_flash_attn_ext(
             S = S0*ms0 + S1*ms1;
 
             if (tiisg == 0) {
-                ss[j*TF + 0] = S;
-                ss[j*TF + 1] = M;
+                ss[j*TS + 0] = S;
+                ss[j*TS + 1] = M;
 
-                ss[j*TF + C + j        ] = ms0;
-                ss[j*TF + C + j + sg*SH] = ms1;
+                ss[j*TS + C + j        ] = ms0;
+                ss[j*TS + C + j + sg*SH] = ms1;
             }
         }
 
        // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
        {
            simdgroup_half8x8 t;
-            simdgroup_half8x8 ms0;
-            simdgroup_half8x8 ms1;
-            simdgroup_load(ms0, ss + C,         TF, 0, false);
-            simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
+            s8x8_t ms0;
+            s8x8_t ms1;
+
+            simdgroup_load(ms0, ss + C,         TS, 0, false);
+            simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false);
 
             for (short i = 0; i < D8; ++i) {
                 simdgroup_load    (t, sq + i*8, T, 0, false);
@@ -3165,7 +3176,7 @@ kernel void kernel_flash_attn_ext(
     // final rescale with 1/S and store to global memory
     if (sgitg == 0) {
         for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
-            const float S = ss[j*TF + 0];
+            const float S = ss[j*TS + 0];
 
             for (short i = tiisg; i < D4; i += NW) {
                 dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
@@ -3174,49 +3185,57 @@ kernel void kernel_flash_attn_ext(
         }
     }
 }
-typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
+#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
+#define S_T    half
+#define S8x8_T simdgroup_half8x8
+#else
+#define S_T    float
+#define S8x8_T simdgroup_float8x8
+#endif
 
-template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>;
-template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 80>;
-template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 96>;
-template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 112>;
-template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 128>;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 256>;
+typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 64>) flash_attn_ext_t;
 
-template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 256>;
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 64>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 80>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 96>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 112>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 128>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 256>;
 
-template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 64>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 80>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 96>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 112>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 128>;
-template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 256>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 80>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 96>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 256>;
 
-template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 256>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 64>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 80>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 96>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 112>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 128>;
+template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 256>;
 
-template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 64>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 80>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 96>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 112>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 128>;
-template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 256>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 80>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 96>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 256>;
 
-template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 64>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 80>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 96>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 112>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 128>;
-template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 64>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 80>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 96>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 112>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 128>;
+template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 256>;
+
+template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 64>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 80>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 96>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 112>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 128>;
+template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 256>;
 
 // NOTE: can use half instead of float precision for some extra perf
 // however, by default use F32 since the op should be mostly memory bandwidth bound