From 40e717263e72e3af1f83b2a92499615fd391b0a5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 2 Nov 2024 20:41:42 +0200 Subject: [PATCH] metal : minor fixup in FA kernel ggml-ci --- ggml/src/ggml-metal.metal | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index defde6246..6fc114e29 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2776,11 +2776,11 @@ kernel void kernel_flash_attn_ext_vec_f16( const short iv3 = iq3 / rv3; // load the queries from shared memory into local memory - float4 mq[D4]; + float4 mq[D4/NW]; for (short ii = 0; ii < D4; ii += NW) { short i = ii + tiisg; - mq[i] = (float4) sq4[i]; + mq[i/NW] = (float4) sq4[i]; } // pointer to the mask @@ -2812,7 +2812,7 @@ kernel void kernel_flash_attn_ext_vec_f16( mk[2] = (float4) pk4[i + 2*(nb11/8)]; mk[3] = (float4) pk4[i + 3*(nb11/8)]; - mqk += (float4) (mq[i] * mk); + mqk += (float4) (mq[i/NW] * mk); } // reduce the results from the threads in the simdgroup