diff --git a/Makefile b/Makefile index 25f5db074..749925a57 100644 --- a/Makefile +++ b/Makefile @@ -934,6 +934,7 @@ OBJ_LLAMA = \ OBJ_COMMON = \ common/common.o \ + common/chat-template.o \ common/arg.o \ common/log.o \ common/console.o \ @@ -1170,6 +1171,8 @@ $(LIB_LLAMA_S): \ common/common.o: \ common/common.cpp \ common/common.h \ + common/chat-template.cpp \ + common/chat-template.h \ common/console.h \ common/sampling.h \ common/json.hpp \ @@ -1465,6 +1468,7 @@ llama-server: \ examples/server/prompt-formats.js.hpp \ examples/server/json-schema-to-grammar.mjs.hpp \ examples/server/loading.html.hpp \ + common/chat-template.h \ common/json.hpp \ common/stb_image.h \ $(OBJ_ALL) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index c132e8333..3fb2865ca 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -54,6 +54,8 @@ add_library(${TARGET} STATIC arg.cpp arg.h base64.hpp + chat-template.cpp + chat-template.h common.cpp common.h console.cpp diff --git a/common/chat-template.cpp b/common/chat-template.cpp new file mode 100644 index 000000000..3f84a1fb5 --- /dev/null +++ b/common/chat-template.cpp @@ -0,0 +1,118 @@ +#include "chat-template.h" +#include "minja.hpp" +#include "llama.h" + +using json = nlohmann::ordered_json; + +static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) { + std::string piece; + piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' + const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); + if (n_chars < 0) { + piece.resize(-n_chars); + int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); + GGML_ASSERT(check == -n_chars); + } + else { + piece.resize(n_chars); + } + + return piece; +} + +static std::string llama_model_meta_val_str(const struct llama_model * model, const char * key) { + int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0); + if (tlen > 0) { + std::vector curr_tmpl_buf(tlen + 1, 0); + if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) { + return std::string(curr_tmpl_buf.data(), tlen); + } + } + return ""; +} + +llama_chat_template llama_chat_template::from_model( + const struct llama_model * model, + const std::string & chat_template_override) +{ + // TODO: handle "chatml"? + auto chat_template = chat_template_override.empty() + ? 
llama_model_meta_val_str(model, "tokenizer.chat_template") + : chat_template_override; + auto bos_token = _llama_token_to_piece(model, llama_token_bos(model), true); + auto eos_token = _llama_token_to_piece(model, llama_token_eos(model), true); + return llama_chat_template(chat_template, bos_token, eos_token); +} + +std::string llama_chat_template::apply( + const json & messages, + const json & tools, + bool add_generation_prompt) const +{ + auto actual_messages = messages; + + // First, "fix" messages so they have a chance to be rendered correctly by the template + + if (_requires_object_arguments || !_supports_system_role) { + std::string pending_system; + auto flush_sys = [&]() { + if (!pending_system.empty()) { + actual_messages.push_back({ + {"role", "user"}, + {"content", pending_system}, + }); + pending_system.clear(); + } + }; + for (auto & message : actual_messages) { + if (!message.contains("role") || !message.contains("content")) { + throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump()); + } + std::string role = message.at("role"); + std::string content = message.at("content"); + + if (!_supports_system_role) { + if (role == "system") { + if (!pending_system.empty()) pending_system += "\n"; + pending_system += content; + continue; + } else { + if (role == "user") { + if (!pending_system.empty()) { + message["content"] = pending_system + (content.empty() ? "" : "\n" + content); + pending_system.clear(); + } + } else { + flush_sys(); + } + } + } + if (_requires_object_arguments && message.contains("tool_calls")) { + for (auto & tool_call : message.at("tool_calls")) { + std::string arguments = tool_call.at("arguments"); + tool_call["arguments"] = json::parse(arguments); + } + } + } + flush_sys(); + } + + auto context = minja::Context::make(json({ + {"messages", actual_messages}, + {"add_generation_prompt", add_generation_prompt}, + {"bos_token", _bos_token}, + {"eos_token", _eos_token}, + })); + + if (!tools.is_null() && !tools.empty()) { + auto tools_val = minja::Value(tools); + context->set("tools", tools_val); + } + + auto tmpl_root = minja::Parser::parse(_chat_template, { + /* .trim_blocks = */ true, + /* .lstrip_blocks = */ true, + /* .keep_trailing_newline = */ false, + }); + return tmpl_root->render(context); +} diff --git a/common/chat-template.h b/common/chat-template.h new file mode 100644 index 000000000..4bab3ff08 --- /dev/null +++ b/common/chat-template.h @@ -0,0 +1,64 @@ +#pragma once + +#include +#include +#include + +using json = nlohmann::ordered_json; + +enum llama_tool_call_style { + Unknown, + Llama31, + FunctionaryV3Llama3, + FunctionaryV3Llama31, + Hermes2Pro, +}; + +class llama_chat_template { + public: + + private: + llama_tool_call_style _tool_call_style = Unknown; + bool _supports_tools = true; + // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object. + // Most other templates (and OpenAI's API) expect the arguments object to be stringified. 
+ bool _requires_object_arguments = false; + bool _supports_system_role = true; + std::string _chat_template; + std::string _bos_token; + std::string _eos_token; + public: + llama_chat_template(const std::string & chat_template, const std::string & bos_token, const std::string & eos_token) + : _chat_template(chat_template), _bos_token(bos_token), _eos_token(eos_token) { + + _supports_tools = chat_template.find("tools") != std::string::npos; + _requires_object_arguments = chat_template.find("tool_call.arguments | items") != std::string::npos; + _supports_system_role = chat_template.find("System role not supported") == std::string::npos; + + if (chat_template.find("<tool_call>") != std::string::npos) { + _tool_call_style = Hermes2Pro; + } else if (chat_template.find(">>>all") != std::string::npos) { + _tool_call_style = FunctionaryV3Llama3; + } else if (chat_template.find("<|start_header_id|>") != std::string::npos) { + if (chat_template.find("<function=") != std::string::npos) { + _tool_call_style = FunctionaryV3Llama31; + } else if (chat_template.find("<|python_tag|>") != std::string::npos) { + _tool_call_style = Llama31; + } + } + } + + static llama_chat_template from_model( + const struct llama_model * model, + const std::string & chat_template_override); + + llama_tool_call_style tool_call_style() const { return _tool_call_style; } + + const std::string & chat_template() const { return _chat_template; } + bool supports_tools() const { return _supports_tools; } + + std::string apply( + const nlohmann::ordered_json & messages, + const nlohmann::ordered_json & tools, + bool add_generation_prompt) const; +}; diff --git a/common/common.cpp b/common/common.cpp index e6254ef3b..e247a2eb4 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -9,6 +9,7 @@ #include "json.hpp" #include "json-schema-to-grammar.h" #include "llama.h" +#include "chat-template.h" #include #include @@ -1511,6 +1512,20 @@ std::string llama_detokenize(llama_context * ctx, const std::vector<llama_token> // bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) { + if (use_jinja) { + try { + auto chat_template = llama_chat_template(tmpl, "", ""); + chat_template.apply({{ + {"role", "user"}, + {"content", "test"}, + }}, json(), true); + return true; + } catch (const std::exception & e) { + LOG_ERR("%s: failed to apply template: %s\n", __func__, e.what()); + return false; + } + } + llama_chat_message chat[] = {{"user", "test"}}; int res = llama_chat_apply_template( nullptr, @@ -1519,22 +1534,14 @@ bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja) { 1, /* add_ass= */ true, /* buffer= */ nullptr, - /* length= */ 0, - use_jinja, - /* tools= */ nullptr, - "", - ""); + /* length= */ 0); return res >= 0; } std::string llama_chat_apply_template(const struct llama_model * model, const std::string & tmpl, const std::vector<llama_chat_msg> & msgs, - bool add_ass, - bool use_jinja, - const char * tools, - const char * bos_token, - const char * eos_token) { + bool add_ass) { int alloc_size = 0; bool fallback = false; // indicate if we must fallback to default chatml std::vector<llama_chat_message> chat; @@ -1547,7 +1554,7 @@ std::string llama_chat_apply_template(const struct llama_model * model, std::vector<char> buf(alloc_size); // run the first time to get the total output length - int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token); + int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); // error: chat template is not supported if (res < 0) { @@ -1557,7 +1564,7 @@ std::string llama_chat_apply_template(const struct 
llama_model * model, throw std::runtime_error("this custom template is not supported"); } else { // If the built-in template is not supported, we default to chatml - res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token); + res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size()); fallback = true; } } @@ -1568,7 +1575,7 @@ std::string llama_chat_apply_template(const struct llama_model * model, res = llama_chat_apply_template( fallback ? nullptr : model, fallback ? "chatml" : ptr_tmpl, - chat.data(), chat.size(), add_ass, buf.data(), buf.size(), use_jinja, tools, bos_token, eos_token); + chat.data(), chat.size(), add_ass, buf.data(), buf.size()); } std::string formatted_chat(buf.data(), res); @@ -1579,13 +1586,9 @@ std::string llama_chat_format_single(const struct llama_model * model, const std::string & tmpl, const std::vector<llama_chat_msg> & past_msg, const llama_chat_msg & new_msg, - bool add_ass, - bool use_jinja, - const char * tools, - const char * bos_token, - const char * eos_token) { + bool add_ass) { std::ostringstream ss; - auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false, use_jinja, tools, bos_token, eos_token); + auto fmt_past_msg = past_msg.empty() ? "" : llama_chat_apply_template(model, tmpl, past_msg, false); std::vector<llama_chat_msg> chat_new(past_msg); // if the past_msg ends with a newline, we must preserve it in the formatted version if (add_ass && !fmt_past_msg.empty() && fmt_past_msg.back() == '\n') { @@ -1593,7 +1596,7 @@ std::string llama_chat_format_single(const struct llama_model * model, }; // format chat with new_msg chat_new.push_back(new_msg); - auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass, use_jinja, tools, bos_token, eos_token); + auto fmt_new_msg = llama_chat_apply_template(model, tmpl, chat_new, add_ass); // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); return ss.str(); diff --git a/common/common.h b/common/common.h index 0d34c962e..b7a6c9181 100644 --- a/common/common.h +++ b/common/common.h @@ -471,21 +471,14 @@ std::string llama_detokenize( // Chat template utils // -struct llama_chat_msg_tool_call { - std::string name; - std::string arguments; -}; - // same as llama_chat_message, but uses std::string and std::vector struct llama_chat_msg { std::string role; std::string content; - std::string tool; - std::vector<llama_chat_msg_tool_call> tool_calls; }; -// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid -bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja = false); +// Check if the template is supported or not. 
Returns true if it's valid +bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja); // CPP wrapper for llama_chat_apply_template // If the built-in template is not supported, we default to chatml @@ -493,22 +486,14 @@ bool llama_chat_verify_template(const std::string & tmpl, bool use_jinja = false std::string llama_chat_apply_template(const struct llama_model * model, const std::string & tmpl, const std::vector & chat, - bool add_ass, - bool use_jinja = false, - const char * tools = nullptr, - const char * bos_token = nullptr, - const char * eos_token = nullptr); + bool add_ass); // Format single message, while taking into account the position of that message in chat history std::string llama_chat_format_single(const struct llama_model * model, const std::string & tmpl, const std::vector & past_msg, const llama_chat_msg & new_msg, - bool add_ass, - bool use_jinja = false, - const char * tools = nullptr, - const char * bos_token = nullptr, - const char * eos_token = nullptr); + bool add_ass); // Returns an example of formatted chat std::string llama_chat_format_example(const struct llama_model * model, diff --git a/common/tool-call.cpp b/common/tool-call.cpp index 8304069ac..7b435703a 100644 --- a/common/tool-call.cpp +++ b/common/tool-call.cpp @@ -12,27 +12,6 @@ using json = nlohmann::ordered_json; -// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3.llama3.txt -static bool needs_functionary_v3_tool_call(const std::string & chat_template) { - return chat_template.find("<|start_header_id|>") != std::string::npos - && chat_template.find(">>>all") != std::string::npos; -} - -// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt -static bool needs_functionary_v3_llama_3_1_tool_call(const std::string & chat_template) { - return chat_template.find("<|start_header_id|>") != std::string::npos - && chat_template.find("") != std::string::npos - && chat_template.find("<|python_tag|>") != std::string::npos; -} - -static bool needs_hermes_pro_tool_call(const std::string & chat_template) { - return chat_template.find("") != std::string::npos; -} - static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) { // // https://json.nlohmann.me/features/parsing/sax_interface/ struct json_error_locator : public nlohmann::json_sax { @@ -209,137 +188,145 @@ static llama_tool_calls parse_functionary_v3_tool_calls(const std::string& input return parse_functionary_tool_calls(input, function_regex, close_regex); } -llama_tool_calls parse_tool_calls(const json & tools, const std::string & chat_template, const std::string& input) { - if (needs_hermes_pro_tool_call(chat_template)) { - return parse_hermes_tool_calls(input); - } else if (needs_functionary_v3_tool_call(chat_template)) { - return parse_functionary_v3_tool_calls(input); - } else if (needs_functionary_v3_llama_3_1_tool_call(chat_template)) { - return parse_functionary_v3_llama_3_1_tool_calls(input); - } else if (needs_llama_3_1_tool_call(chat_template)) { - return parse_llama_3_1_tool_calls(tools, input); - } else { - throw std::runtime_error("Unsupported chat template for tool calls"); +llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) { + switch (style) { + case llama_tool_call_style::Llama31: + return parse_llama_3_1_tool_calls(tools, input); + case llama_tool_call_style::FunctionaryV3Llama3: + return parse_functionary_v3_tool_calls(input); + case 
llama_tool_call_style::FunctionaryV3Llama31: + return parse_functionary_v3_llama_3_1_tool_calls(input); + case llama_tool_call_style::Hermes2Pro: + return parse_hermes_tool_calls(input); + default: + throw std::runtime_error("Unsupported tool call style"); } } llama_tool_call_handler llama_tool_call_handler_init( - const std::string & chat_template, + const llama_chat_template & tmpl, bool allow_content, bool parallel_tool_calls, const nlohmann::ordered_json & tools) { llama_tool_call_handler handler; - if (needs_functionary_v3_tool_call(chat_template)) { - // MeetKaiFunctionary_3_2 - // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... - // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar - handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { - std::vector tool_rules; - for (size_t i = 0, n = tools.size(); i < n; i++) { - auto & tool = tools[i]; - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; - auto tool_rule = builder.add_rule(name + "-call", "\">>>" + name + "\\n\" " + builder.add_schema(name + "-args", parameters)); - tool_rules.push_back(tool_rule); + switch (tmpl.tool_call_style()) { + case llama_tool_call_style::Llama31: { + handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { + static std::vector builtin_tools {"wolfram_alpha", "brave_search"}; + std::vector tool_rules; + + for (const auto & tool : tools) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + builder.resolve_refs(parameters); + if (name == "ipython" || std::find(builtin_tools.begin(), builtin_tools.end(), name) != builtin_tools.end()) { + tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*")); + if (allow_content) { + handler.grammar_trigger_words.push_back("<|python_tag|>"); + } + } else { + //"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " + + tool_rules.push_back( + builder.add_rule( + name + "-call", + "\"\\n{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + + builder.add_schema(name + "-args", parameters) + + " \"}\"")); + if (allow_content) { + handler.grammar_trigger_words.push_back("\n{\"" + name + "\""); + } + } + } + + builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | ")); + }); + handler.additional_stop_words.push_back("<|eom_id|>"); + break; + } + case llama_tool_call_style::FunctionaryV3Llama3: { + // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... + // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar + handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { + std::vector tool_rules; + for (size_t i = 0, n = tools.size(); i < n; i++) { + auto & tool = tools[i]; + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + auto tool_rule = builder.add_rule(name + "-call", "\">>>" + name + "\\n\" " + builder.add_schema(name + "-args", parameters)); + tool_rules.push_back(tool_rule); + if (allow_content) { + handler.grammar_trigger_words.push_back(">>>" + name + "\n"); + } + } + auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space"; + builder.add_rule("root", parallel_tool_calls ? 
"(" + tool_call + ")+" : tool_call); + }); + // handler.parser = parse_functionary_3_2_tool_calls; + break; + } + case llama_tool_call_style::FunctionaryV3Llama31: { + // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja + // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt + // TODO: handle tool {type: code_interpreter} as python + handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { + std::vector tool_rules; + for (size_t i = 0, n = tools.size(); i < n; i++) { + auto & tool = tools[i]; + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + if (name == "python") { + tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); + if (allow_content) { + handler.grammar_trigger_words.push_back("<|python_tag|>"); + } + } else { + tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\"")); + } + } + auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space"; + builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); if (allow_content) { - handler.grammar_trigger_words.push_back(">>>" + name + "\n"); + handler.grammar_trigger_words.push_back(" tool_rules; - for (size_t i = 0, n = tools.size(); i < n; i++) { - auto & tool = tools[i]; - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; - if (name == "python") { - tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); - if (allow_content) { - handler.grammar_trigger_words.push_back("<|python_tag|>"); - } - } else { - tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\"")); + }); + // handler.parser = parse_functionary_3_2_tool_calls; + break; + } + case llama_tool_call_style::Hermes2Pro: { + // NousResearchHermesPro_2 + // (content)?({"name": "foo", "arguments": {"a": 1}})* + handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { + std::vector tool_rules; + for (const auto & tool : tools) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + builder.resolve_refs(parameters); + tool_rules.push_back(builder.add_schema(name + "-call", { + {"type", "object"}, + {"properties", json { + {"name", json {{"const", name}}}, + {"arguments", parameters}, + }}, + {"required", json::array({"name", "arguments"})}, + })); } - } - auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space"; - builder.add_rule("root", parallel_tool_calls ? 
"(" + tool_call + ")+" : tool_call); - if (allow_content) { - handler.grammar_trigger_words.push_back("{"name": "foo", "arguments": {"a": 1}})* - handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { - std::vector tool_rules; - for (const auto & tool : tools) { - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; - builder.resolve_refs(parameters); - tool_rules.push_back(builder.add_schema(name + "-call", { - {"type", "object"}, - {"properties", json { - {"name", json {{"const", name}}}, - {"arguments", parameters}, - }}, - {"required", json::array({"name", "arguments"})}, - })); - } - auto tool_call = "\"\" " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"\" space"; - builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); - if (allow_content) { - handler.grammar_trigger_words.push_back(""); - } - }); - } else if (needs_llama_3_1_tool_call(chat_template)) { - handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { - static std::vector builtin_tools {"wolfram_alpha", "brave_search"}; - std::vector tool_rules; - - for (const auto & tool : tools) { - const auto & function = tool["function"]; - std::string name = function["name"]; - auto parameters = function["parameters"]; - builder.resolve_refs(parameters); - if (name == "ipython" || std::find(builtin_tools.begin(), builtin_tools.end(), name) != builtin_tools.end()) { - tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*")); - if (allow_content) { - handler.grammar_trigger_words.push_back("<|python_tag|>"); - } - } else { - //"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " + - tool_rules.push_back( - builder.add_rule( - name + "-call", - "\"\\n{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + - builder.add_schema(name + "-args", parameters) + - " \"}\"")); - if (allow_content) { - handler.grammar_trigger_words.push_back("\n{\"" + name + "\""); - } + auto tool_call = "\"\" " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"\" space"; + builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); + if (allow_content) { + handler.grammar_trigger_words.push_back(""); } - } - - builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | ")); - }); - handler.additional_stop_words.push_back("<|eom_id|>"); - } else { - // TODO: generic thoughtful schema. 
- throw std::runtime_error("Unsupported tool call style!"); + }); + break; + } + default: + throw std::runtime_error("Unsupported tool call style"); } return handler; } diff --git a/common/tool-call.h b/common/tool-call.h index de3958575..1cc9f8374 100644 --- a/common/tool-call.h +++ b/common/tool-call.h @@ -5,22 +5,29 @@ // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT #include "json.hpp" +#include "chat-template.h" + +struct llama_tool_call { + std::string name; + std::string arguments; +}; struct llama_tool_calls { std::string content; - std::vector<llama_chat_msg_tool_call> tool_calls; + std::vector<llama_tool_call> tool_calls; }; struct llama_tool_call_handler { std::string grammar; std::vector<std::string> grammar_trigger_words; std::vector<std::string> additional_stop_words; + nlohmann::ordered_json updated_tools; }; -llama_tool_calls parse_tool_calls(const nlohmann::ordered_json & tools, const std::string & chat_template, const std::string& input); +llama_tool_calls parse_tool_calls(llama_tool_call_style style, const nlohmann::ordered_json & tools, const std::string& input); llama_tool_call_handler llama_tool_call_handler_init( - const std::string & chat_template, + const llama_chat_template & tmpl, bool allow_content, bool parallel_tool_calls, const nlohmann::ordered_json & tools); diff --git a/examples/server/README.md b/examples/server/README.md index 838a23254..cf479aeac 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -571,6 +571,12 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte ```shell llama-server --jinja -hfr lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF -hff Meta-Llama-3.1-8B-Instruct-Q5_K_M.gguf -fa + # https://huggingface.co/meetkai/functionary-medium-v3.2 + llama-server --jinja -hfr bartowski/functionary-medium-v3.2-GGUF -hff functionary-medium-v3.2-IQ4_XS.gguf -fa + + # https://huggingface.co/meetkai/functionary-medium-v3.1 + llama-server --jinja -hfr meetkai/functionary-medium-v3.1-GGUF -hff functionary-medium-llama-3.1.Q4_0.gguf -fa + curl http://localhost:8080/v1/chat/completions \ -d '{ "model": "gpt-3.5-turbo", diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 49c412f8b..341d1cb45 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -662,7 +662,7 @@ struct server_context { bool validate_model_chat_template(bool use_jinja) const { llama_chat_message chat[] = {{"user", "test"}}; - const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0, use_jinja, nullptr, nullptr, nullptr); + const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0); return res > 0; } @@ -2860,9 +2860,11 @@ int main(int argc, char ** argv) { return; } + auto chat_template = llama_chat_template::from_model(ctx_server.model, params.chat_template); + json data; try { - data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template, params.use_jinja); + data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), chat_template, params.use_jinja); } catch (const std::runtime_error & e) { res_error(res, format_error_response(e.what(), ERROR_TYPE_NOT_SUPPORTED)); return; } @@ -2880,7 +2882,7 @@ int main(int argc, char ** argv) { ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) { // multitask is never support in chat completion, there is only one result try { - json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose); + json 
result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, chat_template, /*.streaming =*/ false, verbose); res_ok(res, result_oai); } catch (const std::runtime_error & e) { res_error(res, format_error_response(e.what(), ERROR_TYPE_SERVER)); diff --git a/examples/server/tests/features/tool_call.feature b/examples/server/tests/features/tool_call.feature index 4991ed7b3..b7b073025 100644 --- a/examples/server/tests/features/tool_call.feature +++ b/examples/server/tests/features/tool_call.feature @@ -23,19 +23,19 @@ Feature: llama.cpp server And a model test And max tokens to predict And a user prompt write a hello world in python - And a tool choice + And a tool choice required And tools And an OAI compatible chat completions request with no api error Then tool is called with arguments Examples: Prompts - | template_name | n_predict | tool_name | tool_arguments | tool_choice | tools | - | meetkai-functionary-medium-v3.1 | 128 | test | {} | required | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] | - | meetkai-functionary-medium-v3.1 | 128 | ipython | {"code": "Yes, you can."} | required | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] | - | meetkai-functionary-medium-v3.2 | 128 | test | {} | required | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] | - | meetkai-functionary-medium-v3.2 | 128 | ipython | {"code": "Yes,"} | required | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] | - | meta-llama-Meta-Llama-3.1-8B-Instruct | 64 | test | {} | required | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] | - | meta-llama-Meta-Llama-3.1-8B-Instruct | 16 | ipython | {"code": "it and "} | required | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] | + | template_name | n_predict | tool_name | tool_arguments | tools | + | meetkai-functionary-medium-v3.1 | 128 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] | + | meetkai-functionary-medium-v3.1 | 128 | ipython | {"code": "I'm sorry,"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] | + | meetkai-functionary-medium-v3.2 | 128 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] | + | meetkai-functionary-medium-v3.2 | 128 | ipython | {"code": "Yes,"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] | + | meta-llama-Meta-Llama-3.1-8B-Instruct | 64 | test | {} | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] | + | 
meta-llama-Meta-Llama-3.1-8B-Instruct | 16 | ipython | {"code": ". A"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] | Scenario: OAI Compatibility w/ no tool diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index fff4a78bc..e37173885 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -14,6 +14,7 @@ // Change JSON_ASSERT from assert() to GGML_ASSERT: #define JSON_ASSERT GGML_ASSERT +#include "chat-template.h" #include "json.hpp" #include "minja.hpp" #include "tool-call.h" @@ -64,40 +65,30 @@ inline std::string format_chat(const struct llama_model * model, const std::stri for (size_t i = 0; i < messages.size(); ++i) { const auto & curr_msg = messages[i]; - llama_chat_msg msg; - msg.role = json_value(curr_msg, "role", std::string("")); - msg.tool = json_value(curr_msg, "tool", std::string("")); + std::string role = json_value(curr_msg, "role", std::string("")); + + std::string content; if (curr_msg.contains("content")) { if (curr_msg["content"].is_string()) { - msg.content = curr_msg["content"].get<std::string>(); + content = curr_msg["content"].get<std::string>(); } else if (curr_msg["content"].is_array()) { for (const auto & part : curr_msg["content"]) { if (part.contains("text")) { - msg.content += "\n" + part["text"].get<std::string>(); + content += "\n" + part["text"].get<std::string>(); } } - } else if (!(curr_msg.is_null() && curr_msg.contains("tool_calls"))) { - throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367): " + curr_msg.dump()); + } else { + throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); } } else { throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)"); } - if (curr_msg.contains("tool_calls") && curr_msg["tool_calls"].is_array()) { - for (const auto & tool_call : curr_msg["tool_calls"]) { - if (json_value(tool_call, "type", std::string("")) == "function" - && tool_call.contains("function") && tool_call["function"].is_object()) { - msg.tool_calls.push_back({ - json_value(tool_call["function"], "name", std::string("")), - json_value(tool_call["function"], "arguments", std::string("")) - }); - } - } - } - chat.emplace_back(std::move(msg)); + + chat.push_back({role, content}); } - const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true, use_jinja, tools.is_null() ? 
nullptr : tools.dump().c_str()); + const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true); LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str()); return formatted_chat; @@ -315,38 +306,12 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons // OAI utils // -static std::string _llama_token_to_piece(const struct llama_model * model, llama_token token, bool special) { - std::string piece; - piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' - const int n_chars = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); - if (n_chars < 0) { - piece.resize(-n_chars); - int check = llama_token_to_piece(model, token, &piece[0], piece.size(), 0, special); - GGML_ASSERT(check == -n_chars); - } - else { - piece.resize(n_chars); - } - - return piece; -} - -std::string llama_model_meta_val_str(const struct llama_model * model, const char * key) { - int32_t tlen = llama_model_meta_val_str(model, key, nullptr, 0); - if (tlen > 0) { - std::vector curr_tmpl_buf(tlen + 1, 0); - if (llama_model_meta_val_str(model, key, curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) { - return std::string(curr_tmpl_buf.data(), tlen); - } - } - return ""; -} - static json oaicompat_completion_params_parse( const struct llama_model * model, const json & body, /* openai api json semantics */ - const std::string & chat_template_src, - bool use_jinja) { + const llama_chat_template & tmpl, + bool use_jinja) +{ json llama_params; llama_params["__oaicompat"] = true; @@ -355,16 +320,15 @@ static json oaicompat_completion_params_parse( auto has_tools = tools.is_array() && !tools.empty(); // Apply chat template to the list of messages - auto chat_template = chat_template_src.empty() ? llama_model_meta_val_str(model, "tokenizer.chat_template") : chat_template_src; - llama_params["chat_template"] = chat_template; + llama_params["chat_template"] = tmpl.chat_template(); + if (use_jinja) { - if (has_tools && chat_template.find("tools") == std::string::npos) { + if (has_tools && !tmpl.supports_tools()) { throw std::runtime_error("Chat template does not seem to support tools. 
Override the model template with --chat-template."); } } else if (has_tools) { throw std::runtime_error("Tools are only supported in --jinja mode"); } - llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"), tools, use_jinja); // Handle "stop" field if (body.contains("stop") && body.at("stop").is_string()) { @@ -399,26 +363,40 @@ static json oaicompat_completion_params_parse( } else if (!response_type.empty() && response_type != "text") { throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); } - } else if (use_jinja && tool_choice != "none" && has_tools) { - bool parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + } + + if (use_jinja) { bool allow_content = tool_choice != "required"; + if (tool_choice != "none" && has_tools) { + bool parallel_tool_calls = json_value(body, "parallel_tool_calls", false); + llama_params["parse_tool_calls"] = true; + llama_params["parallel_tool_calls"] = parallel_tool_calls; - auto handler = llama_tool_call_handler_init(chat_template, allow_content, parallel_tool_calls, tools); + auto handler = llama_tool_call_handler_init(tmpl, allow_content, parallel_tool_calls, tools); - for (const auto & stop : handler.additional_stop_words) { - llama_params["stop"].push_back(stop); - } - if (!handler.grammar_trigger_words.empty()) { - auto triggers = json::array(); - for (const auto & word : handler.grammar_trigger_words) { - triggers.push_back(word); + for (const auto & stop : handler.additional_stop_words) { + llama_params["stop"].push_back(stop); + } + if (!handler.grammar_trigger_words.empty()) { + auto triggers = json::array(); + for (const auto & word : handler.grammar_trigger_words) { + triggers.push_back(word); + } + llama_params["grammar_trigger_words"] = triggers; + } + if (handler.updated_tools.is_null()) { + tools = handler.updated_tools; + } + if (!handler.grammar.empty()) { + if (llama_params.contains("grammar")) { + throw std::runtime_error("Cannot use custom grammar constraints with tools."); + } + llama_params["grammar"] = handler.grammar; } - llama_params["grammar_trigger_words"] = triggers; } - - llama_params["grammar"] = handler.grammar; - llama_params["parse_tool_calls"] = true; - llama_params["parallel_tool_calls"] = parallel_tool_calls; + llama_params["prompt"] = tmpl.apply(body.at("messages"), tools, /* add_generation_prompt= */ true); + } else { + llama_params["prompt"] = format_chat(model, tmpl.chat_template(), body.at("messages"), tools, /* use_jinja= */ false); } // Handle "n" field @@ -458,7 +436,7 @@ static json oaicompat_completion_params_parse( return llama_params; } -static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) { +static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, const llama_chat_template & tmpl, bool streaming = false, bool verbose = false) { bool stopped_word = result.count("stopped_word") != 0; bool stopped_eos = json_value(result, "stopped_eos", false); int num_tokens_predicted = json_value(result, "tokens_predicted", 0); @@ -474,9 +452,8 @@ static json format_final_response_oaicompat(const json & request, const json & r auto tools = json_value(request, "tools", json::array()); json tool_calls; json message_content; - printf("# CONTENT: %s\n\n", content.c_str()); if (json_value(request, "parse_tool_calls", false) - && 
!(parsed_tool_calls = parse_tool_calls(tools, chat_template, content)).tool_calls.empty()) { + && !(parsed_tool_calls = parse_tool_calls(tmpl.tool_call_style(), tools, content)).tool_calls.empty()) { finish_reason = "tool"; if (!parsed_tool_calls.content.empty()) { message_content = parsed_tool_calls.content; @@ -514,7 +491,6 @@ static json format_final_response_oaicompat(const json & request, const json & r }}, {"id", completion_id} }; - printf("# RES: %s\n\n", res.dump(2).c_str()); // extra fields for debugging purposes if (verbose) { diff --git a/include/llama.h b/include/llama.h index 262142b96..de5a40ef2 100644 --- a/include/llama.h +++ b/include/llama.h @@ -377,19 +377,9 @@ extern "C" { } llama_sampler_chain_params; // used in chat template - - typedef struct llama_chat_message_tool_call { - const char * name; - const char * arguments; - } llama_chat_message_tool_call; - typedef struct llama_chat_message { const char * role; const char * content; - const char * tool; - - const llama_chat_message_tool_call * tool_calls; - uint32_t n_tool_calls; } llama_chat_message; // lora adapter @@ -986,11 +976,7 @@ extern "C" { size_t n_msg, bool add_ass, char * buf, - int32_t length, - bool use_jinja, - const char * tools, - const char * bos_token, - const char * eos_token); + int32_t length); // // Sampling API diff --git a/src/llama.cpp b/src/llama.cpp index ddaaa1f74..758067958 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2,8 +2,6 @@ #include "llama-vocab.h" #include "llama-sampling.h" -#include "minja.hpp" - #include "unicode.h" #include "ggml.h" @@ -21004,95 +21002,7 @@ int32_t llama_detokenize( static int32_t llama_chat_apply_template_internal( const std::string & tmpl, const std::vector & chat, - std::string & dest, bool add_ass, - bool use_jinja, - const std::string & tools, - const std::string & bos_token, const std::string & eos_token) { - - if (use_jinja) { - auto system_not_supported = tmpl.find("System role not supported") != std::string::npos; - - // Meta-Llama-3.1-8B-Instruct's template expects arguments to be an object. - // Most other templates (and OpenAI's API) expect the arguments object to be stringified. - auto tool_call_args_must_be_objects = tmpl.find("tool_call.arguments | items") != std::string::npos; - - auto messages = json::array(); - - std::string pending_system; - auto flush_sys = [&]() { - if (!pending_system.empty()) { - messages.push_back({ - {"role", "user"}, - {"content", pending_system}, - }); - pending_system.clear(); - } - }; - for (const auto * msg : chat) { - std::string role(msg->role); - std::string content(msg->content); - if (system_not_supported) { - if (role == "system") { - if (!pending_system.empty()) pending_system += "\n"; - pending_system += content; - continue; - } else { - if (role == "user") { - if (!pending_system.empty()) { - content = pending_system + (content.empty() ? "" : "\n" + content); - pending_system.clear(); - } - } else { - flush_sys(); - } - } - } - auto message = json({ - {"role", role}, - {"content", content}, - }); - if (msg->tool) message["tool"] = msg->tool; - if (msg->n_tool_calls) { - auto tool_calls = json::array(); - for (uint32_t i = 0; i < msg->n_tool_calls; i++) { - auto args = msg->tool_calls[i].arguments; - tool_calls.push_back(json({ - {"type", "function"}, - {"function", { - {"name", msg->tool_calls[i].name}, - {"arguments", tool_call_args_must_be_objects ? 
json::parse(args) : args}, - }} - })); - } - messages["tool_calls"] = tool_calls; - } - messages.push_back(message); - } - flush_sys(); - - auto context = minja::Context::make(json({ - {"messages", messages}, - {"add_generation_prompt", add_ass}, - {"bos_token", bos_token}, - {"eos_token", eos_token}, - })); - if (!tools.empty()) { - auto tools_val = minja::Value(json::parse(tools)); - context->set("tools", tools_val); - } - auto tmpl_root = minja::Parser::parse(tmpl, { - /* .trim_blocks = */ true, - /* .lstrip_blocks = */ true, - /* .keep_trailing_newline = */ false, - }); - try { - dest = tmpl_root->render(context); - return dest.size(); - } catch (const std::runtime_error & err) { - LLAMA_LOG_ERROR("Error in jinja template: %s\n", err.what()); - return -1; - } - } + std::string & dest, bool add_ass) { // Taken from the research: https://github.com/ggerganov/llama.cpp/issues/5527 std::stringstream ss; @@ -21360,11 +21270,7 @@ int32_t llama_chat_apply_template( size_t n_msg, bool add_ass, char * buf, - int32_t length, - bool use_jinja, - const char * tools, - const char * bos_token, - const char * eos_token) { + int32_t length) { std::string curr_tmpl(tmpl == nullptr ? "" : tmpl); if (tmpl == nullptr) { GGML_ASSERT(model != nullptr); @@ -21379,16 +21285,6 @@ int32_t llama_chat_apply_template( curr_tmpl = std::string(model_template.data(), model_template.size()); } } - std::string curr_bos_token(bos_token ? bos_token : ""); - std::string curr_eos_token(eos_token ? eos_token : ""); - if (bos_token == nullptr) { - GGML_ASSERT(model != nullptr); - curr_bos_token = llama_token_to_piece(model, llama_token_bos(model), true); - } - if (eos_token == nullptr) { - GGML_ASSERT(model != nullptr); - curr_eos_token = llama_token_to_piece(model, llama_token_eos(model), true); - } // format the chat to string std::vector chat_vec; @@ -21398,7 +21294,7 @@ int32_t llama_chat_apply_template( } std::string formatted_chat; - int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass, use_jinja, tools == nullptr ? 
"" : tools, curr_bos_token, curr_eos_token); + int32_t res = llama_chat_apply_template_internal(curr_tmpl, chat_vec, formatted_chat, add_ass); if (res < 0) { return res; } diff --git a/tests/test-tool-call.cpp b/tests/test-tool-call.cpp index a454780e1..9f1cf7e8f 100644 --- a/tests/test-tool-call.cpp +++ b/tests/test-tool-call.cpp @@ -20,9 +20,9 @@ static void assert_equals(const std::string & expected, const std::string & actu cmake -B build -DLLAMA_CURL=1 -DCMAKE_BUILD_TYPE=Release && cmake --build build -t test-tool-call -j && ./build/bin/test-tool-call */ -static void test_parse_tool_call(const json & tools, const std::string & chat_template, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) { +static void test_parse_tool_call(llama_tool_call_style style, const json & tools, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) { std::cout << "# Testing: " << input << std::endl << std::flush; - auto result = parse_tool_calls(tools, chat_template, input); + auto result = parse_tool_calls(style, tools, input); assert_equals(expected_content, result.content); auto tool_calls = json::array(); for (const auto & tc : result.tool_calls) { @@ -59,8 +59,7 @@ int main() { {"tools", tools} }; - std::string hermes_2_pro_like_tmpl = "Hermes 2 Pro template should have inside it"; - test_parse_tool_call(tools, hermes_2_pro_like_tmpl, + test_parse_tool_call(llama_tool_call_style::Hermes2Pro, tools, "{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}", "", json {{ @@ -72,8 +71,7 @@ int main() { }} }}); - std::string functionary_v3_like_tmpl = "Functionary 3.2 template should have <|start_header_id|> and then some >>>all inside it"; - test_parse_tool_call(tools, functionary_v3_like_tmpl, + test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools, ">>>ipython\n{\"code\": \"print('Hello, world!')\"}", "", json {{ @@ -84,7 +82,7 @@ int main() { }).dump()} }} }}); - test_parse_tool_call(tools, functionary_v3_like_tmpl, + test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama3, tools, ">>>test\n{ } \n ", "", json {{ @@ -94,8 +92,7 @@ int main() { }} }}); - std::string functionary_v3_llama_3_1_like_tmpl = "Functionary 3.2 template for llama 3.1 should have <|start_header_id|> and then some {...} inside it"; - test_parse_tool_call(tools, functionary_v3_llama_3_1_like_tmpl, + test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama31, tools, "Hell{\"arg1\": 1}o, world{\"arg2\": 2}!", "Hello, world!", json { @@ -116,7 +113,7 @@ int main() { }} }, }); - test_parse_tool_call(tools, functionary_v3_llama_3_1_like_tmpl, + test_parse_tool_call(llama_tool_call_style::FunctionaryV3Llama31, tools, "{ } ", " ", json {{ @@ -126,8 +123,7 @@ int main() { }} }}); - std::string llama_3_1_like_tmpl = "Llama 3.1 template should have <|start_header_id|> and <|python_tag|> inside it"; - test_parse_tool_call(tools, llama_3_1_like_tmpl, + test_parse_tool_call(llama_tool_call_style::Llama31, tools, "<|python_tag|>this could be anything", "", json {{ @@ -138,7 +134,7 @@ int main() { }).dump()} }} }}); - test_parse_tool_call(tools, llama_3_1_like_tmpl, + test_parse_tool_call(llama_tool_call_style::Llama31, tools, "I'm thinking<|python_tag|>", "I'm thinking", json {{ @@ -147,7 +143,7 @@ int main() { {"arguments", (json {{"code", ""}}).dump()} }} }}); - test_parse_tool_call(tools, llama_3_1_like_tmpl, + test_parse_tool_call(llama_tool_call_style::Llama31, tools, "{\"name\": \"special_function\", \"parameters\": 
{\"arg1\": 1}}", "", json {{ @@ -158,7 +154,7 @@ int main() { }).dump()} }} }}); - test_parse_tool_call(tools, llama_3_1_like_tmpl, + test_parse_tool_call(llama_tool_call_style::Llama31, tools, "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array());