mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 03:14:35 +00:00
metal : another fix for the fa kernel
This commit is contained in:
parent
7a3df798fc
commit
a95225cdfd
@ -2144,6 +2144,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
const short tx = tiisg%4;
|
const short tx = tiisg%4;
|
||||||
const short ty = tiisg/4;
|
const short ty = tiisg/4;
|
||||||
|
|
||||||
|
if (iq1 + ty < ne01) {
|
||||||
// mqk = mqk*scale
|
// mqk = mqk*scale
|
||||||
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
|
ss[8*cc + ty*TF + 2*tx + 0] *= scale;
|
||||||
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
|
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
|
||||||
@ -2160,6 +2161,7 @@ kernel void kernel_flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// used to detect blocks full of -INF
|
// used to detect blocks full of -INF
|
||||||
float smax = -INFINITY;
|
float smax = -INFINITY;
|
||||||
|
Loading…
Reference in New Issue
Block a user