diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 7e1517414..1f233ba7f 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -3450,7 +3450,7 @@ kernel void kernel_flash_attn_ext_vec( { // each simdgroup processes 1 query and 4 keys for (short cc = 0; cc < C/4; ++cc) { - qk_t mqk = 0.0; + qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 }; device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); @@ -3461,13 +3461,14 @@ kernel void kernel_flash_attn_ext_vec( k4x4_t mk; deq_k(pk + i/nl_k, i%nl_k, mk); - mqk += - dot(mq[ii/NL][0], mk[0]) + - dot(mq[ii/NL][1], mk[1]) + - dot(mq[ii/NL][2], mk[2]) + - dot(mq[ii/NL][3], mk[3]); + mqka[0] += dot(mq[ii/NL][0], mk[0]); + mqka[1] += dot(mq[ii/NL][1], mk[1]); + mqka[2] += dot(mq[ii/NL][2], mk[2]); + mqka[3] += dot(mq[ii/NL][3], mk[3]); } + qk_t mqk = mqka[0] + mqka[1] + mqka[2] + mqka[3]; + // simdgroup reduce // [ 0 .. 7] -> [ 0] // [ 8 .. 15] -> [ 8]