mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-09 10:11:44 +00:00
remove ms array
This commit is contained in:
parent
984928109c
commit
61d05b57d9
@ -3052,8 +3052,6 @@ kernel void kernel_flash_attn_ext(
|
|||||||
|
|
||||||
// online softmax
|
// online softmax
|
||||||
{
|
{
|
||||||
float ms[Q];
|
|
||||||
|
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
const float m = M[j];
|
const float m = M[j];
|
||||||
|
|
||||||
@ -3072,18 +3070,17 @@ kernel void kernel_flash_attn_ext(
|
|||||||
|
|
||||||
M[j] = simd_max(max(M[j], s));
|
M[j] = simd_max(max(M[j], s));
|
||||||
|
|
||||||
ms[j] = exp(m - M[j]);
|
const float ms = exp(m - M[j]);
|
||||||
const float vs = exp(s - M[j]);
|
const float vs = exp(s - M[j]);
|
||||||
|
|
||||||
S[j] = S[j]*ms[j] + simd_sum(vs);
|
S[j] = S[j]*ms + simd_sum(vs);
|
||||||
|
|
||||||
// the P matrix from the paper (Q rows, C columns)
|
// the P matrix from the paper (Q rows, C columns)
|
||||||
ss[j*TS + tiisg] = vs;
|
ss[j*TS + tiisg] = vs;
|
||||||
}
|
|
||||||
|
|
||||||
// create a QxQ diagonal matrix for rescaling the output
|
if (tiisg == j) {
|
||||||
if (tiisg < Q) {
|
ss[j*TS + 2*C + j] = ms;
|
||||||
ss[tiisg*TS + 2*C + tiisg] = ms[tiisg];
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user