diff --git a/Makefile b/Makefile index e5e7e62fa..25f5db074 100644 --- a/Makefile +++ b/Makefile @@ -55,6 +55,7 @@ TEST_TARGETS = \ tests/test-grammar-parser \ tests/test-json-schema-to-grammar \ tests/test-minja \ + tests/test-tool-call \ tests/test-llama-grammar \ tests/test-log \ tests/test-model-load-cancel \ @@ -940,7 +941,8 @@ OBJ_COMMON = \ common/sampling.o \ common/train.o \ common/build-info.o \ - common/json-schema-to-grammar.o + common/json-schema-to-grammar.o \ + common/tool-call.o OBJ_ALL = $(OBJ_GGML) $(OBJ_LLAMA) $(OBJ_COMMON) @@ -1201,6 +1203,11 @@ common/json-schema-to-grammar.o: \ common/json-schema-to-grammar.h $(CXX) $(CXXFLAGS) -c $< -o $@ +common/tool-call.o: \ + common/tool-call.cpp \ + common/tool-call.h + $(CXX) $(CXXFLAGS) -c $< -o $@ + common/train.o: \ common/train.cpp \ common/train.h @@ -1574,6 +1581,11 @@ tests/test-antiprompts: tests/test-antiprompts.cpp \ $(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) +tests/test-tool-call: tests/test-tool-call.cpp \ + $(OBJ_ALL) + $(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<) + $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) + tests/test-minja: tests/test-minja.cpp \ $(OBJ_ALL) $(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<) diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 34c3620c2..c132e8333 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -67,6 +67,7 @@ add_library(${TARGET} STATIC ngram-cache.h sampling.cpp sampling.h + tool-call.cpp train.cpp train.h ) diff --git a/common/tool-call.cpp b/common/tool-call.cpp new file mode 100644 index 000000000..3bbec002b --- /dev/null +++ b/common/tool-call.cpp @@ -0,0 +1,274 @@ +#include "tool-call.h" +#include "json-schema-to-grammar.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +static bool needs_functionary_3_2_tool_call(const std::string & chat_template) { + return chat_template.find("<|start_header_id|>") != std::string::npos + && chat_template.find(">>>all") != std::string::npos; +} + +static bool needs_llama_3_1_tool_call(const std::string & chat_template) { + return chat_template.find("<|start_header_id|>") != std::string::npos + && chat_template.find("<|python_tag|>") != std::string::npos; +} + +static bool needs_hermes_pro_tool_call(const std::string & chat_template) { + return chat_template.find("") != std::string::npos; +} + +static bool parse_json(std::string::const_iterator & it, const std::string::const_iterator & end, json & out) { + // // https://json.nlohmann.me/features/parsing/sax_interface/ + struct json_error_locator : public nlohmann::json_sax { + std::size_t position; + bool found_error; + + bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { + // LOG_WARNING("JSON error (Expected)", {{"position", position}, {"last_token", last_token}, {"error", ex.what()}}); + this->position = position - 1; + this->found_error = true; + return false; + } + bool null() override { return true; } + bool boolean(bool) override { return true; } + bool number_integer(number_integer_t) override { return true; } + bool number_unsigned(number_unsigned_t) override { return true; } + bool number_float(number_float_t, const string_t &) override { return true; } + bool string(string_t &) override { return true; } + bool binary(binary_t &) override { return true; } + bool start_object(std::size_t) override { return true; } + bool key(string_t &) override { return true; } + bool end_object() override { return true; } + bool start_array(std::size_t) override { return true; } + bool end_array() override { return true; } + }; + json_error_locator err_loc; + json::sax_parse(it, end, &err_loc); + + std::string::const_iterator temptative_end; + if (err_loc.found_error) { + temptative_end = it + err_loc.position; + } else { + temptative_end = end; + } + std::string json_sub {it, it + err_loc.position}; + // LOG_WARNING("Parsing json", {{"json_sub", json_sub}}); + try { + out = json::parse(json_sub); + it = temptative_end; + return true; + } catch (const std::exception & e) { + // LOG_WARNING("Failed to parse tool call", {{"json_sub", json_sub}, {"error", e.what()}}); + return false; + } +} + +static llama_tool_calls parse_hermes_tool_calls(const std::string& input) { + try { + std::regex start_pattern(R"([\n\s]*)"); + std::regex middle_pattern(R"([\n\s]*[\n\s]*)"); + std::regex end_pattern(R"([\n\s]*[\n\s]*$)"); + + auto end = input.end(); + std::sregex_iterator rend; + std::sregex_iterator rit(input.begin(), end, start_pattern); + if (rit == rend) { + return {input, {}}; + } + + llama_tool_calls result; + result.content = rit->prefix(); + + auto it = rit->suffix().first; + while (it != end) { + json call; + if (!parse_json(it, end, call)) { + throw std::runtime_error("Failed to parse json tool call"); + } + result.tool_calls.push_back({ + call["name"], + call["arguments"].dump(), + }); + rit = {it, end, middle_pattern}; + if (rit != rend) { + it = rit->suffix().first; + } else { + rit = {it, end, end_pattern}; + if (rit == rend) { + throw std::runtime_error("Malformed input, missing "); + } + break; + } + } + return result; + } catch (const std::exception & e) { + return {input, {}}; + } +} + +static llama_tool_calls parse_llama_3_1_tool_calls(const json & tools, const std::string& input) { + static std::regex python_tag_regex(R"(<\|python_tag\|>([\s\S\n]*)$)"); + std::smatch match; + if (std::regex_search(input, match, python_tag_regex)) { + return { + match.prefix().str(), { + {"ipython", (json {{"code", match[1].str()}}).dump()}, + } + }; + } + try { + auto call = json::parse(input); + // Only treat JSON as a tool call if it has a name attribute that matches any of the tools specified in the request. + // There doesn't seem to be any better way to detect a tool call. + if (call.contains("name") && call["name"].is_string()) { + std::string name = call["name"]; + for (const auto & tool : tools) { + if (tool.at("function").at("name") == name) { + return { + "", + { + {name, call["parameters"].dump()}, + } + }; + } + } + } + } catch (const std::exception & e) { + // Do nothing + } + return {input, {}}; +} + + +static llama_tool_calls parse_functionary_3_2_tool_calls(const std::string& input) { + static std::regex python_tag_regex(R"(>>>(\w+)\n((?!>>>)[\s\S\n]*))"); + std::smatch match; + llama_tool_calls result; + std::string content; + std::string in = input; + while (std::regex_search(in, match, python_tag_regex)) { + content += match.prefix().str(); + result.tool_calls.push_back({ + match[1].str(), + (json {{"code", match[2].str()}}).dump(), + }); + in = match.suffix().str(); + } + result.content = content + in; + return result; +} + +llama_tool_calls parse_tool_calls(const json & tools, const std::string & chat_template, const std::string& input) { + if (needs_hermes_pro_tool_call(chat_template)) { + return parse_hermes_tool_calls(input); + } else if (needs_llama_3_1_tool_call(chat_template)) { + return parse_llama_3_1_tool_calls(tools, input); + } else if (needs_functionary_3_2_tool_call(chat_template)) { + return parse_functionary_3_2_tool_calls(input); + } else { + throw std::runtime_error("Unsupported chat template for tool calls"); + } +} + +llama_tool_call_handler llama_tool_call_handler_init( + const std::string & chat_template, + bool allow_content, + bool parallel_tool_calls, + const nlohmann::ordered_json & tools) +{ + llama_tool_call_handler handler; + + if (needs_functionary_3_2_tool_call(chat_template)) { + // MeetKaiFunctionary_3_2 + // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... + // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar + handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { + std::vector tool_rules; + for (size_t i = 0, n = tools.size(); i < n; i++) { + auto & tool = tools[i]; + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + auto tool_rule = builder.add_rule(name + "-call", "\">>>" + name + "\\n\" " + builder.add_schema(name + "-args", parameters)); + tool_rules.push_back(tool_rule); + if (allow_content) { + handler.grammar_trigger_words.push_back(">>>" + name + "\n"); + } + } + auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space"; + builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); + }); + // handler.parser = parse_functionary_3_2_tool_calls; + } else if (needs_hermes_pro_tool_call(chat_template)) { + // NousResearchHermesPro_2 + // (content)?({"name": "foo", "arguments": {"a": 1}})* + handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { + std::vector tool_rules; + for (const auto & tool : tools) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + builder.resolve_refs(parameters); + tool_rules.push_back(builder.add_schema(name + "-call", { + {"type", "object"}, + {"properties", json { + {"name", json {{"const", name}}}, + {"arguments", parameters}, + }}, + {"required", json::array({"name", "arguments"})}, + })); + } + + auto tool_call = "\"\" " + builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " \"\" space"; + builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); + if (allow_content) { + handler.grammar_trigger_words.push_back(""); + } + }); + } else if (needs_llama_3_1_tool_call(chat_template)) { + handler.grammar = build_grammar([&](const llama_grammar_builder & builder) { + static std::vector builtin_tools {"wolfram_alpha", "brave_search"}; + std::vector tool_rules; + + for (const auto & tool : tools) { + const auto & function = tool["function"]; + std::string name = function["name"]; + auto parameters = function["parameters"]; + builder.resolve_refs(parameters); + if (name == "ipython" || std::find(builtin_tools.begin(), builtin_tools.end(), name) != builtin_tools.end()) { + tool_rules.push_back(builder.add_rule("ipython-call", "\"<|python_tag|>\" .*")); + if (allow_content) { + handler.grammar_trigger_words.push_back("<|python_tag|>"); + } + } else { + //"<|start_header_id|>assistant<|end_header_id|>\n\n{\"name\": \"" + name + "\", " + + tool_rules.push_back( + builder.add_rule( + name + "-call", + "\"\\n{\\\"name\\\": " + name + "\\\", \\\"parameters\\\", \" " + + builder.add_schema(name + "-args", parameters) + + " \"}\"")); + if (allow_content) { + handler.grammar_trigger_words.push_back("\n{\"" + name + "\""); + } + } + } + + builder.add_rule("root", join(tool_rules.begin(), tool_rules.end(), " | ")); + }); + handler.additional_stop_words.push_back("<|eom_id|>"); + } else { + // TODO: generic thoughtful schema. + throw std::runtime_error("Unsupported tool call style!"); + } + return handler; +} diff --git a/common/tool-call.h b/common/tool-call.h new file mode 100644 index 000000000..fd30f1f7c --- /dev/null +++ b/common/tool-call.h @@ -0,0 +1,30 @@ +#pragma once + +#include "ggml.h" +// Change JSON_ASSERT from assert() to GGML_ASSERT: +#define JSON_ASSERT GGML_ASSERT +#include "json.hpp" + +struct llama_tool_call { + std::string name; + std::string arguments; +}; + +struct llama_tool_calls { + std::string content; + std::vector tool_calls; +}; + +struct llama_tool_call_handler { + std::string grammar; + std::vector grammar_trigger_words; + std::vector additional_stop_words; +}; + +llama_tool_calls parse_tool_calls(const nlohmann::ordered_json & tools, const std::string & chat_template, const std::string& input); + +llama_tool_call_handler llama_tool_call_handler_init( + const std::string & chat_template, + bool allow_content, + bool parallel_tool_calls, + const nlohmann::ordered_json & tools); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 86705386a..d7ffed8b3 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -124,6 +124,7 @@ llama_target_and_test(test-barrier.cpp) llama_target_and_test(test-backend-ops.cpp) llama_target_and_test(test-antiprompts.cpp) llama_target_and_test(test-minja.cpp) +llama_target_and_test(test-tool-call.cpp) llama_target_and_test(test-rope.cpp) diff --git a/tests/test-tool-call.cpp b/tests/test-tool-call.cpp new file mode 100644 index 000000000..0a2a09416 --- /dev/null +++ b/tests/test-tool-call.cpp @@ -0,0 +1,124 @@ +#include "tool-call.h" + +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +static void assert_equals(const std::string & expected, const std::string & actual) { + if (expected != actual) { + std::cerr << "Expected: " << expected << std::endl; + std::cerr << "Actual: " << actual << std::endl; + std::cerr << std::flush; + throw std::runtime_error("Test failed"); + } +} + +/* + cmake -B build -DLLAMA_CURL=1 -DCMAKE_BUILD_TYPE=Release && cmake --build build -t test-tool-call -j && ./build/bin/test-tool-call +*/ + +static void test_parse_tool_call(const json & tools, const std::string & chat_template, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) { + auto result = parse_tool_calls(tools, chat_template, input); + assert_equals(expected_content, result.content); + auto tool_calls = json::array(); + for (const auto & tc : result.tool_calls) { + tool_calls.push_back({ + {"function", { + {"name", tc.name}, + {"arguments", tc.arguments}, + }} + }); + } + assert_equals(expected_tool_calls.dump(), tool_calls.dump()); +} +int main() { + json tools = json::parse(R"([ + { + "type": "function", + "function": { + "name": "special_function", + "description": "I'm special", + "parameters": { + "type": "object", + "properties": { + "arg1": { + "type": "string", + "description": "The arg." + } + }, + "required": ["arg1"] + } + } + } + ])"); + json request = { + {"tools", tools} + }; + + std::string hermes_2_pro_like_tmpl = "Hermes 2 Pro template should have inside it"; + test_parse_tool_call(tools, hermes_2_pro_like_tmpl, + "{\"name\": \"foo\", \"arguments\": {\"bar\": 1}}", + "", + json {{ + {"function", { + {"name", "foo"}, + {"arguments", (json { + {"bar", 1} + }).dump()} + }} + }}); + + std::string functionary_3_2_like_tmpl = "Functionary 3.2 template should have <|start_header_id|> and then some >>>all inside it"; + test_parse_tool_call(tools, functionary_3_2_like_tmpl, + ">>>ipython\nprint('Hello, world!')", + "", + json {{ + {"function", { + {"name", "ipython"}, + {"arguments", (json { + {"code", "print('Hello, world!')"} + }).dump()} + }} + }}); + + std::string llama_3_1_like_tmpl = "Llama 3.1 template should have <|start_header_id|> and <|python_tag|> inside it"; + test_parse_tool_call(tools, llama_3_1_like_tmpl, + "<|python_tag|>this could be anything", + "", + json {{ + {"function", { + {"name", "ipython"}, + {"arguments", (json { + {"code", "this could be anything"} + }).dump()} + }} + }}); + test_parse_tool_call(tools, llama_3_1_like_tmpl, + "I'm thinking<|python_tag|>", + "I'm thinking", + json {{ + {"function", { + {"name", "ipython"}, + {"arguments", (json {{"code", ""}}).dump()} + }} + }}); + test_parse_tool_call(tools, llama_3_1_like_tmpl, + "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", + "", + json {{ + {"function", { + {"name", "special_function"}, + {"arguments", (json { + {"arg1", 1} + }).dump()} + }} + }}); + test_parse_tool_call(tools, llama_3_1_like_tmpl, + "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", + "{\"name\": \"unknown_function\", \"arguments\": {\"arg1\": 1}}", json::array()); + + return 0; +} \ No newline at end of file