Add chatml fallback for cpp llama_chat_apply_template (#8160)

* add chatml fallback for cpp `llama_chat_apply_template`

* remove redundant code
This commit is contained in:
Xuan Son Nguyen 2024-06-27 18:14:19 +02:00 committed by GitHub
parent ab3679112d
commit 16791b8f0b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 20 additions and 1 deletions

View File

@ -2618,6 +2618,7 @@ std::string llama_chat_apply_template(const struct llama_model * model,
const std::vector<llama_chat_msg> & msgs, const std::vector<llama_chat_msg> & msgs,
bool add_ass) { bool add_ass) {
int alloc_size = 0; int alloc_size = 0;
bool fallback = false; // indicate if we must fallback to default chatml
std::vector<llama_chat_message> chat; std::vector<llama_chat_message> chat;
for (auto & msg : msgs) { for (auto & msg : msgs) {
chat.push_back({msg.role.c_str(), msg.content.c_str()}); chat.push_back({msg.role.c_str(), msg.content.c_str()});
@ -2630,10 +2631,26 @@ std::string llama_chat_apply_template(const struct llama_model * model,
// run the first time to get the total output length // run the first time to get the total output length
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// error: chat template is not supported
if (res < 0) {
if (ptr_tmpl != nullptr) {
// if the custom "tmpl" is not supported, we throw an error
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported");
} else {
// If the built-in template is not supported, we default to chatml
res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
fallback = true;
}
}
// if it turns out that our buffer is too small, we resize it // if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) { if ((size_t) res > buf.size()) {
buf.resize(res); buf.resize(res);
res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); res = llama_chat_apply_template(
fallback ? nullptr : model,
fallback ? "chatml" : ptr_tmpl,
chat.data(), chat.size(), add_ass, buf.data(), buf.size());
} }
std::string formatted_chat(buf.data(), res); std::string formatted_chat(buf.data(), res);

View File

@ -380,6 +380,8 @@ struct llama_chat_msg {
bool llama_chat_verify_template(const std::string & tmpl); bool llama_chat_verify_template(const std::string & tmpl);
// CPP wrapper for llama_chat_apply_template // CPP wrapper for llama_chat_apply_template
// If the built-in template is not supported, we default to chatml
// If the custom "tmpl" is not supported, we throw an error
std::string llama_chat_apply_template(const struct llama_model * model, std::string llama_chat_apply_template(const struct llama_model * model,
const std::string & tmpl, const std::string & tmpl,
const std::vector<llama_chat_msg> & chat, const std::vector<llama_chat_msg> & chat,