mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 04:00:16 +00:00
tool-call
: support Functionary v3 vs. v3-llama3.1 variants
This commit is contained in:
parent
41103c0ed6
commit
4706bdbae1
@ -12,11 +12,18 @@
|
|||||||
|
|
||||||
using json = nlohmann::ordered_json;
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
static bool needs_functionary_3_2_tool_call(const std::string & chat_template) {
|
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3.llama3.txt
|
||||||
|
static bool needs_functionary_v3_tool_call(const std::string & chat_template) {
|
||||||
return chat_template.find("<|start_header_id|>") != std::string::npos
|
return chat_template.find("<|start_header_id|>") != std::string::npos
|
||||||
&& chat_template.find(">>>all") != std::string::npos;
|
&& chat_template.find(">>>all") != std::string::npos;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||||
|
static bool needs_functionary_v3_llama_3_1_tool_call(const std::string & chat_template) {
|
||||||
|
return chat_template.find("<|start_header_id|>") != std::string::npos
|
||||||
|
&& chat_template.find("<function=") != std::string::npos;
|
||||||
|
}
|
||||||
|
|
||||||
static bool needs_llama_3_1_tool_call(const std::string & chat_template) {
|
static bool needs_llama_3_1_tool_call(const std::string & chat_template) {
|
||||||
return chat_template.find("<|start_header_id|>") != std::string::npos
|
return chat_template.find("<|start_header_id|>") != std::string::npos
|
||||||
&& chat_template.find("<|python_tag|>") != std::string::npos;
|
&& chat_template.find("<|python_tag|>") != std::string::npos;
|
||||||
@ -148,8 +155,42 @@ static llama_tool_calls parse_llama_3_1_tool_calls(const json & tools, const std
|
|||||||
return {input, {}};
|
return {input, {}};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static llama_tool_calls parse_functionary_v3_llama_3_1_tool_calls(const std::string& input) {
|
||||||
|
static std::regex function_regex(R"(<function=(\w+)>)");
|
||||||
|
static std::regex close_regex(R"(</function>)");
|
||||||
|
std::smatch match;
|
||||||
|
|
||||||
static llama_tool_calls parse_functionary_3_2_tool_calls(const std::string& input) {
|
llama_tool_calls result;
|
||||||
|
auto end = input.end();
|
||||||
|
auto it = input.begin();
|
||||||
|
|
||||||
|
while (it != end) {
|
||||||
|
std::sregex_iterator rend;
|
||||||
|
std::sregex_iterator rit(it, end, function_regex);
|
||||||
|
if (rit == rend) {
|
||||||
|
result.content += std::string(it, end);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
result.content += std::string(it, rit->prefix().second);
|
||||||
|
it = rit->suffix().first;
|
||||||
|
|
||||||
|
auto name = rit->str(1);
|
||||||
|
|
||||||
|
json arguments;
|
||||||
|
if (!parse_json(it, end, arguments)) {
|
||||||
|
throw std::runtime_error("Failed to parse json tool call arguments");
|
||||||
|
}
|
||||||
|
if (!std::regex_search(it, end, match, close_regex)) {
|
||||||
|
throw std::runtime_error("Malformed input, missing closing pattern");
|
||||||
|
}
|
||||||
|
it = match.suffix().first;
|
||||||
|
result.tool_calls.push_back({name, arguments.dump()});
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
static llama_tool_calls parse_functionary_v3_tool_calls(const std::string& input) {
|
||||||
static std::regex python_tag_regex(R"(>>>(\w+)\n((?!>>>)[\s\S\n]*))");
|
static std::regex python_tag_regex(R"(>>>(\w+)\n((?!>>>)[\s\S\n]*))");
|
||||||
std::smatch match;
|
std::smatch match;
|
||||||
llama_tool_calls result;
|
llama_tool_calls result;
|
||||||
@ -172,8 +213,10 @@ llama_tool_calls parse_tool_calls(const json & tools, const std::string & chat_t
|
|||||||
return parse_hermes_tool_calls(input);
|
return parse_hermes_tool_calls(input);
|
||||||
} else if (needs_llama_3_1_tool_call(chat_template)) {
|
} else if (needs_llama_3_1_tool_call(chat_template)) {
|
||||||
return parse_llama_3_1_tool_calls(tools, input);
|
return parse_llama_3_1_tool_calls(tools, input);
|
||||||
} else if (needs_functionary_3_2_tool_call(chat_template)) {
|
} else if (needs_functionary_v3_tool_call(chat_template)) {
|
||||||
return parse_functionary_3_2_tool_calls(input);
|
return parse_functionary_v3_tool_calls(input);
|
||||||
|
} else if (needs_functionary_v3_llama_3_1_tool_call(chat_template)) {
|
||||||
|
return parse_functionary_v3_llama_3_1_tool_calls(input);
|
||||||
} else {
|
} else {
|
||||||
throw std::runtime_error("Unsupported chat template for tool calls");
|
throw std::runtime_error("Unsupported chat template for tool calls");
|
||||||
}
|
}
|
||||||
@ -187,7 +230,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
|||||||
{
|
{
|
||||||
llama_tool_call_handler handler;
|
llama_tool_call_handler handler;
|
||||||
|
|
||||||
if (needs_functionary_3_2_tool_call(chat_template)) {
|
if (needs_functionary_v3_tool_call(chat_template)) {
|
||||||
// MeetKaiFunctionary_3_2
|
// MeetKaiFunctionary_3_2
|
||||||
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
// >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}...
|
||||||
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
// Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar
|
||||||
@ -208,6 +251,25 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
|||||||
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||||
});
|
});
|
||||||
// handler.parser = parse_functionary_3_2_tool_calls;
|
// handler.parser = parse_functionary_3_2_tool_calls;
|
||||||
|
} else if (needs_functionary_v3_llama_3_1_tool_call(chat_template)) {
|
||||||
|
// https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt
|
||||||
|
handler.grammar = build_grammar([&](const llama_grammar_builder & builder) {
|
||||||
|
std::vector<std::string> tool_rules;
|
||||||
|
for (size_t i = 0, n = tools.size(); i < n; i++) {
|
||||||
|
auto & tool = tools[i];
|
||||||
|
const auto & function = tool["function"];
|
||||||
|
std::string name = function["name"];
|
||||||
|
auto parameters = function["parameters"];
|
||||||
|
auto tool_rule = builder.add_rule(name + "-call", "\"<function=" + name + ">\" " + builder.add_schema(name + "-args", parameters) + " \"</function>\"");
|
||||||
|
tool_rules.push_back(tool_rule);
|
||||||
|
}
|
||||||
|
auto tool_call = builder.add_rule("tool_call", join(tool_rules.begin(), tool_rules.end(), " | ")) + " space";
|
||||||
|
builder.add_rule("root", parallel_tool_calls ? "(" + tool_call + ")+" : tool_call);
|
||||||
|
if (allow_content) {
|
||||||
|
handler.grammar_trigger_words.push_back("<function=");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
// handler.parser = parse_functionary_3_2_tool_calls;
|
||||||
} else if (needs_hermes_pro_tool_call(chat_template)) {
|
} else if (needs_hermes_pro_tool_call(chat_template)) {
|
||||||
// NousResearchHermesPro_2
|
// NousResearchHermesPro_2
|
||||||
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
// (content)?(<tool_call>{"name": "foo", "arguments": {"a": 1}}</tool_call>)*
|
||||||
|
@ -21,6 +21,7 @@ static void assert_equals(const std::string & expected, const std::string & actu
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
static void test_parse_tool_call(const json & tools, const std::string & chat_template, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) {
|
static void test_parse_tool_call(const json & tools, const std::string & chat_template, const std::string & input, const std::string & expected_content, const json & expected_tool_calls) {
|
||||||
|
std::cout << "# Testing: " << input << std::endl << std::flush;
|
||||||
auto result = parse_tool_calls(tools, chat_template, input);
|
auto result = parse_tool_calls(tools, chat_template, input);
|
||||||
assert_equals(expected_content, result.content);
|
assert_equals(expected_content, result.content);
|
||||||
auto tool_calls = json::array();
|
auto tool_calls = json::array();
|
||||||
@ -71,8 +72,8 @@ int main() {
|
|||||||
}}
|
}}
|
||||||
}});
|
}});
|
||||||
|
|
||||||
std::string functionary_3_2_like_tmpl = "Functionary 3.2 template should have <|start_header_id|> and then some >>>all inside it";
|
std::string functionary_v3_like_tmpl = "Functionary 3.2 template should have <|start_header_id|> and then some >>>all inside it";
|
||||||
test_parse_tool_call(tools, functionary_3_2_like_tmpl,
|
test_parse_tool_call(tools, functionary_v3_like_tmpl,
|
||||||
">>>ipython\nprint('Hello, world!')",
|
">>>ipython\nprint('Hello, world!')",
|
||||||
"",
|
"",
|
||||||
json {{
|
json {{
|
||||||
@ -84,6 +85,29 @@ int main() {
|
|||||||
}}
|
}}
|
||||||
}});
|
}});
|
||||||
|
|
||||||
|
std::string functionary_v3_llama_3_1_like_tmpl = "Functionary 3.2 template for llama 3.1 should have <|start_header_id|> and then some <function=foo>{...}</function> inside it";
|
||||||
|
test_parse_tool_call(tools, functionary_v3_llama_3_1_like_tmpl,
|
||||||
|
"Hell<function=foo>{\"arg1\": 1}</function>o, world<function=bar>{\"arg2\": 2}</function>!",
|
||||||
|
"Hello, world!",
|
||||||
|
json {
|
||||||
|
{
|
||||||
|
{"function", {
|
||||||
|
{"name", "foo"},
|
||||||
|
{"arguments", (json {
|
||||||
|
{"arg1", 1}
|
||||||
|
}).dump()}
|
||||||
|
}}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{"function", {
|
||||||
|
{"name", "bar"},
|
||||||
|
{"arguments", (json {
|
||||||
|
{"arg2", 2}
|
||||||
|
}).dump()}
|
||||||
|
}}
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
std::string llama_3_1_like_tmpl = "Llama 3.1 template should have <|start_header_id|> and <|python_tag|> inside it";
|
std::string llama_3_1_like_tmpl = "Llama 3.1 template should have <|start_header_id|> and <|python_tag|> inside it";
|
||||||
test_parse_tool_call(tools, llama_3_1_like_tmpl,
|
test_parse_tool_call(tools, llama_3_1_like_tmpl,
|
||||||
"<|python_tag|>this could be anything",
|
"<|python_tag|>this could be anything",
|
||||||
|
Loading…
Reference in New Issue
Block a user