From 9cfe4d7202da427e5e7f65000021ca33f283b26b Mon Sep 17 00:00:00 2001 From: ochafik Date: Thu, 26 Sep 2024 18:06:03 +0100 Subject: [PATCH] `tool-call`: refactor llama_chat_template class + use in validate_model_chat_template --- common/chat-template.cpp | 58 +++++++++++++++++++++++++------------- common/chat-template.h | 26 ++++------------- examples/server/server.cpp | 20 +++++++++++-- 3 files changed, 61 insertions(+), 43 deletions(-) diff --git a/common/chat-template.cpp b/common/chat-template.cpp index 3f84a1fb5..ed37513be 100644 --- a/common/chat-template.cpp +++ b/common/chat-template.cpp @@ -1,5 +1,4 @@ #include "chat-template.h" -#include "minja.hpp" #include "llama.h" using json = nlohmann::ordered_json; @@ -31,14 +30,39 @@ static std::string llama_model_meta_val_str(const struct llama_model * model, co return ""; } +llama_chat_template::llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token) + : _chat_template(chat_template), _bos_token(bos_token), _eos_token(eos_token) { + + _supports_tools = chat_template.find("tools") != std::string::npos; + _requires_object_arguments = chat_template.find("tool_call.arguments | items") != std::string::npos; + _supports_system_role = chat_template.find("System role not supported") == std::string::npos; + + if (chat_template.find("") != std::string::npos) { + _tool_call_style = Hermes2Pro; + } else if (chat_template.find(">>>all") != std::string::npos) { + _tool_call_style = FunctionaryV3Llama3; + } else if (chat_template.find("<|start_header_id|>") != std::string::npos) { + if (chat_template.find("") != std::string::npos) { + _tool_call_style = Llama31; + } + } + _template_root = minja::Parser::parse(_chat_template, { + /* .trim_blocks = */ true, + /* .lstrip_blocks = */ true, + /* .keep_trailing_newline = */ false, + }); +} + llama_chat_template llama_chat_template::from_model( const struct llama_model * model, - const std::string & chat_template_override) + const char * chat_template_override) { // TODO: handle "chatml"? - auto chat_template = chat_template_override.empty() - ? llama_model_meta_val_str(model, "tokenizer.chat_template") - : chat_template_override; + std::string chat_template = chat_template_override + ? chat_template_override + : llama_model_meta_val_str(model, "tokenizer.chat_template"); auto bos_token = _llama_token_to_piece(model, llama_token_bos(model), true); auto eos_token = _llama_token_to_piece(model, llama_token_eos(model), true); return llama_chat_template(chat_template, bos_token, eos_token); @@ -69,9 +93,9 @@ std::string llama_chat_template::apply( throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump()); } std::string role = message.at("role"); - std::string content = message.at("content"); - if (!_supports_system_role) { + if (!message["content"].is_null() && !_supports_system_role) { + std::string content = message.at("content"); if (role == "system") { if (!pending_system.empty()) pending_system += "\n"; pending_system += content; @@ -89,8 +113,11 @@ std::string llama_chat_template::apply( } if (_requires_object_arguments && message.contains("tool_calls")) { for (auto & tool_call : message.at("tool_calls")) { - std::string arguments = tool_call.at("arguments"); - tool_call["arguments"] = json::parse(arguments); + if (tool_call["type"] == "function") { + auto & function = tool_call.at("function"); + std::string arguments = function.at("arguments"); + function["arguments"] = json::parse(arguments); + } } } } @@ -99,20 +126,11 @@ std::string llama_chat_template::apply( auto context = minja::Context::make(json({ {"messages", actual_messages}, + {"tools", tools}, {"add_generation_prompt", add_generation_prompt}, {"bos_token", _bos_token}, {"eos_token", _eos_token}, })); - if (!tools.is_null() && !tools.empty()) { - auto tools_val = minja::Value(tools); - context->set("tools", tools_val); - } - - auto tmpl_root = minja::Parser::parse(_chat_template, { - /* .trim_blocks = */ true, - /* .lstrip_blocks = */ true, - /* .keep_trailing_newline = */ false, - }); - return tmpl_root->render(context); + return _template_root->render(context); } diff --git a/common/chat-template.h b/common/chat-template.h index 4bab3ff08..e4dc7667f 100644 --- a/common/chat-template.h +++ b/common/chat-template.h @@ -1,11 +1,13 @@ #pragma once +#include "minja.hpp" #include #include #include using json = nlohmann::ordered_json; + enum llama_tool_call_style { Unknown, Llama31, @@ -27,30 +29,14 @@ class llama_chat_template { std::string _chat_template; std::string _bos_token; std::string _eos_token; + std::unique_ptr _template_root; + public: - llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token) - : _chat_template(chat_template), _bos_token(bos_token), _eos_token(eos_token) { - - _supports_tools = chat_template.find("tools") != std::string::npos; - _requires_object_arguments = chat_template.find("tool_call.arguments | items") != std::string::npos; - _supports_system_role = chat_template.find("System role not supported") == std::string::npos; - - if (chat_template.find("") != std::string::npos) { - _tool_call_style = Hermes2Pro; - } else if (chat_template.find(">>>all") != std::string::npos) { - _tool_call_style = FunctionaryV3Llama3; - } else if (chat_template.find("<|start_header_id|>") != std::string::npos) { - if (chat_template.find("") != std::string::npos) { - _tool_call_style = Llama31; - } - } - } + llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token); static llama_chat_template from_model( const struct llama_model * model, - const std::string & chat_template_override); + const char * chat_template_override = nullptr); llama_tool_call_style tool_call_style() const { return _tool_call_style; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 341d1cb45..65c0eab0d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -662,9 +662,23 @@ struct server_context { bool validate_model_chat_template(bool use_jinja) const { llama_chat_message chat[] = {{"user", "test"}}; - const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0); + if (use_jinja) { + auto chat_template = llama_chat_template::from_model(model); + try { + chat_template.apply({{ + {"role", "user"}, + {"content", "test"}, + }}, json(), true); + return true; + } catch (const std::exception & e) { + SRV_ERR("failed to apply template: %s\n", e.what()); + return false; + } + } else { + const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0); - return res > 0; + return res > 0; + } } void init() { @@ -2860,7 +2874,7 @@ int main(int argc, char ** argv) { return; } - auto chat_template = llama_chat_template::from_model(ctx_server.model, params.chat_template); + auto chat_template = llama_chat_template::from_model(ctx_server.model, params.chat_template.empty() ? nullptr : params.chat_template.c_str()); json data; try {