From 963552903f51043ee947a8deeaaa7ec00bc3f1a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 12 Jun 2024 17:41:51 +0200 Subject: [PATCH] CUDA: fix broken oob check for FA vec f32 kernel (#7904) --- ggml-cuda/fattn-vec-f32.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml-cuda/fattn-vec-f32.cuh b/ggml-cuda/fattn-vec-f32.cuh index ddf0c8374..11a5e355f 100644 --- a/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml-cuda/fattn-vec-f32.cuh @@ -149,7 +149,7 @@ static __global__ void flash_attn_vec_ext_f32( for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j ? Q_f2_j[i] : make_float2(0.0f, 0.0f); + Q_f2[j][i0/WARP_SIZE] = ncols <= 2 || ic0 + j < ne01 ? Q_f2_j[i] : make_float2(0.0f, 0.0f); Q_f2[j][i0/WARP_SIZE].x *= scale; Q_f2[j][i0/WARP_SIZE].y *= scale; }