server: tests: passkey challenge / self-extend with context shift demo (#5832)

* server: tests: add models endpoint scenario

* server: /v1/models add some metadata

* server: tests: add debug field in context before scenario

* server: tests: download model from HF, add batch size

* server: tests: add passkey test

* server: tests: add group attention params

* server: do not truncate prompt tokens if self-extend through group attention is enabled

* server: logs: do not truncate log values

* server: tests - passkey - first good working value of nga

* server: tests: fix server timeout

* server: tests: fix passkey, add doc, fix regex content matching, fix timeout

* server: tests: fix regex content matching

* server: tests: schedule slow tests on master

* server: metrics: fix when no prompt processed

* server: tests: self-extend add llama-2-7B and Mixtral-8x7B-v0.1

* server: tests: increase timeout for completion

* server: tests: keep only the PHI-2 test

* server: tests: passkey add a negative test
This commit is contained in:
Pierrick Hymbert 2024-03-02 22:00:14 +01:00 committed by GitHub
parent 4a6e2d6142
commit 9731134296
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 363 additions and 112 deletions

View File

@ -10,6 +10,8 @@ on:
pull_request: pull_request:
types: [opened, synchronize, reopened] types: [opened, synchronize, reopened]
paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/tests/**.*'] paths: ['.github/workflows/server.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/tests/**.*']
schedule:
- cron: '00 0 * * *'
jobs: jobs:
server: server:
@ -70,14 +72,15 @@ jobs:
run: | run: |
pip install -r examples/server/tests/requirements.txt pip install -r examples/server/tests/requirements.txt
- name: Download models
id: download_models
run: |
cd examples/server/tests
../../../scripts/hf.sh --repo ggml-org/models --file tinyllamas/stories260K.gguf
- name: Tests - name: Tests
id: server_integration_test id: server_integration_tests
run: | run: |
cd examples/server/tests cd examples/server/tests
PORT=8888 ./tests.sh PORT=8888 ./tests.sh
- name: Slow tests
id: server_integration_tests_slow
if: github.event.schedule != ''
run: |
cd examples/server/tests
PORT=8888 ./tests.sh --stop --no-skipped --no-capture --tags slow

View File

@ -441,8 +441,8 @@ struct llama_server_context
const int ga_w = params.grp_attn_w; const int ga_w = params.grp_attn_w;
if (ga_n != 1) { if (ga_n != 1) {
GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
@ -1709,8 +1709,8 @@ struct llama_server_context
} }
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep); slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
// if input prompt is too big, truncate it // if input prompt is too big, truncate it, if group attention self-extend is disabled
if (slot.n_prompt_tokens >= slot.n_ctx) if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx)
{ {
const int n_left = slot.n_ctx - slot.params.n_keep; const int n_left = slot.n_ctx - slot.params.n_keep;
const int n_block_size = n_left / 2; const int n_block_size = n_left / 2;
@ -1785,9 +1785,11 @@ struct llama_server_context
} }
LOG_INFO("slot progression", { LOG_INFO("slot progression", {
{ "slot_id", slot.id }, { "slot_id", slot.id },
{ "task_id", slot.task_id }, { "task_id", slot.task_id },
{ "n_past", slot.n_past }, { "n_past", slot.n_past },
{ "n_past_se", slot.n_past_se },
{ "ga_i", slot.ga_i },
{ "n_prompt_tokens_processed", slot.n_prompt_tokens_processed } { "n_prompt_tokens_processed", slot.n_prompt_tokens_processed }
}); });
} }
@ -2001,6 +2003,17 @@ struct llama_server_context
LOG_VERBOSE("slots updated", {}); LOG_VERBOSE("slots updated", {});
return true; return true;
} }
json model_meta() {
return json{
{"vocab_type", llama_vocab_type(model)},
{"n_vocab", llama_n_vocab(model)},
{"n_ctx_train", llama_n_ctx_train(model)},
{"n_embd", llama_n_embd(model)},
{"n_params", llama_model_n_params(model)},
{"size", llama_model_size(model)},
};
}
}; };
static void server_print_usage(const char *argv0, const gpt_params &params, static void server_print_usage(const char *argv0, const gpt_params &params,
@ -2911,9 +2924,10 @@ int main(int argc, char **argv)
for (const auto& metric_def : metrics_def) { for (const auto& metric_def : metrics_def) {
std::string name = metric_def["name"]; std::string name = metric_def["name"];
std::string help = metric_def["help"]; std::string help = metric_def["help"];
prometheus << "# HELP llamacpp:" << name << " " << help << "\n" auto value = json_value(metric_def, "value", 0);
<< "# TYPE llamacpp:" << name << " " << type << "\n" prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
<< "llamacpp:" << name << " " << metric_def["value"] << "\n"; << "# TYPE llamacpp:" << name << " " << type << "\n"
<< "llamacpp:" << name << " " << value << "\n";
} }
} }
@ -2994,6 +3008,7 @@ int main(int argc, char **argv)
state.store(SERVER_STATE_READY); state.store(SERVER_STATE_READY);
LOG_INFO("model loaded", {}); LOG_INFO("model loaded", {});
} }
const auto model_meta = llama.model_meta();
if (sparams.chat_template.empty()) { // custom chat template is not supplied if (sparams.chat_template.empty()) { // custom chat template is not supplied
// check if the template comes with the model is supported by us // check if the template comes with the model is supported by us
@ -3143,7 +3158,7 @@ int main(int argc, char **argv)
} }
}); });
svr.Get("/v1/models", [&params](const httplib::Request& req, httplib::Response& res) svr.Get("/v1/models", [&params, &model_meta](const httplib::Request& req, httplib::Response& res)
{ {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
std::time_t t = std::time(0); std::time_t t = std::time(0);
@ -3152,10 +3167,11 @@ int main(int argc, char **argv)
{"object", "list"}, {"object", "list"},
{"data", { {"data", {
{ {
{"id", params.model_alias}, {"id", params.model_alias},
{"object", "model"}, {"object", "model"},
{"created", t}, {"created", t},
{"owned_by", "llamacpp"} {"owned_by", "llamacpp"},
{"meta", model_meta}
}, },
}} }}
}; };

View File

@ -1,22 +1,30 @@
# Server tests # Server tests
Python based server tests scenario using [BDD](https://en.wikipedia.org/wiki/Behavior-driven_development) and [behave](https://behave.readthedocs.io/en/latest/): Python based server tests scenario using [BDD](https://en.wikipedia.org/wiki/Behavior-driven_development)
* [issues.feature](./features/issues.feature) Pending issues scenario and [behave](https://behave.readthedocs.io/en/latest/):
* [parallel.feature](./features/parallel.feature) Scenario involving multi slots and concurrent requests
* [security.feature](./features/security.feature) Security, CORS and API Key * [issues.feature](./features/issues.feature) Pending issues scenario
* [server.feature](./features/server.feature) Server base scenario: completion, embedding, tokenization, etc... * [parallel.feature](./features/parallel.feature) Scenario involving multi slots and concurrent requests
* [security.feature](./features/security.feature) Security, CORS and API Key
* [server.feature](./features/server.feature) Server base scenario: completion, embedding, tokenization, etc...
Tests target GitHub workflows job runners with 4 vCPU. Tests target GitHub workflows job runners with 4 vCPU.
Requests are using [aiohttp](https://docs.aiohttp.org/en/stable/client_reference.html), [asyncio](https://docs.python.org/fr/3/library/asyncio.html) based http client. Requests are
using [aiohttp](https://docs.aiohttp.org/en/stable/client_reference.html), [asyncio](https://docs.python.org/fr/3/library/asyncio.html)
based http client.
Note: If the host architecture inference speed is faster than GitHub runners one, parallel scenario may randomly fail. To mitigate it, you can increase values in `n_predict`, `kv_size`. Note: If the host architecture inference speed is faster than GitHub runners one, parallel scenario may randomly fail.
To mitigate it, you can increase values in `n_predict`, `kv_size`.
### Install dependencies ### Install dependencies
`pip install -r requirements.txt` `pip install -r requirements.txt`
### Run tests ### Run tests
1. Build the server 1. Build the server
```shell ```shell
cd ../../.. cd ../../..
mkdir build mkdir build
@ -24,24 +32,36 @@ cd build
cmake ../ cmake ../
cmake --build . --target server cmake --build . --target server
``` ```
2. download required models:
1. `../../../scripts/hf.sh --repo ggml-org/models --file tinyllamas/stories260K.gguf` 2. Start the test: `./tests.sh`
3. Start the test: `./tests.sh`
It's possible to override some scenario steps values with environment variables: It's possible to override some scenario steps values with environment variables:
- `PORT` -> `context.server_port` to set the listening port of the server during scenario, default: `8080`
- `LLAMA_SERVER_BIN_PATH` -> to change the server binary path, default: `../../../build/bin/server` | variable | description |
- `DEBUG` -> "ON" to enable steps and server verbose mode `--verbose` |--------------------------|------------------------------------------------------------------------------------------------|
- `SERVER_LOG_FORMAT_JSON` -> if set switch server logs to json format | `PORT` | `context.server_port` to set the listening port of the server during scenario, default: `8080` |
| `LLAMA_SERVER_BIN_PATH` | to change the server binary path, default: `../../../build/bin/server` |
| `DEBUG` | "ON" to enable steps and server verbose mode `--verbose` |
| `SERVER_LOG_FORMAT_JSON` | if set switch server logs to json format |
| `N_GPU_LAYERS` | number of model layers to offload to VRAM `-ngl --n-gpu-layers` |
### Run @bug, @wip or @wrong_usage annotated scenario ### Run @bug, @wip or @wrong_usage annotated scenario
Feature or Scenario must be annotated with `@llama.cpp` to be included in the default scope. Feature or Scenario must be annotated with `@llama.cpp` to be included in the default scope.
- `@bug` annotation aims to link a scenario with a GitHub issue. - `@bug` annotation aims to link a scenario with a GitHub issue.
- `@wrong_usage` are meant to show user issue that are actually an expected behavior - `@wrong_usage` are meant to show user issue that are actually an expected behavior
- `@wip` to focus on a scenario working in progress - `@wip` to focus on a scenario working in progress
- `@slow` heavy test, disabled by default
To run a scenario annotated with `@bug`, start: To run a scenario annotated with `@bug`, start:
`DEBUG=ON ./tests.sh --no-skipped --tags bug`
```shell
DEBUG=ON ./tests.sh --no-skipped --tags bug
```
After changing logic in `steps.py`, ensure that `@bug` and `@wrong_usage` scenario are updated. After changing logic in `steps.py`, ensure that `@bug` and `@wrong_usage` scenario are updated.
```shell
./tests.sh --no-skipped --tags bug,wrong_usage || echo "should failed but compile"
```

View File

@ -7,7 +7,10 @@ from signal import SIGKILL
def before_scenario(context, scenario): def before_scenario(context, scenario):
print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m") context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
if context.debug:
print("DEBUG=ON\n")
print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m\n")
port = 8080 port = 8080
if 'PORT' in os.environ: if 'PORT' in os.environ:
port = int(os.environ['PORT']) port = int(os.environ['PORT'])

View File

@ -1,4 +1,5 @@
# List of ongoing issues # List of ongoing issues
# run with: DEBUG=ON ./tests.sh --no-skipped --tags bug
@bug @bug
Feature: Issues Feature: Issues
# No confirmed issue at the moment # No confirmed issue at the moment

View File

@ -1,11 +1,12 @@
@llama.cpp @llama.cpp
@parallel
Feature: Parallel Feature: Parallel
Background: Server startup Background: Server startup
Given a server listening on localhost:8080 Given a server listening on localhost:8080
And a model file stories260K.gguf And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And a model alias tinyllama-2
And 42 as server seed And 42 as server seed
And 512 as batch size
And 64 KV cache size And 64 KV cache size
And 2 slots And 2 slots
And embeddings extraction And embeddings extraction

View File

@ -0,0 +1,55 @@
# run with: ./tests.sh --no-skipped --tags passkey
@passkey
@slow
Feature: Passkey / Self-extend with context shift
Background: Server startup
Given a server listening on localhost:8080
# Generates a long text of junk and inserts a secret passkey number inside it.
# Then we query the LLM for the secret passkey.
# see #3856 and #4810
Scenario Outline: Passkey
Given a model file <hf_file> from HF repo <hf_repo>
And <n_batch> as batch size
And <n_junk> as number of junk
And <n_predicted> server max tokens to predict
And 42 as seed
And <n_ctx> KV cache size
And 1 slots
And <n_ga> group attention factor to extend context size through self-extend
And <n_ga_w> group attention width to extend context size through self-extend
# Can be override with N_GPU_LAYERS
And <ngl> GPU offloaded layers
Then the server is starting
Then the server is healthy
Given available models
Then model 0 is trained on <n_ctx_train> tokens context
Given a prefix prompt:
"""
here is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there.
"""
And a passkey prompt template:
"""
The pass key is <passkey> Remember it. <passkey> is the pass key.
"""
And a junk suffix prompt:
"""
The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again.
"""
And a suffix prompt:
"""
What is the pass key? The pass key is
"""
Given a "<passkey>" passkey challenge prompt with the passkey inserted every <i_pos> junk
And a completion request with no api error
Then <n_predicted> tokens are predicted matching <re_content>
Examples:
| hf_repo | hf_file | n_ctx_train | ngl | n_ctx | n_batch | n_ga | n_ga_w | n_junk | i_pos | passkey | n_predicted | re_content |
| TheBloke/phi-2-GGUF | phi-2.Q4_K_M.gguf | 2048 | 5 | 8192 | 512 | 4 | 512 | 250 | 50 | 42 | 1 | 42 |
| TheBloke/phi-2-GGUF | phi-2.Q4_K_M.gguf | 2048 | 5 | 8192 | 512 | 2 | 512 | 250 | 50 | 42 | 1 | \b((?!42)\w)+\b |
#| TheBloke/Llama-2-7B-GGUF | llama-2-7b.Q2_K.gguf | 4096 | 3 | 16384 | 512 | 4 | 512 | 500 | 300 | 1234 | 5 | 1234 |
#| TheBloke/Mixtral-8x7B-v0.1-GGUF | mixtral-8x7b-v0.1.Q2_K.gguf | 32768 | 2 | 16384 | 512 | 4 | 512 | 500 | 100 | 0987 | 5 | 0
# 987 |

View File

@ -1,9 +1,10 @@
@llama.cpp @llama.cpp
@security
Feature: Security Feature: Security
Background: Server startup with an api key defined Background: Server startup with an api key defined
Given a server listening on localhost:8080 Given a server listening on localhost:8080
And a model file stories260K.gguf And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And a server api key llama.cpp And a server api key llama.cpp
Then the server is starting Then the server is starting
Then the server is healthy Then the server is healthy

View File

@ -1,15 +1,17 @@
@llama.cpp @llama.cpp
@server
Feature: llama.cpp server Feature: llama.cpp server
Background: Server startup Background: Server startup
Given a server listening on localhost:8080 Given a server listening on localhost:8080
And a model file stories260K.gguf And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
And a model alias tinyllama-2 And a model alias tinyllama-2
And 42 as server seed And 42 as server seed
# KV Cache corresponds to the total amount of tokens # KV Cache corresponds to the total amount of tokens
# that can be stored across all independent sequences: #4130 # that can be stored across all independent sequences: #4130
# see --ctx-size and #5568 # see --ctx-size and #5568
And 32 KV cache size And 32 KV cache size
And 512 as batch size
And 1 slots And 1 slots
And embeddings extraction And embeddings extraction
And 32 server max tokens to predict And 32 server max tokens to predict
@ -29,9 +31,9 @@ Feature: llama.cpp server
And prometheus metrics are exposed And prometheus metrics are exposed
Examples: Prompts Examples: Prompts
| prompt | n_predict | re_content | n_predicted | | prompt | n_predict | re_content | n_predicted |
| I believe the meaning of life is | 8 | (read<or>going)+ | 8 | | I believe the meaning of life is | 8 | (read\|going)+ | 8 |
| Write a joke about AI | 64 | (park<or>friends<or>scared<or>always)+ | 32 | | Write a joke about AI | 64 | (park\|friends\|scared\|always)+ | 32 |
Scenario Outline: OAI Compatibility Scenario Outline: OAI Compatibility
Given a model <model> Given a model <model>
@ -43,9 +45,9 @@ Feature: llama.cpp server
Then <n_predicted> tokens are predicted matching <re_content> Then <n_predicted> tokens are predicted matching <re_content>
Examples: Prompts Examples: Prompts
| model | system_prompt | user_prompt | max_tokens | re_content | n_predicted | enable_streaming | | model | system_prompt | user_prompt | max_tokens | re_content | n_predicted | enable_streaming |
| llama-2 | Book | What is the best book | 8 | (Mom<or>what)+ | 8 | disabled | | llama-2 | Book | What is the best book | 8 | (Mom\|what)+ | 8 | disabled |
| codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | (thanks<or>happy<or>bird)+ | 32 | enabled | | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | (thanks\|happy\|bird)+ | 32 | enabled |
Scenario: Embedding Scenario: Embedding
When embeddings are computed for: When embeddings are computed for:
@ -75,10 +77,15 @@ Feature: llama.cpp server
When an OAI compatible embeddings computation request for multiple inputs When an OAI compatible embeddings computation request for multiple inputs
Then embeddings are generated Then embeddings are generated
Scenario: Tokenize / Detokenize Scenario: Tokenize / Detokenize
When tokenizing: When tokenizing:
""" """
What is the capital of France ? What is the capital of France ?
""" """
Then tokens can be detokenize Then tokens can be detokenize
Scenario: Models available
Given available models
Then 1 models are supported
Then model 0 is identified by tinyllama-2
Then model 0 is trained on 128 tokens context

View File

@ -13,6 +13,7 @@ import aiohttp
import openai import openai
from behave import step from behave import step
from behave.api.async_step import async_run_until_complete from behave.api.async_step import async_run_until_complete
from huggingface_hub import hf_hub_download
from prometheus_client import parser from prometheus_client import parser
@ -26,17 +27,23 @@ def step_server_config(context, server_fqdn, server_port):
context.base_url = f'http://{context.server_fqdn}:{context.server_port}' context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
context.model_alias = None context.model_alias = None
context.n_batch = None
context.n_ctx = None context.n_ctx = None
context.n_ga = None
context.n_ga_w = None
context.n_gpu_layer = None
context.n_predict = None context.n_predict = None
context.n_server_predict = None context.n_server_predict = None
context.n_slots = None context.n_slots = None
context.prompt_prefix = None
context.prompt_suffix = None
context.server_api_key = None context.server_api_key = None
context.server_continuous_batching = False context.server_continuous_batching = False
context.server_embeddings = False context.server_embeddings = False
context.server_metrics = False context.server_metrics = False
context.server_process = None context.server_process = None
context.seed = None
context.server_seed = None context.server_seed = None
context.user_api_key = None context.user_api_key = None
@ -45,9 +52,11 @@ def step_server_config(context, server_fqdn, server_port):
context.prompts = [] context.prompts = []
@step(u'a model file {model_file}') @step(u'a model file {hf_file} from HF repo {hf_repo}')
def step_model_file(context, model_file): def step_download_hf_model(context, hf_file, hf_repo):
context.model_file = model_file context.model_file = hf_hub_download(repo_id=hf_repo, filename=hf_file)
if context.debug:
print(f"model file: {context.model_file}\n")
@step(u'a model alias {model_alias}') @step(u'a model alias {model_alias}')
@ -55,24 +64,34 @@ def step_model_alias(context, model_alias):
context.model_alias = model_alias context.model_alias = model_alias
@step(u'{seed} as server seed') @step(u'{seed:d} as server seed')
def step_seed(context, seed): def step_seed(context, seed):
context.server_seed = int(seed) context.server_seed = seed
@step(u'{n_ctx} KV cache size') @step(u'{ngl:d} GPU offloaded layers')
def step_n_gpu_layer(context, ngl):
if 'N_GPU_LAYERS' in os.environ:
new_ngl = int(os.environ['N_GPU_LAYERS'])
if context.debug:
print(f"-ngl upgraded from {ngl} to {new_ngl}")
ngl = new_ngl
context.n_gpu_layer = ngl
@step(u'{n_ctx:d} KV cache size')
def step_n_ctx(context, n_ctx): def step_n_ctx(context, n_ctx):
context.n_ctx = int(n_ctx) context.n_ctx = n_ctx
@step(u'{n_slots} slots') @step(u'{n_slots:d} slots')
def step_n_slots(context, n_slots): def step_n_slots(context, n_slots):
context.n_slots = int(n_slots) context.n_slots = n_slots
@step(u'{n_predict} server max tokens to predict') @step(u'{n_predict:d} server max tokens to predict')
def step_server_n_predict(context, n_predict): def step_server_n_predict(context, n_predict):
context.n_server_predict = int(n_predict) context.n_server_predict = n_predict
@step(u'continuous batching') @step(u'continuous batching')
@ -116,11 +135,13 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status):
case 'ready' | 'idle': case 'ready' | 'idle':
await wait_for_health_status(context, context.base_url, 200, 'ok', await wait_for_health_status(context, context.base_url, 200, 'ok',
timeout=10,
params={'fail_on_no_slot': 0, 'include_slots': 0}, params={'fail_on_no_slot': 0, 'include_slots': 0},
slots_idle=context.n_slots, slots_idle=context.n_slots,
slots_processing=0, slots_processing=0,
expected_slots=[{'id': slot_id, 'state': 0} expected_slots=[{'id': slot_id, 'state': 0}
for slot_id in range(context.n_slots)]) for slot_id in
range(context.n_slots if context.n_slots else 1)])
case 'busy': case 'busy':
await wait_for_health_status(context, context.base_url, 503, await wait_for_health_status(context, context.base_url, 503,
'no slot available', 'no slot available',
@ -128,7 +149,8 @@ async def step_wait_for_the_server_to_be_started(context, expecting_status):
slots_idle=0, slots_idle=0,
slots_processing=context.n_slots, slots_processing=context.n_slots,
expected_slots=[{'id': slot_id, 'state': 1} expected_slots=[{'id': slot_id, 'state': 1}
for slot_id in range(context.n_slots)]) for slot_id in
range(context.n_slots if context.n_slots else 1)])
case _: case _:
assert False, "unknown status" assert False, "unknown status"
@ -157,24 +179,24 @@ async def step_request_completion(context, api_error):
context.base_url, context.base_url,
debug=context.debug, debug=context.debug,
n_predict=context.n_predict, n_predict=context.n_predict,
server_seed=context.server_seed, seed=await completions_seed(context),
expect_api_error=expect_api_error, expect_api_error=expect_api_error,
user_api_key=context.user_api_key) user_api_key=context.user_api_key)
context.tasks_result.append(completion) context.tasks_result.append(completion)
if context.debug: if context.debug:
print(f"Completion response: {completion}") print(f"Completion response: {completion}\n")
if expect_api_error: if expect_api_error:
assert completion == 401, f"completion must be an 401 status code: {completion}" assert completion == 401, f"completion must be an 401 status code: {completion}"
@step(u'{predicted_n} tokens are predicted matching {re_content}') @step(u'{predicted_n:d} tokens are predicted matching {re_content}')
def step_n_tokens_predicted_with_content(context, predicted_n, re_content): def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n), re_content) assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n, re_content)
@step(u'{predicted_n} tokens are predicted') @step(u'{predicted_n:d} tokens are predicted')
def step_n_tokens_predicted(context, predicted_n): def step_n_tokens_predicted(context, predicted_n):
assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n)) assert_n_tokens_predicted(context.tasks_result.pop(), predicted_n)
@step(u'a user prompt {user_prompt}') @step(u'a user prompt {user_prompt}')
@ -192,9 +214,9 @@ def step_model(context, model):
context.model = model context.model = model
@step(u'{max_tokens} max tokens to predict') @step(u'{max_tokens:d} max tokens to predict')
def step_max_tokens(context, max_tokens): def step_max_tokens(context, max_tokens):
context.n_predict = int(max_tokens) context.n_predict = max_tokens
@step(u'streaming is {enable_streaming}') @step(u'streaming is {enable_streaming}')
@ -222,11 +244,70 @@ def step_server_api_key(context, server_api_key):
context.server_api_key = server_api_key context.server_api_key = server_api_key
@step(u'{n_junk:d} as number of junk')
def step_n_junk(context, n_junk):
context.n_junk = n_junk
@step(u'{n_batch:d} as batch size')
def step_n_batch(context, n_batch):
context.n_batch = n_batch
@step(u'{seed:d} as seed')
def step_seed(context, seed):
context.seed = seed
@step(u'a prefix prompt')
def step_prompt_prefix(context):
context.prompt_prefix = context.text
@step(u'a junk suffix prompt')
def step_prompt_junk_suffix(context):
context.prompt_junk_suffix = context.text
@step(u'a suffix prompt')
def step_prompt_suffix(context):
context.prompt_suffix = context.text
@step(u'{n_ga:d} group attention factor'
u' to extend context size through self-extend')
def step_impl(context, n_ga):
context.n_ga = n_ga
@step(u'{n_ga_w:d} group attention width to extend context size through self-extend')
def step_impl(context, n_ga_w):
context.n_ga_w = n_ga_w
@step(u'a passkey prompt template')
def step_prompt_passkey(context):
context.prompt_passkey = context.text
@step(u'a "{passkey}" passkey challenge prompt with the passkey inserted every {i_pos:d} junk')
def step_prompt_passkey(context, passkey, i_pos):
prompt = ""
for i in range(context.n_junk):
if i % context.n_junk == i_pos:
prompt += context.prompt_passkey # the passkey is already substituted
prompt += context.prompt_junk_suffix
if context.debug:
passkey_highlight = "\x1b[33m" + passkey + "\x1b[0m"
print(f"Passkey challenge:\n```{prompt.replace(passkey, passkey_highlight)}```\n")
context.prompts.append(context.prompt_prefix + prompt + context.prompt_suffix)
@step(u'an OAI compatible chat completions request with {api_error} api error') @step(u'an OAI compatible chat completions request with {api_error} api error')
@async_run_until_complete @async_run_until_complete
async def step_oai_chat_completions(context, api_error): async def step_oai_chat_completions(context, api_error):
if context.debug: if context.debug:
print(f"Submitting OAI compatible completions request...") print(f"Submitting OAI compatible completions request...\n")
expect_api_error = api_error == 'raised' expect_api_error = api_error == 'raised'
completion = await oai_chat_completions(context.prompts.pop(), completion = await oai_chat_completions(context.prompts.pop(),
context.system_prompt, context.system_prompt,
@ -241,8 +322,7 @@ async def step_oai_chat_completions(context, api_error):
enable_streaming=context.enable_streaming enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None, if hasattr(context, 'enable_streaming') else None,
server_seed=context.server_seed seed=await completions_seed(context),
if hasattr(context, 'server_seed') else None,
user_api_key=context.user_api_key user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None, if hasattr(context, 'user_api_key') else None,
@ -276,8 +356,10 @@ async def step_concurrent_completion_requests(context):
# prompt is inserted automatically # prompt is inserted automatically
context.base_url, context.base_url,
debug=context.debug, debug=context.debug,
prompt_prefix=context.prompt_prefix,
prompt_suffix=context.prompt_suffix,
n_predict=context.n_predict if hasattr(context, 'n_predict') else None, n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
server_seed=context.server_seed if hasattr(context, 'server_seed') else None, seed=await completions_seed(context),
user_api_key=context.user_api_key if hasattr(context, user_api_key=context.user_api_key if hasattr(context,
'user_api_key') else None) 'user_api_key') else None)
@ -297,8 +379,7 @@ async def step_oai_chat_completions(context):
if hasattr(context, 'n_predict') else None, if hasattr(context, 'n_predict') else None,
enable_streaming=context.enable_streaming enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None, if hasattr(context, 'enable_streaming') else None,
server_seed=context.server_seed seed=await completions_seed(context),
if hasattr(context, 'server_seed') else None,
user_api_key=context.user_api_key user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None) if hasattr(context, 'user_api_key') else None)
@ -318,7 +399,9 @@ async def step_oai_chat_completions(context):
if hasattr(context, 'n_predict') else None, if hasattr(context, 'n_predict') else None,
enable_streaming=context.enable_streaming enable_streaming=context.enable_streaming
if hasattr(context, 'enable_streaming') else None, if hasattr(context, 'enable_streaming') else None,
server_seed=context.server_seed seed=context.seed
if hasattr(context, 'seed') else
context.server_seed
if hasattr(context, 'server_seed') else None, if hasattr(context, 'server_seed') else None,
user_api_key=context.user_api_key user_api_key=context.user_api_key
if hasattr(context, 'user_api_key') else None) if hasattr(context, 'user_api_key') else None)
@ -330,11 +413,10 @@ async def step_all_prompts_are_predicted(context):
await all_prompts_are_predicted(context) await all_prompts_are_predicted(context)
@step(u'all prompts are predicted with {n_predict} tokens') @step(u'all prompts are predicted with {n_expected_predicted:d} tokens')
@async_run_until_complete @async_run_until_complete
async def step_all_prompts_are_predicted_with_n_tokens(context, n_predict): async def step_all_prompts_are_predicted_with_n_tokens(context, n_expected_predicted):
expected_predicted_n = int(n_predict) await all_prompts_are_predicted(context, n_expected_predicted)
await all_prompts_are_predicted(context, expected_predicted_n)
async def all_prompts_are_predicted(context, expected_predicted_n=None): async def all_prompts_are_predicted(context, expected_predicted_n=None):
@ -464,6 +546,8 @@ async def step_prometheus_metrics_exported(context):
assert metrics_response.headers['Content-Type'] == "text/plain; version=0.0.4" assert metrics_response.headers['Content-Type'] == "text/plain; version=0.0.4"
metrics_raw = await metrics_response.text() metrics_raw = await metrics_response.text()
metric_exported = False metric_exported = False
if context.debug:
print(f"/metrics answer:\n{metrics_raw}\n")
for metric in parser.text_string_to_metric_families(metrics_raw): for metric in parser.text_string_to_metric_families(metrics_raw):
match metric.name: match metric.name:
case "llamacpp:kv_cache_usage_ratio": case "llamacpp:kv_cache_usage_ratio":
@ -472,6 +556,37 @@ async def step_prometheus_metrics_exported(context):
assert metric_exported, "No metrics exported" assert metric_exported, "No metrics exported"
@step(u'available models')
def step_available_models(context):
# openai client always expects an api_key
openai.api_key = context.user_api_key if context.user_api_key is not None else 'nope'
openai.api_base = f'{context.base_url}/v1'
context.models = openai.Model.list().data
@step(u'{n_model:d} models are supported')
def step_supported_models(context, n_model):
if context.debug:
print("server models available:", context.models)
assert len(context.models) == n_model
@step(u'model {i_model:d} is {param} {preposition} {param_value}')
def step_supported_models(context, i_model, param, preposition, param_value):
assert i_model < len(context.models)
model = context.models[i_model]
param_value = param_value.split(' ', 1)[0]
match param:
case 'identified':
value = model.id
case 'trained':
value = str(model.meta.n_ctx_train)
case _:
assert False, "param {param} not supported"
assert param_value == value, f"model param {param} {value} != {param_value}"
async def concurrent_requests(context, f_completion, *args, **kwargs): async def concurrent_requests(context, f_completion, *args, **kwargs):
n_prompts = len(context.prompts) n_prompts = len(context.prompts)
if context.debug: if context.debug:
@ -486,8 +601,10 @@ async def concurrent_requests(context, f_completion, *args, **kwargs):
async def request_completion(prompt, async def request_completion(prompt,
base_url, base_url,
debug=False, debug=False,
prompt_prefix=None,
prompt_suffix=None,
n_predict=None, n_predict=None,
server_seed=None, seed=None,
expect_api_error=None, expect_api_error=None,
user_api_key=None): user_api_key=None):
if debug: if debug:
@ -504,11 +621,14 @@ async def request_completion(prompt,
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post(f'{base_url}/completion', async with session.post(f'{base_url}/completion',
json={ json={
"input_prefix": prompt_prefix,
"prompt": prompt, "prompt": prompt,
"n_predict": int(n_predict) if n_predict is not None else -1, "input_suffix": prompt_suffix,
"seed": server_seed if server_seed is not None else 42 "n_predict": n_predict if n_predict is not None else -1,
"seed": seed if seed is not None else 42
}, },
headers=headers) as response: headers=headers,
timeout=3600) as response:
if expect_api_error is None or not expect_api_error: if expect_api_error is None or not expect_api_error:
assert response.status == 200 assert response.status == 200
assert response.headers['Access-Control-Allow-Origin'] == origin assert response.headers['Access-Control-Allow-Origin'] == origin
@ -526,14 +646,14 @@ async def oai_chat_completions(user_prompt,
model=None, model=None,
n_predict=None, n_predict=None,
enable_streaming=None, enable_streaming=None,
server_seed=None, seed=None,
user_api_key=None, user_api_key=None,
expect_api_error=None): expect_api_error=None):
if debug: if debug:
print(f"Sending OAI Chat completions request: {user_prompt}") print(f"Sending OAI Chat completions request: {user_prompt}")
# openai client always expects an api key # openai client always expects an api key
user_api_key = user_api_key if user_api_key is not None else 'nope' user_api_key = user_api_key if user_api_key is not None else 'nope'
seed = server_seed if server_seed is not None else 42 seed = seed if seed is not None else 42
enable_streaming = enable_streaming if enable_streaming is not None else False enable_streaming = enable_streaming if enable_streaming is not None else False
payload = { payload = {
"messages": [ "messages": [
@ -692,20 +812,32 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
content = completion_response['content'] content = completion_response['content']
n_predicted = completion_response['timings']['predicted_n'] n_predicted = completion_response['timings']['predicted_n']
assert len(content) > 0, "no token predicted" assert len(content) > 0, "no token predicted"
if expected_predicted_n is not None: if re_content is not None:
p = re.compile(re_content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL)
matches = p.finditer(content)
last_match = 0
highlighted = ''
for match in matches:
start, end = match.span()
highlighted += content[last_match: start]
highlighted += '\x1b[33m'
highlighted += content[start: end]
highlighted += '\x1b[0m'
last_match = end
highlighted += content[last_match:]
if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
print(f"Checking completion response: {highlighted}\n")
assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
if expected_predicted_n and expected_predicted_n > 0:
assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:' assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
f' {n_predicted} <> {expected_predicted_n}') f' {n_predicted} <> {expected_predicted_n}')
if re_content is not None:
re_content = '^.*' + re_content.replace('<or>', '|') + '.*$'
assert re.match(re_content, content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL), (
f'invalid tokens predicted:'
f' ```\n{content}\n``` do not match /{re_content}/')
async def gather_tasks_results(context): async def gather_tasks_results(context):
n_tasks = len(context.concurrent_tasks) n_tasks = len(context.concurrent_tasks)
if context.debug: if context.debug:
print(f"Waiting for all {n_tasks} tasks results...") print(f"Waiting for all {n_tasks} tasks results...\n")
for task_no in range(n_tasks): for task_no in range(n_tasks):
context.tasks_result.append(await context.concurrent_tasks.pop()) context.tasks_result.append(await context.concurrent_tasks.pop())
n_completions = len(context.tasks_result) n_completions = len(context.tasks_result)
@ -716,15 +848,13 @@ async def wait_for_health_status(context,
base_url, base_url,
expected_http_status_code, expected_http_status_code,
expected_health_status, expected_health_status,
timeout=3,
params=None, params=None,
slots_idle=None, slots_idle=None,
slots_processing=None, slots_processing=None,
expected_slots=None): expected_slots=None):
if context.debug: if context.debug:
print(f"Starting checking for health for expected_health_status={expected_health_status}") print(f"Starting checking for health for expected_health_status={expected_health_status}\n")
timeout = 3 # seconds
if expected_health_status == 'ok':
timeout = 10 # CI slow inference
interval = 0.5 interval = 0.5
counter = 0 counter = 0
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
@ -734,7 +864,7 @@ async def wait_for_health_status(context,
health = await health_response.json() health = await health_response.json()
if context.debug: if context.debug:
print(f"HEALTH - response for expected health status='{expected_health_status}' on " print(f"HEALTH - response for expected health status='{expected_health_status}' on "
f"'{base_url}/health'?{params} is {health}") f"'{base_url}/health'?{params} is {health}\n")
if (status_code == expected_http_status_code if (status_code == expected_http_status_code
and health['status'] == expected_health_status and health['status'] == expected_health_status
and (slots_idle is None or health['slots_idle'] == slots_idle) and (slots_idle is None or health['slots_idle'] == slots_idle)
@ -757,7 +887,7 @@ async def wait_for_health_status(context,
if expected_http_status_code == 503: if expected_http_status_code == 503:
if len(context.tasks_result) == 0: if len(context.tasks_result) == 0:
print("\x1b[5;37;43mWARNING: forcing concurrent tasks," print("\x1b[5;37;43mWARNING: forcing concurrent tasks,"
" busy health check missed, probably too fast inference\x1b[0m") " busy health check missed, probably too fast inference\x1b[0m\n")
n_completions = await gather_tasks_results(context) n_completions = await gather_tasks_results(context)
if n_completions > 0: if n_completions > 0:
return return
@ -791,6 +921,11 @@ def assert_slots_status(slots, expected_slots):
f" = {expected[key]} != {slot[key]}") f" = {expected[key]} != {slot[key]}")
async def completions_seed(context):
return context.seed if hasattr(context, 'seed') and context.seed is not None \
else context.server_seed if hasattr(context, 'server_seed') else None
def start_server_background(context): def start_server_background(context):
context.server_path = '../../../build/bin/server' context.server_path = '../../../build/bin/server'
if 'LLAMA_SERVER_BIN_PATH' in os.environ: if 'LLAMA_SERVER_BIN_PATH' in os.environ:
@ -800,27 +935,35 @@ def start_server_background(context):
'--port', context.server_port, '--port', context.server_port,
'--model', context.model_file '--model', context.model_file
] ]
if context.n_batch:
server_args.extend(['--batch-size', context.n_batch])
if context.n_gpu_layer:
server_args.extend(['--n-gpu-layers', context.n_gpu_layer])
if context.server_continuous_batching: if context.server_continuous_batching:
server_args.append('--cont-batching') server_args.append('--cont-batching')
if context.server_embeddings: if context.server_embeddings:
server_args.append('--embedding') server_args.append('--embedding')
if context.server_metrics: if context.server_metrics:
server_args.append('--metrics') server_args.append('--metrics')
if context.model_alias is not None: if context.model_alias:
server_args.extend(['--alias', context.model_alias]) server_args.extend(['--alias', context.model_alias])
if context.n_ctx is not None: if context.n_ctx:
server_args.extend(['--ctx-size', context.n_ctx]) server_args.extend(['--ctx-size', context.n_ctx])
if context.n_slots is not None: if context.n_slots:
server_args.extend(['--parallel', context.n_slots]) server_args.extend(['--parallel', context.n_slots])
if context.n_server_predict is not None: if context.n_server_predict:
server_args.extend(['--n-predict', context.n_server_predict]) server_args.extend(['--n-predict', context.n_server_predict])
if context.server_api_key is not None: if context.server_api_key:
server_args.extend(['--api-key', context.server_api_key]) server_args.extend(['--api-key', context.server_api_key])
if context.n_ga:
server_args.extend(['--grp-attn-n', context.n_ga])
if context.n_ga_w:
server_args.extend(['--grp-attn-w', context.n_ga_w])
if context.debug: if context.debug:
server_args.append('--verbose') server_args.append('--verbose')
if 'SERVER_LOG_FORMAT_JSON' not in os.environ: if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
server_args.extend(['--log-format', "text"]) server_args.extend(['--log-format', "text"])
print(f"starting server with: {context.server_path}", *server_args) print(f"starting server with: {context.server_path} {server_args}\n")
context.server_process = subprocess.Popen( context.server_process = subprocess.Popen(
[str(arg) for arg in [context.server_path, *server_args]], [str(arg) for arg in [context.server_path, *server_args]],
close_fds=True) close_fds=True)

View File

@ -1,4 +1,4 @@
# run with ./test.sh --tags wrong_usage # run with: ./tests.sh --no-skipped --tags wrong_usage
@wrong_usage @wrong_usage
Feature: Wrong usage of llama.cpp server Feature: Wrong usage of llama.cpp server
@ -7,7 +7,7 @@ Feature: Wrong usage of llama.cpp server
# or pass n_predict/max_tokens in the request. # or pass n_predict/max_tokens in the request.
Scenario: Infinite loop Scenario: Infinite loop
Given a server listening on localhost:8080 Given a server listening on localhost:8080
And a model file stories260K.gguf And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
# Uncomment below to fix the issue # Uncomment below to fix the issue
#And 64 server max tokens to predict #And 64 server max tokens to predict
Then the server is starting Then the server is starting
@ -18,4 +18,5 @@ Feature: Wrong usage of llama.cpp server
# Uncomment below to fix the issue # Uncomment below to fix the issue
#And 128 max tokens to predict #And 128 max tokens to predict
Given concurrent completion requests Given concurrent completion requests
Then the server is idle
Then all prompts are predicted Then all prompts are predicted

View File

@ -1,4 +1,5 @@
aiohttp~=3.9.3 aiohttp~=3.9.3
behave~=1.2.6 behave~=1.2.6
huggingface_hub~=0.20.3
openai~=0.25.0 openai~=0.25.0
prometheus-client~=0.20.0 prometheus-client~=0.20.0

View File

@ -5,7 +5,7 @@ set -eu
if [ $# -lt 1 ] if [ $# -lt 1 ]
then then
# Start @llama.cpp scenario # Start @llama.cpp scenario
behave --summary --stop --no-capture --exclude 'issues|wrong_usages' --tags llama.cpp behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey' --tags llama.cpp
else else
behave "$@" behave "$@"
fi fi

View File

@ -126,8 +126,7 @@ static inline void server_log(const char *level, const char *function, int line,
for (const auto& el : log.items()) for (const auto& el : log.items())
{ {
const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace); const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace);
snprintf(buf, 1024, " %s=%s", el.key().c_str(), value.c_str()); ss << " " << el.key() << "=" << value;
ss << buf;
} }
const std::string str = ss.str(); const std::string str = ss.str();