tool-call: force printing of lazy grammar trigger tokens to regularize function call parsing

This commit is contained in:
Olivier Chafik 2024-10-29 15:26:51 +00:00
parent fa4c1119c9
commit 773ff91b7a
2 changed files with 6 additions and 7 deletions

View File

@ -455,12 +455,10 @@ llama_tool_call_handler llama_tool_call_handler_init(
if (!parallel) {
schema["maxItems"] = 1;
}
builder.add_schema("root", schema);
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
});
if (allow_content) {
handler.grammar_trigger_words.push_back("[TOOL_CALLS]");
handler.grammar_trigger_words.push_back("[{\"");
handler.grammar_trigger_words.push_back("[ { \"");
}
// auto tweaked_messages = add_system(messages, "You are a helpful AI with tool calling capabilities. Prefix any tool calls with [TOOL_CALLS]");
handler.prompt = tmpl.apply(messages, tools, /* add_generation_prompt= */ true);
@ -468,7 +466,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
}
case llama_tool_call_style::Llama31:
case llama_tool_call_style::Llama32: {
static auto builtin_tools = json {"wolfram_alpha", "brave_search"};
static auto builtin_tools = json {"wolfram_alpha", "brave_search", "code_interpreter"};
auto uses_python_tag = style == llama_tool_call_style::Llama31;
@ -569,7 +567,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
const auto & function = tool["function"];
std::string name = function["name"];
auto parameters = function["parameters"];
if (name == "python") {
if (name == "python" || name == "ipython") {
tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*"));
if (allow_content) {
handler.grammar_trigger_words.push_back("<|python_tag|>");

View File

@ -1062,11 +1062,12 @@ struct server_context {
}
bool process_token(completion_token_output & result, server_slot & slot) {
auto match = slot.antiprompts.findSingleTokenMatch(result.tok);
// remember which tokens were sampled - used for repetition penalties during sampling
const std::string token_str = common_token_to_piece(ctx, result.tok, params.special);
const std::string token_str = common_token_to_piece(ctx, result.tok, params.special || (match.pos != std::string::npos && match.is_grammar_trigger));
slot.sampled = result.tok;
auto match = slot.antiprompts.findSingleTokenMatch(result.tok);
if (match.pos != std::string::npos && !match.is_partial) {
if (match.is_grammar_trigger) {
common_sampler_trigger_grammar(model, slot.smpl, common_token_to_piece(ctx, result.tok, params.special));