cuda : switch to 1 warp for bs > 16
commit a7b471569b
parent b958151e3f
@@ -10933,7 +10933,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
 
     const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
     // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
-    const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 2;
+    const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1;
 
     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);
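The hunk changes only the fallback branch: when the batch size Q->ne[1] exceeds nqpb, the kernel is now launched with one warp per block instead of two. Below is a minimal host-side sketch of the selection logic, using illustrative stand-in values for nqpb (queries per block), ncpw (KV-cache cells per warp), and K->ne[1], since their real values are set earlier in ggml_cuda_flash_attn_ext and are not part of this hunk.

// Sketch only: nqpb, ncpw, and kv_len are hypothetical stand-ins for
// values computed elsewhere in ggml-cuda.cu; the nwarps expression
// mirrors the line changed by this commit.
#include <algorithm>
#include <cstdio>

int main() {
    const int nqpb       = 16;   // assumed: queries processed per block
    const int ncpw       = 32;   // assumed: KV-cache cells per warp
    const int nwarps_max = 8;    // from the hunk above
    const int kv_len     = 4096; // stand-in for K->ne[1]

    const int batches[] = {1, 8, 16, 17, 32}; // stand-ins for Q->ne[1]
    for (int batch : batches) {
        // batch <= nqpb: split the KV dimension across 2..nwarps_max warps;
        // batch >  nqpb: after this commit, a single warp (previously 2).
        const int nwarps = batch <= nqpb
            ? std::max(2, std::min(kv_len / ncpw, nwarps_max))
            : 1;
        printf("batch=%2d -> nwarps=%d, block_dim=(32, %d, 1)\n",
               batch, nwarps, nwarps);
    }
    return 0;
}

With these assumed values, any batch up to 16 launches blocks of (32, 8, 1) threads, while a batch of 17 or more now gets (32, 1, 1) rather than the previous (32, 2, 1).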