From 16791b8f0b4526aafbf5d0e5bbbd2e99c2253418 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 27 Jun 2024 18:14:19 +0200 Subject: [PATCH] Add chatml fallback for cpp `llama_chat_apply_template` (#8160) * add chatml fallback for cpp `llama_chat_apply_template` * remove redundant code --- common/common.cpp | 19 ++++++++++++++++++- common/common.h | 2 ++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/common/common.cpp b/common/common.cpp index 70349ad70..57d03a578 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2618,6 +2618,7 @@ std::string llama_chat_apply_template(const struct llama_model * model, const std::vector & msgs, bool add_ass) { int alloc_size = 0; + bool fallback = false; // indicate if we must fallback to default chatml std::vector chat; for (auto & msg : msgs) { chat.push_back({msg.role.c_str(), msg.content.c_str()}); @@ -2630,10 +2631,26 @@ std::string llama_chat_apply_template(const struct llama_model * model, // run the first time to get the total output length int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + // error: chat template is not supported + if (res < 0) { + if (ptr_tmpl != nullptr) { + // if the custom "tmpl" is not supported, we throw an error + // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() + throw std::runtime_error("this custom template is not supported"); + } else { + // If the built-in template is not supported, we default to chatml + res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + fallback = true; + } + } + // if it turns out that our buffer is too small, we resize it if ((size_t) res > buf.size()) { buf.resize(res); - res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); + res = llama_chat_apply_template( + fallback ? nullptr : model, + fallback ? "chatml" : ptr_tmpl, + chat.data(), chat.size(), add_ass, buf.data(), buf.size()); } std::string formatted_chat(buf.data(), res); diff --git a/common/common.h b/common/common.h index c541204f6..0486ba380 100644 --- a/common/common.h +++ b/common/common.h @@ -380,6 +380,8 @@ struct llama_chat_msg { bool llama_chat_verify_template(const std::string & tmpl); // CPP wrapper for llama_chat_apply_template +// If the built-in template is not supported, we default to chatml +// If the custom "tmpl" is not supported, we throw an error std::string llama_chat_apply_template(const struct llama_model * model, const std::string & tmpl, const std::vector & chat,