mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 12:24:35 +00:00
cuda : fix -INF block check
This commit is contained in:
parent
5b263dd83a
commit
e04ff39181
@ -6658,7 +6658,7 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
smax = warp_reduce_max(smax);
|
smax = warp_reduce_max(smax);
|
||||||
|
|
||||||
// skip -INF blocks
|
// skip -INF blocks
|
||||||
if (__hisinf(smax.x) == -1 || __hisinf(smax.y) == -1) {
|
if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -6676,8 +6676,10 @@ static __global__ void flash_attn_ext_f16(
|
|||||||
|
|
||||||
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr);
|
nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// restore zeros
|
// restore zeros
|
||||||
|
for (int j = 0; j < Q16; ++j) {
|
||||||
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
|
nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user