Add prompt preprocessing (prefix/suffix tokens) for ChatGLM3 and ChatGLM4

This commit is contained in:
toyer 2024-06-21 07:47:51 +00:00
parent e773174052
commit 4b65b648ce
3 changed files with 33 additions and 14 deletions

View File

@ -2792,6 +2792,9 @@ class ChatGLMModel(Model):
toktypes.append(toktype) toktypes.append(toktype)
self.gguf_writer.add_tokenizer_model("llama") self.gguf_writer.add_tokenizer_model("llama")
# glm3 needs prefix and suffix formatted as:
# prompt = "[gMASK]sop<|user|>\n" + prompt + "<|assistant|>"
self.gguf_writer.add_tokenizer_pre("chatglm-spm")
self.gguf_writer.add_token_list(tokens) self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores) self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes) self.gguf_writer.add_token_types(toktypes)

View File

@ -4789,6 +4789,10 @@ static void llm_load_vocab(
return; return;
} else if (tokenizer_model == "llama") { } else if (tokenizer_model == "llama") {
vocab.type = LLAMA_VOCAB_TYPE_SPM; vocab.type = LLAMA_VOCAB_TYPE_SPM;
// chatglm3 needs a prefix and suffix added around the prompt during tokenization
if (tokenizer_pre == "chatglm-spm") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM3;
}
// default special tokens // default special tokens
vocab.special_bos_id = 1; vocab.special_bos_id = 1;
@ -13923,6 +13927,14 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
output.push_back(vocab.special_bos_id); output.push_back(vocab.special_bos_id);
is_prev_special = true; is_prev_special = true;
} }
// add prefix to chatglm3
if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM3) {
output.push_back(64790);
output.push_back(64792);
output.push_back(64795);
output.push_back(30910);
output.push_back(13);
}
for (const auto & fragment : fragment_buffer) { for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@ -13957,6 +13969,10 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
GGML_ASSERT(vocab.special_eos_id != -1); GGML_ASSERT(vocab.special_eos_id != -1);
output.push_back(vocab.special_eos_id); output.push_back(vocab.special_eos_id);
} }
// add suffix to chatglm3
if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM3) {
output.push_back(64796);
}
} break; } break;
case LLAMA_VOCAB_TYPE_BPE: case LLAMA_VOCAB_TYPE_BPE:
{ {
@ -13965,7 +13981,13 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
if (add_special) { if (add_special) {
tokenizer.append_bos(output); tokenizer.append_bos(output);
} }
// add prefix to chatglm4
if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM4) {
output.push_back(151331);
output.push_back(151333);
output.push_back(151336);
output.push_back(198);
}
for (const auto & fragment : fragment_buffer) { for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@ -13983,6 +14005,10 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
tokenizer.append_eos(output); tokenizer.append_eos(output);
tokenizer.check_double_bos_eos(output); tokenizer.check_double_bos_eos(output);
} }
// add suffix to chatglm4
if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM4) {
output.push_back(151337);
}
} break; } break;
case LLAMA_VOCAB_TYPE_WPM: case LLAMA_VOCAB_TYPE_WPM:
{ {
@ -18599,18 +18625,7 @@ int32_t llama_tokenize(
int32_t n_tokens_max, int32_t n_tokens_max,
bool add_special, bool add_special,
bool parse_special) { bool parse_special) {
auto arch_name = llama_model_arch_name(model->arch); auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_special, parse_special);
auto prompt = std::move(std::string(text, text_len));
auto vocab_type = model->vocab.type;
if (strcmp(arch_name, "chatglm") == 0) {
// chatglm3
if (LLAMA_VOCAB_TYPE_SPM == vocab_type) {
prompt = "[gMASK]sop<|user|>\n" + prompt + "<|assistant|>";
} else if (LLAMA_VOCAB_TYPE_BPE == vocab_type) { // glm4
prompt = "[gMASK]<sop><|user|>\n" + prompt + "<|assistant|>";
}
}
auto res = llama_tokenize_internal(model->vocab, prompt, add_special, parse_special);
if (n_tokens_max < (int) res.size()) { if (n_tokens_max < (int) res.size()) {
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
return -((int) res.size()); return -((int) res.size());

View File

@ -87,7 +87,8 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_DBRX = 13, LLAMA_VOCAB_PRE_TYPE_DBRX = 13,
LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, LLAMA_VOCAB_PRE_TYPE_SMAUG = 14,
LLAMA_VOCAB_PRE_TYPE_PORO = 15, LLAMA_VOCAB_PRE_TYPE_PORO = 15,
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 16, LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16,
LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17,
}; };
// note: these values should be synchronized with ggml_rope // note: these values should be synchronized with ggml_rope