add eos_id_list to llama.cpp

toyer 2024-06-24 12:27:02 +00:00
parent 4b65b648ce
commit 3a4d5790bf
13 changed files with 122 additions and 55 deletions

View File

@ -2417,14 +2417,21 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
}
}
const int n_eos = llama_n_eos(llama_get_model(lctx));
std::vector<int32_t> eos_tokens(n_eos, 0);
int32_t* eos_ptr = eos_tokens.data();
llama_token_eos(llama_get_model(lctx), eos_ptr);
if (params.ignore_eos) {
params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
for (int32_t i = 0; i < n_eos; ++i) {
params.sparams.logit_bias[eos_ptr[i]] = -INFINITY;
}
}
if (params.warmup) {
LOG("warming up the model with an empty run\n");
std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
std::vector<llama_token> tmp = { llama_token_bos(model) };
tmp.insert(tmp.end(), eos_tokens.begin(), eos_tokens.end());
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
llama_kv_cache_clear(lctx);
llama_synchronize(lctx);
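The size-query-then-fill pattern above (llama_n_eos to get the count, then llama_token_eos to fill a caller-provided buffer) recurs at every call site touched by this commit. A caller-side wrapper could fold it into one call; the sketch below assumes the new signatures introduced here, and the helper itself is not part of the commit:

    #include <vector>
    #include "llama.h"

    // sketch: collect the model's EOS ids into a vector via the new two-call API
    static std::vector<llama_token> get_eos_tokens(const struct llama_model * model) {
        const int n_eos = llama_n_eos(model);
        std::vector<llama_token> eos_tokens(n_eos, 0);
        if (n_eos > 0) {
            llama_token_eos(model, eos_tokens.data());
        }
        return eos_tokens;
    }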
@ -3357,8 +3364,17 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks);
const auto logit_bias_eos = sparams.logit_bias.find(llama_token_eos(llama_get_model(lctx)));
const bool ignore_eos = logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY;
const int n_eos = llama_n_eos(llama_get_model(lctx));
std::vector<int32_t> eos_tokens(n_eos, 0);
int32_t* eos_ptr = eos_tokens.data();
llama_token_eos(llama_get_model(lctx), eos_ptr);
bool ignore_eos = false;
for (auto eos: eos_tokens) {
const auto logit_bias_eos = sparams.logit_bias.find(eos);
if (logit_bias_eos != sparams.logit_bias.end() && logit_bias_eos->second == -INFINITY) {
ignore_eos = true;
}
}
fprintf(stream, "ignore_eos: %s # default: false\n", ignore_eos ? "true" : "false");
yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str());
@ -3371,7 +3387,7 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
fprintf(stream, "logit_bias:\n");
for (std::pair<llama_token, float> lb : sparams.logit_bias) {
if (ignore_eos && lb.first == logit_bias_eos->first) {
if (ignore_eos && std::count(eos_tokens.begin(), eos_tokens.end(), lb.first)) {
continue;
}
fprintf(stream, " %d: %f", lb.first, lb.second);

View File

@ -240,7 +240,11 @@ int64_t get_example_targets_batch(
ggml_set_f32(target_probs, 0.0f);
llama_token bos = llama_token_bos(llama_get_model(lctx));
llama_token eos = llama_token_eos(llama_get_model(lctx));
const int n_eos = llama_n_eos(llama_get_model(lctx));
std::vector<int32_t> eos_tokens(n_eos, 0);
int32_t* eos_ptr = eos_tokens.data();
llama_token_eos(llama_get_model(lctx), eos_ptr);
llama_token eos = eos_ptr[0];
// printf("%s: example_id=%d n_batch=%d n_train_samples=%zu\n", __func__, example_id, n_batch, n_train_samples);
for (int k=0; k<n_batch; ++k) {
// printf("%s: batch %d\n", __func__, k);

View File

@ -801,7 +801,7 @@ class MPTModel(Model):
self._set_vocab_sentencepiece()
self.gguf_writer.add_add_bos_token(False)
self.gguf_writer.add_pad_token_id(3)
self.gguf_writer.add_eos_token_id(1)
self.gguf_writer.add_eos_token_id_list([1])
self.gguf_writer.add_unk_token_id(0)
def set_gguf_parameters(self):
@ -2339,8 +2339,8 @@ class MambaModel(Model):
field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)
self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0] if field else 1)
field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)
self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0] if field else 0)
field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID_LIST)
self.gguf_writer.add_eos_token_id_list([field.parts[-1].tolist()[0] if field else 0])
field = neox_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)
self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0] if field else 0)
@ -2875,9 +2875,10 @@ class ChatGLMModel(Model):
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
self.gguf_writer.add_eos_token_id_list([151329, 151336, 151338])
special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
special_vocab.chat_template = "ChatGLM4"
special_vocab.chat_template = "chatglm4"
special_vocab.merges = merges
# only add special tokens when they were not already loaded from config.json
# if len(special_vocab.special_token_ids) == 0:
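The list written above is stored in the GGUF file as a plain INT32 array under tokenizer.ggml.eos_token_id_list, so it can be inspected with the existing C gguf API. A rough sketch, assuming a gguf_context obtained from gguf_init_from_file (error handling omitted):

    #include <cstdint>
    #include <cstdio>
    #include "ggml.h"

    // sketch: print the EOS id list written by the converters above
    static void dump_eos_list(const struct gguf_context * ctx) {
        const int key_id = gguf_find_key(ctx, "tokenizer.ggml.eos_token_id_list");
        if (key_id == -1) {
            return; // older model file: only tokenizer.ggml.eos_token_id is present
        }
        const int n = gguf_get_arr_n(ctx, key_id);
        const int32_t * ids = (const int32_t *) gguf_get_arr_data(ctx, key_id);
        for (int i = 0; i < n; ++i) {
            printf("eos[%d] = %d\n", i, ids[i]);
        }
    }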

View File

@ -331,7 +331,7 @@ class GGMLToGGUF:
gguf_writer.add_token_types(toktypes)
gguf_writer.add_unk_token_id(0)
gguf_writer.add_bos_token_id(1)
gguf_writer.add_eos_token_id(2)
gguf_writer.add_eos_token_id_list([2])
def add_tensors(self, gguf_writer):
tensor_map = self.name_map

View File

@ -95,7 +95,6 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
std::string result;
const llama_model * mdl = llama_get_model(ctx);
llama_token eos_token = llama_token_eos(mdl);
llama_kv_cache_clear(ctx);
llama_set_causal_attn(ctx, true);
@ -123,7 +122,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
if (token == eos_token) {
if (llama_token_is_eog(mdl, token)) {
break;
}
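Switching the stop check to llama_token_is_eog means the example no longer needs to know any specific EOS id. A minimal shape for such a loop, following the greedy sampling used above (sketch only; candidate refill and decoding between iterations are elided):

    #include "llama.h"

    // sketch: greedy sampling loop that stops on any end-of-generation token
    static void generate_until_eog(llama_context * ctx, const llama_model * mdl,
            llama_token_data_array * candidates_p) {
        for (;;) {
            const llama_token token = llama_sample_token_greedy(ctx, candidates_p);
            if (llama_token_is_eog(mdl, token)) {
                break; // covers every id in the EOS list as well as EOT
            }
            // ... append token to the output and decode the next batch here ...
        }
    }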

View File

@ -184,8 +184,13 @@ int main(int argc, char ** argv) {
return 1;
}
// add eos if not present
if (llama_token_eos(model) >= 0 && (inp.empty() || inp.back() != llama_token_eos(model))) {
inp.push_back(llama_token_eos(model));
const int n_eos = llama_n_eos(model);
std::vector<int32_t> eos_tokens(n_eos, 0);
int32_t* eos_ptr = eos_tokens.data();
llama_token_eos(model, eos_ptr);
if (!eos_tokens.empty() && (inp.empty() || !std::count(eos_tokens.begin(), eos_tokens.end(), inp.back()))) {
inp.insert(inp.end(), eos_tokens.begin(), eos_tokens.end());
}
chunk.tokens = inp;
}

View File

@ -1021,7 +1021,13 @@ struct server_context {
slot.sparams.logit_bias.clear();
if (json_value(data, "ignore_eos", false)) {
slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
const int n_eos = llama_n_eos(model);
std::vector<int32_t> eos_tokens(n_eos, 0);
int32_t* eos_ptr = eos_tokens.data();
llama_token_eos(model, eos_ptr);
for (int32_t i = 0; i < n_eos; ++i) {
slot.sparams.logit_bias[eos_ptr[i]] = -INFINITY;
}
}
const auto & logit_bias = data.find("logit_bias");
@ -1308,9 +1314,17 @@ struct server_context {
}
json get_formated_generation(const server_slot & slot) const {
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
const int n_eos = llama_n_eos(model);
std::vector<int32_t> eos_tokens(n_eos, 0);
int32_t* eos_ptr = eos_tokens.data();
llama_token_eos(model, eos_ptr);
bool ignore_eos = false;
for (auto eos: eos_tokens) {
const auto logit_bias_eos = slot.sparams.logit_bias.find(eos);
if (logit_bias_eos != slot.sparams.logit_bias.end() && logit_bias_eos->second < 0.0f && std::isinf(logit_bias_eos->second)) {
ignore_eos = true;
}
}
std::vector<std::string> samplers_sequence;
samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
for (const auto & sampler_type : slot.sparams.samplers_sequence) {

View File

@ -88,12 +88,21 @@ int main(int argc, char ** argv) {
fprintf(stderr, "vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
return 1;
}
const int n_eos_tgt = llama_n_eos(model_tgt);
std::vector<int32_t> eos_tokens_tgt(n_eos_tgt, 0);
int32_t* eos_ptr_tgt = eos_tokens_tgt.data();
llama_token_eos(model_tgt, eos_ptr_tgt);
const int n_eos_dft = llama_n_eos(model_dft);
std::vector<int32_t> eos_tokens_dft(n_eos_dft, 0);
int32_t* eos_ptr_dft = eos_tokens_dft.data();
llama_token_eos(model_dft, eos_ptr_dft);
if (
llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
llama_token_eos(model_tgt) != llama_token_eos(model_dft)
eos_tokens_tgt != eos_tokens_dft
) {
fprintf(stderr, "%s: error: draft model special tokens must match target model to use speculation\n", __func__);
return 1;

View File

@ -88,7 +88,7 @@ class Keys:
SCORES = "tokenizer.ggml.scores"
MERGES = "tokenizer.ggml.merges"
BOS_ID = "tokenizer.ggml.bos_token_id"
EOS_ID = "tokenizer.ggml.eos_token_id"
EOS_ID = "tokenizer.ggml.eos_token_id" # recommand eos_id_list
UNK_ID = "tokenizer.ggml.unknown_token_id"
SEP_ID = "tokenizer.ggml.seperator_token_id"
PAD_ID = "tokenizer.ggml.padding_token_id"
@ -107,6 +107,8 @@ class Keys:
SUFFIX_ID = "tokenizer.ggml.suffix_token_id"
MIDDLE_ID = "tokenizer.ggml.middle_token_id"
EOT_ID = "tokenizer.ggml.eot_token_id"
EOS_ID_LIST = "tokenizer.ggml.eos_token_id_list"
#
@ -1091,7 +1093,7 @@ KEY_TOKENIZER_TOKEN_TYPE = Keys.Tokenizer.TOKEN_TYPE
KEY_TOKENIZER_SCORES = Keys.Tokenizer.SCORES
KEY_TOKENIZER_MERGES = Keys.Tokenizer.MERGES
KEY_TOKENIZER_BOS_ID = Keys.Tokenizer.BOS_ID
KEY_TOKENIZER_EOS_ID = Keys.Tokenizer.EOS_ID
KEY_TOKENIZER_EOS_ID_LIST = Keys.Tokenizer.EOS_ID_LIST
KEY_TOKENIZER_UNK_ID = Keys.Tokenizer.UNK_ID
KEY_TOKENIZER_SEP_ID = Keys.Tokenizer.SEP_ID
KEY_TOKENIZER_PAD_ID = Keys.Tokenizer.PAD_ID

View File

@ -510,9 +510,9 @@ class GGUFWriter:
def add_bos_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.BOS_ID, id)
def add_eos_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.EOS_ID, id)
def add_eos_token_id_list(self, id: Sequence[int]) -> None:
self.add_array(Keys.Tokenizer.EOS_ID_LIST, id)
def add_unk_token_id(self, id: int) -> None:
self.add_uint32(Keys.Tokenizer.UNK_ID, id)
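For comparison, the same key can be written from C/C++ with the existing gguf API; a sketch using the ChatGLM4 ids added earlier in this commit (assumes a gguf_context that is being prepared for writing):

    #include <cstdint>
    #include "ggml.h"

    // sketch: store an EOS id list as an INT32 array in a GGUF being written
    static void write_eos_list(struct gguf_context * ctx) {
        const int32_t eos_ids[] = { 151329, 151336, 151338 }; // ChatGLM4 ids from this commit
        gguf_set_arr_data(ctx, "tokenizer.ggml.eos_token_id_list",
                GGUF_TYPE_INT32, eos_ids, 3);
    }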

View File

@ -337,7 +337,7 @@ enum llm_kv {
LLM_KV_TOKENIZER_SCORES,
LLM_KV_TOKENIZER_MERGES,
LLM_KV_TOKENIZER_BOS_ID,
LLM_KV_TOKENIZER_EOS_ID,
LLM_KV_TOKENIZER_EOS_ID, // kept for compatibility with previous versions
LLM_KV_TOKENIZER_UNK_ID,
LLM_KV_TOKENIZER_SEP_ID,
LLM_KV_TOKENIZER_PAD_ID,
@ -352,6 +352,7 @@ enum llm_kv {
LLM_KV_TOKENIZER_SUFFIX_ID,
LLM_KV_TOKENIZER_MIDDLE_ID,
LLM_KV_TOKENIZER_EOT_ID,
LLM_KV_TOKENIZER_EOS_ID_LIST
};
static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
@ -438,6 +439,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" },
{ LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" },
{ LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" },
{ LLM_KV_TOKENIZER_EOS_ID_LIST, "tokenizer.ggml.eos_token_id_list" },
};
struct LLM_KV {
@ -2328,6 +2330,7 @@ struct llama_vocab {
id special_pad_id = -1;
id special_cls_id = -1;
id special_mask_id = -1;
std::set<id> special_eos_id_list;
id linefeed_id = 13;
id special_prefix_id = -1;
@ -5084,6 +5087,24 @@ static void llm_load_vocab(
}
}
const int eos_idx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_EOS_ID_LIST).c_str());
vocab.special_eos_id_list.clear();
if (eos_idx == -1) {
vocab.special_eos_id_list.insert(vocab.special_eos_id);
} else {
const uint32_t n_eos = gguf_get_arr_n(ctx, eos_idx);
const int32_t * eos_tokens = (const int32_t *) gguf_get_arr_data(ctx, eos_idx);
if (n_eos == 0) {
vocab.special_eos_id_list.insert(vocab.special_eos_id);
}
for (uint32_t i = 0; i < n_eos; ++i) {
vocab.special_eos_id_list.insert(eos_tokens[i]);
}
}
// Handle add_bos_token and add_eos_token
{
bool temp = true;
@ -5273,7 +5294,11 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
// special tokens
if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); }
if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); }
if (!vocab.special_eos_id_list.empty()) {
for (auto it = vocab.special_eos_id_list.begin(); it != vocab.special_eos_id_list.end(); ++it) {
LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, *it, vocab.id_to_token[*it].text.c_str() );
}
}
if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); }
if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); }
if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); }
@ -13482,8 +13507,8 @@ struct llm_tokenizer_bpe {
bool append_eos(std::vector<llama_vocab::id> & output) const {
if (vocab.tokenizer_add_eos) {
GGML_ASSERT(vocab.special_eos_id != -1);
output.push_back(vocab.special_eos_id);
GGML_ASSERT(!vocab.special_eos_id_list.empty());
output.insert(output.end(), vocab.special_eos_id_list.begin(), vocab.special_eos_id_list.end());
return true;
}
return false;
@ -13496,7 +13521,7 @@ struct llm_tokenizer_bpe {
"also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
"Are you sure this is what you want?\n", __FUNCTION__);
}
if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) {
if (vocab.tokenizer_add_eos && output.size() >= 2 && vocab.special_eos_id_list.find(*(output.end()-2)) != vocab.special_eos_id_list.end()) {
LLAMA_LOG_WARN(
"%s: Added a EOS token to the prompt as specified by the model but the prompt "
"also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "
@ -13966,8 +13991,8 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
}
if (add_special && vocab.tokenizer_add_eos) {
GGML_ASSERT(vocab.special_eos_id != -1);
output.push_back(vocab.special_eos_id);
GGML_ASSERT(!vocab.special_eos_id_list.empty());
output.insert(output.end(), vocab.special_eos_id_list.begin(), vocab.special_eos_id_list.end());
}
// add suffix to chatglm3
if (vocab.type_pre == LLAMA_VOCAB_PRE_TYPE_CHATGLM3) {
@ -16966,6 +16991,10 @@ int32_t llama_n_vocab(const struct llama_model * model) {
return model->hparams.n_vocab;
}
int32_t llama_n_eos(const struct llama_model * model) {
return model->vocab.special_eos_id_list.size();
}
int32_t llama_n_ctx_train(const struct llama_model * model) {
return model->hparams.n_ctx_train;
}
@ -18550,21 +18579,8 @@ llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_to
}
bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
auto arch_name = llama_model_arch_name(model->arch);
auto vocab_type = model->vocab.type;
if (strcmp(arch_name, "chatglm") == 0) {
if (LLAMA_VOCAB_TYPE_BPE == vocab_type) { // glm4
return token != -1 && (
token == llama_token_eos(model) ||
token == llama_token_eot(model) ||
token == 151329 ||
token == 151336 ||
token == 151338
);
}
}
return token != -1 && (
token == llama_token_eos(model) ||
model->vocab.special_eos_id_list.count(token) ||
token == llama_token_eot(model)
);
}
@ -18577,8 +18593,11 @@ llama_token llama_token_bos(const struct llama_model * model) {
return model->vocab.special_bos_id;
}
llama_token llama_token_eos(const struct llama_model * model) {
return model->vocab.special_eos_id;
void llama_token_eos(const struct llama_model * model, llama_token* token_list) {
int ind = 0;
for (auto it = model->vocab.special_eos_id_list.begin(); it != model->vocab.special_eos_id_list.end(); ++it) {
token_list[ind++] = *it;
}
}
llama_token llama_token_cls(const struct llama_model * model) {
@ -18952,10 +18971,7 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|start_header_id|>assistant<|end_header_id|>\n\n";
}
} else if (tmpl == "chatglm3" ||
(tmpl.find("add_generation_prompt") != std::string::npos &&
tmpl.find("for message in messages") != std::string::npos &&
tmpl.find("loop.first") != std::string::npos)) {
} else if (tmpl == "chatglm3" || tmpl.find("[gMASK]sop") != std::string::npos) {
// chatglm3-6b
ss << "[gMASK]" << "sop";
for (auto message : chat) {
@ -18965,7 +18981,7 @@ static int32_t llama_chat_apply_template_internal(
if (add_ass) {
ss << "<|assistant|>";
}
} else if (tmpl == "ChatGLM4") {
} else if (tmpl == "chatglm4" || tmpl.find("[gMASK]<sop>") != std::string::npos) {
ss << "[gMASK]" << "<sop>";
for (auto message : chat) {
std::string role(message->role);

View File

@ -448,6 +448,7 @@ extern "C" {
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
LLAMA_API int32_t llama_n_eos (const struct llama_model * model);
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
@ -851,7 +852,7 @@ extern "C" {
// Special tokens
LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
LLAMA_API void llama_token_eos(const struct llama_model * model, llama_token* token_list); // end-of-sentence
LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line

View File

@ -60,7 +60,7 @@ int main(void) {
// ChatGLM3
"{% for message in messages %}{% if loop.first %}[gMASK]sop<|{{ message['role'] }}|>\n {{ message['content'] }}{% else %}<|{{ message['role'] }}|>\n {{ message['content'] }}{% endif %}{% endfor %}{% if add_generation_prompt %}<|assistant|>{% endif %}",
// ChatGLM4
"ChatGLM4",
"chatglm4",
};
std::vector<std::string> expected_output = {
// teknium/OpenHermes-2.5-Mistral-7B