#include "tool-call.h" #include "json-schema-to-grammar.h" #include #include #include #include #include #include #include #include #include using json = nlohmann::ordered_json; llama_tool_call_style llama_tool_call_style_detect(const minja::chat_template & chat_template) { const auto & src = chat_template.source(); if (src.find("") != std::string::npos) { return Hermes2Pro; } else if (src.find(">>>all") != std::string::npos) { return FunctionaryV3Llama3; } else if (src.find("<|start_header_id|>") != std::string::npos && src.find("ipython<|end_header_id|>") != std::string::npos) { if (src.find("<|python_tag|>") != std::string::npos) { return Llama31; } else { return Llama32; } } else if (src.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) { return CommandRPlus; } else { return Generic; } } static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) { // // https://json.nlohmann.me/features/parsing/sax_interface/ struct json_error_locator : public nlohmann::json_sax { std::size_t position; bool found_error; json_error_locator() : position(0), found_error(false) {} bool parse_error(std::size_t position, const std::string &, const json::exception &) override { this->position = position - 1; this->found_error = true; return false; } bool null() override { return true; } bool boolean(bool) override { return true; } bool number_integer(number_integer_t) override { return true; } bool number_unsigned(number_unsigned_t) override { return true; } bool number_float(number_float_t, const string_t &) override { return true; } bool string(string_t &) override { return true; } bool binary(binary_t &) override { return true; } bool start_object(std::size_t) override { return true; } bool key(string_t &) override { return true; } bool end_object() override { return true; } bool start_array(std::size_t) override { return true; } bool end_array() override { return true; } }; json_error_locator err_loc; json::sax_parse(it, end, &err_loc); std::string::const_iterator temptative_end; if (err_loc.found_error) { temptative_end = it + err_loc.position; } else { temptative_end = end; } std::string json_sub {it, temptative_end}; try { out = json::parse(json_sub); it = temptative_end; return true; } catch (const std::exception &) { return false; } } /** * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. * Aggregates the prefix, suffix and in-between text into the content. */ static llama_tool_calls parse_json_tool_calls(const json & tools, const std::string& input, const std::regex & function_regex, const std::regex & close_regex, bool check_names) { std::smatch match; llama_tool_calls result; auto end = input.end(); auto it = input.begin(); std::unordered_set tool_names; if (check_names) { for (const auto & tool : tools) { if (tool.contains("type") && tool["type"] == "function") { tool_names.insert(tool["function"]["name"]); } } } while (it != end) { std::sregex_iterator rend; std::sregex_iterator rit(it, end, function_regex); if (rit == rend) { result.content += std::string(it, end); break; } auto name = rit->str(1); if (check_names && tool_names.find(name) == tool_names.end()) { result.content += std::string(it, rit->suffix().first); break; } result.content += std::string(it, rit->prefix().second); it = rit->suffix().first; json arguments; if (!parse_json(it, end, arguments)) { throw std::runtime_error("Failed to parse json tool call arguments"); } if (!std::regex_search(it, end, match, close_regex)) { throw std::runtime_error("Malformed input, missing closing pattern"); } it = match.suffix().first; result.tool_calls.push_back({name, arguments.dump()}); } return result; } static llama_tool_calls parse_hermes_tool_calls(const std::string& input) { try { std::regex start_pattern(R"([\n\s]*)"); std::regex middle_pattern(R"([\n\s]*[\n\s]*)"); std::regex end_pattern(R"([\n\s]*[\n\s]*$)"); auto end = input.end(); std::sregex_iterator rend; std::sregex_iterator rit(input.begin(), end, start_pattern); if (rit == rend) { return {input, {}}; } llama_tool_calls result; result.content = rit->prefix(); auto it = rit->suffix().first; while (it != end) { json call; if (!parse_json(it, end, call)) { throw std::runtime_error("Failed to parse json tool call"); } result.tool_calls.push_back({ call["name"], call["arguments"].dump(), }); rit = {it, end, middle_pattern}; if (rit != rend) { it = rit->suffix().first; } else { rit = {it, end, end_pattern}; if (rit == rend) { throw std::runtime_error("Malformed input, missing "); } break; } } return result; } catch (const std::exception & e) { return {input, {}}; } } static llama_tool_calls parse_llama_3_tool_calls(const json & tools, const std::string& input, bool allow_python_tag) { if (allow_python_tag) { static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); std::smatch match; if (std::regex_search(input, match, python_tag_regex)) { return { match.prefix().str(), { {"ipython", (json {{"code", match[1].str()}}).dump()}, } }; } } static std::regex function_regex("\\{(?:\"type\": \"function\", |[\\s\\n\\r]*)\"name\": \"([^\"]+)\", \"parameters\": "); static std::regex close_regex("\\}"); return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true); } static llama_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const json & tools, const std::string& input) { // This version of Functionary still supports the llama 3.1 tool call format for the python tool. static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); std::smatch match; if (std::regex_search(input, match, python_tag_regex)) { return { match.prefix().str(), { {"ipython", (json {{"code", match[1].str()}}).dump()}, } }; } static std::regex function_regex(R"()"); static std::regex close_regex(R"()"); return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ false); } static llama_tool_calls parse_functionary_v3_tool_calls(const json & tools, const std::string& input) { static std::regex function_regex(R"((?:>>>)?(\w+)\n)"); static std::regex close_regex(R"($|(?=>>>))"); return parse_json_tool_calls(tools, input, function_regex, close_regex, /* check_names= */ true); } static llama_tool_calls parse_generic_tool_calls(const std::string& input) { json data = json::parse(input); llama_tool_calls result; if (data.contains("tool_calls")) { for (const auto & tool_call : data["tool_calls"]) { result.tool_calls.push_back({ tool_call["name"], tool_call["arguments"].dump(), }); } } else if (data.contains("tool_call")) { result.tool_calls.push_back({ data["tool_call"]["name"], data["tool_call"]["arguments"].dump(), }); } else if (data.contains("response")) { const auto & response = data["response"]; result.content = response.is_string() ? response.get() : response.dump(2); } return result; } llama_tool_calls parse_tool_calls(llama_tool_call_style style, const json & tools, const std::string& input) { switch (style) { case llama_tool_call_style::Generic: return parse_generic_tool_calls(input); case llama_tool_call_style::Llama31: return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ true); case llama_tool_call_style::Llama32: return parse_llama_3_tool_calls(tools, input, /* parse_llama_3_tool_calls= */ false); case llama_tool_call_style::FunctionaryV3Llama3: return parse_functionary_v3_tool_calls(tools, input); case llama_tool_call_style::FunctionaryV3Llama31: return parse_functionary_v3_llama_3_1_tool_calls(tools, input); case llama_tool_call_style::Hermes2Pro: return parse_hermes_tool_calls(input); default: throw std::runtime_error("Unsupported tool call style"); } } llama_tool_call_handler llama_tool_call_handler_init( llama_tool_call_style style, const minja::chat_template & tmpl, bool allow_content, bool parallel_tool_calls, const nlohmann::ordered_json & messages, const nlohmann::ordered_json & tools, const nlohmann::ordered_json & json_schema) { llama_tool_call_handler handler; switch (style) { case llama_tool_call_style::Generic: { auto tool_call_schemas = json::array(); for (const auto & tool : tools) { if (tool["type"] != "function") { continue; } const auto & function = tool["function"]; std::string name = function["name"]; auto parameters = function["parameters"]; tool_call_schemas.emplace_back(json { {"type", "object"}, {"properties", { {"name", { {"type", "string"}, {"const", name}, }}, {"arguments", parameters}, }}, {"required", json::array({"name", "arguments"})}, }); } const auto tool_call = json {{"anyOf", tool_call_schemas}}; const auto schema = json { {"anyOf", json::array({ parallel_tool_calls ? json { {"type", "object"}, {"properties", { {"tool_calls", { {"type", "array"}, {"items", tool_call} }}, }}, {"required", json::array({"tool_calls"})}, } : json { {"type", "object"}, {"properties", { {"tool_call", tool_call}, }}, {"required", json::array({"tool_call"})}, }, { {"type", "object"}, {"properties", { {"response", json_schema.is_null() ? json {{"type", "string"}} : json_schema }, }}, }, })} }; handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { builder.add_schema("", schema); }); // TODO: add schema to system prompt. handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true); break; } case llama_tool_call_style::Llama31: case llama_tool_call_style::Llama32: { static auto builtin_tools = json {"wolfram_alpha", "brave_search"}; auto uses_python_tag = style == llama_tool_call_style::Llama31; // Technically we should only trigger on `"\n{\"name\": \"" + name + "\""` for each tool name, // but Llama-3.2-3B (and 1B) struggles to output valid tool calls so we're "guiding" it strongly as soon // as it seems to be outputting some JSON. // TODO: make this conditional on a very small model (e.g. 1B / 3B). auto eagerly_match_any_json = style == llama_tool_call_style::Llama32; handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { std::vector tool_rules; for (const auto & tool : tools) { const auto & function = tool["function"]; std::string name = function["name"]; auto parameters = function["parameters"]; builder.resolve_refs(parameters); if (uses_python_tag && (name == "ipython" || builtin_tools.contains(name))) { tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*")); if (allow_content) { handler.grammar_trigger_words.push_back("<|python_tag|>"); } } else { //"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " + tool_rules.push_back( builder.add_rule( name + "-call", "\"\\n\"? \"{\" ( \"\\\"type\\\": \\\"function\\\", \" | space ) \"\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " + builder.add_schema(name + "-args", parameters) + " \"}\"")); if (allow_content && !eagerly_match_any_json) { handler.grammar_trigger_words.push_back("{\"name\": \"" + name + "\""); // Accommodate most common tool call variations from Llama-3.1-8B and Llama-3.2-3B. // Note that c++11's regex doesn't support partial matches, otherwise it would make // sense to add support for trigger regexes to the antiprompt mechanism. handler.grammar_trigger_words.push_back("{\n\t\"name\": \"" + name + "\""); handler.grammar_trigger_words.push_back("{\n \"name\": \"" + name + "\""); handler.grammar_trigger_words.push_back("{\n \"name\": \"" + name + "\""); handler.grammar_trigger_words.push_back("{\"type\": \"function\", \"name\": \"" + name + "\""); } } } if (allow_content && eagerly_match_any_json) { handler.grammar_trigger_words.push_back("{\""); handler.grammar_trigger_words.push_back("{\n\t\""); handler.grammar_trigger_words.push_back("{\n \""); handler.grammar_trigger_words.push_back("{\n \""); } builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | ")); }); handler.additional_stop_words.push_back("<|eom_id|>"); handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true, { {"builtin_tools", builtin_tools}, }); break; } case llama_tool_call_style::FunctionaryV3Llama3: { // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { std::vector first_tool_rules; std::vector subsequent_tool_rules; for (size_t i = 0, n = tools.size(); i < n; i++) { auto & tool = tools[i]; const auto & function = tool["function"]; std::string name = function["name"]; auto parameters = function["parameters"]; auto args_rule = builder.add_schema(name + "-args", parameters); first_tool_rules.push_back(builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule)); subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>" + name + "\\n\" " + args_rule)); if (allow_content) { handler.grammar_trigger_words.push_back(name + "\n"); handler.grammar_trigger_words.push_back(">>>" + name + "\n"); } } auto first_rule = builder.add_rule("first_tool_call", join(first_tool_rules.begin(), first_tool_rules.end(), " | ")) + " space"; if (parallel_tool_calls) { auto subsequent_rule = builder.add_rule("subsequent_tool_call", join(subsequent_tool_rules.begin(), subsequent_tool_rules.end(), " | ")) + " space"; builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*"); } else { builder.add_rule("root", first_rule); } }); handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true); // handler.parser = parse_functionary_3_2_tool_calls; break; } case llama_tool_call_style::FunctionaryV3Llama31: { // ./tests/chat/templates/meetkai-functionary-medium-v3.1.jinja // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt // TODO: handle tool {type: code_interpreter} as python handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { std::vector tool_rules; for (size_t i = 0, n = tools.size(); i < n; i++) { auto & tool = tools[i]; const auto & function = tool["function"]; std::string name = function["name"]; auto parameters = function["parameters"]; if (name == "python") { tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); if (allow_content) { handler.grammar_trigger_words.push_back("<|python_tag|>"); } } else { tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\" space")); } } auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space"; builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); if (allow_content) { handler.grammar_trigger_words.push_back("{"name": "foo", "arguments": {"a": 1}})* handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { std::vector tool_rules; for (const auto & tool : tools) { const auto & function = tool["function"]; std::string name = function["name"]; auto parameters = function["parameters"]; builder.resolve_refs(parameters); tool_rules.push_back(builder.add_schema(name + "-call", { {"type", "object"}, {"properties", json { {"name", json {{"const", name}}}, {"arguments", parameters}, }}, {"required", json::array({"name", "arguments"})}, })); } auto tool_call = "\"\" space " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"\" space"; builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); if (allow_content) { handler.grammar_trigger_words.push_back(""); } }); handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true); break; } default: throw std::runtime_error("Unsupported tool call style"); } return handler; }