tool-call: add server tests for llama 3.1

ochafik 2024-09-26 02:17:30 +01:00
parent 9e366b3d03
commit a774093a99
3 changed files with 129 additions and 16 deletions


@@ -316,7 +316,7 @@ llama_tool_call_handler llama_tool_call_handler_init(
tool_rules.push_back(
builder.add_rule(
name + "-call",
"\"\\n{\\\"name\\\": " + name + "\\\", \\\"parameters\\\", \" " +
"\"\\n{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" " +
builder.add_schema(name + "-args", parameters) +
" \"}\""));
if (allow_content) {
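For reference, here is a minimal Python sketch (not part of the patch) that reproduces the corrected concatenation, i.e. the second of the two string variants above, to show the GBNF rule text it yields; the tool name "ipython" and the "ipython-args" rule reference stand in for whatever builder.add_schema() returns for the real tool.

# Illustration only: the \"-escaping rules are the same in Python as in C++,
# so this mirrors the fixed concatenation and prints the resulting GBNF rule body.
name = "ipython"
args_rule = "ipython-args"  # assumed return value of builder.add_schema(name + "-args", parameters)
rule = (
    "\"\\n{\\\"name\\\": \\\"" + name + "\\\", \\\"parameters\\\": \" "
    + args_rule
    + " \"}\""
)
print(rule)
# "\n{\"name\": \"ipython\", \"parameters\": " ipython-args "}"
# i.e. the grammar forces completions of the form: \n{"name": "ipython", "parameters": {...}}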


@@ -80,6 +80,8 @@ def step_server_config(context, server_fqdn: str, server_port: str):
context.temperature = None
context.lora_file = None
context.disable_ctx_shift = False
context.use_jinja = False
context.chat_template_file = None
context.tasks_result = []
context.concurrent_tasks = []
@@ -159,6 +161,16 @@ def step_slot_save_path(context, slot_save_path: str):
context.slot_save_path = slot_save_path
@step('jinja templates are enabled')
def step_use_jinja(context):
context.use_jinja = True
@step('chat template file {file}')
def step_chat_template_file(context, file):
context.chat_template_file = file
@step('using slot id {id_slot:d}')
def step_id_slot(context, id_slot: int):
context.id_slot = id_slot
@@ -369,7 +381,7 @@ def step_response_format(context, response_format):
def step_tools(context, tools):
context.tools = json.loads(tools)
@step('tool choice {tool_choice}')
@step('a tool choice {tool_choice}')
def step_tool_choice(context, tool_choice):
context.tool_choice = tool_choice
@@ -490,8 +502,11 @@ async def step_oai_chat_completions(context, api_error):
expect_api_error = api_error == 'raised'
seeds = await completions_seed(context, num_seeds=1)
completion = await oai_chat_completions(context.prompts.pop(),
seeds[0] if seeds is not None else seeds,
context.system_prompt,
seeds[0] if seeds else None,
context.system_prompt
if hasattr(context, 'system_prompt') else None,
context.base_url,
'/v1/chat',
False,
@@ -631,6 +646,43 @@ async def all_prompts_are_predicted(context, expected_predicted_n=None):
assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
@step('tool {expected_name} is called with arguments {expected_arguments}')
@async_run_until_complete
async def step_tool_called(context, expected_name, expected_arguments):
n_completions = await gather_tasks_results(context)
assert n_completions > 0
expected_name = expected_name if expected_name else None
expected_arguments = json.loads(expected_arguments) if expected_arguments else None
def check(tool_calls):
if tool_calls is None:
assert expected_name is None and expected_arguments is None, f'expected_name = {expected_name}, expected_arguments = {expected_arguments}'
else:
assert len(tool_calls) == 1, f"tool calls: {tool_calls}"
tool_call = tool_calls[0]
actual_name = tool_call.function.name
actual_arguments = json.loads(tool_call.function.arguments)
assert expected_name == actual_name, f"tool name: {actual_name}, expected: {expected_name}"
assert json.dumps(expected_arguments) == json.dumps(actual_arguments), f"tool arguments: {json.dumps(actual_arguments)}, expected: {json.dumps(expected_arguments)}"
for i in range(n_completions):
assert_n_tokens_predicted(context.tasks_result.pop(), tool_calls_check=check)
assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
@step('no tool is called')
@async_run_until_complete
async def step_no_tool_called(context):
n_completions = await gather_tasks_results(context)
assert n_completions > 0
def check(tool_calls):
assert tool_calls is None
for i in range(n_completions):
assert_n_tokens_predicted(context.tasks_result.pop(), tool_calls_check=check)
assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
@step('embeddings are computed for')
@async_run_until_complete
async def step_compute_embedding(context):
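A minimal sketch of the tool_calls value that the check() callbacks above receive on the OpenAI-client path, assuming the standard OpenAI response shape (each tool call nests the name and the JSON-encoded arguments string under function); SimpleNamespace merely stands in for the client's response objects.

import json
from types import SimpleNamespace

# Stand-in for one OpenAI-style tool call as returned in message.tool_calls.
tool_calls = [SimpleNamespace(
    id="call_0",
    type="function",
    function=SimpleNamespace(
        name="ipython",
        arguments=json.dumps({"code": "print('hello world')"}),
    ),
)]

# The same checks step_tool_called performs:
assert len(tool_calls) == 1
assert tool_calls[0].function.name == "ipython"
assert json.loads(tool_calls[0].function.arguments) == {"code": "print('hello world')"}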
@@ -1001,19 +1053,23 @@ async def oai_chat_completions(user_prompt,
print(f"Sending OAI Chat completions request: {user_prompt}")
# openai client always expects an api key
user_api_key = user_api_key if user_api_key is not None else 'nope'
assert isinstance(seed, int), f'seed: {seed}'
seed = seed if seed is not None else 42
enable_streaming = enable_streaming if enable_streaming is not None else False
messages = []
if system_prompt:
messages.append({
"role": "system",
"content": system_prompt,
})
if user_prompt:
messages.append({
"role": "user",
"content": user_prompt,
})
payload = {
"messages": [
{
"role": "system",
"content": system_prompt,
},
{
"role": "user",
"content": user_prompt,
}
],
"messages": messages,
"model": model,
"max_tokens": n_predict,
"stream": enable_streaming,
@@ -1115,6 +1171,7 @@ async def oai_chat_completions(user_prompt,
assert chat_completion.usage is not None
completion_response = {
'content': chat_completion.choices[0].message.content,
'tool_calls': chat_completion.choices[0].message.tool_calls,
'timings': {
'predicted_n': chat_completion.usage.completion_tokens,
'prompt_n': chat_completion.usage.prompt_tokens
@@ -1181,11 +1238,13 @@ async def request_oai_embeddings(input, seed,
return [e.embedding for e in oai_embeddings.data]
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None, tool_calls_check=None):
content = completion_response['content']
tool_calls = completion_response.get('tool_calls')
n_predicted = completion_response['timings']['predicted_n']
assert len(content) > 0, "no token predicted"
assert (content and len(content) > 0) or (tool_calls and len(tool_calls) > 0), "no token predicted"
if re_content is not None:
assert content
p = re.compile(re_content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL)
matches = p.finditer(content)
last_match = 0
@@ -1201,6 +1260,8 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
print(f"Checking completion response: {highlighted}")
assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
if tool_calls_check:
tool_calls_check(tool_calls)
if expected_predicted_n and expected_predicted_n > 0:
assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
f' {n_predicted} <> {expected_predicted_n}')
@@ -1409,6 +1470,10 @@ def start_server_background(context):
server_args.extend(['--grp-attn-w', context.n_ga_w])
if context.debug:
server_args.append('--verbose')
if context.use_jinja:
server_args.append('--jinja')
if context.chat_template_file:
server_args.extend(['--chat-template-file', context.chat_template_file])
if context.lora_file:
server_args.extend(['--lora', context.lora_file])
if context.disable_ctx_shift:
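With the new background steps enabled, the two branches above roughly amount to appending the following flags to the server command line (values taken from the feature file below; a small sketch, not the full argument list):

server_args = []
use_jinja = True
chat_template_file = "../../../tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja"
if use_jinja:
    server_args.append('--jinja')
if chat_template_file:
    server_args.extend(['--chat-template-file', chat_template_file])
# -> ['--jinja', '--chat-template-file', '../../../tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja']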


@@ -0,0 +1,48 @@
@llama.cpp
@server
Feature: llama.cpp server
Background: Server startup
Given a server listening on localhost:8080
And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And a model file test-model.gguf
And a model alias tinyllama-2
And BOS token is 1
And 42 as server seed
And 8192 KV cache size
And 32 as batch size
And 2 slots
And 64 server max tokens to predict
And prometheus compatible metrics exposed
And jinja templates are enabled
And chat template file ../../../tests/chat/templates/meta-llama-Meta-Llama-3.1-8B-Instruct.jinja
Then the server is starting
Then the server is healthy
Scenario: Health
Then the server is ready
And all slots are idle
Scenario Outline: OAI Compatibility w/ required tool
Given a model test
And <n> max tokens to predict
And a user prompt write a hello world in python
And a tool choice <tool_choice>
And tools <tools>
Given an OAI compatible chat completions request with no api error
Then tool <tool_name> is called with arguments <tool_arguments>
Examples: Prompts
| n | tool_name | tool_arguments | tool_choice | tools |
| 64 | test | {} | required | [{"type":"function", "function": {"name": "test", "description": "", "parameters": {"type": "object", "properties": {}}}}] |
| 16 | ipython | {"code": "it and "} | required | [{"type":"function", "function": {"name": "ipython", "description": "", "parameters": {"type": "object", "properties": {"code": {"type": "string", "description": ""}}, "required": ["code"]}}}] |
Scenario: OAI Compatibility w/ no tool
Given a model test
And 16 max tokens to predict
And a user prompt write a hello world in python
And a tool choice none
And tools []
Given an OAI compatible chat completions request with no api error
Then no tool is called
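Tying the pieces together: under the grammar from the first hunk, the raw completion for the ipython example is constrained to the JSON shape below. The server is then expected to surface it as an OpenAI-style tool call (that mapping happens server-side and is only assumed here), which the "tool ... is called with arguments ..." step verifies; the nonsensical "it and " code string is simply what the tiny stories260K test model produces under that constraint with the fixed seed.

import json

# Grammar-constrained raw model output for the ipython scenario (illustrative).
raw_model_output = '\n{"name": "ipython", "parameters": {"code": "it and "}}'
parsed = json.loads(raw_model_output)
assert parsed["name"] == "ipython"
assert parsed["parameters"] == {"code": "it and "}  # the feature's expected tool_arguments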