tool-call: fix crash / test non-tool call case (added llama_sampler_is_grammar_empty)

This commit is contained in:
ochafik 2024-09-27 06:04:41 +01:00
parent 0abfa36ca7
commit f62e688387
9 changed files with 42 additions and 9 deletions

View File

@ -140,7 +140,7 @@ std::string gpt_sampler_params::print() const {
}
bool gpt_sampler_trigger_grammar(const struct llama_model * model, gpt_sampler * gsmpl, const std::string & trigger) {
if (gsmpl->grmr) {
if (!llama_sampler_is_grammar_empty(gsmpl->grmr)) {
return false;
}
gsmpl->grmr = llama_sampler_init_grammar(model, gsmpl->params.grammar.c_str(), "root");
@ -155,7 +155,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
auto * result = new gpt_sampler {
/* .params = */ params,
/* .grmr = */ params.grammar_trigger_words.empty() ? llama_sampler_init_grammar(model, params.grammar.c_str(), "root") : nullptr,
/* .grmr = */ llama_sampler_init_grammar(model, params.grammar_trigger_words.empty() ? params.grammar.c_str() : "", "root"),
/* .chain = */ llama_sampler_chain_init(lparams),
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {},
@ -256,7 +256,9 @@ void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool acce
}
void gpt_sampler_reset(struct gpt_sampler * gsmpl) {
llama_sampler_reset(gsmpl->grmr);
if (gsmpl->grmr) {
llama_sampler_reset(gsmpl->grmr);
}
llama_sampler_reset(gsmpl->chain);
}

View File

@ -236,7 +236,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
builder.add_schema(name + "-args", parameters) +
" \"}\""));
if (allow_content) {
handler.grammar_trigger_words.push_back("\n{\"" + name + "\"");
handler.grammar_trigger_words.push_back("\n{\"name\": \"" + name + "\"");
}
}
}

View File

@ -999,12 +999,12 @@ struct server_context {
};
std::vector<std::string> stop_words;
std::vector<std::string> grammar_trigger_words;
copy_string_array(data, "stop", stop_words);
copy_string_array(data, "grammar_trigger_words", grammar_trigger_words);
copy_string_array(data, "grammar_trigger_words", slot.sparams.grammar_trigger_words);
slot.antiprompts.build(ctx, stop_words, grammar_trigger_words);
slot.antiprompts.build(ctx, stop_words, slot.sparams.grammar_trigger_words);
}
{

View File

@ -677,7 +677,7 @@ async def step_tool_called(context):
assert n_completions > 0
def check(tool_calls):
assert tool_calls is None
assert tool_calls is None, f"tool calls: {tool_calls}"
for i in range(n_completions):
assert_n_tokens_predicted(context.tasks_result.pop(), tool_calls_check=check)

View File

@ -16,7 +16,7 @@ Feature: llama.cpp server
And jinja templates are enabled
Scenario Outline: OAI Compatibility w/ required tool
Scenario Outline: OAI Compatibility w/ tools and required tool_choice
Given a chat template file ../../../tests/chat/templates/<template_name>.jinja
And the server is starting
And the server is healthy
@ -38,6 +38,24 @@ Feature: llama.cpp server
| meta-llama-Meta-Llama-3.1-8B-Instruct | 16 | ipython | {"code": ". A"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
Scenario Outline: OAI Compatibility w/ tools and auto tool_choice
Given a chat template file ../../../tests/chat/templates/<template_name>.jinja
And the server is starting
And the server is healthy
And a model test
And <n_predict> max tokens to predict
And a user prompt write a hello world in python
And tools [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}]
And an OAI compatible chat completions request with no api error
Then no tool is called
Examples: Prompts
| template_name | n_predict |
| meta-llama-Meta-Llama-3.1-8B-Instruct | 64 |
| meetkai-functionary-medium-v3.1 | 128 |
| meetkai-functionary-medium-v3.2 | 128 |
Scenario: OAI Compatibility w/ no tool
Given a chat template file ../../../tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja
And the server is starting

View File

@ -1118,6 +1118,8 @@ extern "C" {
const char * grammar_str,
const char * grammar_root);
LLAMA_API bool llama_sampler_is_grammar_empty(struct llama_sampler * gsmpl);
LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
int32_t n_vocab, // llama_n_vocab()
llama_token special_eos_id, // llama_token_eos()

View File

@ -1371,6 +1371,11 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
/* .clone = */ llama_sampler_grammar_clone,
/* .free = */ llama_sampler_grammar_free,
};
bool llama_sampler_is_grammar_empty_impl(struct llama_sampler * gsmpl) {
struct llama_sampler_grammar * ctx = (struct llama_sampler_grammar *) gsmpl->ctx;
return ctx->grammar == nullptr;
}
struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
auto * ctx = new llama_sampler_grammar;

View File

@ -27,3 +27,5 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
const struct llama_vocab & vocab,
const char * grammar_str,
const char * grammar_root);
bool llama_sampler_is_grammar_empty_impl(struct llama_sampler * gsmpl);

View File

@ -21312,6 +21312,10 @@ int32_t llama_chat_apply_template(
struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
}
bool llama_sampler_is_grammar_empty(struct llama_sampler * gsmpl) {
return llama_sampler_is_grammar_empty_impl(gsmpl);
}
//
// model split