tool-call: test tool call style detection

This commit is contained in:
ochafik 2024-09-28 17:43:09 +01:00
parent 887951beb0
commit 0c85bc7a8f
3 changed files with 30 additions and 5 deletions

View File

@@ -41,12 +41,19 @@ llama_chat_template::llama_chat_template(const std::string & chat_template, cons
_tool_call_style = Hermes2Pro;
} else if (chat_template.find(">>>all") != std::string::npos) {
_tool_call_style = FunctionaryV3Llama3;
} else if (chat_template.find("<|start_header_id|>") != std::string::npos) {
if (chat_template.find("<function=") != std::string::npos) {
_tool_call_style = FunctionaryV3Llama31;
} else if (chat_template.find("<|python_tag|>") != std::string::npos) {
} else if (chat_template.find("<|start_header_id|>") != std::string::npos
&& chat_template.find("<function=") != std::string::npos) {
_tool_call_style = FunctionaryV3Llama31;
} else if (chat_template.find("<|start_header_id|>ipython<|end_header_id|>") != std::string::npos) {
if (chat_template.find("<|python_tag|>") != std::string::npos) {
_tool_call_style = Llama31;
} else {
_tool_call_style = Llama32;
}
} else if (chat_template.find("<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>") != std::string::npos) {
_tool_call_style = CommandRPlus;
} else {
_tool_call_style = UnknownToolCallStyle;
}
_template_root = minja::Parser::parse(_chat_template, {
/* .trim_blocks = */ true,

View File

@@ -11,9 +11,11 @@ using json = nlohmann::ordered_json;
// Dialect a chat template uses to express tool/function calls.
// Detected heuristically from characteristic marker strings found in the
// template source text (see the detection chain in the llama_chat_template
// constructor).
enum llama_tool_call_style {
UnknownToolCallStyle,   // no known marker matched
Llama31,                // Llama 3.1: ipython header plus <|python_tag|> builtin-tool syntax
Llama32,                // Llama 3.2: ipython header without <|python_tag|>
FunctionaryV3Llama3,    // Functionary v3 (v3.2 template): ">>>all" dispatch marker
FunctionaryV3Llama31,   // Functionary v3.1: "<function=...>" call syntax
Hermes2Pro,             // Hermes 2 Pro style (detection marker not visible in this hunk)
CommandRPlus,           // Cohere Command R+: <|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|> marker
};
class llama_chat_template {

View File

@@ -27,7 +27,8 @@ static std::string filename_without_extension(const std::string & path) {
return res;
}
static void assert_equals(const std::string & expected, const std::string & actual) {
template <class T>
static void assert_equals(const T & expected, const T & actual) {
if (expected != actual) {
std::cerr << "Expected: " << expected << std::endl;
std::cerr << "Actual: " << actual << std::endl;
@@ -118,6 +119,20 @@ static void test_jinja_templates() {
}
}
// Loads the chat template from `template_file` and asserts that the
// detected tool-call style matches `expected`.
// NOTE(review): paths are relative, so this presumably must run from the
// repository root — confirm against the test harness.
void test_tool_call_style(const std::string & template_file, llama_tool_call_style expected) {
    auto tmpl = llama_chat_template(read_file(template_file), "<s>", "</s>");
    // std::endl already flushes the stream; the trailing std::flush was redundant.
    std::cout << "# Testing tool call style of: " << template_file << std::endl;
    assert_equals(expected, tmpl.tool_call_style());
}
void test_tool_call_styles() {
test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.1.jinja", FunctionaryV3Llama31);
test_tool_call_style("tests/chat/templates/meetkai-functionary-medium-v3.2.jinja", FunctionaryV3Llama3);
test_tool_call_style("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", Llama31);
test_tool_call_style("tests/chat/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", Llama32);
test_tool_call_style("tests/chat/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", CommandRPlus);
}
static void test_legacy_templates() {
struct test_template {
std::string name;
@@ -330,6 +345,7 @@ int main(void) {
if (getenv("LLAMA_SKIP_TESTS_SLOW_ON_EMULATOR")) {
fprintf(stderr, "\033[33mWARNING: Skipping slow tests on emulator.\n\033[0m");
} else {
test_tool_call_styles();
test_jinja_templates();
}