cuda : fix -INF block check

This commit is contained in:
Georgi Gerganov 2024-02-03 16:57:46 +02:00
parent 5b263dd83a
commit e04ff39181
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -6658,7 +6658,7 @@ static __global__ void flash_attn_ext_f16(
smax = warp_reduce_max(smax); smax = warp_reduce_max(smax);
// skip -INF blocks // skip -INF blocks
if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) { if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) {
continue; continue;
} }
@@ -6676,8 +6676,10 @@ static __global__ void flash_attn_ext_f16(
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr);
} }
}
// restore zeros // restore zeros
for (int j = 0; j < Q16; ++j) {
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
} }