diff --git a/common/tool-call.cpp b/common/tool-call.cpp
index 3bbec002b..7355a887b 100644
--- a/common/tool-call.cpp
+++ b/common/tool-call.cpp
@@ -12,11 +12,18 @@ using json = nlohmann::ordered_json;
 
-static bool needs_functionary_3_2_tool_call(const std::string & chat_template) {
+// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3.llama3.txt
+static bool needs_functionary_v3_tool_call(const std::string & chat_template) {
     return chat_template.find("<|start_header_id|>") != std::string::npos
         && chat_template.find(">>>all") != std::string::npos;
 }
 
+// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
+static bool needs_functionary_v3_llama_3_1_tool_call(const std::string & chat_template) {
+    return chat_template.find("<|start_header_id|>") != std::string::npos
+        && chat_template.find("<function=") != std::string::npos;
+}
+
 static bool needs_llama_3_1_tool_call(const std::string & chat_template) {
     return chat_template.find("<|start_header_id|>") != std::string::npos
         && chat_template.find("<|python_tag|>") != std::string::npos;
 }
@@ -148,8 +155,42 @@ static llama_tool_calls parse_llama_3_1_tool_calls(const json & tools, const std
     return {input, {}};
 }
 
+static llama_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const std::string& input) {
+    static std::regex function_regex(R"(<function=(\w+)>)");
+    static std::regex close_regex(R"(</function>)");
+    std::smatch match;
 
-static llama_tool_calls parse_functionary_3_2_tool_calls(const std::string& input) {
+    llama_tool_calls result;
+    auto end = input.end();
+    auto it = input.begin();
+
+    while (it != end) {
+        std::sregex_iterator rend;
+        std::sregex_iterator rit(it, end, function_regex);
+        if (rit == rend) {
+            result.content += std::string(it, end);
+            break;
+        }
+
+        result.content += std::string(it, rit->prefix().second);
+        it = rit->suffix().first;
+
+        auto name = rit->str(1);
+
+        json arguments;
+        if (!parse_json(it, end, arguments)) {
+            throw std::runtime_error("Failed to parse json tool call arguments");
+        }
+        if (!std::regex_search(it, end, match, close_regex)) {
+            throw std::runtime_error("Malformed input, missing closing pattern");
+        }
+        it = match.suffix().first;
+        result.tool_calls.push_back({name, arguments.dump()});
+    }
+    return result;
+}
+
+static llama_tool_calls parse_functionary_v3_tool_calls(const std::string& input) {
     static std::regex python_tag_regex(R"(>>>(\w+)\n((?!>>>)[\s\S\n]*))");
     std::smatch match;
     llama_tool_calls result;
@@ -172,8 +213,10 @@ llama_tool_calls parse_tool_calls(const json & tools, const std::string & chat_t
         return parse_hermes_tool_calls(input);
     } else if (needs_llama_3_1_tool_call(chat_template)) {
         return parse_llama_3_1_tool_calls(tools, input);
-    } else if (needs_functionary_3_2_tool_call(chat_template)) {
-        return parse_functionary_3_2_tool_calls(input);
+    } else if (needs_functionary_v3_tool_call(chat_template)) {
+        return parse_functionary_v3_tool_calls(input);
+    } else if (needs_functionary_v3_llama_3_1_tool_call(chat_template)) {
+        return parse_functionary_v3_llama_3_1_tool_calls(input);
     } else {
         throw std::runtime_error("Unsupported chat template for tool calls");
     }
@@ -187,7 +230,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
 {
     llama_tool_call_handler handler;
 
-    if (needs_functionary_3_2_tool_call(chat_template)) {
+    if (needs_functionary_v3_tool_call(chat_template)) {
        // MeetKaiFunctionary_3_2
        // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
        // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
@@ -208,6 +251,25 @@ llama_tool_call_handler llama_tool_call_handler_init(
             builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
         });
         // handler.parser = parse_functionary_3_2_tool_calls;
+    } else if (needs_functionary_v3_llama_3_1_tool_call(chat_template)) {
+        // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
+        handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
+            std::vector<std::string> tool_rules;
+            for (size_t i = 0, n = tools.size(); i < n; i++) {
+                auto & tool = tools[i];
+                const auto & function = tool["function"];
+                std::string name = function["name"];
+                auto parameters = function["parameters"];
+                auto tool_rule = builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\"");
+                tool_rules.push_back(tool_rule);
+            }
+            auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
+            builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
+            if (allow_content) {
+                handler.grammar_trigger_words.push_back("<function=");
+            }
+        });
+        // handler.parser = parse_functionary_v3_llama_3_1_tool_calls;
     } else if (needs_hermes_pro_tool_call(chat_template)) {
         // (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
diff --git a/tests/test-tool-call.cpp b/tests/test-tool-call.cpp
index 0a2a09416..fd0eeed01 100644
--- a/tests/test-tool-call.cpp
+++ b/tests/test-tool-call.cpp
@@ -21,6 +21,7 @@ static void assert_equals(const std::string & expected, const std::string & actu
  */
 static void test_parse_tool_call(const json & tools, const std::string & chat_template, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) {
+    std::cout << "# Testing: " << input << std::endl << std::flush;
     auto result = parse_tool_calls(tools, chat_template, input);
     assert_equals(expected_content, result.content);
     auto tool_calls = json::array();
@@ -71,8 +72,8 @@ int main() {
         }}
     }});
 
-    std::string functionary_3_2_like_tmpl = "Functionary 3.2 template should have <|start_header_id|> and then some >>>all inside it";
-    test_parse_tool_call(tools, functionary_3_2_like_tmpl,
+    std::string functionary_v3_like_tmpl = "Functionary 3.2 template should have <|start_header_id|> and then some >>>all inside it";
+    test_parse_tool_call(tools, functionary_v3_like_tmpl,
         ">>>ipython\nprint('Hello, world!')",
         "",
         json {{
@@ -84,6 +85,29 @@ int main() {
         }}
     }});
 
+    std::string functionary_v3_llama_3_1_like_tmpl = "Functionary 3.2 template for llama 3.1 should have <|start_header_id|> and then some <function=...>{...}</function> inside it";
+    test_parse_tool_call(tools, functionary_v3_llama_3_1_like_tmpl,
+        "Hell<function=foo>{\"arg1\": 1}</function>o, world<function=bar>{\"arg2\": 2}</function>!",
+        "Hello, world!",
+        json {
+            {
+                {"function", {
+                    {"name", "foo"},
+                    {"arguments", (json {
+                        {"arg1", 1}
+                    }).dump()}
+                }}
+            },
+            {
+                {"function", {
+                    {"name", "bar"},
+                    {"arguments", (json {
+                        {"arg2", 2}
+                    }).dump()}
+                }}
+            },
+        });
+
     std::string llama_3_1_like_tmpl = "Llama 3.1 template should have <|start_header_id|> and <|python_tag|> inside it";
     test_parse_tool_call(tools, llama_3_1_like_tmpl,
         "<|python_tag|>this could be anything",