Only use FIM middle token if it exists (#7648)

* Only use FIM middle if it exists

* Only use FIM middle if it exists
This commit is contained in:
Sigbjørn Skjæret 2024-06-18 14:19:45 +02:00 committed by GitHub
parent 84f6de17f6
commit 91c188d6c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 3 deletions

View File

@ -223,7 +223,11 @@ int main(int argc, char ** argv) {
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model)); inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
embd_inp = inp_pfx; embd_inp = inp_pfx;
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
embd_inp.push_back(llama_token_middle(model));
const llama_token middle_token = llama_token_middle(model);
if (middle_token >= 0) {
embd_inp.push_back(middle_token);
}
LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix)); LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix));
LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix)); LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix));
@ -528,7 +532,12 @@ int main(int argc, char ** argv) {
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model)); inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
embd_inp = inp_pfx; embd_inp = inp_pfx;
embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
embd_inp.push_back(llama_token_middle(model));
const llama_token middle_token = llama_token_middle(model);
if (middle_token >= 0) {
embd_inp.push_back(middle_token);
}
embd.clear(); embd.clear();
n_remain = params.n_predict; n_remain = params.n_predict;
n_past = 0; n_past = 0;

View File

@ -2038,7 +2038,12 @@ struct server_context {
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model)); prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.push_back(llama_token_middle(model));
const llama_token middle_token = llama_token_middle(model);
if (middle_token >= 0) {
prefix_tokens.push_back(middle_token);
}
prompt_tokens = prefix_tokens; prompt_tokens = prefix_tokens;
} else { } else {
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt