llama : support glm3 and glm4 (#8031)

* add chatglm3-6b model support huggingface model:
 https://hf-mirror.com/THUDM/chatglm3-6b

Signed-off-by: XingXing Qiao <qiaoxx@dingdao.com>

* remove .rotary_pos_emb.inv_freq and unuse code for chatglm3 model

Signed-off-by: XingXing Qiao <qiaoxx@dingdao.com>

* fix lint error

Signed-off-by: XingXing Qiao <qiaoxx@dingdao.com>

* optimize convert-hf-to-gguf.py for chatglm model

Signed-off-by: XingXing Qiao <qiaoxx@dingdao.com>

* support glm-4-9b-chat

Signed-off-by: XingXing Qiao <qiaoxx@dingdao.com>

* fix eos tokens to glm4

* remove unused log

* add preprocess to chatglm3 and chatglm4

* add eos_id_list to llama.cpp

* fix code style

* fix code style

* fix conflicts

* fix conflicts

* Revert "add eos_id_list to llama.cpp"

This reverts commit 3a4d5790bf.

* set <|endoftext|> as eos and <|user|> as eot

* fix chat template bug

* add comment to glm prefix and suffix

* fix conflicts and add rope_ratio & ChatGLMForConditionalGeneration

* fix chat template bug

* fix codestyle

* fix conflicts

* modified the general name of glm model

* fix conflicts

* remove prefix and suffix

* use normal glm4 chattempalte & use LLM_FFN_SWIGLU in phi3

* fix: resolve Flake8 errors in `convert-hf-to-gguf.py`

- Fix E302 by adding two blank lines before top-level function definitions
- Replace print statements to fix NP100
- Fix E303 by ensuring only one blank line between lines of code

* fix rope ratio to solve incorrect answers

* fix by comments

---------

Signed-off-by: XingXing Qiao <qiaoxx@dingdao.com>
Co-authored-by: XingXing Qiao <qiaoxx@dingdao.com>
Co-authored-by: Umpire2018 <138990495+Umpire2018@users.noreply.github.com>
This commit is contained in:
toyer 2024-07-07 20:52:10 +08:00 committed by GitHub
parent b5040086d4
commit 905942abdb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 455 additions and 25 deletions

View File

@ -487,6 +487,9 @@ class Model:
if chkhsh == "7967bfa498ade6b757b064f31e964dddbb80f8f9a4d68d4ba7998fcf281c531a":
# ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-code
res = "jina-v2-code"
if chkhsh == "b6e8e1518dc4305be2fe39c313ed643381c4da5db34a98f6a04c093f8afbe99b":
# ref: https://huggingface.co/THUDM/glm-4-9b-chat
res = "chatglm-bpe"
if chkhsh == "7fc505bd3104ca1083b150b17d088b59534ede9bde81f0dd2090967d7fe52cee":
# ref: https://huggingface.co/LumiOpen/Viking-7B
res = "viking"
@ -3176,6 +3179,190 @@ class JaisModel(Model):
self.gguf_writer.add_max_alibi_bias(self.max_alibi_bias)
@Model.register("ChatGLMModel", "ChatGLMForConditionalGeneration")
class ChatGLMModel(Model):
model_arch = gguf.MODEL_ARCH.CHATGLM
def set_vocab_chatglm3(self):
dir_model = self.dir_model
hparams = self.hparams
tokens: list[bytearray] = []
toktypes: list[int] = []
scores: list[float] = []
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
vocab_size = hparams.get("padded_vocab_size", len(tokenizer.get_vocab()))
assert max(tokenizer.get_vocab().values()) < vocab_size
role_special_tokens = ["<|system|>", "<|user|>", "<|assistant|>", "<|observation|>"]
special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "sop", "eop"] + role_special_tokens
for token_id in range(vocab_size):
piece = tokenizer._convert_id_to_token(token_id)
if token_id == 0:
piece = "<unk>"
elif token_id == 1:
piece = "<bos>"
elif token_id == 2:
piece = "<eos>"
text = piece.encode("utf-8")
score = 0.0
# Referencing the tokenizer Python implementation(https://huggingface.co/THUDM/chatglm3-6b/blob/main/tokenization_chatglm.py),
# it is only valid if it is less than tokenizer.tokenizer.sp_model.vocab_size()
if len(piece) != 0 and token_id < tokenizer.tokenizer.sp_model.vocab_size():
score = tokenizer.tokenizer.sp_model.get_score(token_id)
if len(piece) == 0:
text = f"[PAD{token_id}]".encode("utf-8")
if token_id >= tokenizer.tokenizer.sp_model.vocab_size():
if piece in special_tokens:
# show special tokens in prompt
toktype = SentencePieceTokenTypes.USER_DEFINED
else:
toktype = SentencePieceTokenTypes.UNKNOWN
tokens.append(text)
scores.append(score)
toktypes.append(toktype)
continue
toktype = SentencePieceTokenTypes.NORMAL
if tokenizer.tokenizer.sp_model.is_unknown(token_id):
toktype = SentencePieceTokenTypes.UNKNOWN
elif tokenizer.tokenizer.sp_model.is_control(token_id):
toktype = SentencePieceTokenTypes.CONTROL
elif tokenizer.tokenizer.sp_model.is_unused(token_id):
toktype = SentencePieceTokenTypes.UNUSED
elif tokenizer.tokenizer.sp_model.is_byte(token_id):
toktype = SentencePieceTokenTypes.BYTE
tokens.append(text)
scores.append(score)
toktypes.append(toktype)
self.gguf_writer.add_tokenizer_model("llama")
# glm3 needs prefix and suffix formatted as:
# prompt = "[gMASK]sop<|user|>\n" + prompt + "<|assistant|>"
self.gguf_writer.add_tokenizer_pre("chatglm-spm")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
@staticmethod
def token_bytes_to_string(b):
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
byte_encoder = bytes_to_unicode()
return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
@staticmethod
def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
parts = [bytes([b]) for b in token]
while True:
min_idx = None
min_rank = None
for i, pair in enumerate(zip(parts[:-1], parts[1:])):
rank = mergeable_ranks.get(pair[0] + pair[1])
if rank is not None and (min_rank is None or rank < min_rank):
min_idx = i
min_rank = rank
if min_rank is None or (max_rank is not None and min_rank >= max_rank):
break
assert min_idx is not None
parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
return parts
def set_vocab(self):
if "THUDM/chatglm3-6b" in self.hparams.get("_name_or_path", ""):
self.set_vocab_chatglm3()
return
dir_model = self.dir_model
hparams = self.hparams
tokens: list[str] = []
toktypes: list[int] = []
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
vocab_size = hparams["padded_vocab_size"]
assert max(tokenizer.get_vocab().values()) < vocab_size
tokpre = self.get_vocab_base_pre(tokenizer)
merges = []
vocab = {}
mergeable_ranks = tokenizer.mergeable_ranks
for token, rank in mergeable_ranks.items():
vocab[ChatGLMModel.token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
merged = ChatGLMModel.bpe(mergeable_ranks, token, max_rank=rank)
assert len(merged) >= 2 and len(merged) <= 7
merges.append(' '.join(map(ChatGLMModel.token_bytes_to_string, merged)))
# for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
added_vocab = tokenizer.get_added_vocab()
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}
for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.USER_DEFINED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
if tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
special_vocab.merges = merges
# only add special tokens when they were not already loaded from config.json
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
# this one is usually not in config.json anyway
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab.add_to_gguf(self.gguf_writer)
def set_gguf_parameters(self):
self.gguf_writer.add_name(self.hparams.get("_name_or_path").split("/")[1]) # THUDM/glm4-9b-chat or THUDM/chatglm3-6b
n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
n_head_kv = self.hparams.get("multi_query_group_num", n_head)
self.gguf_writer.add_context_length(self.hparams.get("seq_length", n_embed))
self.gguf_writer.add_embedding_length(n_embed)
self.gguf_writer.add_feed_forward_length(self.hparams.get("ffn_hidden_size", 4 * n_embed))
self.gguf_writer.add_block_count(self.hparams["num_layers"])
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head_kv)
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layernorm_epsilon"])
self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_rope_dimension_count(64)
self.gguf_writer.add_add_bos_token(False)
rope_freq = 10000
if "rope_ratio" in self.hparams:
rope_freq = rope_freq * self.hparams["rope_ratio"]
self.gguf_writer.add_rope_freq_base(rope_freq)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused
if name.endswith(".rotary_pos_emb.inv_freq"):
return []
name = name.removeprefix("transformer.")
return [(self.map_tensor_name(name), data_torch)]
###### CONVERSION LOGIC ######

