mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 10:54:36 +00:00
CUDA: fix Gemma 2 numerical issues for FA (#9166)
This commit is contained in:
parent
e11bd856d5
commit
f91fc5639b
@ -8877,7 +8877,7 @@ static struct ggml_tensor * llm_build_kqv(
|
|||||||
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
|
||||||
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
|
||||||
|
|
||||||
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
|
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) {
|
||||||
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
|
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user