diff --git a/common/chat-template.cpp b/common/chat-template.cpp
index eee134dba..266ae7c80 100644
--- a/common/chat-template.cpp
+++ b/common/chat-template.cpp
@@ -41,12 +41,19 @@ llama_chat_template::llama_chat_template(const std::string & chat_template, cons
         _tool_call_style = Hermes2Pro;
     } else if (chat_template.find(">>>all") != std::string::npos) {
         _tool_call_style = FunctionaryV3Llama3;
-    } else if (chat_template.find("<|start_header_id|>") != std::string::npos) {
-        if (chat_template.find("") != std::string::npos) {
+    } else if (chat_template.find("<|start_header_id|>") != std::string::npos
+            && chat_template.find("ipython<|end_header_id|>") != std::string::npos) {
+        if (chat_template.find("<|python_tag|>") != std::string::npos) {
             _tool_call_style = Llama31;
+        } else {
+            _tool_call_style = Llama32;
         }
+    } else if (chat_template.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
+        _tool_call_style = CommandRPlus;
+    } else {
+        _tool_call_style = UnknownToolCallStyle;
     }
 
 
     _template_root = minja::Parser::parse(_chat_template, {
         /* .trim_blocks = */ true,
diff --git a/common/chat-template.h b/common/chat-template.h
index 162497b8e..ff2b56745 100644
--- a/common/chat-template.h
+++ b/common/chat-template.h
@@ -11,9 +11,11 @@ using json = nlohmann::ordered_json;
 enum llama_tool_call_style {
     UnknownToolCallStyle,
     Llama31,
+    Llama32,
     FunctionaryV3Llama3,
     FunctionaryV3Llama31,
     Hermes2Pro,
+    CommandRPlus,
 };
 
 class llama_chat_template {
diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp
index 8f2a58bc4..b9e07b109 100644
--- a/tests/test-chat-template.cpp
+++ b/tests/test-chat-template.cpp
@@ -27,7 +27,8 @@ static std::string filename_without_extension(const std::string & path) {
     return res;
 }
 
-static void assert_equals(const std::string & expected, const std::string & actual) {
+template <class T>
+static void assert_equals(const T & expected, const T & actual) {
     if (expected != actual) {
         std::cerr << "Expected: " << expected << std::endl;
         std::cerr << "Actual: " << actual << std::endl;
@@ -118,6 +119,20 @@ static void test_jinja_templates() {
     }
 }
 
+static void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
+    auto tmpl = llama_chat_template(read_file(template_file), "", "");
+    std::cout << "# Testing tool call style of: " << template_file << std::endl << std::flush;
+    assert_equals(expected, tmpl.tool_call_style());
+}
+
+static void test_tool_call_styles() {
+    test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", FunctionaryV3Llama31);
+    test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", FunctionaryV3Llama3);
+    test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", Llama31);
+    test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", Llama32);
+    test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", CommandRPlus);
+}
+
 static void test_legacy_templates() {
     struct test_template {
         std::string name;
@@ -330,6 +345,7 @@ int main(void) {
     if (getenv("LLAMA_SKIP_TESTS_SLOW_ON_EMULATOR")) {
         fprintf(stderr, "\033[33mWARNING: Skipping slow tests on emulator.\n\033[0m");
     } else {
+        test_tool_call_styles();
         test_jinja_templates();
     }