cuda : fix -INF block check

This commit is contained in:
Georgi Gerganov 2024-02-03 16:57:46 +02:00
parent 5b263dd83a
commit e04ff39181
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -6658,7 +6658,7 @@ static __global__ void flash_attn_ext_f16(
smax = warp_reduce_max(smax);
// skip -INF blocks
if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) {
if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) {
continue;
}
@ -6676,8 +6676,10 @@ static __global__ void flash_attn_ext_f16(
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr);
}
}
// restore zeros
for (int j = 0; j < Q16; ++j) {
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
}