mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 11:40:17 +00:00
CUDA: fix FA out-of-bounds writes (#7465)
This commit is contained in:
parent
b18532a4ef
commit
38c03478a3
@ -238,6 +238,10 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||||||
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
||||||
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
||||||
|
|
||||||
|
if (ic0 + j_VKQ >= ne01) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
|
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
|
||||||
kqsum_j = warp_reduce_sum(kqsum_j);
|
kqsum_j = warp_reduce_sum(kqsum_j);
|
||||||
|
|
||||||
|
@ -237,6 +237,10 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||||||
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
||||||
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
||||||
|
|
||||||
|
if (ic0 + j_VKQ >= ne01) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
float kqsum_j = kqsum[j_VKQ_0/nwarps];
|
float kqsum_j = kqsum[j_VKQ_0/nwarps];
|
||||||
kqsum_j = warp_reduce_sum(kqsum_j);
|
kqsum_j = warp_reduce_sum(kqsum_j);
|
||||||
|
|
||||||
|
@ -212,6 +212,10 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
||||||
|
if (ic0 + j_VKQ >= ne01) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
|
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
|
||||||
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
|
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
|
||||||
|
|
||||||
@ -223,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (parallel_blocks != 1 && tid < ncols) {
|
if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
|
||||||
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
|
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
|
@ -200,6 +200,10 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
||||||
|
if (ic0 + j_VKQ >= ne01) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
|
kqsum[j_VKQ] = kqsum_shared[j_VKQ][threadIdx.x];
|
||||||
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
|
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
|
||||||
|
|
||||||
@ -211,7 +215,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (parallel_blocks != 1 && tid < ncols) {
|
if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
|
||||||
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
|
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user