mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 13:30:35 +00:00
llama : support MiniCPM3 (#9322)
Co-authored-by: 范睿凯 <fanruikai@modelbest.cn>
This commit is contained in:
parent
441b72b91f
commit
95ca85168b
@ -1841,6 +1841,60 @@ class MiniCPMModel(Model):
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
|
||||
@Model.register("MiniCPM3ForCausalLM")
|
||||
class MiniCPM3Model(Model):
|
||||
model_arch = gguf.MODEL_ARCH.MINICPM3
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
hparams = self.hparams
|
||||
|
||||
rope_dims = hparams["qk_rope_head_dim"]
|
||||
|
||||
self.gguf_writer.add_file_type(self.ftype)
|
||||
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
|
||||
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
|
||||
self.gguf_writer.add_block_count(self.block_count)
|
||||
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
|
||||
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
|
||||
self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])
|
||||
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
|
||||
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
|
||||
if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
|
||||
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
|
||||
self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
|
||||
self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
|
||||
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
|
||||
|
||||
rope_scaling = self.find_hparam(['rope_scaling'], True)
|
||||
if rope_scaling is None:
|
||||
return
|
||||
|
||||
long_factors = rope_scaling.get('long_factor', None)
|
||||
short_factors = rope_scaling.get('short_factor', None)
|
||||
|
||||
if long_factors is None or short_factors is None:
|
||||
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')
|
||||
|
||||
if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
|
||||
raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')
|
||||
|
||||
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
|
||||
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_llama_hf()
|
||||
|
||||
def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
|
||||
if n_kv_head is not None and n_head != n_kv_head:
|
||||
n_head //= n_kv_head
|
||||
|
||||
return (
|
||||
weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
|
||||
.swapaxes(1, 2)
|
||||
.reshape(weights.shape)
|
||||
)
|
||||
|
||||
|
||||
@Model.register("QWenLMHeadModel")
|
||||
class QwenModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.QWEN
|
||||
|
@ -210,6 +210,7 @@ class MODEL_ARCH(IntEnum):
|
||||
ORION = auto()
|
||||
INTERNLM2 = auto()
|
||||
MINICPM = auto()
|
||||
MINICPM3 = auto()
|
||||
GEMMA = auto()
|
||||
GEMMA2 = auto()
|
||||
STARCODER2 = auto()
|
||||
@ -364,6 +365,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
||||
MODEL_ARCH.ORION: "orion",
|
||||
MODEL_ARCH.INTERNLM2: "internlm2",
|
||||
MODEL_ARCH.MINICPM: "minicpm",
|
||||
MODEL_ARCH.MINICPM3: "minicpm3",
|
||||
MODEL_ARCH.GEMMA: "gemma",
|
||||
MODEL_ARCH.GEMMA2: "gemma2",
|
||||
MODEL_ARCH.STARCODER2: "starcoder2",
|
||||
@ -867,6 +869,23 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
MODEL_ARCH.MINICPM3: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q_A,
|
||||
MODEL_TENSOR.ATTN_Q_B,
|
||||
MODEL_TENSOR.ATTN_KV_A_MQA,
|
||||
MODEL_TENSOR.ATTN_KV_B,
|
||||
MODEL_TENSOR.ATTN_Q_A_NORM,
|
||||
MODEL_TENSOR.ATTN_KV_A_NORM,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.GEMMA: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
298
src/llama.cpp
298
src/llama.cpp
@ -193,6 +193,7 @@ enum llm_arch {
|
||||
LLM_ARCH_ORION,
|
||||
LLM_ARCH_INTERNLM2,
|
||||
LLM_ARCH_MINICPM,
|
||||
LLM_ARCH_MINICPM3,
|
||||
LLM_ARCH_GEMMA,
|
||||
LLM_ARCH_GEMMA2,
|
||||
LLM_ARCH_STARCODER2,
|
||||
@ -241,6 +242,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
||||
{ LLM_ARCH_ORION, "orion" },
|
||||
{ LLM_ARCH_INTERNLM2, "internlm2" },
|
||||
{ LLM_ARCH_MINICPM, "minicpm" },
|
||||
{ LLM_ARCH_MINICPM3, "minicpm3" },
|
||||
{ LLM_ARCH_GEMMA, "gemma" },
|
||||
{ LLM_ARCH_GEMMA2, "gemma2" },
|
||||
{ LLM_ARCH_STARCODER2, "starcoder2" },
|
||||
@ -1034,6 +1036,29 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
|
||||
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_MINICPM3,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FACTORS_LONG, "rope_factors_long" },
|
||||
{ LLM_TENSOR_ROPE_FACTORS_SHORT, "rope_factors_short" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q_A_NORM, "blk.%d.attn_q_a_norm" },
|
||||
{ LLM_TENSOR_ATTN_KV_A_NORM, "blk.%d.attn_kv_a_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_A, "blk.%d.attn_q_a" },
|
||||
{ LLM_TENSOR_ATTN_Q_B, "blk.%d.attn_q_b" },
|
||||
{ LLM_TENSOR_ATTN_KV_A_MQA, "blk.%d.attn_kv_a_mqa" },
|
||||
{ LLM_TENSOR_ATTN_KV_B, "blk.%d.attn_kv_b" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GEMMA,
|
||||
{
|
||||
@ -5390,6 +5415,17 @@ static void llm_load_hparams(
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_MINICPM3:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q);
|
||||
ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 62: model.type = e_model::MODEL_4B; break;
|
||||
default: model.type = e_model::MODEL_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GROK:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
@ -6897,6 +6933,54 @@ static bool llm_load_tensors(
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_MINICPM3:
|
||||
{
|
||||
const int64_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const int64_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
|
||||
|
||||
const int64_t q_lora_rank = hparams.n_lora_q;
|
||||
const int64_t kv_lora_rank = hparams.n_lora_kv;
|
||||
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
|
||||
|
||||
// output
|
||||
{
|
||||
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
|
||||
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (model.output == NULL) {
|
||||
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
ggml_context * ctx_layer = ctx_for_layer(i);
|
||||
ggml_context * ctx_split = ctx_for_layer_split(i);
|
||||
|
||||
auto & layer = model.layers[i];
|
||||
|
||||
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
|
||||
layer.attn_q_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank});
|
||||
|
||||
layer.attn_kv_a_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank});
|
||||
|
||||
layer.wq_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank});
|
||||
layer.wq_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k});
|
||||
|
||||
layer.wkv_a_mqa = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + (n_embd_head_qk_rope)});
|
||||
layer.wkv_b = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_KV_B, "weight", i), {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)});
|
||||
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_head * ( n_embd_head_v), n_embd});
|
||||
|
||||
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
|
||||
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
|
||||
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd});
|
||||
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||
|
||||
layer.rope_long = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight"), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
layer.rope_short = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight"), { n_embd_head_qk_rope/2 }, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0));
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GROK:
|
||||
{
|
||||
if (n_expert == 0) {
|
||||
@ -12843,6 +12927,215 @@ struct llm_build_context {
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_minicpm3() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||
|
||||
//TODO: if the model varies, these parameters need to be read from the model
|
||||
const int64_t n_embd_base = 256;
|
||||
const float scale_embd = 12.0f;
|
||||
const float scale_depth = 1.4f;
|
||||
const float kq_scale = 1.0f / sqrtf(float(hparams.n_embd_head_k));
|
||||
|
||||
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
|
||||
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
|
||||
const uint32_t kv_lora_rank = hparams.n_lora_kv;
|
||||
|
||||
struct ggml_tensor * cur;
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
|
||||
|
||||
// scale the input embeddings
|
||||
inpL = ggml_scale(ctx0, inpL, scale_embd);
|
||||
cb(inpL, "inp_scaled", -1);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
struct ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
|
||||
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
struct ggml_tensor * rope_factors = build_rope_factors(il);
|
||||
// norm
|
||||
cur = llm_build_norm(ctx0, inpL, hparams,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self_attention
|
||||
{
|
||||
struct ggml_tensor * q = NULL;
|
||||
// {n_embd, q_lora_rank} * {n_embd, n_tokens} -> {q_lora_rank, n_tokens}
|
||||
q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
|
||||
cb(q, "q", il);
|
||||
|
||||
q = llm_build_norm(ctx0, q, hparams,
|
||||
model.layers[il].attn_q_a_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(q, "q", il);
|
||||
|
||||
// {q_lora_rank, n_head * hparams.n_embd_head_k} * {q_lora_rank, n_tokens} -> {n_head * hparams.n_embd_head_k, n_tokens}
|
||||
q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
|
||||
cb(q, "q", il);
|
||||
|
||||
// split into {n_head * n_embd_head_qk_nope, n_tokens}
|
||||
struct ggml_tensor * q_nope = ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens,
|
||||
ggml_row_size(q->type, hparams.n_embd_head_k),
|
||||
ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
|
||||
0);
|
||||
cb(q_nope, "q_nope", il);
|
||||
|
||||
// and {n_head * n_embd_head_qk_rope, n_tokens}
|
||||
struct ggml_tensor * q_pe = ggml_view_3d(ctx0, q, n_embd_head_qk_rope, n_head, n_tokens,
|
||||
ggml_row_size(q->type, hparams.n_embd_head_k),
|
||||
ggml_row_size(q->type, hparams.n_embd_head_k * n_head),
|
||||
ggml_row_size(q->type, n_embd_head_qk_nope));
|
||||
cb(q_pe, "q_pe", il);
|
||||
|
||||
// {n_embd, kv_lora_rank + n_embd_head_qk_rope} * {n_embd, n_tokens} -> {kv_lora_rank + n_embd_head_qk_rope, n_tokens}
|
||||
struct ggml_tensor * kv_pe_compresseed = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
|
||||
cb(kv_pe_compresseed, "kv_pe_compresseed", il);
|
||||
|
||||
// split into {kv_lora_rank, n_tokens}
|
||||
struct ggml_tensor * kv_compressed = ggml_view_2d(ctx0, kv_pe_compresseed, kv_lora_rank, n_tokens,
|
||||
kv_pe_compresseed->nb[1],
|
||||
0);
|
||||
cb(kv_compressed, "kv_compressed", il);
|
||||
|
||||
// and {n_embd_head_qk_rope, n_tokens}
|
||||
struct ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_pe_compresseed, n_embd_head_qk_rope, 1, n_tokens,
|
||||
kv_pe_compresseed->nb[1],
|
||||
kv_pe_compresseed->nb[1],
|
||||
ggml_row_size(kv_pe_compresseed->type, kv_lora_rank));
|
||||
cb(k_pe, "k_pe", il);
|
||||
|
||||
kv_compressed = ggml_cont(ctx0, kv_compressed); // TODO: the CUDA backend does not support non-contiguous norm
|
||||
kv_compressed = llm_build_norm(ctx0, kv_compressed, hparams,
|
||||
model.layers[il].attn_kv_a_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(kv_compressed, "kv_compressed", il);
|
||||
|
||||
// {kv_lora_rank, n_head * (n_embd_head_qk_nope + n_embd_head_v)} * {kv_lora_rank, n_tokens} -> {n_head * (n_embd_head_qk_nope + n_embd_head_v), n_tokens}
|
||||
struct ggml_tensor * kv = ggml_mul_mat(ctx0, model.layers[il].wkv_b, kv_compressed);
|
||||
cb(kv, "kv", il);
|
||||
|
||||
// split into {n_head * n_embd_head_qk_nope, n_tokens}
|
||||
struct ggml_tensor * k_nope = ggml_view_3d(ctx0, kv, n_embd_head_qk_nope, n_head, n_tokens,
|
||||
ggml_row_size(kv->type, n_embd_head_qk_nope + hparams.n_embd_head_v),
|
||||
ggml_row_size(kv->type, n_head * (n_embd_head_qk_nope + hparams.n_embd_head_v)),
|
||||
0);
|
||||
cb(k_nope, "k_nope", il);
|
||||
|
||||
// and {n_head * n_embd_head_v, n_tokens}
|
||||
struct ggml_tensor * v_states = ggml_view_3d(ctx0, kv, hparams.n_embd_head_v, n_head, n_tokens,
|
||||
ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)),
|
||||
ggml_row_size(kv->type, (n_embd_head_qk_nope + hparams.n_embd_head_v)*n_head),
|
||||
ggml_row_size(kv->type, (n_embd_head_qk_nope)));
|
||||
cb(v_states, "v_states", il);
|
||||
|
||||
v_states = ggml_cont(ctx0, v_states);
|
||||
cb(v_states, "v_states", il);
|
||||
|
||||
v_states = ggml_view_2d(ctx0, v_states, hparams.n_embd_head_v * n_head, n_tokens,
|
||||
ggml_row_size(kv->type, hparams.n_embd_head_v * n_head),
|
||||
0);
|
||||
cb(v_states, "v_states", il);
|
||||
|
||||
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
||||
q_pe = ggml_rope_ext(
|
||||
ctx0, q_pe, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(q_pe, "q_pe", il);
|
||||
|
||||
// shared RoPE key
|
||||
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
|
||||
k_pe = ggml_rope_ext(
|
||||
ctx0, k_pe, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
cb(k_pe, "k_pe", il);
|
||||
|
||||
struct ggml_tensor * q_states = ggml_concat(ctx0, q_nope, q_pe, 0);
|
||||
cb(q_states, "q_states", il);
|
||||
|
||||
struct ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
|
||||
cb(k_states, "k_states", il);
|
||||
|
||||
cur = llm_build_kv(ctx0, lctx, kv_self, gf,
|
||||
model.layers[il].wo, NULL,
|
||||
k_states, v_states, q_states, KQ_mask, n_tokens, kv_head, n_kv, kq_scale, cb, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1) {
|
||||
// skip computing output for unused tokens
|
||||
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
// scale_res - scale the hidden states for residual connection
|
||||
const float scale_res = scale_depth/sqrtf(float(n_layer));
|
||||
cur = ggml_scale(ctx0, cur, scale_res);
|
||||
cb(cur, "hidden_scaled", il);
|
||||
|
||||
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network
|
||||
{
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = llm_build_ffn(ctx0, lctx, cur,
|
||||
model.layers[il].ffn_up, NULL, NULL,
|
||||
model.layers[il].ffn_gate, NULL, NULL,
|
||||
model.layers[il].ffn_down, NULL, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
}
|
||||
|
||||
// scale the hidden states for residual connection
|
||||
cur = ggml_scale(ctx0, cur, scale_res);
|
||||
cb(cur, "hidden_scaled_ffn", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cur = lctx.cvec.apply_to(ctx0, cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = llm_build_norm(ctx0, cur, hparams,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, cb, -1);
|
||||
cb(cur, "result_norm", -1);
|
||||
|
||||
// lm_head scaling
|
||||
const float scale_lmhead = float(n_embd_base)/float(n_embd);
|
||||
cur = ggml_scale(ctx0, cur, scale_lmhead);
|
||||
cb(cur, "lmhead_scaling", -1);
|
||||
|
||||
// lm_head
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
|
||||
return gf;
|
||||
}
|
||||
|
||||
struct ggml_cgraph * build_gemma() {
|
||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
|
||||
|
||||
@ -15383,6 +15676,10 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
{
|
||||
result = llm.build_minicpm();
|
||||
} break;
|
||||
case LLM_ARCH_MINICPM3:
|
||||
{
|
||||
result = llm.build_minicpm3();
|
||||
} break;
|
||||
case LLM_ARCH_GEMMA:
|
||||
{
|
||||
result = llm.build_gemma();
|
||||
@ -18609,6 +18906,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
|
||||
case LLM_ARCH_CODESHELL:
|
||||
case LLM_ARCH_NEMOTRON:
|
||||
case LLM_ARCH_EXAONE:
|
||||
case LLM_ARCH_MINICPM3:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
// all model arches should be listed explicitly here
|
||||
|
Loading…
Reference in New Issue
Block a user