mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 03:31:46 +00:00
tool-call: stabilize server tests
This commit is contained in:
parent
7bfcd0a8dd
commit
7e3feff073
@ -646,7 +646,7 @@ class llama_antiprompts {
|
||||
};
|
||||
|
||||
std::vector<std::string> stop_words;
|
||||
std::vector<std::string> grammar_trigger_words;
|
||||
std::vector<std::string> grammar_triggers;
|
||||
|
||||
private:
|
||||
// The Aho–Corasick algorithm allows efficient string matching with multiple patterns.
|
||||
@ -740,25 +740,25 @@ private:
|
||||
stop_tokens.clear();
|
||||
}
|
||||
|
||||
void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
|
||||
void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_triggers) {
|
||||
build(
|
||||
[&](const std::string & text) {
|
||||
return common_tokenize(ctx, text, /* special= */ true);
|
||||
},
|
||||
stop_words,
|
||||
grammar_trigger_words
|
||||
grammar_triggers
|
||||
);
|
||||
}
|
||||
|
||||
void build(const std::function<std::vector<llama_token>(const std::string &)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
|
||||
void build(const std::function<std::vector<llama_token>(const std::string &)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_triggers) {
|
||||
clear();
|
||||
this->stop_words = stop_words;
|
||||
this->grammar_trigger_words = grammar_trigger_words;
|
||||
this->grammar_triggers = grammar_triggers;
|
||||
|
||||
for (const std::string & stop_word : stop_words) {
|
||||
antiprompts.push_back({stop_word, /* is_grammar_trigger= */ false});
|
||||
}
|
||||
for (const std::string & trigger : grammar_trigger_words) {
|
||||
for (const std::string & trigger : grammar_triggers) {
|
||||
antiprompts.push_back({trigger, /* is_grammar_trigger= */ true});
|
||||
}
|
||||
|
||||
|
@ -520,7 +520,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
|
||||
if (!parallel) {
|
||||
schema["maxItems"] = 1;
|
||||
}
|
||||
builder.add_rule("root", "\"[TOOL_CALLS]\"? " + builder.add_schema("tool_calls", schema));
|
||||
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
|
||||
});
|
||||
if (allow_content) {
|
||||
handler.grammar_triggers.push_back("[TOOL_CALLS]");
|
||||
|
@ -93,7 +93,6 @@ struct slot_params {
|
||||
json input_prefix;
|
||||
json input_suffix;
|
||||
std::vector<std::string> antiprompt;
|
||||
std::vector<std::string> grammar_triggers;
|
||||
bool timings_per_token = false;
|
||||
bool ignore_eos = false;
|
||||
|
||||
@ -318,47 +317,39 @@ struct server_task {
|
||||
}
|
||||
}
|
||||
|
||||
if (data.contains("grammar_triggers")) {
|
||||
const auto & triggers = data.at("grammar_triggers");
|
||||
if (triggers.is_array()) {
|
||||
for (const auto & trigger : triggers) {
|
||||
if (trigger.is_string()) {
|
||||
params.grammar_triggers.push_back(trigger);
|
||||
auto to_string_vec = [](const json & j) {
|
||||
std::vector<std::string> out;
|
||||
if (j.is_array()) {
|
||||
for (const auto & e : j) {
|
||||
if (e.is_string()) {
|
||||
out.push_back(e);
|
||||
}
|
||||
}
|
||||
}
|
||||
return out;
|
||||
};
|
||||
|
||||
{
|
||||
const auto grammar_trigger_words = data.find("grammar_trigger_words");
|
||||
if (grammar_trigger_words != data.end()) {
|
||||
params.sampling.grammar_trigger_words = to_string_vec(*grammar_trigger_words);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
params.antiprompt.clear();
|
||||
|
||||
const auto & stop = data.find("stop");
|
||||
if (stop != data.end() && stop->is_array()) {
|
||||
for (const auto & word : *stop) {
|
||||
if (!word.empty()) {
|
||||
params.antiprompt.push_back(word);
|
||||
}
|
||||
}
|
||||
const auto stop = data.find("stop");
|
||||
if (stop != data.end()) {
|
||||
params.antiprompt = to_string_vec(*stop);
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
const auto & samplers = data.find("samplers");
|
||||
const auto samplers = data.find("samplers");
|
||||
if (samplers != data.end()) {
|
||||
if (samplers->is_array()) {
|
||||
std::vector<std::string> sampler_names;
|
||||
for (const auto & name : *samplers) {
|
||||
if (name.is_string()) {
|
||||
sampler_names.emplace_back(name);
|
||||
}
|
||||
}
|
||||
params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
|
||||
params.sampling.samplers = common_sampler_types_from_names(to_string_vec(*samplers), false);
|
||||
} else if (samplers->is_string()){
|
||||
std::string sampler_string;
|
||||
for (const auto & name : *samplers) {
|
||||
sampler_string += name;
|
||||
}
|
||||
params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
|
||||
params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
|
||||
}
|
||||
} else {
|
||||
params.sampling.samplers = defaults.sampling.samplers;
|
||||
@ -546,7 +537,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
llama_tool_calls parsed_tool_calls;
|
||||
json tool_calls;
|
||||
json message_content;
|
||||
if (!oaicompat_tools.is_null()) {
|
||||
if (oaicompat_tool_call_style != llama_tool_call_style::None && !oaicompat_tools.is_null()) {
|
||||
parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content);
|
||||
if (!parsed_tool_calls.tool_calls.empty()) {
|
||||
finish_reason = "tool_calls";
|
||||
@ -1759,7 +1750,7 @@ struct server_context {
|
||||
|
||||
{
|
||||
slot.antiprompts.clear();
|
||||
slot.antiprompts.build(ctx, slot.params.antiprompt, slot.params.grammar_triggers);
|
||||
slot.antiprompts.build(ctx, slot.params.antiprompt, slot.params.sampling.grammar_trigger_words);
|
||||
}
|
||||
|
||||
{
|
||||
@ -1805,7 +1796,7 @@ struct server_context {
|
||||
|
||||
if (match.pos != std::string::npos && !match.is_partial) {
|
||||
if (match.is_grammar_trigger) {
|
||||
common_sampler_trigger_grammar(model, slot.smpl, common_token_to_piece(ctx, result.tok, params_base.special));
|
||||
common_sampler_trigger_grammar(model, slot.smpl, token_str);
|
||||
} else {
|
||||
// slot.stopped_word = true;
|
||||
slot.stopping_word = match.pattern;
|
||||
@ -2014,7 +2005,7 @@ struct server_context {
|
||||
{"mirostat_eta", slot.params.sampling.mirostat_eta},
|
||||
{"penalize_nl", slot.params.sampling.penalize_nl},
|
||||
{"stop", slot.params.antiprompt},
|
||||
{"grammar_trigger", slot.params.grammar_triggers},
|
||||
{"grammar_trigger_words", slot.params.sampling.grammar_trigger_words},
|
||||
{"max_tokens", slot.params.n_predict}, // User configured n_predict
|
||||
{"n_keep", slot.params.n_keep},
|
||||
{"n_discard", slot.params.n_discard},
|
||||
@ -3564,7 +3555,7 @@ int main(int argc, char ** argv) {
|
||||
task.params.oaicompat = oaicompat;
|
||||
task.params.oaicompat_chat = oaicompat_chat;
|
||||
task.params.oaicompat_cmpl_id = completion_id;
|
||||
task.params.oaicompat_tools = json_value(data, "tools", json::array());
|
||||
task.params.oaicompat_tools = json_value(data, "tools", json());
|
||||
task.params.oaicompat_tool_call_style = tool_call_style;
|
||||
|
||||
// oaicompat_model is already populated by params_from_json_cmpl
|
||||
|
@ -202,23 +202,24 @@ CODE_INTEPRETER_TOOL = {
|
||||
|
||||
@pytest.mark.parametrize("template_name,n_predict,tool,expected_arguments", [
|
||||
("meetkai-functionary-medium-v3.1", 32, TEST_TOOL, {} ),
|
||||
("meetkai-functionary-medium-v3.1", 32, PYTHON_TOOL, {"code": ". She was so excited to go to the park and s"} ),
|
||||
("meetkai-functionary-medium-v3.2", 32, TEST_TOOL, {} ),
|
||||
("meetkai-functionary-medium-v3.2", 32, PYTHON_TOOL, {"code": "Yes,"} ),
|
||||
("meetkai-functionary-medium-v3.1", 32, PYTHON_TOOL, {"code": " and played all day.\" exclasted her pare"} ),
|
||||
("meetkai-functionary-medium-v3.2", 128, TEST_TOOL, {} ),
|
||||
("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, {"code": "Sure, I cannything,"} ),
|
||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, {} ),
|
||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes,"} ),
|
||||
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": " out the owl cried. Jack said "} ),
|
||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, {} ),
|
||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes,"} ),
|
||||
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": " out the owl cried. Jack said "} ),
|
||||
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, {} ),
|
||||
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "It's a shark."} ),
|
||||
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "Let's feel out cooking fun together,"} ),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, {} ),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "It's a shark."} ),
|
||||
("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "Well you fight. Peopballs donto cheep and come again."} ),
|
||||
("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, {} ),
|
||||
("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "It's a small cost."} ),
|
||||
("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "I can cannot count."} ),
|
||||
])
|
||||
def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, expected_arguments: dict):
|
||||
global server
|
||||
server.use_jinja = True
|
||||
server.n_predict = n_predict
|
||||
server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
@ -227,13 +228,14 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool:
|
||||
{"role": "system", "content": "You are a coding assistant."},
|
||||
{"role": "user", "content": "Write an example"},
|
||||
],
|
||||
"tool_choice": tool["function"]["name"],
|
||||
"tool_choice": "required",
|
||||
"tools": [tool],
|
||||
"parallel_tool_calls": False,
|
||||
})
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls==1), f'Expected 1 tool call in {choice["message"]}'
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert tool["function"]["name"] == tool_call["function"]["name"]
|
||||
actual_arguments = json.loads(tool_call["function"]["arguments"])
|
||||
@ -254,6 +256,7 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool:
|
||||
def test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
|
||||
global server
|
||||
server.use_jinja = True
|
||||
server.n_predict = n_predict
|
||||
server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
@ -267,7 +270,7 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
|
||||
})
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
assert "tool_calls" not in choice["message"], f'Expected no tool call in {choice["message"]}'
|
||||
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@ -296,6 +299,7 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
|
||||
def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
|
||||
global server
|
||||
server.use_jinja = True
|
||||
server.n_predict = 128
|
||||
server.model_hf_repo = hf_repo
|
||||
server.model_hf_file = hf_file
|
||||
if template_override:
|
||||
@ -314,7 +318,7 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: st
|
||||
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
|
||||
choice = res.body["choices"][0]
|
||||
tool_calls = choice["message"].get("tool_calls")
|
||||
assert tool_calls and len(tool_calls==1), f'Expected 1 tool call in {choice["message"]}'
|
||||
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
|
||||
tool_call = tool_calls[0]
|
||||
assert tool["function"]["name"] == tool_call["function"]["name"]
|
||||
actual_arguments = json.loads(tool_call["function"]["arguments"])
|
||||
|
@ -494,7 +494,7 @@ static json oaicompat_completion_params_parse(
|
||||
auto tools = json_value(body, "tools", json());
|
||||
auto has_tools = tools.is_array() && !tools.empty();
|
||||
|
||||
auto stream = json_value(body, "stream", json());
|
||||
auto stream = json_value(body, "stream", false);
|
||||
if (stream && has_tools) {
|
||||
throw std::runtime_error("Cannot use tools with stream");
|
||||
}
|
||||
@ -561,11 +561,12 @@ static json oaicompat_completion_params_parse(
|
||||
llama_params["stop"].push_back(stop);
|
||||
}
|
||||
if (!handler.grammar_triggers.empty()) {
|
||||
auto triggers = json::array();
|
||||
auto trigger_words = json::array();
|
||||
for (const auto & word : handler.grammar_triggers) {
|
||||
triggers.push_back(word);
|
||||
trigger_words.push_back(word);
|
||||
|
||||
}
|
||||
llama_params["grammar_triggers"] = triggers;
|
||||
llama_params["grammar_trigger_words"] = trigger_words;
|
||||
}
|
||||
if (!handler.grammar.empty()) {
|
||||
if (llama_params.contains("grammar")) {
|
||||
|
Loading…
Reference in New Issue
Block a user