diff --git a/common/common.h b/common/common.h
index a7aeda5cf..693561569 100644
--- a/common/common.h
+++ b/common/common.h
@@ -646,7 +646,7 @@ class llama_antiprompts {
     };

     std::vector<std::string> stop_words;
-    std::vector<std::string> grammar_trigger_words;
+    std::vector<std::string> grammar_triggers;

 private:
     // The Aho–Corasick algorithm allows efficient string matching with multiple patterns.
@@ -740,25 +740,25 @@ private:
         stop_tokens.clear();
     }

-    void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
+    void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_triggers) {
         build(
             [&](const std::string & text) {
                 return common_tokenize(ctx, text, /* special= */ true);
             },
             stop_words,
-            grammar_trigger_words
+            grammar_triggers
         );
     }

-    void build(const std::function<std::vector<llama_token>(const std::string &)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
+    void build(const std::function<std::vector<llama_token>(const std::string &)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_triggers) {
        clear();
         this->stop_words = stop_words;
-        this->grammar_trigger_words = grammar_trigger_words;
+        this->grammar_triggers = grammar_triggers;

         for (const std::string & stop_word : stop_words) {
             antiprompts.push_back({stop_word, /* is_grammar_trigger= */ false});
         }
-        for (const std::string & trigger : grammar_trigger_words) {
+        for (const std::string & trigger : grammar_triggers) {
             antiprompts.push_back({trigger, /* is_grammar_trigger= */ true});
         }
diff --git a/common/tool-call.cpp b/common/tool-call.cpp
index 39b6326d5..f6d509f4d 100644
--- a/common/tool-call.cpp
+++ b/common/tool-call.cpp
@@ -520,7 +520,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
             if (!parallel) {
                 schema["maxItems"] = 1;
             }
-            builder.add_rule("root", "\"[TOOL_CALLS]\"? " + builder.add_schema("tool_calls", schema));
+            builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
         });
         if (allow_content) {
             handler.grammar_triggers.push_back("[TOOL_CALLS]");
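Context for the rename above: `llama_antiprompts` folds stop words and grammar-trigger words into a single Aho–Corasick matcher, distinguishing them only by an `is_grammar_trigger` flag on each pattern. Below is a self-contained toy sketch of that flattening step (not part of the diff; the `antiprompt` struct here is a stand-in for the real type):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Stand-in for the entries llama_antiprompts::build pushes above.
struct antiprompt {
    std::string pattern;
    bool        is_grammar_trigger;
};

int main() {
    const std::vector<std::string> stop_words       = {"</s>"};
    const std::vector<std::string> grammar_triggers = {"[TOOL_CALLS]"};

    // Same flattening as in build(): one pattern list, flagged per kind.
    std::vector<antiprompt> antiprompts;
    for (const auto & w : stop_words)       antiprompts.push_back({w, false});
    for (const auto & t : grammar_triggers) antiprompts.push_back({t, true});

    for (const auto & a : antiprompts) {
        std::cout << a.pattern << (a.is_grammar_trigger ? " => trigger grammar" : " => stop") << "\n";
    }
}
```

When a match completes, the server either stops generation (stop word) or lazily activates the tool-call grammar (trigger), as the server.cpp hunks below show.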
" + builder.add_schema("tool_calls", schema)); + builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema)); }); if (allow_content) { handler.grammar_triggers.push_back("[TOOL_CALLS]"); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 8304ecaac..3a18844b6 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -93,7 +93,6 @@ struct slot_params { json input_prefix; json input_suffix; std::vector antiprompt; - std::vector grammar_triggers; bool timings_per_token = false; bool ignore_eos = false; @@ -318,47 +317,39 @@ struct server_task { } } - if (data.contains("grammar_triggers")) { - const auto & triggers = data.at("grammar_triggers"); - if (triggers.is_array()) { - for (const auto & trigger : triggers) { - if (trigger.is_string()) { - params.grammar_triggers.push_back(trigger); + auto to_string_vec = [](const json & j) { + std::vector out; + if (j.is_array()) { + for (const auto & e : j) { + if (e.is_string()) { + out.push_back(e); } } } + return out; + }; + + { + const auto grammar_trigger_words = data.find("grammar_trigger_words"); + if (grammar_trigger_words != data.end()) { + params.sampling.grammar_trigger_words = to_string_vec(*grammar_trigger_words); + } } { - params.antiprompt.clear(); - - const auto & stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) { - for (const auto & word : *stop) { - if (!word.empty()) { - params.antiprompt.push_back(word); - } - } + const auto stop = data.find("stop"); + if (stop != data.end()) { + params.antiprompt = to_string_vec(*stop); } } { - const auto & samplers = data.find("samplers"); + const auto samplers = data.find("samplers"); if (samplers != data.end()) { if (samplers->is_array()) { - std::vector sampler_names; - for (const auto & name : *samplers) { - if (name.is_string()) { - sampler_names.emplace_back(name); - } - } - params.sampling.samplers = common_sampler_types_from_names(sampler_names, false); + params.sampling.samplers = common_sampler_types_from_names(to_string_vec(*samplers), false); } else if (samplers->is_string()){ - std::string sampler_string; - for (const auto & name : *samplers) { - sampler_string += name; - } - params.sampling.samplers = common_sampler_types_from_chars(sampler_string); + params.sampling.samplers = common_sampler_types_from_chars(samplers->get()); } } else { params.sampling.samplers = defaults.sampling.samplers; @@ -546,7 +537,7 @@ struct server_task_result_cmpl_final : server_task_result { llama_tool_calls parsed_tool_calls; json tool_calls; json message_content; - if (!oaicompat_tools.is_null()) { + if (oaicompat_tool_call_style != llama_tool_call_style::None && !oaicompat_tools.is_null()) { parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content); if (!parsed_tool_calls.tool_calls.empty()) { finish_reason = "tool_calls"; @@ -1759,7 +1750,7 @@ struct server_context { { slot.antiprompts.clear(); - slot.antiprompts.build(ctx, slot.params.antiprompt, slot.params.grammar_triggers); + slot.antiprompts.build(ctx, slot.params.antiprompt, slot.params.sampling.grammar_trigger_words); } { @@ -1805,7 +1796,7 @@ struct server_context { if (match.pos != std::string::npos && !match.is_partial) { if (match.is_grammar_trigger) { - common_sampler_trigger_grammar(model, slot.smpl, common_token_to_piece(ctx, result.tok, params_base.special)); + common_sampler_trigger_grammar(model, slot.smpl, token_str); } else { // slot.stopped_word = true; slot.stopping_word = match.pattern; @@ -2014,7 +2005,7 
diff --git a/examples/server/tests/unit/test_chat_completion.py b/examples/server/tests/unit/test_chat_completion.py
index 3b1f25f97..1da9f8c4b 100644
--- a/examples/server/tests/unit/test_chat_completion.py
+++ b/examples/server/tests/unit/test_chat_completion.py
@@ -202,23 +202,24 @@ CODE_INTEPRETER_TOOL = {

 @pytest.mark.parametrize("template_name,n_predict,tool,expected_arguments", [
     ("meetkai-functionary-medium-v3.1", 32, TEST_TOOL, {} ),
-    ("meetkai-functionary-medium-v3.1", 32, PYTHON_TOOL, {"code": ". She was so excited to go to the park and s"} ),
-    ("meetkai-functionary-medium-v3.2", 32, TEST_TOOL, {} ),
-    ("meetkai-functionary-medium-v3.2", 32, PYTHON_TOOL, {"code": "Yes,"} ),
+    ("meetkai-functionary-medium-v3.1", 32, PYTHON_TOOL, {"code": " and played all day.\" exclasted her pare"} ),
+    ("meetkai-functionary-medium-v3.2", 128, TEST_TOOL, {} ),
+    ("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, {"code": "Sure, I cannything,"} ),
     ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, {} ),
-    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes,"} ),
+    ("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": " out the owl cried. Jack said "} ),
     ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, {} ),
-    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes,"} ),
+    ("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": " out the owl cried. Jack said "} ),
     ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, {} ),
-    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "It's a shark."} ),
+    ("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "Let's feel out cooking fun together,"} ),
     ("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, {} ),
-    ("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "It's a shark."} ),
+    ("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "Well you fight. Peopballs donto cheep and come again."} ),
     ("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, {} ),
-    ("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "It's a small cost."} ),
+    ("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "I can cannot count."} ),
 ])
 def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, expected_arguments: dict):
     global server
     server.use_jinja = True
+    server.n_predict = n_predict
     server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
     server.start()
     res = server.make_request("POST", "/chat/completions", data={
@@ -227,13 +228,14 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, expected_arguments: dict):
             {"role": "system", "content": "You are a coding assistant."},
             {"role": "user", "content": "Write an example"},
         ],
-        "tool_choice": tool["function"]["name"],
+        "tool_choice": "required",
         "tools": [tool],
+        "parallel_tool_calls": False,
     })
     assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
     choice = res.body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
-    assert tool_calls and len(tool_calls==1), f'Expected 1 tool call in {choice["message"]}'
+    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
     assert tool["function"]["name"] == tool_call["function"]["name"]
     actual_arguments = json.loads(tool_call["function"]["arguments"])
@@ -254,6 +256,7 @@
 def test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
     global server
     server.use_jinja = True
+    server.n_predict = n_predict
     server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
     server.start()
     res = server.make_request("POST", "/chat/completions", data={
@@ -267,7 +270,7 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
     })
     assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
     choice = res.body["choices"][0]
-    assert "tool_calls" not in choice["message"], f'Expected no tool call in {choice["message"]}'
+    assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'


 @pytest.mark.slow
@@ -296,6 +299,7 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
 def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
     global server
     server.use_jinja = True
+    server.n_predict = 128
     server.model_hf_repo = hf_repo
     server.model_hf_file = hf_file
     if template_override:
@@ -314,7 +318,7 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
     assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
     choice = res.body["choices"][0]
     tool_calls = choice["message"].get("tool_calls")
-    assert tool_calls and len(tool_calls==1), f'Expected 1 tool call in {choice["message"]}'
+    assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
     tool_call = tool_calls[0]
     assert tool["function"]["name"] == tool_call["function"]["name"]
     actual_arguments = json.loads(tool_call["function"]["arguments"])
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index c73a5f042..e5ae16a70 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -494,7 +494,7 @@ static json oaicompat_completion_params_parse(
     auto tools = json_value(body, "tools", json());
     auto has_tools = tools.is_array() && !tools.empty();

-    auto stream = json_value(body, "stream", json());
+    auto stream = json_value(body, "stream", false);
     if (stream && has_tools) {
         throw std::runtime_error("Cannot use tools with stream");
     }
@@ -561,11 +561,12 @@ static json oaicompat_completion_params_parse(
         llama_params["stop"].push_back(stop);
     }
     if (!handler.grammar_triggers.empty()) {
-        auto triggers = json::array();
+        auto trigger_words = json::array();
         for (const auto & word : handler.grammar_triggers) {
-            triggers.push_back(word);
+            trigger_words.push_back(word);
         }
+
-        llama_params["grammar_triggers"] = triggers;
+        llama_params["grammar_trigger_words"] = trigger_words;
     }
     if (!handler.grammar.empty()) {
         if (llama_params.contains("grammar")) {
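Taken together, the renames settle on one wire format: tool-call handlers emit `grammar_trigger_words`, `oaicompat_completion_params_parse` forwards it, and the server maps it onto `params.sampling.grammar_trigger_words`. A hypothetical `llama_params` payload illustrating the resulting shape (the grammar string and trigger values are illustrative, not taken from the diff):

```cpp
#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // Illustrative only: roughly what utils.hpp above might hand to the
    // completion path when a handler allows plain content alongside tool calls.
    const json llama_params = {
        {"grammar",               "root ::= \"[TOOL_CALLS]\" tool-calls"},  // hypothetical GBNF
        {"grammar_trigger_words", json::array({"[TOOL_CALLS]"})},
        {"stop",                  json::array({"</s>"})},
    };
    std::cout << llama_params.dump(2) << "\n";
}
```

Until one of the trigger words is matched, sampling stays unconstrained; once `[TOOL_CALLS]` is produced, `common_sampler_trigger_grammar` switches the slot to grammar-constrained sampling, as wired up in the server.cpp hunks above.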