diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index defde6246..57eb34f13 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2776,11 +2776,11 @@ kernel void kernel_flash_attn_ext_vec_f16( const short iv3 = iq3 / rv3; // load the queries from shared memory into local memory - float4 mq[D4]; + float4 mq[D4/NW]; for (short ii = 0; ii < D4; ii += NW) { short i = ii + tiisg; - mq[i] = (float4) sq4[i]; + mq[ii/NW] = (float4) sq4[i]; } // pointer to the mask @@ -2812,7 +2812,7 @@ kernel void kernel_flash_attn_ext_vec_f16( mk[2] = (float4) pk4[i + 2*(nb11/8)]; mk[3] = (float4) pk4[i + 3*(nb11/8)]; - mqk += (float4) (mq[i] * mk); + mqk += (float4) (mq[ii/NW] * mk); } // reduce the results from the threads in the simdgroup @@ -2857,8 +2857,7 @@ kernel void kernel_flash_attn_ext_vec_f16( // O = diag(ms)*O #pragma unroll for (short ii = 0; ii < D4; ii += NW) { - const short i = ii + tiisg; - lo[i/NW] *= ms; + lo[ii/NW] *= ms; } } @@ -2872,10 +2871,10 @@ kernel void kernel_flash_attn_ext_vec_f16( for (short ii = 0; ii < D4; ii += NW) { const short i = ii + tiisg; - lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; - lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; - lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; - lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; + lo[ii/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; + lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; + lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; + lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; } } }