mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 12:24:35 +00:00
cuda : fix -INF block check
This commit is contained in:
parent
5b263dd83a
commit
e04ff39181
@ -6658,7 +6658,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
smax = warp_reduce_max(smax);
|
||||
|
||||
// skip -INF blocks
|
||||
if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) {
|
||||
if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -6676,8 +6676,10 @@ static __global__ void flash_attn_ext_f16(
|
||||
|
||||
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr);
|
||||
}
|
||||
}
|
||||
|
||||
// restore zeros
|
||||
for (int j = 0; j < Q16; ++j) {
|
||||
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user