metal : use F32 prec for K*Q in vec FA

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-09-22 21:56:31 +03:00
parent c35e586ea5
commit 5d888c48a3
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -2631,11 +2631,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
const short iv3 = iq3 / rv3; const short iv3 = iq3 / rv3;
// load the queries from shared memory into local memory // load the queries from shared memory into local memory
half4 mq[D4]; float4 mq[D4];
for (short ii = 0; ii < D4; ii += NW) { for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg; short i = ii + tiisg;
mq[i] = sq4[i]; mq[i] = (float4) sq4[i];
} }
// pointer to the mask // pointer to the mask
@@ -2661,11 +2661,11 @@ kernel void kernel_flash_attn_ext_vec_f16(
for (short ii = 0; ii < D4; ii += NW) { for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg; const short i = ii + tiisg;
half4x4 mk; float4x4 mk;
mk[0] = pk4[i + 0*(nb11/8)]; mk[0] = (float4) pk4[i + 0*(nb11/8)];
mk[1] = pk4[i + 1*(nb11/8)]; mk[1] = (float4) pk4[i + 1*(nb11/8)];
mk[2] = pk4[i + 2*(nb11/8)]; mk[2] = (float4) pk4[i + 2*(nb11/8)];
mk[3] = pk4[i + 3*(nb11/8)]; mk[3] = (float4) pk4[i + 3*(nb11/8)];
mqk += (float4) (mq[i] * mk); mqk += (float4) (mq[i] * mk);
} }