CUDA: fix broken oob check for FA vec f32 kernel (#7904)

This commit is contained in:
Johannes Gäßler 2024-06-12 17:41:51 +02:00 committed by GitHub
parent a9cae48003
commit 963552903f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -149,7 +149,7 @@ static __global__ void flash_attn_vec_ext_f32(
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f);
Q_f2[j][i0/WARP_SIZE].x *= scale;
Q_f2[j][i0/WARP_SIZE].y *= scale;
}