llama: Add attention and final logit soft-capping, update scaling factor to Gemma2 (#8197)
* Add attention and final logit softcapping.
* fix
* Add custom add_ functions
* Disable flash attention for Gemma2
* Update src/llama.cpp
  Co-authored-by: slaren <slarengh@gmail.com>
* Add default value for attention and final logit softcap value
* Add custom kq scaling from Gemma2Attention
* Remove custom pre attention scaling and use computed value instead.

---------

Co-authored-by: slaren <slarengh@gmail.com>
commit 1c5eba6f8e (parent 72272b83a3)
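For orientation: logit soft-capping squeezes a tensor of logits into the open interval (-cap, cap) by dividing by the cap, applying tanh, and multiplying back — exactly the three-op sequence the hunks below add to the ggml graph. A minimal NumPy sketch of the transform, using the default cap values this commit introduces (50.0 for attention logits, 30.0 for the final logits); the helper name `soft_cap` is just for illustration:

```python
import numpy as np

def soft_cap(logits: np.ndarray, cap: float) -> np.ndarray:
    """Soft-cap logits into (-cap, cap): scale down, tanh, scale back up."""
    return cap * np.tanh(logits / cap)

x = np.array([-120.0, -5.0, 0.0, 5.0, 120.0])

print(soft_cap(x, cap=50.0))  # attention-score cap: outputs stay inside (-50, 50)
print(soft_cap(x, cap=30.0))  # final-logit cap: outputs stay inside (-30, 30)
```

Because tanh is strictly monotonic, the ordering of logits is preserved; only their magnitude is bounded, and small values pass through almost unchanged.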
```diff
@@ -2363,6 +2363,12 @@ class Gemma2Model(Model):
         self.gguf_writer.add_key_length(hparams["head_dim"])
         self.gguf_writer.add_value_length(hparams["head_dim"])
         self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_attn_logit_softcapping(
+            self.hparams["attn_logit_softcapping"]
+        )
+        self.gguf_writer.add_final_logit_softcapping(
+            self.hparams["final_logit_softcapping"]
+        )

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unusem
```
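The hunk above pulls `attn_logit_softcapping` and `final_logit_softcapping` out of the model's Hugging Face hyperparameters during conversion. A rough sketch of where those values live, assuming the usual Hugging Face checkpoint layout (the local directory name here is hypothetical):

```python
import json
from pathlib import Path

# Hypothetical local Gemma 2 checkout; substitute your own path.
config = json.loads(Path("gemma-2-9b/config.json").read_text())

# The same dictionary lookups the Gemma2Model hunk performs before
# handing the values to the GGUF writer.
attn_cap  = config["attn_logit_softcapping"]   # typically 50.0
final_cap = config["final_logit_softcapping"]  # typically 30.0
print(attn_cap, final_cap)
```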
```diff
@@ -50,6 +50,8 @@ class Keys:
         POOLING_TYPE            = "{arch}.pooling_type"
         LOGIT_SCALE             = "{arch}.logit_scale"
         DECODER_START_TOKEN_ID  = "{arch}.decoder_start_token_id"
+        ATTN_LOGIT_SOFTCAPPING  = "{arch}.attn_logit_softcapping"
+        FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"

     class Attention:
         HEAD_COUNT        = "{arch}.attention.head_count"
```
```diff
@@ -516,6 +516,12 @@ class GGUFWriter:
     def add_logit_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)

+    def add_attn_logit_softcapping(self, value: float) -> None:
+        self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
+    def add_final_logit_softcapping(self, value: float) -> None:
+        self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
     def add_expert_count(self, count: int) -> None:
         self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)

```
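A quick usage sketch of the two new writer methods. The `GGUFWriter` constructor arguments and the file-writing calls here are illustrative (a real conversion goes through `Model.set_gguf_parameters()` as in the first hunk); what the methods themselves do is store plain float32 key-value metadata under the `{arch}` key templates added to `Keys.LLM`, i.e. `gemma2.attn_logit_softcapping` and `gemma2.final_logit_softcapping` for this architecture:

```python
from gguf import GGUFWriter  # gguf-py, the package the conversion script uses

# Illustrative only: a metadata-only file for the gemma2 architecture.
writer = GGUFWriter("gemma2-metadata-only.gguf", "gemma2")

# Stored as float32 KV pairs: "gemma2.attn_logit_softcapping" and
# "gemma2.final_logit_softcapping".
writer.add_attn_logit_softcapping(50.0)
writer.add_final_logit_softcapping(30.0)

writer.write_header_to_file()
writer.write_kv_data_to_file()
writer.close()
```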
```diff
@@ -302,6 +302,8 @@ enum llm_kv {
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
+    LLM_KV_ATTN_LOGIT_SOFTCAPPING,
+    LLM_KV_FINAL_LOGIT_SOFTCAPPING,

     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
```
```diff
@@ -392,6 +394,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_POOLING_TYPE,             "%s.pooling_type"             },
     { LLM_KV_LOGIT_SCALE,              "%s.logit_scale"              },
     { LLM_KV_DECODER_START_TOKEN_ID,   "%s.decoder_start_token_id"   },
+    { LLM_KV_ATTN_LOGIT_SOFTCAPPING,   "%s.attn_logit_softcapping"   },
+    { LLM_KV_FINAL_LOGIT_SOFTCAPPING,  "%s.final_logit_softcapping"  },

     { LLM_KV_ATTENTION_HEAD_COUNT,     "%s.attention.head_count"     },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV,  "%s.attention.head_count_kv"  },
```
```diff
@@ -2099,6 +2103,9 @@ struct llama_hparams {
     float f_norm_eps;
     float f_norm_rms_eps;

+    float f_attn_logit_softcapping  = 50.0f;
+    float f_final_logit_softcapping = 30.0f;
+
     float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
     float rope_freq_scale_train;
```
```diff
@@ -2117,6 +2124,7 @@ struct llama_hparams {

     bool causal_attn = true;
     bool use_alibi   = false;
+    bool attn_soft_cap = false;

     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE;
```
```diff
@@ -4702,6 +4710,9 @@ static void llm_load_hparams(
         case LLM_ARCH_GEMMA2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
+                hparams.attn_soft_cap = true;

                 switch (hparams.n_layer) {
                     case 42: model.type = e_model::MODEL_9B; break;
```
```diff
@@ -7579,6 +7590,12 @@ static struct ggml_tensor * llm_build_kqv(
             kq = ggml_scale(ctx, kq, 30);
         }

+        if (hparams.attn_soft_cap) {
+            kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
+            kq = ggml_tanh(ctx, kq);
+            kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
+        }
+
         kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
         cb(kq, "kq_soft_max_ext", il);

```
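In `llm_build_kqv` the cap is applied to the raw K·Q scores after the existing `ggml_scale(ctx, kq, 30)` branch and before `ggml_soft_max_ext`, so the attention scores are bounded to roughly (-50, 50) by default before masking and softmax. A NumPy sketch of that ordering; the mask, ALiBi bias, and `kq_scale` arguments of `ggml_soft_max_ext` are left out to keep it short:

```python
import numpy as np

f_attn_logit_softcapping = 50.0  # default added to llama_hparams by this commit

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

kq = np.array([[3.0, 250.0, -90.0]])  # toy raw attention scores for one query row

# The same three graph ops as the hunk: scale down, tanh, scale back up ...
kq = kq * (1.0 / f_attn_logit_softcapping)
kq = np.tanh(kq)
kq = kq * f_attn_logit_softcapping

# ... then the (simplified) softmax sees scores bounded to (-50, 50).
print(kq)
print(softmax(kq))
```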
```diff
@@ -11039,7 +11056,7 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);

-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head)));
                 cb(Qcur, "Qcur_scaled", il);

                 Kcur = ggml_rope_ext(
```
```diff
@@ -11106,6 +11123,12 @@ struct llm_build_context {

         // lm_head
         cur = ggml_mul_mat(ctx0, model.output, cur);
+
+        // final logit soft-capping
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        cur = ggml_tanh(ctx0, cur);
+        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+
         cb(cur, "result_output", -1);

         ggml_build_forward_expand(gf, cur);
```
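The same scale/tanh/scale pattern is applied once more to the lm_head output, bounding every vocabulary logit to (-30, 30) by default. Since tanh is strictly monotonic the greedy (argmax) choice is unchanged, but the capped logits flatten the softmax and therefore shift sampled probabilities; a small sketch with toy logits:

```python
import numpy as np

f_final_logit_softcapping = 30.0  # default added to llama_hparams by this commit

def soft_cap(x: np.ndarray, cap: float) -> np.ndarray:
    return cap * np.tanh(x / cap)

def softmax(x: np.ndarray) -> np.ndarray:
    e = np.exp(x - x.max())
    return e / e.sum()

logits = np.array([60.0, 55.0, 10.0])  # toy lm_head outputs
capped = soft_cap(logits, f_final_logit_softcapping)

print(np.argmax(logits) == np.argmax(capped))  # True: ordering preserved
print(softmax(logits))  # top token takes ~0.99 of the mass
print(softmax(capped))  # after capping, the top two tokens are much closer
```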
```diff
@@ -17379,6 +17402,12 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }

+    if (params.flash_attn && model->hparams.attn_soft_cap) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
+
     if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
         LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
         params.flash_attn = false;
```