mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 03:31:46 +00:00
metal : fix bug in soft_max kernels (out-of-bounds access) (#3194)
This commit is contained in:
parent
e3d87a6c36
commit
c6f1491da0
@ -118,7 +118,7 @@ kernel void kernel_soft_max(
|
|||||||
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||||
|
|
||||||
// parallel max
|
// parallel max
|
||||||
float lmax = psrc0[tpitg[0]];
|
float lmax = tpitg[0] < ne00 ? psrc0[tpitg[0]] : -INFINITY;
|
||||||
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
|
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
|
||||||
lmax = MAX(lmax, psrc0[i00]);
|
lmax = MAX(lmax, psrc0[i00]);
|
||||||
}
|
}
|
||||||
@ -158,7 +158,7 @@ kernel void kernel_soft_max_4(
|
|||||||
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||||
|
|
||||||
// parallel max
|
// parallel max
|
||||||
float4 lmax4 = psrc4[tpitg[0]];
|
float4 lmax4 = tpitg[0] < ne00/4 ? psrc4[tpitg[0]] : -INFINITY;
|
||||||
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
|
||||||
lmax4 = fmax(lmax4, psrc4[i00]);
|
lmax4 = fmax(lmax4, psrc4[i00]);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user