mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-10 10:41:47 +00:00
metal : use F32 prec for K*Q in vec FA
ggml-ci
This commit is contained in:
parent
c35e586ea5
commit
5d888c48a3
@ -2631,11 +2631,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||
const short iv3 = iq3 / rv3;
|
||||
|
||||
// load the queries from shared memory into local memory
|
||||
half4 mq[D4];
|
||||
float4 mq[D4];
|
||||
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
short i = ii + tiisg;
|
||||
mq[i] = sq4[i];
|
||||
mq[i] = (float4) sq4[i];
|
||||
}
|
||||
|
||||
// pointer to the mask
|
||||
@ -2661,11 +2661,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
const short i = ii + tiisg;
|
||||
|
||||
half4x4 mk;
|
||||
mk[0] = pk4[i + 0*(nb11/8)];
|
||||
mk[1] = pk4[i + 1*(nb11/8)];
|
||||
mk[2] = pk4[i + 2*(nb11/8)];
|
||||
mk[3] = pk4[i + 3*(nb11/8)];
|
||||
float4x4 mk;
|
||||
mk[0] = (float4) pk4[i + 0*(nb11/8)];
|
||||
mk[1] = (float4) pk4[i + 1*(nb11/8)];
|
||||
mk[2] = (float4) pk4[i + 2*(nb11/8)];
|
||||
mk[3] = (float4) pk4[i + 3*(nb11/8)];
|
||||
|
||||
mqk += (float4) (mq[i] * mk);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user