From 4b65b648ce4a06b566cffdfb0ee09b55b79a15bb Mon Sep 17 00:00:00 2001 From: toyer <2042519524@qq.com> Date: Fri, 21 Jun 2024 07:47:51 +0000 Subject: [PATCH] add preprocess to chatglm3 and chatglm4 --- convert-hf-to-gguf.py | 3 +++ llama.cpp | 41 ++++++++++++++++++++++++++++------------- llama.h | 3 ++- 3 files changed, 33 insertions(+), 14 deletions(-) diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py index c5cb8bbec..3305b8ceb 100755 --- a/convert-hf-to-gguf.py +++ b/convert-hf-to-gguf.py @@ -2792,6 +2792,9 @@ class ChatGLMModel(Model): toktypes.append(toktype) self.gguf_writer.add_tokenizer_model("llama") + # glm3 needs prefix and suffix formatted as: + # prompt = "[gMASK]sop<|user|>\n" + prompt + "<|assistant|>" + self.gguf_writer.add_tokenizer_pre("chatglm-spm") self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_scores(scores) self.gguf_writer.add_token_types(toktypes) diff --git a/llama.cpp b/llama.cpp index a2df298a8..a2ac68379 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4789,6 +4789,10 @@ static void llm_load_vocab( return; } else if (tokenizer_model == "llama") { vocab.type = LLAMA_VOCAB_TYPE_SPM; + // chatglm3 needs to preprocess prefix and suffix + if (tokenizer_pre == "chatglm-spm") { + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM3; + } // default special tokens vocab.special_bos_id = 1; @@ -13923,6 +13927,14 @@ static std::vector llama_tokenize_internal(const llama_vocab & output.push_back(vocab.special_bos_id); is_prev_special = true; } + // add prefix to chatglm3 + if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM3) { + output.push_back(64790); + output.push_back(64792); + output.push_back(64795); + output.push_back(30910); + output.push_back(13); + } for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { @@ -13957,6 +13969,10 @@ static std::vector llama_tokenize_internal(const llama_vocab & GGML_ASSERT(vocab.special_eos_id != -1); output.push_back(vocab.special_eos_id); } + // add suffix to chatglm3 + if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM3) { + output.push_back(64796); + } } break; case LLAMA_VOCAB_TYPE_BPE: { @@ -13965,7 +13981,13 @@ static std::vector llama_tokenize_internal(const llama_vocab & if (add_special) { tokenizer.append_bos(output); } - + // add prefix to chatglm4 + if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM4) { + output.push_back(151331); + output.push_back(151333); + output.push_back(151336); + output.push_back(198); + } for (const auto & fragment : fragment_buffer) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); @@ -13983,6 +14005,10 @@ static std::vector llama_tokenize_internal(const llama_vocab & tokenizer.append_eos(output); tokenizer.check_double_bos_eos(output); } + // add suffix to chatglm4 + if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM4) { + output.push_back(151337); + } } break; case LLAMA_VOCAB_TYPE_WPM: { @@ -18599,18 +18625,7 @@ int32_t llama_tokenize( int32_t n_tokens_max, bool add_special, bool parse_special) { - auto arch_name = llama_model_arch_name(model->arch); - auto prompt = std::move(std::string(text, text_len)); - auto vocab_type = model->vocab.type; - if (strcmp(arch_name, "chatglm") == 0) { - // chatglm3 - if (LLAMA_VOCAB_TYPE_SPM == vocab_type) { - prompt = "[gMASK]sop<|user|>\n" + prompt + "<|assistant|>"; - } else if (LLAMA_VOCAB_TYPE_BPE == vocab_type) { // glm4 - prompt = "[gMASK]<|user|>\n" + prompt + "<|assistant|>"; - } - } - auto res = llama_tokenize_internal(model->vocab, prompt, add_special, parse_special); + auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special); if (n_tokens_max < (int) res.size()) { // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); return -((int) res.size()); diff --git a/llama.h b/llama.h index b1ff05bd7..a85b568b9 100644 --- a/llama.h +++ b/llama.h @@ -87,7 +87,8 @@ extern "C" { LLAMA_VOCAB_PRE_TYPE_DBRX = 13, LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, LLAMA_VOCAB_PRE_TYPE_PORO = 15, - LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 16, + LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, + LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, }; // note: these values should be synchronized with ggml_rope