View File

@ -120,7 +120,6 @@ class Keys:
MIDDLE_ID = "tokenizer.ggml.middle_token_id"
EOT_ID = "tokenizer.ggml.eot_token_id"
#
# recommended mapping of model tensor names for storage in gguf
#
@ -163,6 +162,7 @@ class MODEL_ARCH(IntEnum):
OPENELM = auto()
ARCTIC = auto()
DEEPSEEK2 = auto()
CHATGLM = auto()
BITNET = auto()
T5 = auto()
JAIS = auto()
@ -289,6 +289,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.OPENELM: "openelm",
MODEL_ARCH.ARCTIC: "arctic",
MODEL_ARCH.DEEPSEEK2: "deepseek2",
MODEL_ARCH.CHATGLM: "chatglm",
MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.T5: "t5",
MODEL_ARCH.JAIS: "jais",
@ -924,6 +925,18 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
MODEL_ARCH.CHATGLM : [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.BITNET: [
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
@ -1020,6 +1033,9 @@ MODEL_TENSOR_SKIP: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_ROT_EMBD,
],
MODEL_ARCH.CHATGLM: [
MODEL_TENSOR.ROPE_FREQS,
],
}
#

View File

@ -24,6 +24,7 @@ class TensorNameMap:
"backbone.embedding", # mamba
"backbone.embeddings", # mamba-hf
"transformer.in_out_embed", # Grok
"embedding.word_embeddings", # chatglm
"transformer.token_embeddings", # openelm
"shared", # t5
),
@ -55,6 +56,7 @@ class TensorNameMap:
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
"output_layer", # chatglm
),
# Output norm
@ -71,12 +73,14 @@ class TensorNameMap:
"model.norm_f", # mamba-qbert
"backbone.norm_f", # mamba
"transformer.rms_norm", # Grok
"encoder.final_layernorm", # chatglm
"transformer.norm", # openelm
),
# Rope frequencies
MODEL_TENSOR.ROPE_FREQS: (
"rope.freqs", # llama-pth
"rotary_pos_emb.inv_freq", # chatglm
),
}
@ -101,6 +105,7 @@ class TensorNameMap:
"backbone.layers.{bid}.norm", # mamba
"transformer.decoder_layer.{bid}.rms_norm", # Grok
"transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
"encoder.layers.{bid}.input_layernorm", # chatglm
"transformer.layers.{bid}.attn_norm", # openelm
),
@ -124,6 +129,7 @@ class TensorNameMap:
"transformer.h.{bid}.mixer.Wqkv", # phi2
"encoder.layers.{bid}.attn.Wqkv", # nomic-bert
"model.layers.{bid}.self_attn.qkv_proj", # phi3
"encoder.layers.{bid}.self_attention.query_key_value", # chatglm
"transformer.layers.{bid}.attn.qkv_proj", # openelm
),
@ -135,7 +141,7 @@ class TensorNameMap:
"transformer.h.{bid}.attn.q_proj", # gpt-j
"model.layers.layers.{bid}.self_attn.q_proj", # plamo
"model.layers.{bid}.attention.wq", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.query" # Grok
"transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
),
# Attention key
@ -147,7 +153,7 @@ class TensorNameMap:
"transformer.h.{bid}.attn.k", # refact
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
"model.layers.{bid}.attention.wk", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
),
# Attention value
@ -182,6 +188,7 @@ class TensorNameMap:
"encoder.layers.{bid}.attn.out_proj", # nomic-bert
"transformer.decoder_layer.{bid}.multi_head_attention.linear", # Grok
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
"encoder.layers.{bid}.self_attention.dense", # chatglm
"transformer.layers.{bid}.attn.out_proj", # openelm
),
@ -218,6 +225,7 @@ class TensorNameMap:
"h.{bid}.ln_2", # gpt2
"model.layers.{bid}.ffn_norm", # internlm2
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
"encoder.layers.{bid}.post_attention_layernorm", # chatglm
"transformer.layers.{bid}.ffn_norm", # openelm
),
@ -268,6 +276,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.c_fc", # starcoder2
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
"model.layers.{bid}.residual_mlp.w3", # arctic
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
),
MODEL_TENSOR.FFN_UP_EXP: (
@ -337,6 +346,7 @@ class TensorNameMap:
"transformer.layers.{bid}.ffn.proj_2", # openelm
"model.layers.{bid}.residual_mlp.w2", # arctic
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
),
MODEL_TENSOR.FFN_DOWN_EXP: (

View File

@ -88,8 +88,10 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
LLAMA_VOCAB_PRE_TYPE_PORO = 15,
LLAMA_VOCAB_PRE_TYPE_VIKING = 16,
LLAMA_VOCAB_PRE_TYPE_JAIS = 17,
LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
LLAMA_VOCAB_PRE_TYPE_VIKING = 18,
LLAMA_VOCAB_PRE_TYPE_JAIS = 19,
};
// note: these values should be synchronized with ggml_rope

