mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 03:14:35 +00:00
CUDA: fix FA out-of-bounds reads (#7479)
This commit is contained in:
parent
1e374365d1
commit
cd93a28cb1
@ -83,7 +83,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||||
const int i = i0 + threadIdx.x;
|
const int i = i0 + threadIdx.x;
|
||||||
|
|
||||||
const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
|
const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
|
||||||
Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
|
Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -79,7 +79,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
|
for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
|
||||||
float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x];
|
float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f);
|
||||||
Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
|
Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
|
||||||
Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
|
Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
|
||||||
}
|
}
|
||||||
|
@ -94,7 +94,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||||
const int i = i0 + threadIdx.x;
|
const int i = i0 + threadIdx.x;
|
||||||
|
|
||||||
const float2 tmp = Q_f2[j*(nb01/sizeof(float2)) + i];
|
const float2 tmp = ncols <= 2 || ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
|
||||||
Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
|
Q_h2[j][i0/WARP_SIZE] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -212,7 +212,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
||||||
if (ic0 + j_VKQ >= ne01) {
|
if (ncols > 2 && ic0 + j_VKQ >= ne01) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -227,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
|
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||||
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
|
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
|
@ -91,7 +91,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||||||
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
|
||||||
const int i = i0 + threadIdx.x;
|
const int i = i0 + threadIdx.x;
|
||||||
|
|
||||||
Q_h2[j][i0/WARP_SIZE] = Q_f2[j*(nb01/sizeof(float2)) + i];
|
Q_h2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
|
||||||
Q_h2[j][i0/WARP_SIZE].x *= scale;
|
Q_h2[j][i0/WARP_SIZE].x *= scale;
|
||||||
Q_h2[j][i0/WARP_SIZE].y *= scale;
|
Q_h2[j][i0/WARP_SIZE].y *= scale;
|
||||||
}
|
}
|
||||||
@ -200,7 +200,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
for (int j_VKQ = 0; j_VKQ < ncols; ++j_VKQ) {
|
||||||
if (ic0 + j_VKQ >= ne01) {
|
if (ncols > 2 && ic0 + j_VKQ >= ne01) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -215,7 +215,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||||||
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (parallel_blocks != 1 && tid < ncols && ic0 + tid < ne01) {
|
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||||
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
|
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user