mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-08 09:41:45 +00:00
metal : cont + avoid potential int overflow [no ci]
This commit is contained in:
parent
089404f3a1
commit
593bc1aef5
@ -2760,7 +2760,7 @@ kernel void kernel_flash_attn_ext(
|
||||
device const char * k,
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device float * dst,
|
||||
device char * dst,
|
||||
constant ggml_metal_kargs_flash_attn_ext & args,
|
||||
threadgroup half * shared [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
@ -3156,7 +3156,7 @@ kernel void kernel_flash_attn_ext(
|
||||
const float S = ss[j*TS + 0];
|
||||
|
||||
for (short i = tiisg; i < D4; i += NW) {
|
||||
dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
|
||||
dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -3252,7 +3252,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
device const char * k,
|
||||
device const char * v,
|
||||
device const char * mask,
|
||||
device float * dst,
|
||||
device char * dst,
|
||||
constant ggml_metal_kargs_flash_attn_ext & args,
|
||||
threadgroup half * shared [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
@ -3543,7 +3543,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
const float S = ss[0];
|
||||
|
||||
for (short i = tiisg; i < D16; i += NW) {
|
||||
dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1)*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
|
||||
dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user