mirror of https://github.com/ggerganov/llama.cpp.git, synced 2025-01-13 04:00:16 +00:00

tool-call: add server tests for llama 3.1

This commit is contained in:
parent 9e366b3d03
commit a774093a99
@@ -316,7 +316,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
             tool_rules.push_back(
                 builder.add_rule(
                     name + "-call",
-                    "\"\\n{\\\"name\\\": " + name + "\\\", \\\"parameters\\\", \" " +
+                    "\"\\n{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
                     builder.add_schema(name + "-args", parameters) +
                     " \"}\""));
             if (allow_content) {
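Decoded, the corrected literal describes a Llama 3.1 style tool call: a newline followed by a JSON object whose "name" is the tool name and whose "parameters" carry the schema-constrained arguments (the old literal was missing the opening quote before the tool name and had a comma instead of a colon after "parameters"). A minimal Python sketch of the text shape the rule targets, with an illustrative tool name and arguments:

    import json

    name = "ipython"                        # illustrative tool name
    arguments = {"code": "print('hello')"}  # illustrative arguments matching the tool's schema
    # a newline, then a JSON object with "name" and "parameters" keys
    tool_call_text = "\n" + json.dumps({"name": name, "parameters": arguments})
    print(tool_call_text)
    # {"name": "ipython", "parameters": {"code": "print('hello')"}}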
@@ -80,6 +80,8 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.temperature = None
     context.lora_file = None
     context.disable_ctx_shift = False
+    context.use_jinja = False
+    context.chat_template_file = None
 
     context.tasks_result = []
     context.concurrent_tasks = []
@@ -159,6 +161,16 @@ def step_slot_save_path(context, slot_save_path: str):
     context.slot_save_path = slot_save_path
 
 
+@step('jinja templates are enabled')
+def step_use_jinja(context):
+    context.use_jinja = True
+
+
+@step('chat template file {file}')
+def step_use_jinja(context, file):
+    context.chat_template_file = file
+
+
 @step('using slot id {id_slot:d}')
 def step_id_slot(context, id_slot: int):
     context.id_slot = id_slot
@@ -369,7 +381,7 @@ def step_response_format(context, response_format):
 def step_tools(context, tools):
     context.tools = json.loads(tools)
 
-@step('tool choice {tool_choice}')
+@step('a tool choice {tool_choice}')
 def step_tool_choice(context, tool_choice):
     context.tool_choice = tool_choice
 
@@ -490,8 +502,11 @@ async def step_oai_chat_completions(context, api_error):
     expect_api_error = api_error == 'raised'
     seeds = await completions_seed(context, num_seeds=1)
     completion = await oai_chat_completions(context.prompts.pop(),
-                                            seeds[0] if seeds is not None else seeds,
-                                            context.system_prompt,
+                                            seeds[0] if seeds else None,
+
+                                            context.system_prompt
+                                            if hasattr(context, 'system_prompt') else None,
+
                                             context.base_url,
                                             '/v1/chat',
                                             False,
@@ -631,6 +646,43 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
     assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
 
 
+@step('tool {expected_name} is called with arguments {expected_arguments}')
+@async_run_until_complete
+async def step_tool_called(context, expected_name, expected_arguments):
+    n_completions = await gather_tasks_results(context)
+    assert n_completions > 0
+
+    expected_name = expected_name if expected_name else None
+    expected_arguments = json.loads(expected_arguments) if expected_arguments else None
+
+    def check(tool_calls):
+        if tool_calls is None:
+            assert expected_name is None and expected_arguments is None, f'expected_name = {expected_name}, expected_arguments = {expected_arguments}'
+        else:
+            assert len(tool_calls) == 1, f"tool calls: {tool_calls}"
+            tool_call = tool_calls[0]
+            actual_name = tool_call.name
+            actual_arguments = json.loads(tool_call.arguments)
+            assert expected_name == actual_name, f"tool name: {actual_name}, expected: {expected_name}"
+            assert json.dumps(expected_arguments) == json.dumps(actual_arguments), f"tool arguments: {json.dumps(actual_arguments)}, expected: {json.dumps(expected_arguments)}"
+
+    for i in range(n_completions):
+        assert_n_tokens_predicted(context.tasks_result.pop(), tool_calls_check=check)
+    assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
+
+@step('no tool is called')
+@async_run_until_complete
+async def step_tool_called(context):
+    n_completions = await gather_tasks_results(context)
+    assert n_completions > 0
+
+    def check(tool_calls):
+        assert tool_calls is None
+
+    for i in range(n_completions):
+        assert_n_tokens_predicted(context.tasks_result.pop(), tool_calls_check=check)
+    assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
+
 @step('embeddings are computed for')
 @async_run_until_complete
 async def step_compute_embedding(context):
@@ -1001,19 +1053,23 @@ async def oai_chat_completions(user_prompt,
     print(f"Sending OAI Chat completions request: {user_prompt}")
     # openai client always expects an api key
     user_api_key = user_api_key if user_api_key is not None else 'nope'
-    assert isinstance(seed, int), f'seed: {seed}'
+    seed = seed if seed is not None else 42
+
     enable_streaming = enable_streaming if enable_streaming is not None else False
+    messages = []
+    if system_prompt:
+        messages.append({
+            "role": "system",
+            "content": system_prompt,
+        })
+    if user_prompt:
+        messages.append({
+            "role": "user",
+            "content": user_prompt,
+        })
+
     payload = {
-        "messages": [
-            {
-                "role": "system",
-                "content": system_prompt,
-            },
-            {
-                "role": "user",
-                "content": user_prompt,
-            }
-        ],
+        "messages": messages,
         "model": model,
         "max_tokens": n_predict,
         "stream": enable_streaming,
@@ -1115,6 +1171,7 @@ async def oai_chat_completions(user_prompt,
         assert chat_completion.usage is not None
         completion_response = {
             'content': chat_completion.choices[0].message.content,
+            'tool_calls': chat_completion.choices[0].message.tool_calls,
             'timings': {
                 'predicted_n': chat_completion.usage.completion_tokens,
                 'prompt_n': chat_completion.usage.prompt_tokens
@@ -1181,11 +1238,13 @@ async def request_oai_embeddings(input, seed,
     return [e.embedding for e in oai_embeddings.data]
 
 
-def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
+def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None, tool_calls_check=None):
     content = completion_response['content']
+    tool_calls = completion_response.get('tool_calls')
     n_predicted = completion_response['timings']['predicted_n']
-    assert len(content) > 0, "no token predicted"
+    assert (content and len(content) > 0) or (tool_calls and len(tool_calls) > 0), "no token predicted"
     if re_content is not None:
+        assert content
         p = re.compile(re_content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL)
         matches = p.finditer(content)
         last_match = 0
@@ -1201,6 +1260,8 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
         if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
             print(f"Checking completion response: {highlighted}")
         assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
+    if tool_calls_check:
+        tool_calls_check(tool_calls)
     if expected_predicted_n and expected_predicted_n > 0:
         assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
                                                      f' {n_predicted} <> {expected_predicted_n}')
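The new tool_calls_check hook lets a step assert on the parsed tool calls instead of only the text content. A minimal sketch of how it can be driven, assuming a completion response shaped like the one assembled in oai_chat_completions above (the concrete values are illustrative, not taken from a real run):

    def expect_no_tool_call(tool_calls):
        # mirrors the 'no tool is called' step: the response must not carry tool calls
        assert tool_calls is None

    # illustrative response dict; real ones are built by oai_chat_completions()
    response = {
        'content': 'Hello world',
        'tool_calls': None,
        'timings': {'predicted_n': 3, 'prompt_n': 5},
    }
    assert_n_tokens_predicted(response, tool_calls_check=expect_no_tool_call)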
@@ -1409,6 +1470,10 @@ def start_server_background(context):
         server_args.extend(['--grp-attn-w', context.n_ga_w])
     if context.debug:
         server_args.append('--verbose')
+    if context.use_jinja:
+        server_args.append('--jinja')
+    if context.chat_template_file:
+        server_args.extend(['--chat-template-file', context.chat_template_file])
     if context.lora_file:
         server_args.extend(['--lora', context.lora_file])
     if context.disable_ctx_shift:
examples/server/tests/features/tool_call.feature (new file, 48 lines)
@@ -0,0 +1,48 @@
+@llama.cpp
+@server
+Feature: llama.cpp server
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
+    And a model file test-model.gguf
+    And a model alias tinyllama-2
+    And BOS token is 1
+    And 42 as server seed
+    And 8192 KV cache size
+    And 32 as batch size
+    And 2 slots
+    And 64 server max tokens to predict
+    And prometheus compatible metrics exposed
+    And jinja templates are enabled
+    And chat template file ../../../tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja
+    Then the server is starting
+    Then the server is healthy
+
+  Scenario: Health
+    Then the server is ready
+    And all slots are idle
+
+  Scenario Outline: OAI Compatibility w/ required tool
+    Given a model test
+    And <n> max tokens to predict
+    And a user prompt write a hello world in python
+    And a tool choice <tool_choice>
+    And tools <tools>
+    Given an OAI compatible chat completions request with no api error
+    Then tool <tool_name> is called with arguments <tool_arguments>
+
+    Examples: Prompts
+      | n  | tool_name | tool_arguments      | tool_choice | tools |
+      | 64 | test      | {}                  | required    | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
+      | 16 | ipython   | {"code": "it and "} | required    | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
+
+  Scenario: OAI Compatibility w/ no tool
+    Given a model test
+    And 16 max tokens to predict
+    And a user prompt write a hello world in python
+    And a tool choice <tool_choice>
+    And tools []
+    Given an OAI compatible chat completions request with no api error
+    Then no tool is called
+
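Outside the behave harness, the "required tool" scenario corresponds roughly to the following request against the server's OpenAI-compatible chat endpoint. This is a hedged sketch, not what the test code literally sends: it assumes the server configured in the Background above is running on localhost:8080 and honours OpenAI-style tools / tool_choice fields; the prompt and tool definition are copied from the first Examples row:

    import requests  # illustrative; the suite itself goes through aiohttp / the openai client

    payload = {
        "model": "test",
        "max_tokens": 64,
        "messages": [{"role": "user", "content": "write a hello world in python"}],
        "tool_choice": "required",
        "tools": [{"type": "function", "function": {
            "name": "test", "description": "",
            "parameters": {"type": "object", "properties": {}}}}],
    }
    r = requests.post("http://localhost:8080/v1/chat/completions", json=payload)
    r.raise_for_status()
    tool_calls = r.json()["choices"][0]["message"]["tool_calls"]
    print(tool_calls)  # the scenario expects a single call to "test" with arguments {}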