mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 11:40:17 +00:00
metal : fix fa kernel
This commit is contained in:
parent
7a3df798fc
commit
aa931d0375
@ -2144,19 +2144,26 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
const short tx = tiisg%4;
|
const short tx = tiisg%4;
|
||||||
const short ty = tiisg/4;
|
const short ty = tiisg/4;
|
||||||
|
|
||||||
|
if (logit_softcap == 0.0f) {
|
||||||
|
if (mask != q) {
|
||||||
|
// mqk = mqk*scale + mask*slope
|
||||||
|
ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
|
||||||
|
ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
|
||||||
|
} else {
|
||||||
// mqk = mqk*scale
|
// mqk = mqk*scale
|
||||||
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
|
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
|
||||||
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
|
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
|
||||||
|
|
||||||
if (logit_softcap != 0.0f) {
|
|
||||||
ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
|
|
||||||
ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
if (mask != q) {
|
if (mask != q) {
|
||||||
// mqk = mqk + mask*slope
|
// mqk = ls*tanh(mqk*scale) + mask*slope
|
||||||
ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
|
ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(scale*ss[8*cc + ty*TF + 2*tx + 0]) + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
|
||||||
ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
|
ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(scale*ss[8*cc + ty*TF + 2*tx + 1]) + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
|
||||||
|
} else {
|
||||||
|
// mqk = ls*tanh(mqk*scale)
|
||||||
|
ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(scale*ss[8*cc + ty*TF + 2*tx + 0]);
|
||||||
|
ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(scale*ss[8*cc + ty*TF + 2*tx + 1]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user