From 0f6f1c789cf3517e0955611caf4b01ff47bd3fb4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 16:22:29 +0200 Subject: [PATCH] metal : more precise Q*K in FA vec kernel --- ggml/src/ggml-metal.metal | 35 ++++++++++++++++++++++------------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 413661c8a..e8b71a9f8 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2942,6 +2942,7 @@ kernel void kernel_flash_attn_ext( half smax = -INFINITY; // load the mask in shared memory + #pragma unroll(Q) for (short j = 0; j < Q; ++j) { device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31); @@ -2968,7 +2969,7 @@ kernel void kernel_flash_attn_ext( // we can read directly from global memory device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); -#pragma unroll + #pragma unroll(D8) for (short i = 0; i < D8; ++i) { k8x8_t mk; simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10 @@ -2989,7 +2990,7 @@ kernel void kernel_flash_attn_ext( simdgroup_barrier(mem_flags::mem_threadgroup); -#pragma unroll + #pragma unroll(4) for (short k = 0; k < 4; ++k) { k8x8_t mk; @@ -3067,7 +3068,7 @@ kernel void kernel_flash_attn_ext( s8x8_t mm; simdgroup_load(mm, ss + 2*C, TS, 0, false); -#pragma unroll + #pragma unroll(D8) for (short i = 0; i < D8; ++i) { simdgroup_multiply(lo[i], mm, lo[i]); } @@ -3082,7 +3083,8 @@ kernel void kernel_flash_attn_ext( if (is_same::value) { // we can read directly from global memory device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); -#pragma unroll + + #pragma unroll(D8) for (short i = 0; i < D8; ++i) { v8x8_t mv; simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20 @@ -3103,7 +3105,7 @@ kernel void kernel_flash_attn_ext( simdgroup_barrier(mem_flags::mem_threadgroup); -#pragma unroll + #pragma unroll(4) for (short k = 0; k < 4; ++k) { v8x8_t mv; @@ -3196,6 +3198,7 @@ kernel void kernel_flash_attn_ext( simdgroup_load(ms0, ss + 2*C, TS, 0, false); simdgroup_load(ms1, ss + 2*C + sg*SH, TS, 0, false); + #pragma unroll(D8) for (short i = 0; i < D8; ++i) { o8x8_t t; @@ -3413,6 +3416,7 @@ kernel void kernel_flash_attn_ext_vec( // load the queries from shared memory into local memory q4x4_t mq[D16/NL]; + #pragma unroll(D16/NL) for (short ii = 0; ii < D16; ii += NL) { mq[ii/NL] = sq4x4[ii + tx]; } @@ -3454,17 +3458,23 @@ kernel void kernel_flash_attn_ext_vec( device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); -#pragma unroll + #pragma unroll(D16/NL) for (short ii = 0; ii < D16; ii += NL) { const short i = ii + tx; k4x4_t mk; deq_k(pk + i/nl_k, i%nl_k, mk); - mqka[0] += dot(mq[ii/NL][0], mk[0]); - mqka[1] += dot(mq[ii/NL][1], mk[1]); - mqka[2] += dot(mq[ii/NL][2], mk[2]); - mqka[3] += dot(mq[ii/NL][3], mk[3]); + // note: this is less precise than the version below + //mqka[0] += dot(mq[ii/NL][0], mk[0]); + //mqka[1] += dot(mq[ii/NL][1], mk[1]); + //mqka[2] += dot(mq[ii/NL][2], mk[2]); + //mqka[3] += dot(mq[ii/NL][3], mk[3]); + + mqka[0] += dot((float4) mq[ii/NL][0], (float4) mk[0]); + mqka[1] += dot((float4) mq[ii/NL][1], (float4) mk[1]); + mqka[2] += dot((float4) mq[ii/NL][2], (float4) mk[2]); + mqka[3] += dot((float4) mq[ii/NL][3], (float4) mk[3]); } qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3]; @@ -3513,7 +3523,7 @@ kernel void kernel_flash_attn_ext_vec( ss[tiisg] = vs; // O = diag(ms)*O -#pragma unroll + #pragma unroll(D16/NL) for (short ii = 0; ii < D16; ii += NL) { lo[ii/NL] *= ms; } @@ -3523,13 +3533,12 @@ kernel void kernel_flash_attn_ext_vec( // O = O + (Q*K^T)*V { -#pragma unroll for (short cc = 0; cc < C/4; ++cc) { device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); const s4x4_t ms(ss[4*cc + ty]); -#pragma unroll + #pragma unroll(D16/NL) for (short ii = 0; ii < D16; ii += NL) { const short i = ii + tx;