mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-07 17:21:46 +00:00
metal : cont + avoid potential int overflow [no ci]
This commit is contained in:
parent
089404f3a1
commit
593bc1aef5
@ -2756,11 +2756,11 @@ template<
|
||||
short KV = 8, // key/value processed per each simdgroup
|
||||
short C = 32> // cache items per threadgroup
|
||||
kernel void kernel_flash_attn_ext(
|
||||
device const char * q,
|
||||
device const char * k,
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device float * dst,
|
||||
device const char * q,
|
||||
device const char * k,
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device char * dst,
|
||||
constant ggml_metal_kargs_flash_attn_ext & args,
|
||||
threadgroup half * shared [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
@ -3156,7 +3156,7 @@ kernel void kernel_flash_attn_ext(
|
||||
const float S = ss[j*TS + 0];
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
|
||||
dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3248,11 +3248,11 @@ template<
|
||||
short Q = 1, // queries per threadgroup
|
||||
short C = 32> // cache items per threadgroup
|
||||
kernel void kernel_flash_attn_ext_vec(
|
||||
device const char * q,
|
||||
device const char * k,
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device float * dst,
|
||||
device const char * q,
|
||||
device const char * k,
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device char * dst,
|
||||
constant ggml_metal_kargs_flash_attn_ext & args,
|
||||
threadgroup half * shared [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
@ -3543,7 +3543,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
const float S = ss[0];
|
||||
|
||||
for (short i = tiisg; i < D16; i += NW) {
|
||||
dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1)*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
|
||||
dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user