diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 8e540e943..ebce329a9 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -3052,8 +3052,6 @@ kernel void kernel_flash_attn_ext( // online softmax { - float ms[Q]; - for (short j = 0; j < Q; ++j) { const float m = M[j]; @@ -3072,18 +3070,17 @@ kernel void kernel_flash_attn_ext( M[j] = simd_max(max(M[j], s)); - ms[j] = exp(m - M[j]); - const float vs = exp(s - M[j]); + const float ms = exp(m - M[j]); + const float vs = exp(s - M[j]); - S[j] = S[j]*ms[j] + simd_sum(vs); + S[j] = S[j]*ms + simd_sum(vs); // the P matrix from the paper (Q rows, C columns) ss[j*TS + tiisg] = vs; - } - // create a QxQ diagonal matrix for rescaling the output - if (tiisg < Q) { - ss[tiisg*TS + 2*C + tiisg] = ms[tiisg]; + if (tiisg == j) { + ss[j*TS + 2*C + j] = ms; + } } }