llama : add MiniCPM support (#5346)

* support minicpm arch.

* fix tab/space typo.

* convert minicpm model via convert-hf-gguf.py

* try to make tokenizer work

* fix bug for quantize minicpm

* fix for flake8 lint

* remove convert-minicpm.py

* fix for editorconfig

* correct minicpm model type (size)

* constants expanded for minicpm

* Minor change of the constant names for minicpm
This commit is contained in:
runfuture 2024-02-07 14:15:56 +08:00 committed by GitHub
parent f3e2b4fa3f
commit 316c7faf77
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 259 additions and 1 deletions

View File

@ -22,6 +22,8 @@ if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py'))
import gguf
from convert import HfVocab
# check for any of the given keys in the dictionary and return the value of the first key found
def get_key_opts(d, keys):
@ -205,6 +207,8 @@ class Model:
return OrionModel
if model_architecture == "InternLM2ForCausalLM":
return InternLM2Model
if model_architecture == "MiniCPMForCausalLM":
return MiniCPMModel
return Model
def _is_model_safetensors(self) -> bool:
@ -258,6 +262,8 @@ class Model:
return gguf.MODEL_ARCH.ORION
if arch == "InternLM2ForCausalLM":
return gguf.MODEL_ARCH.INTERNLM2
if arch == "MiniCPMForCausalLM":
return gguf.MODEL_ARCH.MINICPM
raise NotImplementedError(f'Architecture "{arch}" not supported!')
@ -402,6 +408,31 @@ class Model:
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
def _set_vocab_hf(self):
path = self.dir_model
added_tokens_path = self.dir_model
vocab = HfVocab(
path, added_tokens_path if added_tokens_path.exists() else None
)
tokens = []
scores = []
toktypes = []
for text, score, toktype in vocab.all_tokens():
tokens.append(text)
scores.append(score)
toktypes.append(toktype)
assert len(tokens) == vocab.vocab_size
self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
class GPTNeoXModel(Model):
def set_gguf_parameters(self):
@ -1041,6 +1072,24 @@ class MixtralModel(Model):
self._set_vocab_sentencepiece()
class MiniCPMModel(Model):
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
self.gguf_writer.add_name("MiniCPM")
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
def set_vocab(self):
self._set_vocab_hf()
class QwenModel(Model):
@staticmethod
def token_bytes_to_string(b):

View File

@ -104,6 +104,7 @@ class MODEL_ARCH(IntEnum):
CODESHELL = auto()
ORION = auto()
INTERNLM2 = auto()
MINICPM = auto()
class MODEL_TENSOR(IntEnum):
@ -156,6 +157,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.CODESHELL: "codeshell",
MODEL_ARCH.ORION: "orion",
MODEL_ARCH.INTERNLM2: "internlm2",
MODEL_ARCH.MINICPM: "minicpm",
}
TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@ -464,6 +466,25 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.MINICPM: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
# TODO
}

188
llama.cpp
View File

@ -205,6 +205,7 @@ enum llm_arch {
LLM_ARCH_CODESHELL,
LLM_ARCH_ORION,
LLM_ARCH_INTERNLM2,
LLM_ARCH_MINICPM,
LLM_ARCH_UNKNOWN,
};
@ -228,6 +229,7 @@ static std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_CODESHELL, "codeshell" },
{ LLM_ARCH_ORION, "orion" },
{ LLM_ARCH_INTERNLM2, "internlm2" },
{ LLM_ARCH_MINICPM, "minicpm" },
};
enum llm_kv {
@ -690,6 +692,29 @@ static std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES =
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_MINICPM,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
},
},
{
LLM_ARCH_UNKNOWN,
{
@ -1390,6 +1415,7 @@ enum e_model {
MODEL_UNKNOWN,
MODEL_0_5B,
MODEL_1B,
MODEL_2B,
MODEL_3B,
MODEL_4B,
MODEL_7B,
@ -2748,6 +2774,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
static const char * llama_model_type_name(e_model type) {
switch (type) {
case MODEL_1B: return "1B";
case MODEL_2B: return "2B";
case MODEL_3B: return "3B";
case MODEL_7B: return "7B";
case MODEL_8B: return "8B";
@ -2887,6 +2914,13 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_MINICPM:
{
switch (hparams.n_layer) {
case 40: model.type = e_model::MODEL_2B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_FALCON:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
@ -3524,14 +3558,17 @@ static bool llm_load_tensors(
switch (model.arch) {
case LLM_ARCH_LLAMA:
case LLM_ARCH_REFACT:
case LLM_ARCH_MINICPM:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
// output
{
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
if (model.arch != LLM_ARCH_MINICPM){
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
}
}
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
@ -6781,6 +6818,153 @@ struct llm_build_context {
return gf;
}
// ref: https://arxiv.org/abs/2203.03466
// https://github.com/ggerganov/llama.cpp/issues/5276#issuecomment-1925774738
// based on the original build_llama() function
struct ggml_cgraph * build_minicpm() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
const int64_t n_embd = hparams.n_embd;
//TODO: if the model varies, these parameters need to be read from the model
const int64_t n_embd_base = 256;
const float scale_embd = 12.0f;
const float scale_depth = 1.4f;
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
cb(inpL, "inp_embd", -1);
// scale the input embeddings
inpL = ggml_scale(ctx0, inpL, scale_embd);
cb(inpL, "inp_scaled", -1);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = ggml_view_1d(ctx0, lctx.inp_pos, n_tokens, 0);
cb(inp_pos, "inp_pos", -1);
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
cb(KQ_mask, "KQ_mask", -1);
// shift the entire K-cache if needed
if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
// norm
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// self-attention
{
// compute Q and K and RoPE them
struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
if (model.layers[il].bq) {
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
cb(Qcur, "Qcur", il);
}
struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
if (model.layers[il].bk) {
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
cb(Kcur, "Kcur", il);
}
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
if (model.layers[il].bv) {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
}
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
model.layers[il].wo, model.layers[il].bo,
Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
cb(cur, "kqv_out", il);
}
// scale_res - scale the hidden states for residual connection
const float scale_res = scale_depth/sqrtf(float(n_layer));
cur = ggml_scale(ctx0, cur, scale_res);
cb(cur, "hidden_scaled", -1);
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
{
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
cur = llm_build_ffn(ctx0, cur,
model.layers[il].ffn_up, NULL,
model.layers[il].ffn_gate, NULL,
model.layers[il].ffn_down, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
cb(cur, "ffn_out", il);
}
// scale the hidden states for residual connection
cur = ggml_scale(ctx0, cur, scale_res);
cb(cur, "hidden_scaled_ffn", -1);
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = llm_build_norm(ctx0, cur, hparams,
model.output_norm, NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
// lm_head scaling
const float scale_lmhead = float(n_embd_base)/float(n_embd);
cur = ggml_scale(ctx0, cur, scale_lmhead);
cb(cur, "lmhead_scaling", -1);
// lm_head
cur = ggml_mul_mat(ctx0, model.tok_embd, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
};
static struct ggml_cgraph * llama_build_graph(
@ -6943,6 +7127,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_internlm2();
} break;
case LLM_ARCH_MINICPM:
{
result = llm.build_minicpm();
} break;
default:
GGML_ASSERT(false);
}