remove ms array

This commit is contained in:
Georgi Gerganov 2024-11-07 13:35:33 +02:00
parent 984928109c
commit 61d05b57d9
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -3052,8 +3052,6 @@ kernel void kernel_flash_attn_ext(
// online softmax
{
float ms[Q];
for (short j = 0; j < Q; ++j) {
const float m = M[j];
@@ -3072,18 +3070,17 @@ kernel void kernel_flash_attn_ext(
M[j] = simd_max(max(M[j], s));
ms[j] = exp(m - M[j]);
const float vs = exp(s - M[j]);
const float ms = exp(m - M[j]);
const float vs = exp(s - M[j]);
S[j] = S[j]*ms[j] + simd_sum(vs);
S[j] = S[j]*ms + simd_sum(vs);
// the P matrix from the paper (Q rows, C columns)
ss[j*TS + tiisg] = vs;
}
// create a QxQ diagonal matrix for rescaling the output
if (tiisg < Q) {
ss[tiisg*TS + 2*C + tiisg] = ms[tiisg];
if (tiisg == j) {
ss[j*TS + 2*C + j] = ms;
}
}
}