tool-call: nemo tweak (accept raw sql again)

This commit is contained in:
ochafik 2024-10-31 04:39:40 +00:00
parent 542853b34b
commit 7d9c90f46b
2 changed files with 5 additions and 3 deletions

View File

@ -285,7 +285,7 @@ static llama_tool_calls parse_mistral_nemo_tool_calls(const std::string& input)
result.tool_calls.push_back({
tool_call["name"],
arguments.is_string() ? arguments.get<std::string>() : arguments.dump(),
tool_call["id"],
tool_call.contains("id") ? tool_call["id"] : "",
});
}
};
@ -453,7 +453,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
{"pattern", "^[a-zA-Z0-9]{9}$"},
}},
}},
{"required", json::array({"arguments", "id", "name"})},
{"required", json::array({"name", "arguments", "id"})},
};
schemas.push_back(schema);
}
@ -465,10 +465,11 @@ llama_tool_call_handler llama_tool_call_handler_init(
if (!parallel) {
schema["maxItems"] = 1;
}
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
builder.add_rule("root", "\"[TOOL_CALLS]\"? " + builder.add_schema("tool_calls", schema));
});
if (allow_content) {
handler.grammar_trigger_words.push_back("[TOOL_CALLS]");
handler.grammar_trigger_words.push_back("[{\"arguments\":");
}
// auto tweaked_messages = add_system(messages, "You are a helpful AI with tool calling capabilities. Prefix any tool calls with [TOOL_CALLS]");
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);

View File

@ -397,6 +397,7 @@ static void test_grammars() {
test_template("tests/chat/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", "<s>", "</s>", { "</s>" }, tool_call_message_with_id, tools,
/* skip_grammar_test= */ true);
test_template("tests/chat/templates/Qwen-Qwen2.5-7B-Instruct.jinja", "<s>", "</s>", { "</s>" }, tool_call_message, tools);
test_template("tests/chat/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools);
test_template("tests/chat/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", "<s>", "</s>", { "<|im_end|>" }, tool_call_message, tools);
test_template("tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja", "<s>", "</s>", { "<|eom_id|>", "<|eot_id|>" }, tool_call_message, tools);