diff --git a/src/llama.cpp b/src/llama.cpp index 39d80d33e..d2ee2d6ad 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17402,6 +17402,11 @@ struct llama_context * llama_new_context_with_model( params.flash_attn = false; } + if (params.flash_attn && model->arch == LLM_ARCH_GEMMA2) { + LLAMA_LOG_WARN("%s: flash_attn is not compatible with Gemma2 - forcing off\n", __func__); + params.flash_attn = false; + } + if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) { LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__); params.flash_attn = false;