View File

@ -229,6 +229,7 @@ enum llm_arch {
LLM_ARCH_OPENELM,
LLM_ARCH_ARCTIC,
LLM_ARCH_DEEPSEEK2,
LLM_ARCH_CHATGLM,
LLM_ARCH_BITNET,
LLM_ARCH_T5,
LLM_ARCH_JAIS,
@ -272,6 +273,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_OPENELM, "openelm" },
{ LLM_ARCH_ARCTIC, "arctic" },
{ LLM_ARCH_DEEPSEEK2, "deepseek2" },
{ LLM_ARCH_CHATGLM, "chatglm" },
{ LLM_ARCH_BITNET, "bitnet" },
{ LLM_ARCH_T5, "t5" },
{ LLM_ARCH_JAIS, "jais" },
@ -1205,6 +1207,21 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
},
},
{
LLM_ARCH_CHATGLM,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
},
},
{
LLM_ARCH_BITNET,
{
@ -2087,9 +2104,11 @@ enum e_model {
MODEL_2_8B,
MODEL_3B,
MODEL_4B,
MODEL_6B,
MODEL_6_9B,
MODEL_7B,
MODEL_8B,
MODEL_9B,
MODEL_11B,
MODEL_12B,
MODEL_13B,
@ -2115,7 +2134,6 @@ enum e_model {
MODEL_16x12B,
MODEL_10B_128x3_66B,
MODEL_57B_A14B,
MODEL_9B,
MODEL_27B,
};
@ -4490,9 +4508,11 @@ static const char * llama_model_type_name(e_model type) {
case MODEL_2_8B: return "2.8B";
case MODEL_3B: return "3B";
case MODEL_4B: return "4B";
case MODEL_6B: return "6B";
case MODEL_6_9B: return "6.9B";
case MODEL_7B: return "7B";
case MODEL_8B: return "8B";
case MODEL_9B: return "9B";
case MODEL_11B: return "11B";
case MODEL_12B: return "12B";
case MODEL_13B: return "13B";
@ -4518,7 +4538,6 @@ static const char * llama_model_type_name(e_model type) {
case MODEL_16x12B: return "16x12B";
case MODEL_10B_128x3_66B: return "10B+128x3.66B";
case MODEL_57B_A14B: return "57B.A14B";
case MODEL_9B: return "9B";
case MODEL_27B: return "27B";
default: return "?B";
}
@ -5124,6 +5143,15 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_CHATGLM:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
switch (hparams.n_layer) {
case 28: model.type = e_model::MODEL_6B; break;
case 40: model.type = e_model::MODEL_9B; break;
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_BITNET:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@ -5256,9 +5284,7 @@ static void llm_load_vocab(
if (merges_keyidx == -1) {
throw std::runtime_error("cannot find tokenizer merges in model file\n");
}
const int n_merges = gguf_get_arr_n(ctx, merges_keyidx);
for (int i = 0; i < n_merges; i++) {
const std::string word = gguf_get_arr_str(ctx, merges_keyidx, i);
GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
@ -5401,6 +5427,10 @@ static void llm_load_vocab(
tokenizer_pre == "poro-chat") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO;
vocab.tokenizer_clean_spaces = false;
} else if (
tokenizer_pre == "chatglm-bpe") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
vocab.special_bos_id = -1;
} else if (
tokenizer_pre == "viking") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING;
@ -5525,7 +5555,6 @@ static void llm_load_vocab(
vocab.special_eot_id = 107;
}
}
try {
vocab.linefeed_id = llama_byte_to_token(vocab, '\n');
} catch (const std::exception & e) {
@ -7433,6 +7462,36 @@ static bool llm_load_tensors(
layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff});
}
} break;
case LLM_ARCH_CHATGLM:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});
// output
{
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab});
}
for (int i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i);
ggml_context * ctx_split = ctx_for_layer_split(i);
auto & layer = model.layers[i];
layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + (hparams.n_embd_head_k << 2)});
layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + (hparams.n_embd_head_k << 2)});
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2});
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
}
} break;
default:
throw std::runtime_error("unknown architecture");
}
@ -7657,6 +7716,7 @@ enum llm_ffn_op_type {
LLM_FFN_GELU,
LLM_FFN_RELU,
LLM_FFN_RELU_SQR,
LLM_FFN_SWIGLU,
};
enum llm_ffn_gate_type {
@ -7861,6 +7921,19 @@ static struct ggml_tensor * llm_build_ffn(
cur = ggml_sqr(ctx, cur);
cb(cur, "ffn_sqr(relu)", il);
} break;
case LLM_FFN_SWIGLU:
{
// Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
int64_t split_point = cur->ne[0] / 2;
struct ggml_tensor * x0 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], 0));
struct ggml_tensor * x1 = ggml_cont(ctx, ggml_view_2d(ctx, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
x0 = ggml_silu(ctx, x0);
cb(cur, "ffn_silu", il);
cur = ggml_mul(ctx, x0, x1);
cb(cur, "ffn_mul", il);
} break;
}
if (type_gate == LLM_FFN_PAR) {
@ -10709,19 +10782,12 @@ struct llm_build_context {
// special-case: the up and gate tensors are merged into a single tensor
// TOOD: support into llm_build_ffn
{
struct ggml_tensor* up = ggml_mul_mat(ctx0, model.layers[il].ffn_up, cur);
cb(up, "ffn_up", il);
auto g = ggml_cont(ctx0, ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], ggml_row_size(up->type, up->ne[0]), 0));
auto y = ggml_cont(ctx0, ggml_view_2d(ctx0, up, up->ne[0] / 2, up->ne[1], ggml_row_size(up->type, up->ne[0]), up->nb[1] / 2));
y = ggml_mul(ctx0, y, ggml_silu(ctx0, g));
cb(y, "ffn_gate", il);
auto down = ggml_mul_mat(ctx0, model.layers[il].ffn_down, y);
cb(down, "ffn_down", il);
cur = down;
cur = llm_build_ffn(ctx0, cur,
model.layers[il].ffn_up, NULL, NULL,
NULL, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
cb(cur, "ffn_out", il);
}
@ -13413,6 +13479,120 @@ struct llm_build_context {
return gf;
}
struct ggml_cgraph * build_chatglm() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
struct ggml_tensor * cur;
struct ggml_tensor * inpL;
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
// inp_pos - contains the positions
struct ggml_tensor * inp_pos = build_inp_pos();
// KQ_mask (mask for 1 head, it will be broadcasted to all heads)
struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * inpSA = inpL;
cur = llm_build_norm(ctx0, inpL, hparams,
model.layers[il].attn_norm,
NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "attn_norm", il);
// self-attention
{
struct ggml_tensor * Qcur = nullptr;
struct ggml_tensor * Kcur = nullptr;
struct ggml_tensor * Vcur = nullptr;
cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
cb(cur, "wqkv", il);
cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
cb(cur, "bqkv", il);
Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
cb(Qcur, "Qcur", il);
cb(Kcur, "Kcur", il);
cb(Vcur, "Vcur", il);
//printf("freq_base: %f freq_scale: %f ext_factor: %f attn_factor: %f\n", freq_base, freq_scale, ext_factor, attn_factor);
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur_rope", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur_rope", il);
cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
model.layers[il].wo, NULL,
Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
}
if (il == n_layer - 1) {
// skip computing output for unused tokens
struct ggml_tensor * inp_out_ids = build_inp_out_ids();
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
// Add the input
struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
cb(ffn_inp, "ffn_inp", il);
// FF
{
cur = llm_build_norm(ctx0, ffn_inp, hparams,
model.layers[il].ffn_norm,
NULL,
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
cur = llm_build_ffn(ctx0, cur,
model.layers[il].ffn_up, NULL, NULL,
NULL, NULL, NULL,
model.layers[il].ffn_down, NULL, NULL,
NULL,
LLM_FFN_SWIGLU, LLM_FFN_SEQ, cb, il);
cb(cur, "ffn_out", il);
}
inpL = ggml_add(ctx0, cur, ffn_inp);
cb(inpL, "l_out", il);
}
cur = llm_build_norm(ctx0, inpL, hparams,
model.output_norm,
NULL,
LLM_NORM_RMS, cb, -1);
cb(cur, "result_norm", -1);
cur = ggml_mul_mat(ctx0, model.output, cur);
cb(cur, "result_output", -1);
ggml_build_forward_expand(gf, cur);
return gf;
}
};
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@ -13644,6 +13824,10 @@ static struct ggml_cgraph * llama_build_graph(
{
result = llm.build_deepseek2();
} break;
case LLM_ARCH_CHATGLM:
{
result = llm.build_chatglm();
} break;
case LLM_ARCH_BITNET:
{
result = llm.build_bitnet();
@ -15259,6 +15443,11 @@ struct llm_tokenizer_bpe {
" ?[^(\\s|.,!?…。,、।۔،)]+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_CHATGLM4:
regex_exprs = {
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_VIKING:
regex_exprs = {
"\\p{N}",
@ -16160,7 +16349,6 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
if (add_special) {
tokenizer.append_bos(output);
}
for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@ -19151,6 +19339,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_OLMO:
case LLM_ARCH_ARCTIC:
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_CHATGLM:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2
@ -20883,7 +21072,6 @@ int32_t llama_tokenize(
bool add_special,
bool parse_special) {
auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special);
if (n_tokens_max < (int) res.size()) {
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
return -((int) res.size());
@ -21302,6 +21490,25 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
}
} else if (tmpl == "chatglm3" || tmpl_contains("[gMASK]sop")) {
// chatglm3-6b
ss << "[gMASK]" << "sop";
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>" << "\n " << message->content;
}
if (add_ass) {
ss << "<|assistant|>";
}
} else if (tmpl == "chaglm4" || tmpl_contains("[gMASK]<sop>")) {
ss << "[gMASK]" << "<sop>";
for (auto message : chat) {
std::string role(message->role);
ss << "<|" << role << "|>" << "\n" << message->content;
}
if (add_ass) {
ss << "<|assistant|>";
}
} else if (tmpl == "minicpm" || tmpl_contains(u8"<用户>")) {
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
for (auto message : chat) {

View File

@ -58,6 +58,10 @@ int main(void) {
"{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
//Phi-3-vision
"{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{- '<|assistant|>\n' -}}{% endif %}",
// ChatGLM3
"{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
// ChatGLM4
u8"[gMASK]<sop>{% for item in messages %}{% if item['tools'] is defined %}<|system|>\n你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n# 可用工具{% set tools = item['tools'] %}{% for tool in tools %}{% if tool['type'] == 'function' %}\n\n## {{ tool['function']['name'] }}\n\n{{ tool['function'] | tojson(indent=4) }}\n......{% endif %}{% endfor %}{% endif %}{% if item['content'] %}<|{{ item['role'] }}|>{{ item['metadata'] }}\n{{ item['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
u8"{% for message in messages %}{% if message['role'] == 'user' %}{{'<用户>' + message['content'].strip() + '<AI>'}}{% else %}{{message['content'].strip()}}{% endif %}{% endfor %}",
// DeepSeek-V2
@ -98,6 +102,10 @@ int main(void) {
"<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
//Phi-3-vision
"<|system|>\nYou are a helpful assistant<|end|>\n<|user|>\nHello<|end|>\n<|assistant|>\nHi there<|end|>\n<|user|>\nWho are you<|end|>\n<|assistant|>\n I am an assistant <|end|>\n<|user|>\nAnother question<|end|>\n<|assistant|>\n",
// ChatGLM3
"[gMASK]sop<|system|>\n You are a helpful assistant<|user|>\n Hello<|assistant|>\n Hi there<|user|>\n Who are you<|assistant|>\n I am an assistant <|user|>\n Another question<|assistant|>",
// ChatGLM4
"[gMASK]<sop><|system|>\nYou are a helpful assistant<|user|>\nHello<|assistant|>\nHi there<|user|>\nWho are you<|assistant|>\n I am an assistant <|user|>\nAnother question<|assistant|>",
// MiniCPM-3B-OpenHermes-2.5-v2-GGUF
u8"You are a helpful assistant<用户>Hello<AI>Hi there<用户>Who are you<AI>I am an assistant<用户>Another question<AI>",
// DeepSeek-V2