update: support Qwen2-57B-A14B (#7835)
* update: convert-hf-to-gguf.py to support Qwen2-57B-A14B
* fix: QWEN2MOE support for expert_feed_forward_length. Previously, the expert feed-forward length was taken from n_ff (the intermediate size); it is now properly read from LLM_KV_EXPERT_FEED_FORWARD_LENGTH, and n_ff_exp and n_ff_shexp are now calculated properly.
* update: convert-hf-to-gguf.py cleanup for Qwen2MoeForCausalLM
parent 5b6da18750
commit a94e6ff877
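As a rough sketch of what the updated converter does (not code from this commit), the snippet below shows how the relevant Qwen2-MoE config entries map onto the new GGUF metadata keys via gguf-py; the output path and the numeric values are placeholders, not the actual Qwen2-57B-A14B hyperparameters.

# Illustrative sketch only: mirrors what Qwen2MoeModel.set_gguf_parameters()
# now records. The filename and values below are placeholders.
import gguf

hparams = {
    "num_experts": 64,
    "moe_intermediate_size": 2560,
    "shared_expert_intermediate_size": 20480,
}

writer = gguf.GGUFWriter("qwen2moe-demo.gguf", arch="qwen2moe")
if (n_experts := hparams.get("num_experts")) is not None:
    writer.add_expert_count(n_experts)                          # {arch}.expert_count
if (moe_ff := hparams.get("moe_intermediate_size")) is not None:
    writer.add_expert_feed_forward_length(moe_ff)               # {arch}.expert_feed_forward_length
if (shared_ff := hparams.get("shared_expert_intermediate_size")) is not None:
    writer.add_expert_shared_feed_forward_length(shared_ff)     # new: {arch}.expert_shared_feed_forward_length
writer.close()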
convert-hf-to-gguf.py

@@ -1632,6 +1632,12 @@ class Qwen2MoeModel(Model):
         super().set_gguf_parameters()
         if (n_experts := self.hparams.get("num_experts")) is not None:
             self.gguf_writer.add_expert_count(n_experts)
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+            logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
+        if (shared_expert_intermediate_size := self.hparams.get('shared_expert_intermediate_size')) is not None:
+            self.gguf_writer.add_expert_shared_feed_forward_length(shared_expert_intermediate_size)
+            logger.info(f"gguf: expert shared feed forward length = {shared_expert_intermediate_size}")
 
     _experts: list[dict[str, Tensor]] | None = None
 
gguf-py/gguf/constants.py

@@ -40,6 +40,7 @@ class Keys:
         LEADING_DENSE_BLOCK_COUNT         = "{arch}.leading_dense_block_count"
         FEED_FORWARD_LENGTH               = "{arch}.feed_forward_length"
         EXPERT_FEED_FORWARD_LENGTH        = "{arch}.expert_feed_forward_length"
+        EXPERT_SHARED_FEED_FORWARD_LENGTH = "{arch}.expert_shared_feed_forward_length"
         USE_PARALLEL_RESIDUAL             = "{arch}.use_parallel_residual"
         TENSOR_DATA_LAYOUT                = "{arch}.tensor_data_layout"
         EXPERT_COUNT                      = "{arch}.expert_count"
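For illustration (not part of the diff): the key is a template that gets formatted with the architecture name, so for Qwen2-MoE models the stored metadata key becomes qwen2moe.expert_shared_feed_forward_length.

# Illustration: how the templated key expands for the qwen2moe architecture.
key = "{arch}.expert_shared_feed_forward_length".format(arch="qwen2moe")
print(key)  # qwen2moe.expert_shared_feed_forward_length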
gguf-py/gguf/gguf_writer.py

@@ -394,6 +394,9 @@ class GGUFWriter:
     def add_expert_feed_forward_length(self, length: int) -> None:
         self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
 
+    def add_expert_shared_feed_forward_length(self, length: int) -> None:
+        self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
+
     def add_parallel_residual(self, use: bool) -> None:
         self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
 
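A hedged usage sketch (not part of this commit) for double-checking that a converted file actually carries the new metadata, using gguf-py's GGUFReader; the file name is hypothetical, and the scalar extraction via field.parts is an assumption about the reader's low-level field layout.

# Sketch: inspect the expert FFN metadata of a converted GGUF file.
# "qwen2-57b-a14b.gguf" is a hypothetical path.
from gguf import GGUFReader

reader = GGUFReader("qwen2-57b-a14b.gguf")
for key in ("qwen2moe.expert_feed_forward_length",
            "qwen2moe.expert_shared_feed_forward_length"):
    field = reader.get_field(key)
    if field is None:
        print(f"{key}: missing (converted before this change?)")
    else:
        # for a scalar KV, the value lives in the part indexed by field.data[0]
        print(f"{key} = {int(field.parts[field.data[0]][0])}")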
llama.cpp
@@ -286,6 +286,7 @@ enum llm_kv {
     LLM_KV_LEADING_DENSE_BLOCK_COUNT,
     LLM_KV_FEED_FORWARD_LENGTH,
     LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
+    LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
     LLM_KV_USE_PARALLEL_RESIDUAL,
     LLM_KV_TENSOR_DATA_LAYOUT,
     LLM_KV_EXPERT_COUNT,
@@ -371,6 +372,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_LEADING_DENSE_BLOCK_COUNT,         "%s.leading_dense_block_count"         },
     { LLM_KV_FEED_FORWARD_LENGTH,               "%s.feed_forward_length"               },
     { LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        "%s.expert_feed_forward_length"        },
+    { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" },
     { LLM_KV_USE_PARALLEL_RESIDUAL,             "%s.use_parallel_residual"             },
     { LLM_KV_TENSOR_DATA_LAYOUT,                "%s.tensor_data_layout"                },
     { LLM_KV_EXPERT_COUNT,                      "%s.expert_count"                      },
@@ -1970,6 +1972,7 @@ struct llama_hparams {
     uint32_t n_lora_q = 0;
     uint32_t n_lora_kv = 0;
     uint32_t n_ff_exp = 0;
+    uint32_t n_ff_shexp = 0;
     uint32_t n_expert_shared = 0;
     float    expert_weights_scale = 0.0;
 
@@ -2018,6 +2021,7 @@ struct llama_hparams {
         if (this->n_lora_q   != other.n_lora_q)   return true;
         if (this->n_lora_kv  != other.n_lora_kv)  return true;
         if (this->n_ff_exp   != other.n_ff_exp)   return true;
+        if (this->n_ff_shexp != other.n_ff_shexp) return true;
         if (this->n_expert_shared != other.n_expert_shared) return true;
 
         if (this->rope_finetuned != other.rope_finetuned) return true;
@@ -4455,6 +4459,9 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_QWEN2MOE:
             {
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp, false);
+                ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
+
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 switch (hparams.n_layer) {
                     case 24: model.type = e_model::MODEL_A2_7B; break;
@@ -5240,6 +5247,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
         LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
         LLAMA_LOG_INFO("%s: rope_yarn_log_mul    = %.4f\n", __func__, hparams.rope_yarn_log_mul);
     }
+
+    if (model.arch == LLM_ARCH_QWEN2MOE) {
+        LLAMA_LOG_INFO("%s: n_ff_exp   = %d\n", __func__, hparams.n_ff_exp);
+        LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
+    }
 }
 
 // Returns false if cancelled by progress_callback
@@ -6026,16 +6038,17 @@ static bool llm_load_tensors(
                         GGML_ASSERT(hparams.n_expert_used > 0);
 
                         // MoE branch
-                        auto n_ff_exp = n_ff / hparams.n_expert_used;
+                        auto n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / hparams.n_expert_used;
                         layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert});
                         layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert});
                         layer.ffn_up_exps   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert});
 
                         // Shared expert branch
+                        auto n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff;
                         layer.ffn_gate_inp_shexp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), {n_embd});
-                        layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd,   n_ff});
-                        layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {  n_ff, n_embd});
-                        layer.ffn_up_shexp   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {n_embd,   n_ff});
+                        layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {    n_embd, n_ff_shexp});
+                        layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {n_ff_shexp,     n_embd});
+                        layer.ffn_up_shexp   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP,   "weight", i), {    n_embd, n_ff_shexp});
                     }
                 } break;
             case LLM_ARCH_PHI2:
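To spell out the fallback above in one place, here is a small Python rendering (illustration only, not repo code) of how the loader now picks the per-expert and shared-expert FFN widths: metadata values win when present, otherwise the old heuristics based on the dense n_ff are used. The example numbers are placeholders rather than the real Qwen2-57B-A14B hyperparameters.

# Illustration of the QWEN2MOE fallback logic in llm_load_tensors:
# prefer sizes read from GGUF metadata, else fall back to the old heuristic.
def expert_ffn_sizes(n_ff: int, n_expert_used: int,
                     n_ff_exp: int = 0, n_ff_shexp: int = 0) -> tuple[int, int]:
    ff_exp   = n_ff_exp   if n_ff_exp   else n_ff // n_expert_used  # per-expert FFN width
    ff_shexp = n_ff_shexp if n_ff_shexp else n_ff                   # shared-expert FFN width
    return ff_exp, ff_shexp

# Placeholder numbers only:
print(expert_ffn_sizes(n_ff=18944, n_expert_used=8))                 # heuristic -> (2368, 18944)
print(expert_ffn_sizes(n_ff=18944, n_expert_used=8,
                       n_ff_exp=2560, n_ff_shexp=20480))             # metadata  -> (2560, 20480)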