From f62e68838780dade9fca2dad9c9a267b5cccdce1 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Fri, 27 Sep 2024 06:04:41 +0100
Subject: [PATCH] `tool-call`: fix crash / test non-tool call case (added
 llama_sampler_is_grammar_empty)

---
 common/sampling.cpp                           |  8 +++++---
 common/tool-call.cpp                          |  2 +-
 examples/server/server.cpp                    |  6 +++---
 examples/server/tests/features/steps/steps.py |  2 +-
 .../server/tests/features/tool_call.feature   | 20 ++++++++++++++++++-
 include/llama.h                               |  2 ++
 src/llama-sampling.cpp                        |  5 +++++
 src/llama-sampling.h                          |  2 ++
 src/llama.cpp                                 |  4 ++++
 9 files changed, 42 insertions(+), 9 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index bbe2f81e6..5593ae4ef 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -140,7 +140,7 @@ std::string gpt_sampler_params::print() const {
 }
 
 bool gpt_sampler_trigger_grammar(const struct llama_model * model, gpt_sampler * gsmpl, const std::string & trigger) {
-    if (gsmpl->grmr) {
+    if (!llama_sampler_is_grammar_empty(gsmpl->grmr)) {
         return false;
     }
     gsmpl->grmr = llama_sampler_init_grammar(model, gsmpl->params.grammar.c_str(), "root");
@@ -155,7 +155,7 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
 
     auto * result = new gpt_sampler {
         /* .params = */ params,
-        /* .grmr   = */ params.grammar_trigger_words.empty() ? llama_sampler_init_grammar(model, params.grammar.c_str(), "root") : nullptr,
+        /* .grmr   = */ llama_sampler_init_grammar(model, params.grammar_trigger_words.empty() ? params.grammar.c_str() : "", "root"),
         /* .chain  = */ llama_sampler_chain_init(lparams),
         /* .prev   = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
         /* .cur    = */ {},
@@ -256,7 +256,9 @@ void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool acce
 }
 
 void gpt_sampler_reset(struct gpt_sampler * gsmpl) {
-    llama_sampler_reset(gsmpl->grmr);
+    if (gsmpl->grmr) {
+        llama_sampler_reset(gsmpl->grmr);
+    }
 
     llama_sampler_reset(gsmpl->chain);
 }
diff --git a/common/tool-call.cpp b/common/tool-call.cpp
index 7b435703a..0b4750b92 100644
--- a/common/tool-call.cpp
+++ b/common/tool-call.cpp
@@ -236,7 +236,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
                     builder.add_schema(name + "-args", parameters) +
                     " \"}\""));
                 if (allow_content) {
-                    handler.grammar_trigger_words.push_back("\n{\"" + name + "\"");
+                    handler.grammar_trigger_words.push_back("\n{\"name\": \"" + name + "\"");
                 }
             }
         }
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 1a0ffa0bf..cc509d286 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -999,12 +999,12 @@ struct server_context {
             };
 
             std::vector<std::string> stop_words;
-            std::vector<std::string> grammar_trigger_words;
 
             copy_string_array(data, "stop", stop_words);
-            copy_string_array(data, "grammar_trigger_words", grammar_trigger_words);
+            copy_string_array(data, "grammar_trigger_words", slot.sparams.grammar_trigger_words);
 
-            slot.antiprompts.build(ctx, stop_words, grammar_trigger_words);
+            slot.antiprompts.build(ctx, stop_words, slot.sparams.grammar_trigger_words);
+
         }
 
         {
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index ac822a2eb..922ba0288 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -677,7 +677,7 @@ async def step_tool_called(context):
     assert n_completions > 0
 
     def check(tool_calls):
-        assert tool_calls is None
+        assert tool_calls is None, f"tool calls: {tool_calls}"
 
     for i in range(n_completions):
         assert_n_tokens_predicted(context.tasks_result.pop(), tool_calls_check=check)
diff --git a/examples/server/tests/features/tool_call.feature b/examples/server/tests/features/tool_call.feature
index b7b073025..6cc3e2174 100644
--- a/examples/server/tests/features/tool_call.feature
+++ b/examples/server/tests/features/tool_call.feature
@@ -16,7 +16,7 @@ Feature: llama.cpp server
     And jinja templates are enabled
 
 
-  Scenario Outline: OAI Compatibility w/ required tool
+  Scenario Outline: OAI Compatibility w/ tools and required tool_choice
     Given a chat template file ../../../tests/chat/templates/<template_name>.jinja
     And the server is starting
     And the server is healthy
@@ -38,6 +38,24 @@ Feature: llama.cpp server
       | meta-llama-Meta-Llama-3.1-8B-Instruct | 16 | ipython | {"code": ". A"} | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
 
 
+  Scenario Outline: OAI Compatibility w/ tools and auto tool_choice
+    Given a chat template file ../../../tests/chat/templates/<template_name>.jinja
+    And the server is starting
+    And the server is healthy
+    And a model test
+    And <n_predict> max tokens to predict
+    And a user prompt write a hello world in python
+    And tools [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}]
+    And an OAI compatible chat completions request with no api error
+    Then no tool is called
+
+    Examples: Prompts
+      | template_name                         | n_predict |
+      | meta-llama-Meta-Llama-3.1-8B-Instruct | 64        |
+      | meetkai-functionary-medium-v3.1       | 128       |
+      | meetkai-functionary-medium-v3.2       | 128       |
+
+
   Scenario: OAI Compatibility w/ no tool
     Given a chat template file ../../../tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja
     And the server is starting
diff --git a/include/llama.h b/include/llama.h
index de5a40ef2..d94aeda0a 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1118,6 +1118,8 @@ extern "C" {
             const char * grammar_str,
             const char * grammar_root);
 
+    LLAMA_API bool llama_sampler_is_grammar_empty(struct llama_sampler * gsmpl);
+
     LLAMA_API struct llama_sampler * llama_sampler_init_penalties(
                              int32_t   n_vocab,         // llama_n_vocab()
                          llama_token   special_eos_id,  // llama_token_eos()
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index 0773cd94f..8caf9f73b 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1371,6 +1371,11 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
     /* .clone  = */ llama_sampler_grammar_clone,
     /* .free   = */ llama_sampler_grammar_free,
 };
+
+bool llama_sampler_is_grammar_empty_impl(struct llama_sampler * gsmpl) {
+    struct llama_sampler_grammar * ctx = (struct llama_sampler_grammar *) gsmpl->ctx;
+    return ctx->grammar == nullptr;
+}
 
 struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
     auto * ctx = new llama_sampler_grammar;
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index d90b14713..07f8a66a2 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -27,3 +27,5 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
         const struct llama_vocab & vocab,
         const char * grammar_str,
         const char * grammar_root);
+
+bool llama_sampler_is_grammar_empty_impl(struct llama_sampler * gsmpl);
diff --git a/src/llama.cpp b/src/llama.cpp
index 758067958..e7ebc4d1f 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -21312,6 +21312,10 @@ int32_t llama_chat_apply_template(
 struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
     return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
 }
+
+bool llama_sampler_is_grammar_empty(struct llama_sampler * gsmpl) {
+    return llama_sampler_is_grammar_empty_impl(gsmpl);
+}
 
 //
 // model split
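
Note: the sketch below is not part of the patch; it only illustrates how the lazy grammar activation above is meant to be driven from a caller's generation loop. It assumes the common/sampling.h API from this branch (gpt_sampler_trigger_grammar, gpt_sampler_sample, gpt_sampler_accept); the substring scan over trigger words is a simplified stand-in for the server's antiprompt matching, and sample_with_lazy_grammar is a hypothetical helper name.

// Sketch only (not part of the patch). The grammar sampler starts out "empty"
// when grammar_trigger_words are configured; once a trigger word shows up in the
// output, gpt_sampler_trigger_grammar() builds the real grammar, and
// llama_sampler_is_grammar_empty() makes any repeat calls no-ops.
#include <string>
#include <vector>

#include "sampling.h"

static llama_token sample_with_lazy_grammar(
        const llama_model * model,
        llama_context     * ctx,
        gpt_sampler       * gsmpl,
        const std::vector<std::string> & trigger_words,  // e.g. from llama_tool_call_handler
        const std::string & generated_text,              // text decoded so far
        int idx) {
    // Activate the grammar the first time any trigger word appears in the output.
    for (const auto & word : trigger_words) {
        if (generated_text.find(word) != std::string::npos) {
            gpt_sampler_trigger_grammar(model, gsmpl, word);
            break;
        }
    }

    // Sample the next token; the grammar constrains sampling only after activation.
    const llama_token id = gpt_sampler_sample(gsmpl, ctx, idx);
    gpt_sampler_accept(gsmpl, id, /* accept_grammar= */ true);
    return id;
}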