Server: use llama_chat_apply_template (#5593)

* server: use llama_chat_apply_template

* server: remove trailing space

* server: fix format_chat

* server: fix help message

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* server: fix formatted_chat

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Xuan Son Nguyen 2024-02-20 15:58:27 +01:00 committed by GitHub
parent 5207b3fbc5
commit 9c405c9f9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 45 additions and 49 deletions

View File

@ -15,13 +15,11 @@
using json = nlohmann::json; using json = nlohmann::json;
inline static json oaicompat_completion_params_parse( inline static json oaicompat_completion_params_parse(
const struct llama_model * model,
const json &body, /* openai api json semantics */ const json &body, /* openai api json semantics */
const std::string &chat_template) const std::string &chat_template)
{ {
json llama_params; json llama_params;
std::string formatted_prompt = chat_template == "chatml"
? format_chatml(body["messages"]) // OpenAI 'messages' to chatml (with <|im_start|>,...)
: format_llama2(body["messages"]); // OpenAI 'messages' to llama2 (with [INST],...)
llama_params["__oaicompat"] = true; llama_params["__oaicompat"] = true;
@ -34,7 +32,7 @@ inline static json oaicompat_completion_params_parse(
// https://platform.openai.com/docs/api-reference/chat/create // https://platform.openai.com/docs/api-reference/chat/create
llama_sampling_params default_sparams; llama_sampling_params default_sparams;
llama_params["model"] = json_value(body, "model", std::string("unknown")); llama_params["model"] = json_value(body, "model", std::string("unknown"));
llama_params["prompt"] = formatted_prompt; llama_params["prompt"] = format_chat(model, chat_template, body["messages"]);
llama_params["cache_prompt"] = json_value(body, "cache_prompt", false); llama_params["cache_prompt"] = json_value(body, "cache_prompt", false);
llama_params["temperature"] = json_value(body, "temperature", 0.0); llama_params["temperature"] = json_value(body, "temperature", 0.0);
llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k); llama_params["top_k"] = json_value(body, "top_k", default_sparams.top_k);

View File

@ -37,7 +37,7 @@ struct server_params
std::string hostname = "127.0.0.1"; std::string hostname = "127.0.0.1";
std::vector<std::string> api_keys; std::vector<std::string> api_keys;
std::string public_path = "examples/server/public"; std::string public_path = "examples/server/public";
std::string chat_template = "chatml"; std::string chat_template = "";
int32_t port = 8080; int32_t port = 8080;
int32_t read_timeout = 600; int32_t read_timeout = 600;
int32_t write_timeout = 600; int32_t write_timeout = 600;
@ -1937,8 +1937,9 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n"); printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
printf(" -gan N, --grp-attn-n N set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`"); printf(" -gan N, --grp-attn-n N set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
printf(" -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`"); printf(" -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
printf(" --chat-template FORMAT_NAME"); printf(" --chat-template JINJA_TEMPLATE\n");
printf(" set chat template, possible value is: llama2, chatml (default %s)", sparams.chat_template.c_str()); printf(" set custom jinja chat template (default: template taken from model's metadata)\n");
printf(" Note: only commonly used templates are accepted, since we don't have jinja parser\n");
printf("\n"); printf("\n");
} }
@ -2389,13 +2390,13 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
invalid_param = true; invalid_param = true;
break; break;
} }
std::string value(argv[i]); if (!verify_custom_template(argv[i])) {
if (value != "chatml" && value != "llama2") { fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]);
fprintf(stderr, "error: chat template can be \"llama2\" or \"chatml\", but got: %s\n", value.c_str()); fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n");
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.chat_template = value; sparams.chat_template = argv[i];
} }
else if (arg == "--override-kv") else if (arg == "--override-kv")
{ {
@ -2913,7 +2914,7 @@ int main(int argc, char **argv)
if (!validate_api_key(req, res)) { if (!validate_api_key(req, res)) {
return; return;
} }
json data = oaicompat_completion_params_parse(json::parse(req.body), sparams.chat_template); json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);
const int task_id = llama.queue_tasks.get_new_id(); const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id); llama.queue_results.add_waiting_task_id(task_id);

View File

@ -167,50 +167,47 @@ static T json_value(const json &body, const std::string &key, const T &default_v
: default_value; : default_value;
} }
inline std::string format_llama2(std::vector<json> messages) // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
inline bool verify_custom_template(const std::string & tmpl) {
llama_chat_message chat[] = {{"user", "test"}};
std::vector<char> buf(1);
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, buf.data(), buf.size());
return res >= 0;
}
// Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const struct llama_model * model, const std::string & tmpl, const std::vector<json> & messages)
{ {
std::ostringstream output; size_t alloc_size = 0;
bool is_inside_turn = false; // vector holding all allocated string to be passed to llama_chat_apply_template
std::vector<std::string> str(messages.size() * 2);
std::vector<llama_chat_message> chat(messages.size());
for (auto it = messages.begin(); it != messages.end(); ++it) { for (size_t i = 0; i < messages.size(); ++i) {
if (!is_inside_turn) { auto &curr_msg = messages[i];
output << "[INST] "; str[i*2 + 0] = json_value(curr_msg, "role", std::string(""));
} str[i*2 + 1] = json_value(curr_msg, "content", std::string(""));
std::string role = json_value(*it, "role", std::string("user")); alloc_size += str[i*2 + 1].length();
std::string content = json_value(*it, "content", std::string("")); chat[i].role = str[i*2 + 0].c_str();
if (role == "system") { chat[i].content = str[i*2 + 1].c_str();
output << "<<SYS>>\n" << content << "\n<<SYS>>\n\n";
is_inside_turn = true;
} else if (role == "user") {
output << content << " [/INST]";
is_inside_turn = true;
} else {
output << " " << content << " </s>";
is_inside_turn = false;
}
} }
LOG_VERBOSE("format_llama2", {{"text", output.str()}}); const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
std::vector<char> buf(alloc_size * 2);
return output.str(); // run the first time to get the total output length
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
// if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) {
buf.resize(res);
res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), true, buf.data(), buf.size());
} }
inline std::string format_chatml(std::vector<json> messages) std::string formatted_chat(buf.data(), res);
{ LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
std::ostringstream chatml_msgs;
for (auto it = messages.begin(); it != messages.end(); ++it) { return formatted_chat;
chatml_msgs << "<|im_start|>"
<< json_value(*it, "role", std::string("user")) << '\n';
chatml_msgs << json_value(*it, "content", std::string(""))
<< "<|im_end|>\n";
}
chatml_msgs << "<|im_start|>assistant" << '\n';
LOG_VERBOSE("format_chatml", {{"text", chatml_msgs.str()}});
return chatml_msgs.str();
} }
// //

View File

@ -12602,7 +12602,7 @@ LLAMA_API int32_t llama_chat_apply_template(
// load template from model // load template from model
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
std::string template_key = "tokenizer.chat_template"; std::string template_key = "tokenizer.chat_template";
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), curr_tmpl.size()); int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
if (res < 0) { if (res < 0) {
// worst case: there is no information about template, we will use chatml by default // worst case: there is no information about template, we will use chatml by default
curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal curr_tmpl = "<|im_start|>"; // see llama_chat_apply_template_internal