mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 10:24:35 +00:00
Add support for DeepseekV2ForCausalLM (#7519)
* common : increase max number of experts to 160 * common : add tensors ATTN_Q_A, ATTN_Q_A_NORM, ATTN_Q_B, ATTN_KV_A_MQA, ATTN_KV_A_NORM, ATTN_KV_B needed by DeepSeek-V2 MLA (multi-head latent attention) architecture * common : add model header parameters: leading_dense_block_count, expert_feed_forward_length, expert_shared_count, expert_weights_scale, attention.q_lora_rank, attention.kv_lora_rank, rope.scaling.yarn_log_multiplier * convert-hf : add model conversion support for DeepseekV2ForCausalLM * llama : add model types for DeepSeek-V2 and DeepSeek-V2-Lite models * llama : add two new llm_build_moe_ffn() arguments: scale_w (whether to scale weights of selected MoE experts) and w_scale (numerical value of the scaling factor) * llama : add inference support for LLM_ARCH_DEEPSEEK2 --------- Co-authored-by: Stanisław Szymczyk <sszymczy@gmail.com>
This commit is contained in:
parent
edc29433fa
commit
ee3dff6b8e
@ -2620,6 +2620,85 @@ class ArcticModel(Model):
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@Model.register("DeepseekV2ForCausalLM")
|
||||
class DeepseekV2Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.DEEPSEEK2
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_gpt2()
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
hparams = self.hparams
|
||||
|
||||
self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
|
||||
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
|
||||
self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
|
||||
self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
|
||||
self.gguf_writer.add_value_length(hparams["v_head_dim"])
|
||||
self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
|
||||
self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
|
||||
self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
|
||||
self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||
|
||||
if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
|
||||
if self.hparams["rope_scaling"].get("type") == "yarn":
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["rope_scaling"]["original_max_position_embeddings"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * hparams["rope_scaling"]["mscale_all_dim"])
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# process the experts separately
|
||||
if name.find("mlp.experts") != -1:
|
||||
n_experts = self.hparams["n_routed_experts"]
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["down_proj", "gate_proj", "up_proj"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
|
||||
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
|
||||
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
|
||||
tensors.append((new_name, data_torch))
|
||||
return tensors
|
||||
else:
|
||||
return []
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def write_tensors(self):
|
||||
super().write_tensors()
|
||||
|
||||
if self._experts is not None:
|
||||
# flatten `list[dict[str, Tensor]]` into `list[str]`
|
||||
experts = [k for d in self._experts for k in d.keys()]
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
###### CONVERSION LOGIC ######
|
||||
|
||||
|
||||
|
@ -33,17 +33,21 @@ class Keys:
|
||||
FILE_TYPE = "general.file_type"
|
||||
|
||||
class LLM:
|
||||
VOCAB_SIZE = "{arch}.vocab_size"
|
||||
CONTEXT_LENGTH = "{arch}.context_length"
|
||||
EMBEDDING_LENGTH = "{arch}.embedding_length"
|
||||
BLOCK_COUNT = "{arch}.block_count"
|
||||
FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
|
||||
USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
|
||||
TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
|
||||
EXPERT_COUNT = "{arch}.expert_count"
|
||||
EXPERT_USED_COUNT = "{arch}.expert_used_count"
|
||||
POOLING_TYPE = "{arch}.pooling_type"
|
||||
LOGIT_SCALE = "{arch}.logit_scale"
|
||||
VOCAB_SIZE = "{arch}.vocab_size"
|
||||
CONTEXT_LENGTH = "{arch}.context_length"
|
||||
EMBEDDING_LENGTH = "{arch}.embedding_length"
|
||||
BLOCK_COUNT = "{arch}.block_count"
|
||||
LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count"
|
||||
FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
|
||||
EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length"
|
||||
USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
|
||||
TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
|
||||
EXPERT_COUNT = "{arch}.expert_count"
|
||||
EXPERT_USED_COUNT = "{arch}.expert_used_count"
|
||||
EXPERT_SHARED_COUNT = "{arch}.expert_shared_count"
|
||||
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
|
||||
POOLING_TYPE = "{arch}.pooling_type"
|
||||
LOGIT_SCALE = "{arch}.logit_scale"
|
||||
|
||||
class Attention:
|
||||
HEAD_COUNT = "{arch}.attention.head_count"
|
||||
@ -55,6 +59,8 @@ class Keys:
|
||||
LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon"
|
||||
LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon"
|
||||
CAUSAL = "{arch}.attention.causal"
|
||||
Q_LORA_RANK = "{arch}.attention.q_lora_rank"
|
||||
KV_LORA_RANK = "{arch}.attention.kv_lora_rank"
|
||||
|
||||
class Rope:
|
||||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
@ -64,6 +70,7 @@ class Keys:
|
||||
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
|
||||
SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
|
||||
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
|
||||
SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
|
||||
|
||||
class SSM:
|
||||
CONV_KERNEL = "{arch}.ssm.conv_kernel"
|
||||
@ -140,6 +147,7 @@ class MODEL_ARCH(IntEnum):
|
||||
DBRX = auto()
|
||||
OLMO = auto()
|
||||
ARCTIC = auto()
|
||||
DEEPSEEK2 = auto()
|
||||
|
||||
|
||||
class MODEL_TENSOR(IntEnum):
|
||||
@ -185,6 +193,12 @@ class MODEL_TENSOR(IntEnum):
|
||||
SSM_A = auto()
|
||||
SSM_D = auto()
|
||||
SSM_OUT = auto()
|
||||
ATTN_Q_A = auto()
|
||||
ATTN_Q_B = auto()
|
||||
ATTN_KV_A_MQA = auto()
|
||||
ATTN_KV_B = auto()
|
||||
ATTN_Q_A_NORM = auto()
|
||||
ATTN_KV_A_NORM = auto()
|
||||
|
||||
|
||||
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
@ -221,6 +235,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.DBRX: "dbrx",
|
||||
MODEL_ARCH.OLMO: "olmo",
|
||||
MODEL_ARCH.ARCTIC: "arctic",
|
||||
MODEL_ARCH.DEEPSEEK2: "deepseek2",
|
||||
}
|
||||
|
||||
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
@ -266,6 +281,12 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
||||
MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a",
|
||||
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
|
||||
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
|
||||
MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a",
|
||||
MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b",
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA: "blk.{bid}.attn_kv_a_mqa",
|
||||
MODEL_TENSOR.ATTN_KV_B: "blk.{bid}.attn_kv_b",
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM: "blk.{bid}.attn_q_a_norm",
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM: "blk.{bid}.attn_kv_a_norm",
|
||||
}
|
||||
|
||||
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
@ -757,6 +778,33 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.DEEPSEEK2: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_A,
|
||||
MODEL_TENSOR.ATTN_Q_B,
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA,
|
||||
MODEL_TENSOR.ATTN_KV_B,
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM,
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP,
|
||||
MODEL_TENSOR.FFN_UP_SHEXP,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
@ -790,6 +838,10 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
],
|
||||
MODEL_ARCH.DEEPSEEK2: [
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
],
|
||||
}
|
||||
|
||||
#
|
||||
|
@ -374,9 +374,15 @@ class GGUFWriter:
|
||||
def add_block_count(self, length: int) -> None:
|
||||
self.add_uint32(Keys.LLM.BLOCK_COUNT.format(arch=self.arch), length)
|
||||
|
||||
def add_leading_dense_block_count(self, length: int) -> None:
|
||||
self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length)
|
||||
|
||||
def add_feed_forward_length(self, length: int) -> None:
|
||||
self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
|
||||
|
||||
def add_expert_feed_forward_length(self, length: int) -> None:
|
||||
self.add_uint32(Keys.LLM.EXPERT_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
|
||||
|
||||
def add_parallel_residual(self, use: bool) -> None:
|
||||
self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
|
||||
|
||||
@ -407,6 +413,12 @@ class GGUFWriter:
|
||||
def add_expert_used_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.LLM.EXPERT_USED_COUNT.format(arch=self.arch), count)
|
||||
|
||||
def add_expert_shared_count(self, count: int) -> None:
|
||||
self.add_uint32(Keys.LLM.EXPERT_SHARED_COUNT.format(arch=self.arch), count)
|
||||
|
||||
def add_expert_weights_scale(self, value: float) -> None:
|
||||
self.add_float32(Keys.LLM.EXPERT_WEIGHTS_SCALE.format(arch=self.arch), value)
|
||||
|
||||
def add_layer_norm_eps(self, value: float) -> None:
|
||||
self.add_float32(Keys.Attention.LAYERNORM_EPS.format(arch=self.arch), value)
|
||||
|
||||
@ -416,6 +428,12 @@ class GGUFWriter:
|
||||
def add_causal_attention(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Attention.CAUSAL.format(arch=self.arch), value)
|
||||
|
||||
def add_q_lora_rank(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.Q_LORA_RANK.format(arch=self.arch), length)
|
||||
|
||||
def add_kv_lora_rank(self, length: int) -> None:
|
||||
self.add_uint32(Keys.Attention.KV_LORA_RANK.format(arch=self.arch), length)
|
||||
|
||||
def add_pooling_type(self, value: PoolingType) -> None:
|
||||
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
|
||||
|
||||
@ -440,6 +458,9 @@ class GGUFWriter:
|
||||
def add_rope_scaling_finetuned(self, value: bool) -> None:
|
||||
self.add_bool(Keys.Rope.SCALING_FINETUNED.format(arch=self.arch), value)
|
||||
|
||||
def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
|
||||
self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)
|
||||
|
||||
def add_ssm_conv_kernel(self, value: int) -> None:
|
||||
self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)
|
||||
|
||||
|
@ -256,6 +256,7 @@ class TensorNameMap:
|
||||
|
||||
MODEL_TENSOR.FFN_UP_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek2
|
||||
),
|
||||
|
||||
# AWQ-activation gate
|
||||
@ -285,6 +286,7 @@ class TensorNameMap:
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.gate_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.gate_proj", # deepseek2
|
||||
),
|
||||
|
||||
# Feed-forward down
|
||||
@ -320,6 +322,7 @@ class TensorNameMap:
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP: (
|
||||
"model.layers.{bid}.mlp.shared_expert.down_proj", # qwen2moe
|
||||
"model.layers.{bid}.mlp.shared_experts.down_proj", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_NORM: (
|
||||
@ -383,6 +386,30 @@ class TensorNameMap:
|
||||
"model.layers.{bid}.out_proj",
|
||||
"backbone.layers.{bid}.mixer.out_proj",
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_A: (
|
||||
"model.layers.{bid}.self_attn.q_a_proj", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_B: (
|
||||
"model.layers.{bid}.self_attn.q_b_proj", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA: (
|
||||
"model.layers.{bid}.self_attn.kv_a_proj_with_mqa", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_KV_B: (
|
||||
"model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM: (
|
||||
"model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM: (
|
||||
"model.layers.{bid}.self_attn.kv_a_layernorm", # deepseek2
|
||||
),
|
||||
}
|
||||
|
||||
# architecture-specific block mappings
|
||||
@ -415,7 +442,7 @@ class TensorNameMap:
|
||||
if tensor not in MODEL_TENSORS[arch]:
|
||||
continue
|
||||
# TODO: make this configurable
|
||||
n_experts = 128
|
||||
n_experts = 160
|
||||
for xid in range(n_experts):
|
||||
tensor_name = TENSOR_NAMES[tensor].format(bid = bid, xid = xid)
|
||||
self.mapping[tensor_name] = (tensor, tensor_name)
|
||||
|
422
llama.cpp
422
llama.cpp
@ -103,7 +103,7 @@
|
||||
#endif
|
||||
|
||||
#define LLAMA_MAX_NODES 8192
|
||||
#define LLAMA_MAX_EXPERTS 128
|
||||
#define LLAMA_MAX_EXPERTS 160
|
||||
|
||||
//
|
||||
// logging
|
||||
@ -222,6 +222,7 @@ enum llm_arch {
|
||||
LLM_ARCH_DBRX,
|
||||
LLM_ARCH_OLMO,
|
||||
LLM_ARCH_ARCTIC,
|
||||
LLM_ARCH_DEEPSEEK2,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
@ -259,6 +260,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_DBRX, "dbrx" },
|
||||
{ LLM_ARCH_OLMO, "olmo" },
|
||||
{ LLM_ARCH_ARCTIC, "arctic" },
|
||||
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
@ -279,11 +281,15 @@ enum llm_kv {
|
||||
LLM_KV_CONTEXT_LENGTH,
|
||||
LLM_KV_EMBEDDING_LENGTH,
|
||||
LLM_KV_BLOCK_COUNT,
|
||||
LLM_KV_LEADING_DENSE_BLOCK_COUNT,
|
||||
LLM_KV_FEED_FORWARD_LENGTH,
|
||||
LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
|
||||
LLM_KV_USE_PARALLEL_RESIDUAL,
|
||||
LLM_KV_TENSOR_DATA_LAYOUT,
|
||||
LLM_KV_EXPERT_COUNT,
|
||||
LLM_KV_EXPERT_USED_COUNT,
|
||||
LLM_KV_EXPERT_SHARED_COUNT,
|
||||
LLM_KV_EXPERT_WEIGHTS_SCALE,
|
||||
LLM_KV_POOLING_TYPE,
|
||||
LLM_KV_LOGIT_SCALE,
|
||||
|
||||
@ -296,6 +302,8 @@ enum llm_kv {
|
||||
LLM_KV_ATTENTION_LAYERNORM_EPS,
|
||||
LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,
|
||||
LLM_KV_ATTENTION_CAUSAL,
|
||||
LLM_KV_ATTENTION_Q_LORA_RANK,
|
||||
LLM_KV_ATTENTION_KV_LORA_RANK,
|
||||
|
||||
LLM_KV_ROPE_DIMENSION_COUNT,
|
||||
LLM_KV_ROPE_FREQ_BASE,
|
||||
@ -305,6 +313,7 @@ enum llm_kv {
|
||||
LLM_KV_ROPE_SCALING_ATTN_FACTOR,
|
||||
LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
|
||||
LLM_KV_ROPE_SCALING_FINETUNED,
|
||||
LLM_KV_ROPE_SCALING_YARN_LOG_MUL,
|
||||
|
||||
LLM_KV_SPLIT_NO,
|
||||
LLM_KV_SPLIT_COUNT,
|
||||
@ -353,17 +362,21 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_GENERAL_SOURCE_URL, "general.source.url" },
|
||||
{ LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" },
|
||||
|
||||
{ LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
|
||||
{ LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
|
||||
{ LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
|
||||
{ LLM_KV_BLOCK_COUNT, "%s.block_count" },
|
||||
{ LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" },
|
||||
{ LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" },
|
||||
{ LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" },
|
||||
{ LLM_KV_EXPERT_COUNT, "%s.expert_count" },
|
||||
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
|
||||
{ LLM_KV_POOLING_TYPE , "%s.pooling_type" },
|
||||
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
|
||||
{ LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
|
||||
{ LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
|
||||
{ LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" },
|
||||
{ LLM_KV_BLOCK_COUNT, "%s.block_count" },
|
||||
{ LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" },
|
||||
{ LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" },
|
||||
{ LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" },
|
||||
{ LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" },
|
||||
{ LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" },
|
||||
{ LLM_KV_EXPERT_COUNT, "%s.expert_count" },
|
||||
{ LLM_KV_EXPERT_USED_COUNT, "%s.expert_used_count" },
|
||||
{ LLM_KV_EXPERT_SHARED_COUNT, "%s.expert_shared_count" },
|
||||
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
|
||||
{ LLM_KV_POOLING_TYPE , "%s.pooling_type" },
|
||||
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
|
||||
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
|
||||
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
|
||||
@ -374,6 +387,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_EPS, "%s.attention.layer_norm_epsilon" },
|
||||
{ LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, "%s.attention.layer_norm_rms_epsilon" },
|
||||
{ LLM_KV_ATTENTION_CAUSAL, "%s.attention.causal" },
|
||||
{ LLM_KV_ATTENTION_Q_LORA_RANK, "%s.attention.q_lora_rank" },
|
||||
{ LLM_KV_ATTENTION_KV_LORA_RANK, "%s.attention.kv_lora_rank" },
|
||||
|
||||
{ LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" },
|
||||
{ LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" },
|
||||
@ -383,6 +398,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
||||
{ LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" },
|
||||
{ LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
|
||||
{ LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
|
||||
{ LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" },
|
||||
|
||||
{ LLM_KV_SPLIT_NO, "split.no" },
|
||||
{ LLM_KV_SPLIT_COUNT, "split.count" },
|
||||
@ -474,6 +490,12 @@ enum llm_tensor {
|
||||
LLM_TENSOR_SSM_A,
|
||||
LLM_TENSOR_SSM_D,
|
||||
LLM_TENSOR_SSM_OUT,
|
||||
LLM_TENSOR_ATTN_Q_A,
|
||||
LLM_TENSOR_ATTN_Q_B,
|
||||
LLM_TENSOR_ATTN_KV_A_MQA,
|
||||
LLM_TENSOR_ATTN_KV_B,
|
||||
LLM_TENSOR_ATTN_Q_A_NORM,
|
||||
LLM_TENSOR_ATTN_KV_A_NORM,
|
||||
};
|
||||
|
||||
static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = {
|
||||
@ -1057,6 +1079,35 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_DEEPSEEK2,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" },
|
||||
{ LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" },
|
||||
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
|
||||
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
|
||||
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
|
||||
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
|
||||
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
|
||||
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
@ -1741,6 +1792,7 @@ enum e_model {
|
||||
MODEL_13B,
|
||||
MODEL_14B,
|
||||
MODEL_15B,
|
||||
MODEL_16B,
|
||||
MODEL_20B,
|
||||
MODEL_30B,
|
||||
MODEL_34B,
|
||||
@ -1748,6 +1800,7 @@ enum e_model {
|
||||
MODEL_40B,
|
||||
MODEL_65B,
|
||||
MODEL_70B,
|
||||
MODEL_236B,
|
||||
MODEL_314B,
|
||||
MODEL_SMALL,
|
||||
MODEL_MEDIUM,
|
||||
@ -1783,6 +1836,13 @@ struct llama_hparams {
|
||||
uint32_t n_expert_used = 0;
|
||||
uint32_t n_vocab_type = 0; // for BERT-style token types
|
||||
|
||||
uint32_t n_layer_dense_lead = 0;
|
||||
uint32_t n_lora_q = 0;
|
||||
uint32_t n_lora_kv = 0;
|
||||
uint32_t n_ff_exp = 0;
|
||||
uint32_t n_expert_shared = 0;
|
||||
float expert_weights_scale = 0.0;
|
||||
|
||||
float f_norm_eps;
|
||||
float f_norm_rms_eps;
|
||||
|
||||
@ -1790,6 +1850,7 @@ struct llama_hparams {
|
||||
float rope_freq_base_train;
|
||||
float rope_freq_scale_train;
|
||||
uint32_t n_yarn_orig_ctx;
|
||||
float rope_yarn_log_mul;
|
||||
|
||||
// for State Space Models
|
||||
uint32_t ssm_d_conv = 0;
|
||||
@ -1823,6 +1884,12 @@ struct llama_hparams {
|
||||
if (this->n_expert != other.n_expert) return true;
|
||||
if (this->n_expert_used != other.n_expert_used) return true;
|
||||
|
||||
if (this->n_layer_dense_lead != other.n_layer_dense_lead) return true;
|
||||
if (this->n_lora_q != other.n_lora_q) return true;
|
||||
if (this->n_lora_kv != other.n_lora_kv) return true;
|
||||
if (this->n_ff_exp != other.n_ff_exp) return true;
|
||||
if (this->n_expert_shared != other.n_expert_shared) return true;
|
||||
|
||||
if (this->rope_finetuned != other.rope_finetuned) return true;
|
||||
if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
|
||||
|
||||
@ -1838,6 +1905,8 @@ struct llama_hparams {
|
||||
if (!is_float_close(this->rope_attn_factor, other.rope_attn_factor, EPSILON)) return true;
|
||||
if (!is_float_close(this->rope_freq_base_train, other.rope_freq_base_train, EPSILON)) return true;
|
||||
if (!is_float_close(this->rope_freq_scale_train, other.rope_freq_scale_train, EPSILON)) return true;
|
||||
if (!is_float_close(this->expert_weights_scale, other.expert_weights_scale, EPSILON)) return true;
|
||||
if (!is_float_close(this->rope_yarn_log_mul, other.rope_yarn_log_mul, EPSILON)) return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
@ -1913,6 +1982,8 @@ struct llama_layer {
|
||||
struct ggml_tensor * attn_k_norm_b;
|
||||
struct ggml_tensor * attn_out_norm;
|
||||
struct ggml_tensor * attn_out_norm_b;
|
||||
struct ggml_tensor * attn_q_a_norm;
|
||||
struct ggml_tensor * attn_kv_a_norm;
|
||||
|
||||
// attention
|
||||
struct ggml_tensor * wq;
|
||||
@ -1920,6 +1991,10 @@ struct llama_layer {
|
||||
struct ggml_tensor * wv;
|
||||
struct ggml_tensor * wo;
|
||||
struct ggml_tensor * wqkv;
|
||||
struct ggml_tensor * wq_a;
|
||||
struct ggml_tensor * wq_b;
|
||||
struct ggml_tensor * wkv_a_mqa;
|
||||
struct ggml_tensor * wkv_b;
|
||||
|
||||
// attention bias
|
||||
struct ggml_tensor * bq;
|
||||
@ -3832,6 +3907,7 @@ static const char * llama_model_type_name(e_model type) {
|
||||
case MODEL_13B: return "13B";
|
||||
case MODEL_14B: return "14B";
|
||||
case MODEL_15B: return "15B";
|
||||
case MODEL_16B: return "16B";
|
||||
case MODEL_20B: return "20B";
|
||||
case MODEL_30B: return "30B";
|
||||
case MODEL_34B: return "34B";
|
||||
@ -3839,6 +3915,7 @@ static const char * llama_model_type_name(e_model type) {
|
||||
case MODEL_40B: return "40B";
|
||||
case MODEL_65B: return "65B";
|
||||
case MODEL_70B: return "70B";
|
||||
case MODEL_236B: return "236B";
|
||||
case MODEL_314B: return "314B";
|
||||
case MODEL_SMALL: return "0.1B";
|
||||
case MODEL_MEDIUM: return "0.4B";
|
||||
@ -4384,6 +4461,26 @@ static void llm_load_hparams(
|
||||
model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_DEEPSEEK2:
|
||||
{
|
||||
bool is_lite = (hparams.n_layer == 27);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead);
|
||||
if (!is_lite) {
|
||||
ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
|
||||
}
|
||||
ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared);
|
||||
ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale);
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 27: model.type = e_model::MODEL_16B; break;
|
||||
case 60: model.type = e_model::MODEL_236B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: (void)0;
|
||||
}
|
||||
|
||||
@ -4895,6 +4992,16 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
|
||||
if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); }
|
||||
if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
|
||||
if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); }
|
||||
|
||||
if (model.arch == LLM_ARCH_DEEPSEEK2) {
|
||||
LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
|
||||
LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
|
||||
LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv);
|
||||
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
|
||||
LLAMA_LOG_INFO("%s: n_expert_shared = %d\n", __func__, hparams.n_expert_shared);
|
||||
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
|
||||
LLAMA_LOG_INFO("%s: rope_yarn_log_mul = %.4f\n", __func__, hparams.rope_yarn_log_mul);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns false if cancelled by progress_callback
|
||||
@ -5051,8 +5158,6 @@ static bool llm_load_tensors(
|
||||
throw std::runtime_error("model has expert layers but no expert layers are used");
|
||||
}
|
||||
|
||||
GGML_ASSERT(n_embd_gqa == n_embd_k_gqa);
|
||||
|
||||
ggml_context * ctx_input = ctx_map.at(model.buft_input.buft);
|
||||
ggml_context * ctx_output = ctx_map.at(model.buft_output.buft);
|
||||
ggml_context * ctx_output_split = ctx_map.at(model.buft_output.buft_matrix);
|
||||
@ -6213,6 +6318,70 @@ static bool llm_load_tensors(
|
||||
layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert});
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_DEEPSEEK2:
|
||||
{
|
||||
bool is_lite = (hparams.n_layer == 27);
|
||||
|
||||
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
|
||||
const uint32_t q_lora_rank = hparams.n_lora_q;
|
||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||
const uint32_t n_ff_exp = hparams.n_ff_exp;
|
||||
|
||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
|
||||
// output
|
||||
{
|
||||
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
if (!is_lite) {
|
||||
layer.attn_q_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank});
|
||||
}
|
||||
layer.attn_kv_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank});
|
||||
|
||||
if (!is_lite) {
|
||||
layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank});
|
||||
layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.n_head * hparams.n_embd_head_k});
|
||||
} else {
|
||||
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_k_gqa});
|
||||
}
|
||||
layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope});
|
||||
layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, hparams.n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)});
|
||||
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {hparams.n_head * hparams.n_embd_head_v, n_embd});
|
||||
|
||||
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
|
||||
if ((uint32_t) i < hparams.n_layer_dense_lead) {
|
||||
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
|
||||
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
|
||||
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||
} else {
|
||||
layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert});
|
||||
|
||||
GGML_ASSERT(hparams.n_expert > 0);
|
||||
GGML_ASSERT(hparams.n_expert_used > 0);
|
||||
|
||||
// MoE branch
|
||||
layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert});
|
||||
layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert});
|
||||
layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert});
|
||||
|
||||
// Shared expert branch
|
||||
layer.ffn_gate_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * hparams.n_expert_shared});
|
||||
layer.ffn_down_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * hparams.n_expert_shared, n_embd});
|
||||
layer.ffn_up_shexp = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * hparams.n_expert_shared});
|
||||
}
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
@ -6667,6 +6836,8 @@ static struct ggml_tensor * llm_build_moe_ffn(
|
||||
int64_t n_expert_used,
|
||||
llm_ffn_op_type type_op,
|
||||
bool norm_w,
|
||||
bool scale_w,
|
||||
float w_scale,
|
||||
const llm_build_cb & cb,
|
||||
int il) {
|
||||
int64_t n_embd = cur->ne[0];
|
||||
@ -6698,6 +6869,10 @@ static struct ggml_tensor * llm_build_moe_ffn(
|
||||
|
||||
weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
|
||||
}
|
||||
if (scale_w) {
|
||||
weights = ggml_scale(ctx, weights, w_scale);
|
||||
cb(weights, "ffn_moe_weights_scaled", il);
|
||||
}
|
||||
|
||||
cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
|
||||
ggml_tensor * up = ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
|
||||
@ -7328,6 +7503,7 @@ struct llm_build_context {
|
||||
model.layers[il].ffn_down_exps,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
cb, il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
}
|
||||
@ -7809,6 +7985,7 @@ struct llm_build_context {
|
||||
model.layers[il].ffn_down_exps,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_GELU, true,
|
||||
false, 0.0,
|
||||
cb, il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
|
||||
@ -7952,6 +8129,7 @@ struct llm_build_context {
|
||||
model.layers[il].ffn_down_exps,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
cb, il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
|
||||
@ -9090,6 +9268,7 @@ struct llm_build_context {
|
||||
model.layers[il].ffn_down_exps,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, false,
|
||||
false, 0.0,
|
||||
cb, il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
|
||||
@ -10977,6 +11156,7 @@ struct llm_build_context {
|
||||
model.layers[il].ffn_down_exps,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
cb, il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
|
||||
@ -11008,6 +11188,215 @@ struct llm_build_context {
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_deepseek2() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
|
||||
|
||||
// mutable variable, needed during the last layer of the computation to skip unused tokens
|
||||
int32_t n_tokens = this->n_tokens;
|
||||
|
||||
bool is_lite = (hparams.n_layer == 27);
|
||||
|
||||
// We have to pre-scale kq_scale and attn_factor to make the YaRN RoPE work correctly.
|
||||
// See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation.
|
||||
const float mscale = attn_factor * (1.0f + hparams.rope_yarn_log_mul * logf(1.0f / freq_scale));
|
||||
const float kq_scale = 1.0f*mscale*mscale/sqrtf(float(hparams.n_embd_head_k));
|
||||
const float attn_factor_scaled = 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale));
|
||||
|
||||
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
|
||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
// {n_embd, n_tokens}
|
||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self_attention
|
||||
{
|
||||
struct ggml_tensor * q = NULL;
|
||||
if (!is_lite) {
|
||||
// {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
|
||||
q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
|
||||
cb(q, "q", il);
|
||||
|
||||
q = llm_build_norm(ctx0, q, hparams,
|
||||
model.layers[il].attn_q_a_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(q, "q", il);
|
||||
|
||||
// {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
|
||||
q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
|
||||
cb(q, "q", il);
|
||||
} else {
|
||||
q = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
|
||||
cb(q, "q", il);
|
||||
}
|
||||
|
||||
// split into {n_head * n_embd_head_qk_nope, n_tokens}
|
||||
struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, 0);
|
||||
cb(q_nope, "q_nope", il);
|
||||
// and {n_head * n_embd_head_qk_rope, n_tokens}
|
||||
struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_element_size(q) * hparams.n_embd_head_k, ggml_element_size(q) * hparams.n_embd_head_k * n_head, ggml_element_size(q) * n_embd_head_qk_nope);
|
||||
cb(q_pe, "q_pe", il);
|
||||
|
||||
// {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
|
||||
struct ggml_tensor * compressed_kv_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
|
||||
cb(compressed_kv_pe, "compressed_kv_pe", il);
|
||||
|
||||
// split into {kv_lora_rank, n_tokens}
|
||||
struct ggml_tensor * compressed_kv = ggml_view_2d(ctx0, compressed_kv_pe, kv_lora_rank, n_tokens, compressed_kv_pe->nb[1], 0);
|
||||
cb(compressed_kv, "compressed_kv", il);
|
||||
// and {n_embd_head_qk_rope, n_tokens}
|
||||
struct ggml_tensor * k_pe = ggml_view_2d(ctx0, compressed_kv_pe, n_embd_head_qk_rope, n_tokens, compressed_kv_pe->nb[1], ggml_element_size(compressed_kv_pe)*kv_lora_rank);
|
||||
cb(k_pe, "k_pe", il);
|
||||
|
||||
compressed_kv = llm_build_norm(ctx0, compressed_kv, hparams,
|
||||
model.layers[il].attn_kv_a_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(compressed_kv, "compressed_kv", il);
|
||||
|
||||
// {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
|
||||
struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, compressed_kv);
|
||||
cb(kv, "kv", il);
|
||||
|
||||
// split into {n_head * n_embd_head_qk_nope, n_tokens}
|
||||
struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens, ggml_element_size(kv) * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v), 0);
|
||||
cb(k_nope, "k_nope", il);
|
||||
|
||||
// and {n_head * n_embd_head_v, n_tokens}
|
||||
struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens, ggml_element_size(kv) * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v), ggml_element_size(kv) * n_embd_head_qk_nope);
|
||||
cb(v_states, "v_states", il);
|
||||
|
||||
v_states = ggml_cont(ctx0, v_states);
|
||||
cb(v_states, "v_states", il);
|
||||
|
||||
v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens, ggml_element_size(kv) * hparams.n_embd_head_v * n_head, 0);
|
||||
cb(v_states, "v_states", il);
|
||||
|
||||
q_pe = ggml_rope_ext(
|
||||
ctx0, q_pe, inp_pos, nullptr,
|
||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor_scaled, beta_fast, beta_slow
|
||||
);
|
||||
cb(q_pe, "q_pe", il);
|
||||
|
||||
// shared RoPE key
|
||||
k_pe = ggml_rope_ext(
|
||||
ctx0, ggml_view_3d(ctx0, k_pe, n_embd_head_qk_rope, 1, n_tokens, k_pe->nb[0], k_pe->nb[1], 0), inp_pos, nullptr,
|
||||
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor_scaled, beta_fast, beta_slow
|
||||
);
|
||||
cb(k_pe, "k_pe", il);
|
||||
|
||||
struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
|
||||
cb(q_states, "q_states", il);
|
||||
|
||||
struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
|
||||
cb(k_states, "k_states", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
n_tokens = n_outputs;
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
if ((uint32_t) il < hparams.n_layer_dense_lead) {
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = llm_build_ffn(ctx0, cur,
|
||||
model.layers[il].ffn_up, NULL,
|
||||
model.layers[il].ffn_gate, NULL,
|
||||
model.layers[il].ffn_down, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
} else {
|
||||
// MoE branch
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
ggml_tensor * moe_out =
|
||||
llm_build_moe_ffn(ctx0, cur,
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, false,
|
||||
true, hparams.expert_weights_scale,
|
||||
cb, il);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
|
||||
// FFN shared expert
|
||||
{
|
||||
ggml_tensor * ffn_shexp = llm_build_ffn(ctx0, cur,
|
||||
model.layers[il].ffn_up_shexp, NULL,
|
||||
model.layers[il].ffn_gate_shexp, NULL,
|
||||
model.layers[il].ffn_down_shexp, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
cb(ffn_shexp, "ffn_shexp", il);
|
||||
|
||||
cur = ggml_add(ctx0, moe_out, ffn_shexp);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
// lm_head
|
||||
cur = ggml_mul_mat(ctx0, model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
|
||||
@ -11226,6 +11615,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
{
|
||||
result = llm.build_arctic();
|
||||
} break;
|
||||
case LLM_ARCH_DEEPSEEK2:
|
||||
{
|
||||
result = llm.build_deepseek2();
|
||||
} break;
|
||||
default:
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
@ -16239,6 +16632,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_COMMAND_R:
|
||||
case LLM_ARCH_OLMO:
|
||||
case LLM_ARCH_ARCTIC:
|
||||
case LLM_ARCH_DEEPSEEK2:
|
||||
return LLAMA_ROPE_TYPE_NORM;
|
||||
|
||||
// the pairs of head values are offset by n_rot/2
|
||||
|
Loading…
Reference in New Issue
Block a user