mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 10:54:36 +00:00
metal : separate scale and mask from QKT in FA kernel (#9189)
* metal : separate scale and mask from QKT in FA kernel * metal : ne01 check no longer necessary * metal : keep data in local memory
This commit is contained in:
parent
fc18425b6a
commit
06658ad7c3
@ -2261,24 +2261,6 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
|
|
||||||
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
|
||||||
|
|
||||||
const short tx = tiisg%4;
|
|
||||||
const short ty = tiisg/4;
|
|
||||||
|
|
||||||
// mqk = mqk*scale
|
|
||||||
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
|
|
||||||
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
|
|
||||||
|
|
||||||
if (logit_softcap != 0.0f) {
|
|
||||||
ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
|
|
||||||
ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mask != q) {
|
|
||||||
// mqk = mqk + mask*slope
|
|
||||||
ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
|
|
||||||
ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2290,10 +2272,19 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
float ms[Q];
|
float ms[Q];
|
||||||
|
|
||||||
for (short j = 0; j < Q; ++j) {
|
for (short j = 0; j < Q; ++j) {
|
||||||
const short p = tiisg;
|
|
||||||
|
|
||||||
const float m = M[j];
|
const float m = M[j];
|
||||||
const float s = ss[j*TF + p];
|
|
||||||
|
// scale and apply the logitcap / mask
|
||||||
|
float s = ss[j*TF + tiisg]*scale;
|
||||||
|
|
||||||
|
if (logit_softcap != 0.0f) {
|
||||||
|
s = logit_softcap*precise::tanh(s);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (mask != q) {
|
||||||
|
// mqk = mqk + mask*slope
|
||||||
|
s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
|
||||||
|
}
|
||||||
|
|
||||||
smax = simd_max(max(smax, s));
|
smax = simd_max(max(smax, s));
|
||||||
M[j] = simd_max(max(M[j], s));
|
M[j] = simd_max(max(M[j], s));
|
||||||
@ -2304,7 +2295,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
S[j] = S[j]*ms[j] + simd_sum(vs);
|
S[j] = S[j]*ms[j] + simd_sum(vs);
|
||||||
|
|
||||||
// the P matrix from the paper (Q rows, C columns)
|
// the P matrix from the paper (Q rows, C columns)
|
||||||
ss[j*TF + p] = vs;
|
ss[j*TF + tiisg] = vs;
|
||||||
}
|
}
|
||||||
|
|
||||||
// create a QxQ diagonal matrix for rescaling the output
|
// create a QxQ diagonal matrix for rescaling the output
|
||||||
|
Loading…
Reference in New Issue
Block a user