diff --git a/ggml-cuda.cu b/ggml-cuda.cu index dbd482239..e51ddc08f 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6658,7 +6658,7 @@ static __global__ void flash_attn_ext_f16( smax = warp_reduce_max(smax); // skip -INF blocks - if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) { + if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) { continue; } @@ -6676,8 +6676,10 @@ static __global__ void flash_attn_ext_f16( nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); } + } - // restore zeros + // restore zeros + for (int j = 0; j < Q16; ++j) { nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); }