Fix new line issue with chat template, disable template when in-prefix/suffix is set (#8203)

* preserve new line in llama_chat_format_single

* disable chat template if in-prefix/suffix is set

* remove redundant change
Xuan Son Nguyen 2024-06-30 20:27:13 +02:00 committed by GitHub
parent 1c5eba6f8e
commit 9ef0780062
4 changed files with 23 additions and 9 deletions


@@ -1014,16 +1014,19 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     }
     if (arg == "--in-prefix-bos") {
         params.input_prefix_bos = true;
+        params.enable_chat_template = false;
         return true;
     }
     if (arg == "--in-prefix") {
         CHECK_ARG
         params.input_prefix = argv[i];
+        params.enable_chat_template = false;
         return true;
     }
     if (arg == "--in-suffix") {
         CHECK_ARG
         params.input_suffix = argv[i];
+        params.enable_chat_template = false;
         return true;
     }
     if (arg == "--spm-infill") {
@@ -1406,7 +1409,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
         "halt generation at PROMPT, return control in interactive mode\n"
         "can be specified more than once for multiple prompts" });
     options.push_back({ "main", "-sp, --special", "special tokens output enabled (default: %s)", params.special ? "true" : "false" });
-    options.push_back({ "main", "-cnv, --conversation", "run in conversation mode (does not print special tokens and suffix/prefix) (default: %s)", params.conversation ? "true" : "false" });
+    options.push_back({ "main", "-cnv, --conversation", "run in conversation mode (does not print special tokens and suffix/prefix, use default chat template) (default: %s)", params.conversation ? "true" : "false" });
     options.push_back({ "main infill", "-i, --interactive", "run in interactive mode (default: %s)", params.interactive ? "true" : "false" });
     options.push_back({ "main infill", "-if, --interactive-first", "run in interactive mode and wait for input right away (default: %s)", params.interactive_first ? "true" : "false" });
     options.push_back({ "main infill", "-mli, --multiline-input", "allows you to write or paste multiple lines without ending each in '\\'" });
@@ -2668,12 +2671,19 @@ std::string llama_chat_format_single(const struct llama_model * model,
         const std::vector<llama_chat_msg> & past_msg,
         const llama_chat_msg & new_msg,
         bool add_ass) {
+    std::ostringstream ss;
     auto fmt_past_msg = llama_chat_apply_template(model, tmpl, past_msg, false);
     std::vector<llama_chat_msg> chat_new(past_msg);
+    // if the past_msg ends with a newline, we must preserve it in the formatted version
+    if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') {
+        ss << "\n";
+    };
+    // format chat with new_msg
     chat_new.push_back(new_msg);
     auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass);
-    auto formatted = fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
-    return formatted;
+    // get the diff part
+    ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size());
+    return ss.str();
 }
 std::string llama_chat_format_example(const struct llama_model * model,
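
For reference, a minimal sketch of the new behaviour (assuming common.h from this repo and a build that links the common library; the exact past conversation is arbitrary): with a named template such as "chatml", the incremental chunk returned for a new user turn now begins with the "\n" that closed the previous assistant turn, matching the updated test assertions at the end of this diff.

// Sketch only: llama_chat_format_single as declared in common.h; the model
// pointer may be nullptr because the template is selected by name here.
#include <cassert>
#include <string>
#include <vector>
#include "common.h"

int main() {
    std::vector<llama_chat_msg> past = {
        {"system",    "You are a helpful assistant"},
        {"user",      "Hello"},
        {"assistant", "Hi there"},
    };
    llama_chat_msg new_msg = {"user", "How are you"};

    // With add_ass=true, the formatted past ends in '\n' for chatml, so the
    // returned chunk is now prefixed with that newline.
    std::string chunk = llama_chat_format_single(nullptr, "chatml", past, new_msg, true);
    assert(chunk == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
    return 0;
}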


@@ -200,6 +200,7 @@ struct gpt_params {
     std::string public_path = "";
     std::string chat_template = "";
     std::string system_prompt = "";
+    bool enable_chat_template = true;
     std::vector<std::string> api_keys;


@@ -261,7 +261,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd_inp;
     {
-        auto prompt = params.conversation
+        auto prompt = (params.conversation && params.enable_chat_template)
             ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
             : params.prompt;
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
@@ -810,7 +810,9 @@ int main(int argc, char ** argv) {
                     is_antiprompt = true;
                 }
-                chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
+                if (params.enable_chat_template) {
+                    chat_add_and_format(model, chat_msgs, "assistant", assistant_ss.str());
+                }
                 is_interacting = true;
                 printf("\n");
             }
@@ -872,12 +874,13 @@ int main(int argc, char ** argv) {
                     string_process_escapes(buffer);
                 }
-                std::string user_inp = params.conversation
+                bool format_chat = params.conversation && params.enable_chat_template;
+                std::string user_inp = format_chat
                     ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
                     : std::move(buffer);
                 // TODO: one inconvenient of current chat template implementation is that we can't distinguish between user input and special tokens (prefix/postfix)
                 const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true);
-                const auto line_inp = ::llama_tokenize(ctx, user_inp, false, params.conversation);
+                const auto line_inp = ::llama_tokenize(ctx, user_inp, false, format_chat);
                 const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true);
                 LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp).c_str());
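
As a standalone illustration (hypothetical stand-in types, not code from this patch): user input only goes through the chat template when conversation mode is active and enable_chat_template has not been cleared by --in-prefix/--in-suffix/--in-prefix-bos; otherwise the raw buffer is wrapped with the configured prefix/suffix before tokenization.

// Sketch with hypothetical stand-ins (params_t, route_user_input) that mirror
// the format_chat gating above; gpt_params/chat_add_and_format are not used.
#include <iostream>
#include <string>

struct params_t {                          // stand-in for gpt_params
    bool conversation         = true;
    bool enable_chat_template = true;      // cleared when --in-prefix/--in-suffix/--in-prefix-bos is set
    std::string input_prefix;
    std::string input_suffix;
};

std::string route_user_input(const params_t & p, const std::string & buffer) {
    const bool format_chat = p.conversation && p.enable_chat_template;
    if (format_chat) {
        // main.cpp would call chat_add_and_format(model, chat_msgs, "user", buffer) here
        return "<chat-template formatted> " + buffer;
    }
    // template disabled: the caller tokenizes input_prefix, the raw buffer and input_suffix
    return p.input_prefix + buffer + p.input_suffix;
}

int main() {
    params_t p;
    p.enable_chat_template = false;        // e.g. --in-prefix "### User: " was passed
    p.input_prefix = "### User: ";
    p.input_suffix = "\n### Assistant: ";
    std::cout << route_user_input(p, "How are you") << std::endl;
    return 0;
}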


@@ -142,9 +142,9 @@ int main(void) {
         std::cout << "fmt_single(" << tmpl << ")\n" << output << "\n-------------------------\n";
         return output;
     };
-    assert(fmt_single("chatml") == "<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
+    assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n");
     assert(fmt_single("llama2") == "[INST] How are you [/INST]");
-    assert(fmt_single("gemma") == "<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
+    assert(fmt_single("gemma") == "\n<start_of_turn>user\nHow are you<end_of_turn>\n<start_of_turn>model\n");
     assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n");
     return 0;