From 91c188d6c296bd3384f2a02a83b71187aa3d18b3 Mon Sep 17 00:00:00 2001
From: Sigbjørn Skjæret
Date: Tue, 18 Jun 2024 14:19:45 +0200
Subject: [PATCH] Only use FIM middle token if it exists (#7648)

* Only use FIM middle if it exists

* Only use FIM middle if it exists
---
 examples/infill/infill.cpp | 13 +++++++++++--
 examples/server/server.cpp |  7 ++++++-
 2 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp
index 0e4ec79c6..3e82e4a81 100644
--- a/examples/infill/infill.cpp
+++ b/examples/infill/infill.cpp
@@ -223,7 +223,11 @@ int main(int argc, char ** argv) {
     inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
     embd_inp = inp_pfx;
     embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
-    embd_inp.push_back(llama_token_middle(model));
+
+    const llama_token middle_token = llama_token_middle(model);
+    if (middle_token >= 0) {
+        embd_inp.push_back(middle_token);
+    }
 
     LOG("prefix: \"%s\"\n", log_tostr(params.input_prefix));
     LOG("suffix: \"%s\"\n", log_tostr(params.input_suffix));
@@ -528,7 +532,12 @@ int main(int argc, char ** argv) {
             inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));
             embd_inp = inp_pfx;
             embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());
-            embd_inp.push_back(llama_token_middle(model));
+
+            const llama_token middle_token = llama_token_middle(model);
+            if (middle_token >= 0) {
+                embd_inp.push_back(middle_token);
+            }
+
             embd.clear();
             n_remain = params.n_predict;
             n_past = 0;
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 919078f2b..ec59307b2 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2038,7 +2038,12 @@ struct server_context {
                     prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
                     prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
                     prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
-                    prefix_tokens.push_back(llama_token_middle(model));
+
+                    const llama_token middle_token = llama_token_middle(model);
+                    if (middle_token >= 0) {
+                        prefix_tokens.push_back(middle_token);
+                    }
+
                     prompt_tokens = prefix_tokens;
                 } else {
                     prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
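
Note on the fix (reviewer sketch, not part of the patch): in the llama.cpp C API of this
period, the FIM special-token getters (llama_token_prefix(), llama_token_middle(),
llama_token_suffix()) return -1 when the model's vocabulary does not define the
corresponding token, which is why the new code guards on middle_token >= 0 before
appending it. Below is a minimal, self-contained sketch of the guarded prompt
construction; build_infill_prompt is a hypothetical helper written for illustration,
not a function in the tree:

    // Sketch only: assumes the mid-2024 llama.cpp C API, where the
    // llama_token_*() getters return -1 for tokens absent from the vocab.
    #include <vector>
    #include "llama.h"

    static std::vector<llama_token> build_infill_prompt(
            const llama_model        * model,
            std::vector<llama_token>   inp_pfx,    // tokenized prefix, no special tokens
            std::vector<llama_token>   inp_sfx) {  // tokenized suffix, no special tokens
        // Target layout: <FIM_PRE>prefix<FIM_SUF>suffix[<FIM_MID>]
        inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
        inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));

        std::vector<llama_token> embd_inp = inp_pfx;
        embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end());

        // The guard this patch adds: some models ship prefix/suffix tokens
        // but no dedicated middle token, in which case the getter yields -1
        // and pushing it would inject an invalid token id into the prompt.
        const llama_token middle_token = llama_token_middle(model);
        if (middle_token >= 0) {
            embd_inp.push_back(middle_token);
        }

        return embd_inp;
    }

The >= 0 comparison suffices because valid token ids are non-negative indices into the
vocabulary, while undefined special tokens are stored as -1.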