tool-call: nemo tweak (accept raw sql again)

This commit is contained in:
ochafik 2024-10-31 04:39:40 +00:00
parent 542853b34b
commit 7d9c90f46b
2 changed files with 5 additions and 3 deletions

View File

@@ -285,7 +285,7 @@ static llama_tool_calls parse_mistral_nemo_tool_calls(const std::string& input)
result.tool_calls.push_back({ result.tool_calls.push_back({
tool_call["name"], tool_call["name"],
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(), arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
tool_call["id"], tool_call.contains("id") ? tool_call["id"] : "",
}); });
} }
}; };
@@ -453,7 +453,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
{"pattern", "^[a-zA-Z0-9]{9}$"}, {"pattern", "^[a-zA-Z0-9]{9}$"},
}}, }},
}}, }},
{"required", json::array({"arguments", "id", "name"})}, {"required", json::array({"name", "arguments", "id"})},
}; };
schemas.push_back(schema); schemas.push_back(schema);
} }
@@ -465,10 +465,11 @@ llama_tool_call_handler llama_tool_call_handler_init(
if (!parallel) { if (!parallel) {
schema["maxItems"] = 1; schema["maxItems"] = 1;
} }
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema)); builder.add_rule("root", "\"[TOOL_CALLS]\"? " + builder.add_schema("tool_calls", schema));
}); });
if (allow_content) { if (allow_content) {
handler.grammar_trigger_words.push_back("[TOOL_CALLS]"); handler.grammar_trigger_words.push_back("[TOOL_CALLS]");
handler.grammar_trigger_words.push_back("[{\"arguments\":");
} }
// auto tweaked_messages = add_system(messages, "You are a helpful AI with tool calling capabilities. Prefix any tool calls with [TOOL_CALLS]"); // auto tweaked_messages = add_system(messages, "You are a helpful AI with tool calling capabilities. Prefix any tool calls with [TOOL_CALLS]");
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true); handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);

View File

@@ -397,6 +397,7 @@ static void test_grammars() {
test_template("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", "<s>", "</s>", { "</s>" }, tool_call_message_with_id, tools, test_template("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", "<s>", "</s>", { "</s>" }, tool_call_message_with_id, tools,
/* skip_grammar_test= */ true); /* skip_grammar_test= */ true);
test_template("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", "<s>", "</s>", { "</s>" }, tool_call_message, tools);
test_template("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools); test_template("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools);
test_template("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools); test_template("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools);
test_template("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools); test_template("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);