mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 04:00:16 +00:00
tool-call
: test tool call style detection
This commit is contained in:
parent
887951beb0
commit
0c85bc7a8f
@ -41,12 +41,19 @@ llama_chat_template::llama_chat_template(const std::string & chat_template, cons
|
||||
_tool_call_style = Hermes2Pro;
|
||||
} else if (chat_template.find(">>>all") != std::string::npos) {
|
||||
_tool_call_style = FunctionaryV3Llama3;
|
||||
} else if (chat_template.find("<|start_header_id|>") != std::string::npos) {
|
||||
if (chat_template.find("<function=") != std::string::npos) {
|
||||
_tool_call_style = FunctionaryV3Llama31;
|
||||
} else if (chat_template.find("<|python_tag|>") != std::string::npos) {
|
||||
} else if (chat_template.find("<|start_header_id|>") != std::string::npos
|
||||
&& chat_template.find("<function=") != std::string::npos) {
|
||||
_tool_call_style = FunctionaryV3Llama31;
|
||||
} else if (chat_template.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
|
||||
if (chat_template.find("<|python_tag|>") != std::string::npos) {
|
||||
_tool_call_style = Llama31;
|
||||
} else {
|
||||
_tool_call_style = Llama32;
|
||||
}
|
||||
} else if (chat_template.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
|
||||
_tool_call_style = CommandRPlus;
|
||||
} else {
|
||||
_tool_call_style = UnknownToolCallStyle;
|
||||
}
|
||||
_template_root = minja::Parser::parse(_chat_template, {
|
||||
/* .trim_blocks = */ true,
|
||||
|
@ -11,9 +11,11 @@ using json = nlohmann::ordered_json;
|
||||
// Tool-call dialects that can be auto-detected from a chat template's text.
// Unscoped enum: enumerators take sequential implicit values starting at 0.
// Do not reorder — callers may rely on the current values.
enum llama_tool_call_style {
    UnknownToolCallStyle, // template matched none of the known markers
    Llama31,              // Llama 3.1: ipython header + "<|python_tag|>" present
    Llama32,              // Llama 3.2: ipython header without "<|python_tag|>"
    FunctionaryV3Llama3,  // Functionary v3.2: template contains ">>>all"
    FunctionaryV3Llama31, // Functionary v3.1: "<|start_header_id|>" + "<function=" markers
    Hermes2Pro,           // Hermes 2 Pro style (detection marker not shown in this hunk)
    CommandRPlus,         // Command R+: "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" marker
};
|
||||
|
||||
class llama_chat_template {
|
||||
|
@ -27,7 +27,8 @@ static std::string filename_without_extension(const std::string & path) {
|
||||
return res;
|
||||
}
|
||||
|
||||
static void assert_equals(const std::string & expected, const std::string & actual) {
|
||||
template <class T>
|
||||
static void assert_equals(const T & expected, const T & actual) {
|
||||
if (expected != actual) {
|
||||
std::cerr << "Expected: " << expected << std::endl;
|
||||
std::cerr << "Actual: " << actual << std::endl;
|
||||
@ -118,6 +119,20 @@ static void test_jinja_templates() {
|
||||
}
|
||||
}
|
||||
|
||||
void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
|
||||
auto tmpl = llama_chat_template(read_file(template_file), "<s>", "</s>");
|
||||
std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush;
|
||||
assert_equals(expected, tmpl.tool_call_style());
|
||||
}
|
||||
|
||||
// Runs tool-call style detection against a fixed set of known templates under
// tests/chat/templates/ and asserts the expected style for each.
// Marked static for consistency with the file's other test entry points
// (e.g. test_legacy_templates, test_jinja_templates).
static void test_tool_call_styles() {
    test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", FunctionaryV3Llama31);
    test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", FunctionaryV3Llama3);
    test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", Llama31);
    test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", Llama32);
    test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", CommandRPlus);
}
|
||||
|
||||
static void test_legacy_templates() {
|
||||
struct test_template {
|
||||
std::string name;
|
||||
@ -330,6 +345,7 @@ int main(void) {
|
||||
if (getenv("LLAMA_SKIP_TESTS_SLOW_ON_EMULATOR")) {
|
||||
fprintf(stderr, "\033[33mWARNING: Skipping slow tests on emulator.\n\033[0m");
|
||||
} else {
|
||||
test_tool_call_styles();
|
||||
test_jinja_templates();
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user