metal : cont (pass dst as char *) + avoid potential int overflow in dst indexing [no ci]

This commit is contained in:
Georgi Gerganov 2024-11-09 16:39:36 +02:00
parent 089404f3a1
commit 593bc1aef5
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -2756,11 +2756,11 @@ template<
short KV = 8, // key/value processed per each simdgroup
short C = 32> // cache items per threadgroup
kernel void kernel_flash_attn_ext(
device const char * q,
device const char * k,
device const char * v,
device const char * mask,
device float * dst,
device const char * q,
device const char * k,
device const char * v,
device const char * mask,
device char * dst,
constant ggml_metal_kargs_flash_attn_ext & args,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
@ -3156,7 +3156,7 @@ kernel void kernel_flash_attn_ext(
const float S = ss[j*TS + 0];
for (short i = tiisg; i < D4; i += NW) {
dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
}
}
}
@ -3248,11 +3248,11 @@ template<
short Q = 1, // queries per threadgroup
short C = 32> // cache items per threadgroup
kernel void kernel_flash_attn_ext_vec(
device const char * q,
device const char * k,
device const char * v,
device const char * mask,
device float * dst,
device const char * q,
device const char * k,
device const char * v,
device const char * mask,
device char * dst,
constant ggml_metal_kargs_flash_attn_ext & args,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
@ -3543,7 +3543,7 @@ kernel void kernel_flash_attn_ext_vec(
const float S = ss[0];
for (short i = tiisg; i < D16; i += NW) {
dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1)*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
}
}
}