tool-call: stabilize server tests

This commit is contained in:
ochafik 2024-12-15 00:16:12 +00:00
parent 7bfcd0a8dd
commit 7e3feff073
5 changed files with 53 additions and 57 deletions

View File

@ -646,7 +646,7 @@ class llama_antiprompts {
};
std::vector<std::string> stop_words;
std::vector<std::string> grammar_trigger_words;
std::vector<std::string> grammar_triggers;
private:
// The AhoCorasick algorithm allows efficient string matching with multiple patterns.
@ -740,25 +740,25 @@ private:
stop_tokens.clear();
}
void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
void build(const llama_context * ctx, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_triggers) {
build(
[&](const std::string & text) {
return common_tokenize(ctx, text, /* special= */ true);
},
stop_words,
grammar_trigger_words
grammar_triggers
);
}
void build(const std::function<std::vector<llama_token>(const std::string &)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_trigger_words) {
void build(const std::function<std::vector<llama_token>(const std::string &)> & tokenizer, const std::vector<std::string> & stop_words, const std::vector<std::string> & grammar_triggers) {
clear();
this->stop_words = stop_words;
this->grammar_trigger_words = grammar_trigger_words;
this->grammar_triggers = grammar_triggers;
for (const std::string & stop_word : stop_words) {
antiprompts.push_back({stop_word, /* is_grammar_trigger= */ false});
}
for (const std::string & trigger : grammar_trigger_words) {
for (const std::string & trigger : grammar_triggers) {
antiprompts.push_back({trigger, /* is_grammar_trigger= */ true});
}

View File

@ -520,7 +520,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
if (!parallel) {
schema["maxItems"] = 1;
}
builder.add_rule("root", "\"[TOOL_CALLS]\"? " + builder.add_schema("tool_calls", schema));
builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
});
if (allow_content) {
handler.grammar_triggers.push_back("[TOOL_CALLS]");

View File

@ -93,7 +93,6 @@ struct slot_params {
json input_prefix;
json input_suffix;
std::vector<std::string> antiprompt;
std::vector<std::string> grammar_triggers;
bool timings_per_token = false;
bool ignore_eos = false;
@ -318,47 +317,39 @@ struct server_task {
}
}
if (data.contains("grammar_triggers")) {
const auto & triggers = data.at("grammar_triggers");
if (triggers.is_array()) {
for (const auto & trigger : triggers) {
if (trigger.is_string()) {
params.grammar_triggers.push_back(trigger);
auto to_string_vec = [](const json & j) {
std::vector<std::string> out;
if (j.is_array()) {
for (const auto & e : j) {
if (e.is_string()) {
out.push_back(e);
}
}
}
return out;
};
{
const auto grammar_trigger_words = data.find("grammar_trigger_words");
if (grammar_trigger_words != data.end()) {
params.sampling.grammar_trigger_words = to_string_vec(*grammar_trigger_words);
}
}
{
params.antiprompt.clear();
const auto & stop = data.find("stop");
if (stop != data.end() && stop->is_array()) {
for (const auto & word : *stop) {
if (!word.empty()) {
params.antiprompt.push_back(word);
}
}
const auto stop = data.find("stop");
if (stop != data.end()) {
params.antiprompt = to_string_vec(*stop);
}
}
{
const auto & samplers = data.find("samplers");
const auto samplers = data.find("samplers");
if (samplers != data.end()) {
if (samplers->is_array()) {
std::vector<std::string> sampler_names;
for (const auto & name : *samplers) {
if (name.is_string()) {
sampler_names.emplace_back(name);
}
}
params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
params.sampling.samplers = common_sampler_types_from_names(to_string_vec(*samplers), false);
} else if (samplers->is_string()){
std::string sampler_string;
for (const auto & name : *samplers) {
sampler_string += name;
}
params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
params.sampling.samplers = common_sampler_types_from_chars(samplers->get<std::string>());
}
} else {
params.sampling.samplers = defaults.sampling.samplers;
@ -546,7 +537,7 @@ struct server_task_result_cmpl_final : server_task_result {
llama_tool_calls parsed_tool_calls;
json tool_calls;
json message_content;
if (!oaicompat_tools.is_null()) {
if (oaicompat_tool_call_style != llama_tool_call_style::None && !oaicompat_tools.is_null()) {
parsed_tool_calls = parse_tool_calls(oaicompat_tool_call_style, oaicompat_tools, content);
if (!parsed_tool_calls.tool_calls.empty()) {
finish_reason = "tool_calls";
@ -1759,7 +1750,7 @@ struct server_context {
{
slot.antiprompts.clear();
slot.antiprompts.build(ctx, slot.params.antiprompt, slot.params.grammar_triggers);
slot.antiprompts.build(ctx, slot.params.antiprompt, slot.params.sampling.grammar_trigger_words);
}
{
@ -1805,7 +1796,7 @@ struct server_context {
if (match.pos != std::string::npos && !match.is_partial) {
if (match.is_grammar_trigger) {
common_sampler_trigger_grammar(model, slot.smpl, common_token_to_piece(ctx, result.tok, params_base.special));
common_sampler_trigger_grammar(model, slot.smpl, token_str);
} else {
// slot.stopped_word = true;
slot.stopping_word = match.pattern;
@ -2014,7 +2005,7 @@ struct server_context {
{"mirostat_eta", slot.params.sampling.mirostat_eta},
{"penalize_nl", slot.params.sampling.penalize_nl},
{"stop", slot.params.antiprompt},
{"grammar_trigger", slot.params.grammar_triggers},
{"grammar_trigger_words", slot.params.sampling.grammar_trigger_words},
{"max_tokens", slot.params.n_predict}, // User configured n_predict
{"n_keep", slot.params.n_keep},
{"n_discard", slot.params.n_discard},
@ -3564,7 +3555,7 @@ int main(int argc, char ** argv) {
task.params.oaicompat = oaicompat;
task.params.oaicompat_chat = oaicompat_chat;
task.params.oaicompat_cmpl_id = completion_id;
task.params.oaicompat_tools = json_value(data, "tools", json::array());
task.params.oaicompat_tools = json_value(data, "tools", json());
task.params.oaicompat_tool_call_style = tool_call_style;
// oaicompat_model is already populated by params_from_json_cmpl

View File

@ -202,23 +202,24 @@ CODE_INTEPRETER_TOOL = {
@pytest.mark.parametrize("template_name,n_predict,tool,expected_arguments", [
("meetkai-functionary-medium-v3.1", 32, TEST_TOOL, {} ),
("meetkai-functionary-medium-v3.1", 32, PYTHON_TOOL, {"code": ". She was so excited to go to the park and s"} ),
("meetkai-functionary-medium-v3.2", 32, TEST_TOOL, {} ),
("meetkai-functionary-medium-v3.2", 32, PYTHON_TOOL, {"code": "Yes,"} ),
("meetkai-functionary-medium-v3.1", 32, PYTHON_TOOL, {"code": " and played all day.\" exclasted her pare"} ),
("meetkai-functionary-medium-v3.2", 128, TEST_TOOL, {} ),
("meetkai-functionary-medium-v3.2", 128, PYTHON_TOOL, {"code": "Sure, I cannything,"} ),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, TEST_TOOL, {} ),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes,"} ),
("NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use", 128, PYTHON_TOOL, {"code": " out the owl cried. Jack said "} ),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, TEST_TOOL, {} ),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": "Yes,"} ),
("NousResearch-Hermes-3-Llama-3.1-8B-tool_use", 128, PYTHON_TOOL, {"code": " out the owl cried. Jack said "} ),
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, TEST_TOOL, {} ),
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "It's a shark."} ),
("meta-llama-Meta-Llama-3.1-8B-Instruct", 128, PYTHON_TOOL, {"code": "Let's feel out cooking fun together,"} ),
("meta-llama-Llama-3.2-3B-Instruct", 128, TEST_TOOL, {} ),
("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "It's a shark."} ),
("meta-llama-Llama-3.2-3B-Instruct", 128, PYTHON_TOOL, {"code": "Well you fight. Peopballs donto cheep and come again."} ),
("mistralai-Mistral-Nemo-Instruct-2407", 128, TEST_TOOL, {} ),
("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "It's a small cost."} ),
("mistralai-Mistral-Nemo-Instruct-2407", 128, PYTHON_TOOL, {"code": "I can cannot count."} ),
])
def test_completion_with_required_tool(template_name: str, n_predict: int, tool: dict, expected_arguments: dict):
global server
server.use_jinja = True
server.n_predict = n_predict
server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
server.start()
res = server.make_request("POST", "/chat/completions", data={
@ -227,13 +228,14 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool:
{"role": "system", "content": "You are a coding assistant."},
{"role": "user", "content": "Write an example"},
],
"tool_choice": tool["function"]["name"],
"tool_choice": "required",
"tools": [tool],
"parallel_tool_calls": False,
})
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls==1), f'Expected 1 tool call in {choice["message"]}'
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
assert tool["function"]["name"] == tool_call["function"]["name"]
actual_arguments = json.loads(tool_call["function"]["arguments"])
@ -254,6 +256,7 @@ def test_completion_with_required_tool(template_name: str, n_predict: int, tool:
def test_completion_without_tool_call(template_name: str, n_predict: int, tools: list[dict], tool_choice: str | None):
global server
server.use_jinja = True
server.n_predict = n_predict
server.chat_template_file = f'../../../tests/chat/templates/{template_name}.jinja'
server.start()
res = server.make_request("POST", "/chat/completions", data={
@ -267,7 +270,7 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
})
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
assert "tool_calls" not in choice["message"], f'Expected no tool call in {choice["message"]}'
assert choice["message"].get("tool_calls") is None, f'Expected no tool call in {choice["message"]}'
@pytest.mark.slow
@ -296,6 +299,7 @@ def test_completion_without_tool_call(template_name: str, n_predict: int, tools:
def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: str, hf_file: str, template_override: Tuple[str, str | None] | None):
global server
server.use_jinja = True
server.n_predict = 128
server.model_hf_repo = hf_repo
server.model_hf_file = hf_file
if template_override:
@ -314,7 +318,7 @@ def test_hello_world_tool_call(tool: dict, expected_arguments: dict, hf_repo: st
assert res.status_code == 200, f"Expected status code 200, got {res.status_code}"
choice = res.body["choices"][0]
tool_calls = choice["message"].get("tool_calls")
assert tool_calls and len(tool_calls==1), f'Expected 1 tool call in {choice["message"]}'
assert tool_calls and len(tool_calls) == 1, f'Expected 1 tool call in {choice["message"]}'
tool_call = tool_calls[0]
assert tool["function"]["name"] == tool_call["function"]["name"]
actual_arguments = json.loads(tool_call["function"]["arguments"])

View File

@ -494,7 +494,7 @@ static json oaicompat_completion_params_parse(
auto tools = json_value(body, "tools", json());
auto has_tools = tools.is_array() && !tools.empty();
auto stream = json_value(body, "stream", json());
auto stream = json_value(body, "stream", false);
if (stream && has_tools) {
throw std::runtime_error("Cannot use tools with stream");
}
@ -561,11 +561,12 @@ static json oaicompat_completion_params_parse(
llama_params["stop"].push_back(stop);
}
if (!handler.grammar_triggers.empty()) {
auto triggers = json::array();
auto trigger_words = json::array();
for (const auto & word : handler.grammar_triggers) {
triggers.push_back(word);
trigger_words.push_back(word);
}
llama_params["grammar_triggers"] = triggers;
llama_params["grammar_trigger_words"] = trigger_words;
}
if (!handler.grammar.empty()) {
if (llama_params.contains("grammar")) {