mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 13:30:35 +00:00
server: init functional tests (#5566)
* server: tests: init scenarios - health and slots endpoints - completion endpoint - OAI compatible chat completion requests w/ and without streaming - completion multi users scenario - multi users scenario on OAI compatible endpoint with streaming - multi users with total number of tokens to predict exceeds the KV Cache size - server wrong usage scenario, like in Infinite loop of "context shift" #3969 - slots shifting - continuous batching - embeddings endpoint - multi users embedding endpoint: Segmentation fault #5655 - OpenAI-compatible embeddings API - tokenize endpoint - CORS and api key scenario * server: CI GitHub workflow --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
fd43d66f46
commit
525213d2f5
2
.github/ISSUE_TEMPLATE/bug.md
vendored
2
.github/ISSUE_TEMPLATE/bug.md
vendored
@ -7,3 +7,5 @@ assignees: ''
|
||||
---
|
||||
|
||||
Please include information about your system, the steps to reproduce the bug, and the version of llama.cpp that you are using. If possible, please provide a minimal code example that reproduces the bug.
|
||||
|
||||
If the bug concerns the server, please try to reproduce it first using the [server test scenario framework](https://github.com/ggerganov/llama.cpp/tree/master/examples/server/tests).
|
||||
|
127
.github/workflows/server.yml
vendored
Normal file
127
.github/workflows/server.yml
vendored
Normal file
@ -0,0 +1,127 @@
|
||||
# Server build and tests
|
||||
name: Server
|
||||
|
||||
on:
|
||||
workflow_dispatch: # allows manual triggering
|
||||
push:
|
||||
branches:
|
||||
- master
|
||||
- test/server-add-ci-test # FIXME remove
|
||||
paths: ['.github/workflows/**', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/**.*']
|
||||
pull_request:
|
||||
types: [opened, synchronize, reopened]
|
||||
paths: ['**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'examples/server/**.*']
|
||||
|
||||
jobs:
|
||||
server:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
strategy:
|
||||
matrix:
|
||||
build: [noavx, avx2, avx, avx512, cublas, clblast, openblas, kompute, vulkan]
|
||||
sanitizer: [ADDRESS, THREAD, UNDEFINED]
|
||||
build_type: [Debug, Release]
|
||||
include:
|
||||
- build: 'noavx'
|
||||
defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_AVX=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF'
|
||||
image: ubuntu:latest
|
||||
- build: 'avx2'
|
||||
defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON'
|
||||
image: ubuntu:latest
|
||||
- build: 'avx'
|
||||
defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_AVX2=OFF'
|
||||
image: ubuntu:latest
|
||||
- build: 'avx512'
|
||||
defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_AVX512=ON'
|
||||
image: ubuntu:latest
|
||||
experimental: true
|
||||
- build: 'cublas'
|
||||
defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_CUBLAS=ON'
|
||||
image: nvidia/cuda:12.3.1-devel-ubuntu22.04
|
||||
arch_not_available: true # require nvidia docker engine
|
||||
- build: 'clblast'
|
||||
defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_CLBLAST=ON'
|
||||
image: ubuntu:latest
|
||||
arch_not_available: true
|
||||
- build: 'openblas'
|
||||
defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=OpenBLAS'
|
||||
image: ubuntu:latest
|
||||
- build: 'kompute'
|
||||
defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_KOMPUTE=ON -DKOMPUTE_OPT_DISABLE_VULKAN_VERSION_CHECK=ON'
|
||||
image: ubuntu:latest
|
||||
arch_not_available: true
|
||||
- build: 'vulkan'
|
||||
defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_VULKAN=ON'
|
||||
image: ubuntu:latest
|
||||
arch_not_available: true
|
||||
|
||||
container:
|
||||
image: ${{ matrix.image }}
|
||||
ports:
|
||||
- 8888
|
||||
options: --cpus 4
|
||||
|
||||
steps:
|
||||
- name: Clone
|
||||
id: checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Dependencies
|
||||
id: depends
|
||||
run: |
|
||||
apt-get update
|
||||
apt-get -y install \
|
||||
build-essential \
|
||||
pkg-config \
|
||||
git \
|
||||
cmake \
|
||||
python3-pip \
|
||||
wget \
|
||||
psmisc
|
||||
|
||||
- name: Download CLBlast
|
||||
id: get_clblast
|
||||
if: ${{ matrix.build == 'clblast' }}
|
||||
run: |
|
||||
apt install -y libclblast-dev
|
||||
|
||||
- name: Download OpenBLAS
|
||||
id: get_openblas
|
||||
if: ${{ matrix.build == 'openblas' }}
|
||||
run: |
|
||||
apt-get -y install libopenblas-dev
|
||||
|
||||
- name: Install Vulkan SDK
|
||||
id: get_vulkan
|
||||
if: ${{ matrix.build == 'kompute' || matrix.build == 'vulkan' }}
|
||||
run: |
|
||||
wget -qO- https://packages.lunarg.com/lunarg-signing-key-pub.asc | tee /etc/apt/trusted.gpg.d/lunarg.asc
|
||||
wget -qO /etc/apt/sources.list.d/lunarg-vulkan-jammy.list http://packages.lunarg.com/vulkan/lunarg-vulkan-jammy.list
|
||||
apt-get update
|
||||
apt-get -y install vulkan-sdk
|
||||
|
||||
- name: Build
|
||||
id: cmake_build
|
||||
run: |
|
||||
mkdir build
|
||||
cd build
|
||||
cmake .. -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ${{ matrix.defines }}
|
||||
cmake --build . --config ${{ matrix.build_type }} -j $(nproc) --target server
|
||||
|
||||
- name: Tests dependencies
|
||||
id: test_dependencies
|
||||
run: |
|
||||
pip install -r examples/server/tests/requirements.txt
|
||||
|
||||
- name: Download models
|
||||
id: download_models
|
||||
run: |
|
||||
cd examples/server/tests
|
||||
../../../scripts/hf.sh --repo ggml-org/models --file tinyllamas/stories260K.gguf
|
||||
|
||||
- name: Tests
|
||||
id: server_integration_test
|
||||
continue-on-error: ${{ matrix.experimental || matrix.arch_not_available }}
|
||||
run: |
|
||||
cd examples/server/tests
|
||||
PORT=8888 ./tests.sh
|
@ -98,6 +98,12 @@ curl --request POST \
|
||||
--data '{"prompt": "Building a website can be done in 10 simple steps:","n_predict": 128}'
|
||||
```
|
||||
|
||||
## Advanced testing
|
||||
|
||||
We implemented a [server test framework](./tests/README.md) using human-readable scenario.
|
||||
|
||||
*Before submitting an issue, please try to reproduce it with this format.*
|
||||
|
||||
## Node JS Test
|
||||
|
||||
You need to have [Node.js](https://nodejs.org/en) installed.
|
||||
|
@ -1410,11 +1410,6 @@ struct llama_server_context
|
||||
int n_processing_slots = 0;
|
||||
|
||||
for (llama_client_slot &slot: slots) {
|
||||
if (slot.available()) {
|
||||
n_idle_slots++;
|
||||
} else {
|
||||
n_processing_slots++;
|
||||
}
|
||||
json slot_data = get_formated_generation(slot);
|
||||
slot_data["id"] = slot.id;
|
||||
slot_data["task_id"] = slot.task_id;
|
||||
@ -1429,6 +1424,11 @@ struct llama_server_context
|
||||
{"stopped_limit", slot.stopped_limit},
|
||||
{"stopping_word", slot.stopping_word},
|
||||
};
|
||||
if (slot_data["state"] == IDLE) {
|
||||
n_idle_slots++;
|
||||
} else {
|
||||
n_processing_slots++;
|
||||
}
|
||||
slots_data.push_back(slot_data);
|
||||
}
|
||||
LOG_TEE("task %i - slots data: idle=%i processing=%i\n", task.id, n_idle_slots, n_processing_slots);
|
||||
@ -2748,19 +2748,6 @@ int main(int argc, char **argv)
|
||||
log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
|
||||
}
|
||||
|
||||
LOG_INFO("HTTP server listening", log_data);
|
||||
// run the HTTP server in a thread - see comment below
|
||||
std::thread t([&]()
|
||||
{
|
||||
if (!svr.listen_after_bind())
|
||||
{
|
||||
state.store(SERVER_STATE_ERROR);
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
});
|
||||
|
||||
// load the model
|
||||
if (!llama.load_model(params))
|
||||
{
|
||||
@ -3228,6 +3215,19 @@ int main(int argc, char **argv)
|
||||
}*/
|
||||
//);
|
||||
|
||||
LOG_INFO("HTTP server listening", log_data);
|
||||
// run the HTTP server in a thread - see comment below
|
||||
std::thread t([&]()
|
||||
{
|
||||
if (!svr.listen_after_bind())
|
||||
{
|
||||
state.store(SERVER_STATE_ERROR);
|
||||
return 1;
|
||||
}
|
||||
|
||||
return 0;
|
||||
});
|
||||
|
||||
llama.queue_tasks.on_new_task(std::bind(
|
||||
&llama_server_context::process_single_task, &llama, std::placeholders::_1));
|
||||
llama.queue_tasks.on_finish_multitask(std::bind(
|
||||
|
46
examples/server/tests/README.md
Normal file
46
examples/server/tests/README.md
Normal file
@ -0,0 +1,46 @@
|
||||
# Server tests
|
||||
|
||||
Python based server tests scenario using [BDD](https://en.wikipedia.org/wiki/Behavior-driven_development) and [behave](https://behave.readthedocs.io/en/latest/):
|
||||
* [issues.feature](./features/issues.feature) Pending issues scenario
|
||||
* [parallel.feature](./features/parallel.feature) Scenario involving multi slots and concurrent requests
|
||||
* [security.feature](./features/security.feature) Security, CORS and API Key
|
||||
* [server.feature](./features/server.feature) Server base scenario: completion, embedding, tokenization, etc...
|
||||
|
||||
Tests target GitHub workflows job runners with 4 vCPU.
|
||||
|
||||
Requests are using [aiohttp](https://docs.aiohttp.org/en/stable/client_reference.html), [asyncio](https://docs.python.org/fr/3/library/asyncio.html) based http client.
|
||||
|
||||
Note: If the host architecture inference speed is faster than GitHub runners one, parallel scenario may randomly fail. To mitigate it, you can increase values in `n_predict`, `kv_size`.
|
||||
|
||||
### Install dependencies
|
||||
`pip install -r requirements.txt`
|
||||
|
||||
### Run tests
|
||||
1. Build the server
|
||||
```shell
|
||||
cd ../../..
|
||||
mkdir build
|
||||
cd build
|
||||
cmake ../
|
||||
cmake --build . --target server
|
||||
```
|
||||
2. download required models:
|
||||
1. `../../../scripts/hf.sh --repo ggml-org/models --file tinyllamas/stories260K.gguf`
|
||||
3. Start the test: `./tests.sh`
|
||||
|
||||
It's possible to override some scenario steps values with environment variables:
|
||||
- `PORT` -> `context.server_port` to set the listening port of the server during scenario, default: `8080`
|
||||
- `LLAMA_SERVER_BIN_PATH` -> to change the server binary path, default: `../../../build/bin/server`
|
||||
- `DEBUG` -> "ON" to enable steps and server verbose mode `--verbose`
|
||||
|
||||
### Run @bug, @wip or @wrong_usage annotated scenario
|
||||
|
||||
Feature or Scenario must be annotated with `@llama.cpp` to be included in the default scope.
|
||||
- `@bug` annotation aims to link a scenario with a GitHub issue.
|
||||
- `@wrong_usage` are meant to show user issue that are actually an expected behavior
|
||||
- `@wip` to focus on a scenario working in progress
|
||||
|
||||
To run a scenario annotated with `@bug`, start:
|
||||
`DEBUG=ON ./tests.sh --no-skipped --tags bug`
|
||||
|
||||
After changing logic in `steps.py`, ensure that `@bug` and `@wrong_usage` scenario are updated.
|
67
examples/server/tests/features/environment.py
Normal file
67
examples/server/tests/features/environment.py
Normal file
@ -0,0 +1,67 @@
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
import time
|
||||
from contextlib import closing
|
||||
from signal import SIGKILL
|
||||
|
||||
|
||||
def before_scenario(context, scenario):
|
||||
print(f"\x1b[33;42mStarting new scenario: {scenario.name}!\x1b[0m")
|
||||
port = 8080
|
||||
if 'PORT' in os.environ:
|
||||
port = int(os.environ['PORT'])
|
||||
if is_server_listening("localhost", port):
|
||||
assert False, "Server already started"
|
||||
|
||||
|
||||
def after_scenario(context, scenario):
|
||||
if scenario.status == "failed":
|
||||
if 'GITHUB_ACTIONS' in os.environ:
|
||||
print(f"\x1b[33;101mSCENARIO FAILED: {scenario.name} server logs:\x1b[0m\n\n")
|
||||
if os.path.isfile('llama.log'):
|
||||
with closing(open('llama.log', 'r')) as f:
|
||||
for line in f:
|
||||
print(line)
|
||||
if not is_server_listening(context.server_fqdn, context.server_port):
|
||||
print("\x1b[33;101mERROR: Server stopped listening\x1b[0m")
|
||||
|
||||
if not pid_exists(context.server_process.pid):
|
||||
assert False, f"Server not running pid={context.server_process.pid} ..."
|
||||
|
||||
print(f"stopping server pid={context.server_process.pid} ...")
|
||||
context.server_process.kill()
|
||||
# Wait few for socket to free up
|
||||
time.sleep(0.05)
|
||||
|
||||
attempts = 0
|
||||
while is_server_listening(context.server_fqdn, context.server_port):
|
||||
print(f"stopping server pid={context.server_process.pid} ...")
|
||||
os.kill(context.server_process.pid, SIGKILL)
|
||||
time.sleep(0.1)
|
||||
attempts += 1
|
||||
if attempts > 5:
|
||||
print(f"Server dangling exits, killing all {context.server_path} ...")
|
||||
process = subprocess.run(['killall', '-9', context.server_path],
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True)
|
||||
print(process)
|
||||
|
||||
|
||||
def is_server_listening(server_fqdn, server_port):
|
||||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
|
||||
result = sock.connect_ex((server_fqdn, server_port))
|
||||
return result == 0
|
||||
|
||||
|
||||
def pid_exists(pid):
|
||||
"""Check whether pid exists in the current process table."""
|
||||
import errno
|
||||
if pid < 0:
|
||||
return False
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
except OSError as e:
|
||||
return e.errno == errno.EPERM
|
||||
else:
|
||||
return True
|
36
examples/server/tests/features/issues.feature
Normal file
36
examples/server/tests/features/issues.feature
Normal file
@ -0,0 +1,36 @@
|
||||
# List of ongoing issues
|
||||
@bug
|
||||
Feature: Issues
|
||||
# Issue #5655
|
||||
Scenario: Multi users embeddings
|
||||
Given a server listening on localhost:8080
|
||||
And a model file stories260K.gguf
|
||||
And a model alias tinyllama-2
|
||||
And 42 as server seed
|
||||
And 64 KV cache size
|
||||
And 2 slots
|
||||
And continuous batching
|
||||
And embeddings extraction
|
||||
Then the server is starting
|
||||
Then the server is healthy
|
||||
|
||||
Given a prompt:
|
||||
"""
|
||||
Write a very long story about AI.
|
||||
"""
|
||||
And a prompt:
|
||||
"""
|
||||
Write another very long music lyrics.
|
||||
"""
|
||||
And a prompt:
|
||||
"""
|
||||
Write a very long poem.
|
||||
"""
|
||||
And a prompt:
|
||||
"""
|
||||
Write a very long joke.
|
||||
"""
|
||||
Given concurrent embedding requests
|
||||
Then the server is busy
|
||||
Then the server is idle
|
||||
Then all embeddings are generated
|
77
examples/server/tests/features/parallel.feature
Normal file
77
examples/server/tests/features/parallel.feature
Normal file
@ -0,0 +1,77 @@
|
||||
@llama.cpp
|
||||
Feature: Parallel
|
||||
|
||||
Background: Server startup
|
||||
Given a server listening on localhost:8080
|
||||
And a model file stories260K.gguf
|
||||
And a model alias tinyllama-2
|
||||
And 42 as server seed
|
||||
And 64 KV cache size
|
||||
And 2 slots
|
||||
And continuous batching
|
||||
Then the server is starting
|
||||
Then the server is healthy
|
||||
|
||||
Scenario Outline: Multi users completion
|
||||
Given a prompt:
|
||||
"""
|
||||
Write a very long story about AI.
|
||||
"""
|
||||
And a prompt:
|
||||
"""
|
||||
Write another very long music lyrics.
|
||||
"""
|
||||
And <n_predict> max tokens to predict
|
||||
Given concurrent completion requests
|
||||
Then the server is busy
|
||||
Then the server is idle
|
||||
And all slots are idle
|
||||
Then all prompts are predicted with <n_predict> tokens
|
||||
Examples:
|
||||
| n_predict |
|
||||
| 128 |
|
||||
|
||||
Scenario Outline: Multi users OAI completions compatibility
|
||||
Given a system prompt You are a writer.
|
||||
And a model tinyllama-2
|
||||
Given a prompt:
|
||||
"""
|
||||
Write a very long book.
|
||||
"""
|
||||
And a prompt:
|
||||
"""
|
||||
Write another a poem.
|
||||
"""
|
||||
And <n_predict> max tokens to predict
|
||||
And streaming is <streaming>
|
||||
Given concurrent OAI completions requests
|
||||
Then the server is busy
|
||||
Then the server is idle
|
||||
Then all prompts are predicted with <n_predict> tokens
|
||||
Examples:
|
||||
| streaming | n_predict |
|
||||
| disabled | 128 |
|
||||
| enabled | 64 |
|
||||
|
||||
Scenario: Multi users with total number of tokens to predict exceeds the KV Cache size #3969
|
||||
Given a prompt:
|
||||
"""
|
||||
Write a very long story about AI.
|
||||
"""
|
||||
And a prompt:
|
||||
"""
|
||||
Write another very long music lyrics.
|
||||
"""
|
||||
And a prompt:
|
||||
"""
|
||||
Write a very long poem.
|
||||
"""
|
||||
And a prompt:
|
||||
"""
|
||||
Write a very long joke.
|
||||
"""
|
||||
And 128 max tokens to predict
|
||||
Given concurrent completion requests
|
||||
Then the server is busy
|
||||
Then the server is idle
|
||||
Then all prompts are predicted
|
50
examples/server/tests/features/security.feature
Normal file
50
examples/server/tests/features/security.feature
Normal file
@ -0,0 +1,50 @@
|
||||
@llama.cpp
|
||||
Feature: Security
|
||||
|
||||
Background: Server startup with an api key defined
|
||||
Given a server listening on localhost:8080
|
||||
And a model file stories260K.gguf
|
||||
And a server api key llama.cpp
|
||||
Then the server is starting
|
||||
Then the server is healthy
|
||||
|
||||
Scenario Outline: Completion with some user api key
|
||||
Given a prompt test
|
||||
And a user api key <api_key>
|
||||
And 4 max tokens to predict
|
||||
And a completion request with <api_error> api error
|
||||
|
||||
Examples: Prompts
|
||||
| api_key | api_error |
|
||||
| llama.cpp | no |
|
||||
| llama.cpp | no |
|
||||
| hackeme | raised |
|
||||
| | raised |
|
||||
|
||||
Scenario Outline: OAI Compatibility
|
||||
Given a system prompt test
|
||||
And a user prompt test
|
||||
And a model test
|
||||
And 2 max tokens to predict
|
||||
And streaming is disabled
|
||||
And a user api key <api_key>
|
||||
Given an OAI compatible chat completions request with <api_error> api error
|
||||
|
||||
Examples: Prompts
|
||||
| api_key | api_error |
|
||||
| llama.cpp | no |
|
||||
| llama.cpp | no |
|
||||
| hackme | raised |
|
||||
|
||||
|
||||
Scenario Outline: CORS Options
|
||||
When an OPTIONS request is sent from <origin>
|
||||
Then CORS header <cors_header> is set to <cors_header_value>
|
||||
|
||||
Examples: Headers
|
||||
| origin | cors_header | cors_header_value |
|
||||
| localhost | Access-Control-Allow-Origin | localhost |
|
||||
| web.mydomain.fr | Access-Control-Allow-Origin | web.mydomain.fr |
|
||||
| origin | Access-Control-Allow-Credentials | true |
|
||||
| web.mydomain.fr | Access-Control-Allow-Methods | POST |
|
||||
| web.mydomain.fr | Access-Control-Allow-Headers | * |
|
69
examples/server/tests/features/server.feature
Normal file
69
examples/server/tests/features/server.feature
Normal file
@ -0,0 +1,69 @@
|
||||
@llama.cpp
|
||||
Feature: llama.cpp server
|
||||
|
||||
Background: Server startup
|
||||
Given a server listening on localhost:8080
|
||||
And a model file stories260K.gguf
|
||||
And a model alias tinyllama-2
|
||||
And 42 as server seed
|
||||
# KV Cache corresponds to the total amount of tokens
|
||||
# that can be stored across all independent sequences: #4130
|
||||
# see --ctx-size and #5568
|
||||
And 32 KV cache size
|
||||
And 1 slots
|
||||
And embeddings extraction
|
||||
And 32 server max tokens to predict
|
||||
Then the server is starting
|
||||
Then the server is healthy
|
||||
|
||||
Scenario: Health
|
||||
Then the server is ready
|
||||
And all slots are idle
|
||||
|
||||
Scenario Outline: Completion
|
||||
Given a prompt <prompt>
|
||||
And <n_predict> max tokens to predict
|
||||
And a completion request with no api error
|
||||
Then <n_predicted> tokens are predicted matching <re_content>
|
||||
|
||||
Examples: Prompts
|
||||
| prompt | n_predict | re_content | n_predicted |
|
||||
| I believe the meaning of life is | 8 | read | 8 |
|
||||
| Write a joke about AI | 64 | (park<or>friends<or>scared)+ | 32 |
|
||||
|
||||
Scenario Outline: OAI Compatibility
|
||||
Given a model <model>
|
||||
And a system prompt <system_prompt>
|
||||
And a user prompt <user_prompt>
|
||||
And <max_tokens> max tokens to predict
|
||||
And streaming is <enable_streaming>
|
||||
Given an OAI compatible chat completions request with no api error
|
||||
Then <n_predicted> tokens are predicted matching <re_content>
|
||||
|
||||
Examples: Prompts
|
||||
| model | system_prompt | user_prompt | max_tokens | re_content | n_predicted | enable_streaming |
|
||||
| llama-2 | Book | What is the best book | 8 | (Mom<or>what)+ | 8 | disabled |
|
||||
| codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 64 | (thanks<or>happy<or>bird)+ | 32 | enabled |
|
||||
|
||||
Scenario: Embedding
|
||||
When embeddings are computed for:
|
||||
"""
|
||||
What is the capital of Bulgaria ?
|
||||
"""
|
||||
Then embeddings are generated
|
||||
|
||||
Scenario: OAI Embeddings compatibility
|
||||
Given a model tinyllama-2
|
||||
When an OAI compatible embeddings computation request for:
|
||||
"""
|
||||
What is the capital of Spain ?
|
||||
"""
|
||||
Then embeddings are generated
|
||||
|
||||
|
||||
Scenario: Tokenize / Detokenize
|
||||
When tokenizing:
|
||||
"""
|
||||
What is the capital of France ?
|
||||
"""
|
||||
Then tokens can be detokenize
|
709
examples/server/tests/features/steps/steps.py
Normal file
709
examples/server/tests/features/steps/steps.py
Normal file
@ -0,0 +1,709 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import socket
|
||||
import subprocess
|
||||
import time
|
||||
from contextlib import closing
|
||||
from re import RegexFlag
|
||||
|
||||
import aiohttp
|
||||
import openai
|
||||
from behave import step
|
||||
from behave.api.async_step import async_run_until_complete
|
||||
|
||||
|
||||
@step(u"a server listening on {server_fqdn}:{server_port}")
|
||||
def step_server_config(context, server_fqdn, server_port):
|
||||
context.server_fqdn = server_fqdn
|
||||
context.server_port = int(server_port)
|
||||
if 'PORT' in os.environ:
|
||||
context.server_port = int(os.environ['PORT'])
|
||||
print(f"$PORT set, overriding server port with to {context.server_port}")
|
||||
|
||||
context.base_url = f'http://{context.server_fqdn}:{context.server_port}'
|
||||
|
||||
context.debug = 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON'
|
||||
context.model_alias = None
|
||||
context.n_ctx = None
|
||||
context.n_predict = None
|
||||
context.n_server_predict = None
|
||||
context.n_slots = None
|
||||
context.server_api_key = None
|
||||
context.server_continuous_batching = False
|
||||
context.server_embeddings = False
|
||||
context.server_seed = None
|
||||
context.user_api_key = None
|
||||
|
||||
context.tasks_result = []
|
||||
context.concurrent_tasks = []
|
||||
context.prompts = []
|
||||
|
||||
|
||||
@step(u'a model file {model_file}')
|
||||
def step_model_file(context, model_file):
|
||||
context.model_file = model_file
|
||||
|
||||
|
||||
@step(u'a model alias {model_alias}')
|
||||
def step_model_alias(context, model_alias):
|
||||
context.model_alias = model_alias
|
||||
|
||||
|
||||
@step(u'{seed} as server seed')
|
||||
def step_seed(context, seed):
|
||||
context.server_seed = int(seed)
|
||||
|
||||
|
||||
@step(u'{n_ctx} KV cache size')
|
||||
def step_n_ctx(context, n_ctx):
|
||||
context.n_ctx = int(n_ctx)
|
||||
|
||||
|
||||
@step(u'{n_slots} slots')
|
||||
def step_n_slots(context, n_slots):
|
||||
context.n_slots = int(n_slots)
|
||||
|
||||
|
||||
@step(u'{n_predict} server max tokens to predict')
|
||||
def step_server_n_predict(context, n_predict):
|
||||
context.n_server_predict = int(n_predict)
|
||||
|
||||
|
||||
@step(u'continuous batching')
|
||||
def step_server_continuous_batching(context):
|
||||
context.server_continuous_batching = True
|
||||
|
||||
|
||||
@step(u'embeddings extraction')
|
||||
def step_server_embeddings(context):
|
||||
context.server_embeddings = True
|
||||
|
||||
|
||||
@step(u"the server is starting")
|
||||
def step_start_server(context):
|
||||
start_server_background(context)
|
||||
attempts = 0
|
||||
while True:
|
||||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock:
|
||||
result = sock.connect_ex((context.server_fqdn, context.server_port))
|
||||
if result == 0:
|
||||
print("\x1b[33;46mserver started!\x1b[0m")
|
||||
return
|
||||
attempts += 1
|
||||
if attempts > 20:
|
||||
assert False, "server not started"
|
||||
print(f"waiting for server to start, connect error code = {result}...")
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
@step(u"the server is {expecting_status}")
|
||||
@async_run_until_complete
|
||||
async def step_wait_for_the_server_to_be_started(context, expecting_status):
|
||||
match expecting_status:
|
||||
case 'healthy':
|
||||
await wait_for_health_status(context, context.base_url, 200, 'ok')
|
||||
|
||||
case 'ready' | 'idle':
|
||||
await wait_for_health_status(context, context.base_url, 200, 'ok',
|
||||
params={'fail_on_no_slot': 0, 'include_slots': 0},
|
||||
slots_idle=context.n_slots,
|
||||
slots_processing=0,
|
||||
expected_slots=[{'id': slot_id, 'state': 0}
|
||||
for slot_id in range(context.n_slots)])
|
||||
case 'busy':
|
||||
await wait_for_health_status(context, context.base_url, 503,
|
||||
'no slot available',
|
||||
params={'fail_on_no_slot': 0, 'include_slots': 0},
|
||||
slots_idle=0,
|
||||
slots_processing=context.n_slots,
|
||||
expected_slots=[{'id': slot_id, 'state': 1}
|
||||
for slot_id in range(context.n_slots)])
|
||||
case _:
|
||||
assert False, "unknown status"
|
||||
|
||||
|
||||
@step(u'all slots are {expected_slot_status_string}')
|
||||
@async_run_until_complete
|
||||
async def step_all_slots_status(context, expected_slot_status_string):
|
||||
match expected_slot_status_string:
|
||||
case 'idle':
|
||||
expected_slot_status = 0
|
||||
case 'busy':
|
||||
expected_slot_status = 1
|
||||
case _:
|
||||
assert False, "unknown status"
|
||||
|
||||
expected_slots = [{'id': slot_id, 'state': expected_slot_status}
|
||||
for slot_id in range(context.n_slots)]
|
||||
await request_slots_status(context, expected_slots)
|
||||
|
||||
|
||||
@step(u'a completion request with {api_error} api error')
|
||||
@async_run_until_complete
|
||||
async def step_request_completion(context, api_error):
|
||||
expect_api_error = api_error == 'raised'
|
||||
completion = await request_completion(context.prompts.pop(),
|
||||
context.base_url,
|
||||
debug=context.debug,
|
||||
n_predict=context.n_predict,
|
||||
server_seed=context.server_seed,
|
||||
expect_api_error=expect_api_error,
|
||||
user_api_key=context.user_api_key)
|
||||
context.tasks_result.append(completion)
|
||||
if context.debug:
|
||||
print(f"Completion response: {completion}")
|
||||
if expect_api_error:
|
||||
assert completion == 401, f"completion must be an 401 status code: {completion}"
|
||||
|
||||
|
||||
@step(u'{predicted_n} tokens are predicted matching {re_content}')
|
||||
def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
|
||||
assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n), re_content)
|
||||
|
||||
|
||||
@step(u'{predicted_n} tokens are predicted')
|
||||
def step_n_tokens_predicted(context, predicted_n):
|
||||
assert_n_tokens_predicted(context.tasks_result.pop(), int(predicted_n))
|
||||
|
||||
|
||||
@step(u'a user prompt {user_prompt}')
|
||||
def step_user_prompt(context, user_prompt):
|
||||
context.prompts.append(user_prompt)
|
||||
|
||||
|
||||
@step(u'a system prompt {system_prompt}')
|
||||
def step_system_prompt(context, system_prompt):
|
||||
context.system_prompt = system_prompt
|
||||
|
||||
|
||||
@step(u'a model {model}')
|
||||
def step_model(context, model):
|
||||
context.model = model
|
||||
|
||||
|
||||
@step(u'{max_tokens} max tokens to predict')
|
||||
def step_max_tokens(context, max_tokens):
|
||||
context.n_predict = int(max_tokens)
|
||||
|
||||
|
||||
@step(u'streaming is {enable_streaming}')
|
||||
def step_streaming(context, enable_streaming):
|
||||
context.enable_streaming = enable_streaming == 'enabled'
|
||||
|
||||
|
||||
@step(u'a user api key {user_api_key}')
|
||||
def step_user_api_key(context, user_api_key):
|
||||
context.user_api_key = user_api_key
|
||||
|
||||
|
||||
@step(u'no user api key')
|
||||
def step_no_user_api_key(context):
|
||||
context.user_api_key = None
|
||||
|
||||
|
||||
@step(u'a user api key ')
|
||||
def step_no_user_api_key_space(context):
|
||||
context.user_api_key = None
|
||||
|
||||
|
||||
@step(u'a server api key {server_api_key}')
|
||||
def step_server_api_key(context, server_api_key):
|
||||
context.server_api_key = server_api_key
|
||||
|
||||
|
||||
@step(u'an OAI compatible chat completions request with {api_error} api error')
|
||||
@async_run_until_complete
|
||||
async def step_oai_chat_completions(context, api_error):
|
||||
if context.debug:
|
||||
print(f"Submitting OAI compatible completions request...")
|
||||
expect_api_error = api_error == 'raised'
|
||||
completion = await oai_chat_completions(context.prompts.pop(),
|
||||
context.system_prompt,
|
||||
context.base_url,
|
||||
False,
|
||||
model=context.model if hasattr(context, 'model') else None,
|
||||
|
||||
n_predict=context.n_predict
|
||||
if hasattr(context, 'n_predict') else None,
|
||||
|
||||
enable_streaming=context.enable_streaming
|
||||
if hasattr(context, 'enable_streaming') else None,
|
||||
|
||||
server_seed=context.server_seed
|
||||
if hasattr(context, 'server_seed') else None,
|
||||
|
||||
user_api_key=context.user_api_key
|
||||
if hasattr(context, 'user_api_key') else None,
|
||||
|
||||
expect_api_error=expect_api_error)
|
||||
context.tasks_result.append(completion)
|
||||
if context.debug:
|
||||
print(f"Completion response: {completion}")
|
||||
if expect_api_error:
|
||||
assert completion == 401, f"completion must be an 401 status code: {completion}"
|
||||
|
||||
if context.debug:
|
||||
print(f"Completion response: {completion}")
|
||||
|
||||
|
||||
@step(u'a prompt')
|
||||
def step_a_prompt(context):
|
||||
context.prompts.append(context.text)
|
||||
|
||||
|
||||
@step(u'a prompt {prompt}')
|
||||
def step_a_prompt_prompt(context, prompt):
|
||||
context.prompts.append(prompt)
|
||||
|
||||
|
||||
@step(u'concurrent completion requests')
|
||||
@async_run_until_complete()
|
||||
async def step_concurrent_completion_requests(context):
|
||||
await concurrent_completion_requests(context,
|
||||
request_completion,
|
||||
# prompt is inserted automatically
|
||||
context.base_url,
|
||||
debug=context.debug,
|
||||
n_predict=context.n_predict if hasattr(context, 'n_predict') else None,
|
||||
server_seed=context.server_seed if hasattr(context, 'server_seed') else None,
|
||||
user_api_key=context.user_api_key if hasattr(context,
|
||||
'user_api_key') else None)
|
||||
|
||||
|
||||
@step(u'concurrent OAI completions requests')
|
||||
@async_run_until_complete
|
||||
async def step_oai_chat_completions(context):
|
||||
await concurrent_completion_requests(context, oai_chat_completions,
|
||||
# user_prompt is inserted automatically
|
||||
context.system_prompt,
|
||||
context.base_url,
|
||||
True, # async_client
|
||||
model=context.model
|
||||
if hasattr(context, 'model') else None,
|
||||
n_predict=context.n_predict
|
||||
if hasattr(context, 'n_predict') else None,
|
||||
enable_streaming=context.enable_streaming
|
||||
if hasattr(context, 'enable_streaming') else None,
|
||||
server_seed=context.server_seed
|
||||
if hasattr(context, 'server_seed') else None,
|
||||
user_api_key=context.user_api_key
|
||||
if hasattr(context, 'user_api_key') else None)
|
||||
|
||||
|
||||
@step(u'all prompts are predicted')
|
||||
@async_run_until_complete
|
||||
async def step_all_prompts_are_predicted(context):
|
||||
await all_prompts_are_predicted(context)
|
||||
|
||||
|
||||
@step(u'all prompts are predicted with {n_predict} tokens')
|
||||
@async_run_until_complete
|
||||
async def step_all_prompts_are_predicted_with_n_tokens(context, n_predict):
|
||||
expected_predicted_n = int(n_predict)
|
||||
await all_prompts_are_predicted(context, expected_predicted_n)
|
||||
|
||||
|
||||
async def all_prompts_are_predicted(context, expected_predicted_n=None):
|
||||
n_completions = await gather_tasks_results(context)
|
||||
assert n_completions > 0
|
||||
for i in range(n_completions):
|
||||
assert_n_tokens_predicted(context.tasks_result.pop(), expected_predicted_n=expected_predicted_n)
|
||||
assert len(context.concurrent_tasks) == 0, f"{len(context.concurrent_tasks)} pending requests"
|
||||
|
||||
|
||||
@step(u'embeddings are computed for')
|
||||
@async_run_until_complete
|
||||
async def step_compute_embedding(context):
|
||||
content = context.text
|
||||
base_url = context.base_url
|
||||
context.embeddings = await request_embedding(content, base_url)
|
||||
|
||||
|
||||
@step(u'embeddings are generated')
|
||||
def step_assert_embeddings(context):
|
||||
assert_embeddings(context.embeddings)
|
||||
|
||||
|
||||
@step(u'an OAI compatible embeddings computation request for')
|
||||
def step_oai_compute_embedding(context):
|
||||
openai.api_key = 'nope' # openai client always expects an api_keu
|
||||
if context.user_api_key is not None:
|
||||
openai.api_key = context.user_api_key
|
||||
openai.api_base = f'{context.base_url}/v1'
|
||||
embeddings = openai.Embedding.create(
|
||||
model=context.model,
|
||||
input=context.text,
|
||||
)
|
||||
context.embeddings = embeddings
|
||||
|
||||
|
||||
@step(u'concurrent embedding requests')
|
||||
@async_run_until_complete()
|
||||
async def step_concurrent_embedding_requests(context):
|
||||
await concurrent_completion_requests(context,
|
||||
request_embedding,
|
||||
# prompt is inserted automatically
|
||||
context.base_url)
|
||||
|
||||
|
||||
@step(u'all embeddings are generated')
|
||||
@async_run_until_complete()
|
||||
async def all_embeddings_are_generated(context):
|
||||
n_embedding_requests = await gather_tasks_results(context)
|
||||
assert n_embedding_requests > 0
|
||||
for i in range(n_embedding_requests):
|
||||
assert_embeddings(context.tasks_result.pop())
|
||||
|
||||
|
||||
@step(u'tokenizing')
|
||||
@async_run_until_complete
|
||||
async def step_tokenize(context):
|
||||
context.tokenized_text = context.text
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f'{context.base_url}/tokenize',
|
||||
json={
|
||||
"content": context.tokenized_text,
|
||||
}) as response:
|
||||
assert response.status == 200
|
||||
tokenize_json = await response.json()
|
||||
context.tokens = tokenize_json['tokens']
|
||||
|
||||
|
||||
@step(u'tokens can be detokenize')
|
||||
@async_run_until_complete
|
||||
async def step_detokenize(context):
|
||||
assert len(context.tokens) > 0
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f'{context.base_url}/detokenize',
|
||||
json={
|
||||
"tokens": context.tokens,
|
||||
}) as response:
|
||||
assert response.status == 200
|
||||
detokenize_json = await response.json()
|
||||
# SPM tokenizer adds a whitespace prefix: https://github.com/google/sentencepiece/issues/15
|
||||
assert context.tokenized_text == detokenize_json['content'].strip()
|
||||
|
||||
|
||||
@step(u'an OPTIONS request is sent from {origin}')
|
||||
@async_run_until_complete
|
||||
async def step_options_request(context, origin):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.options(f'{context.base_url}/v1/chat/completions',
|
||||
headers={"Origin": origin}) as response:
|
||||
assert response.status == 200
|
||||
context.options_response = response
|
||||
|
||||
|
||||
@step(u'CORS header {cors_header} is set to {cors_header_value}')
|
||||
def step_check_options_header_value(context, cors_header, cors_header_value):
|
||||
assert context.options_response.headers[cors_header] == cors_header_value
|
||||
|
||||
|
||||
async def concurrent_completion_requests(context, f_completion, *args, **kwargs):
|
||||
n_prompts = len(context.prompts)
|
||||
if context.debug:
|
||||
print(f"starting {n_prompts} concurrent completion requests...")
|
||||
assert n_prompts > 0
|
||||
for prompt_no in range(n_prompts):
|
||||
shifted_args = [context.prompts.pop(), *args]
|
||||
context.concurrent_tasks.append(asyncio.create_task(f_completion(*shifted_args, **kwargs)))
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
|
||||
async def request_completion(prompt,
|
||||
base_url,
|
||||
debug=False,
|
||||
n_predict=None,
|
||||
server_seed=None,
|
||||
expect_api_error=None,
|
||||
user_api_key=None):
|
||||
if debug:
|
||||
print(f"Sending completion request: {prompt}")
|
||||
origin = "my.super.domain"
|
||||
headers = {
|
||||
'Origin': origin
|
||||
}
|
||||
if user_api_key is not None:
|
||||
if debug:
|
||||
print(f"Set user_api_key: {user_api_key}")
|
||||
headers['Authorization'] = f'Bearer {user_api_key}'
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f'{base_url}/completion',
|
||||
json={
|
||||
"prompt": prompt,
|
||||
"n_predict": int(n_predict) if n_predict is not None else -1,
|
||||
"seed": server_seed if server_seed is not None else 42
|
||||
},
|
||||
headers=headers) as response:
|
||||
if expect_api_error is None or not expect_api_error:
|
||||
assert response.status == 200
|
||||
assert response.headers['Access-Control-Allow-Origin'] == origin
|
||||
return await response.json()
|
||||
else:
|
||||
return response.status
|
||||
|
||||
|
||||
async def oai_chat_completions(user_prompt,
|
||||
system_prompt,
|
||||
base_url,
|
||||
async_client,
|
||||
debug=False,
|
||||
model=None,
|
||||
n_predict=None,
|
||||
enable_streaming=None,
|
||||
server_seed=None,
|
||||
user_api_key=None,
|
||||
expect_api_error=None):
|
||||
if debug:
|
||||
print(f"Sending OAI Chat completions request: {user_prompt}")
|
||||
# openai client always expects an api key
|
||||
user_api_key = user_api_key if user_api_key is not None else 'nope'
|
||||
seed = server_seed if server_seed is not None else 42
|
||||
enable_streaming = enable_streaming if enable_streaming is not None else False
|
||||
payload = {
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": system_prompt,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": user_prompt,
|
||||
}
|
||||
],
|
||||
"model": model,
|
||||
"max_tokens": n_predict,
|
||||
"stream": enable_streaming,
|
||||
"seed": seed
|
||||
}
|
||||
completion_response = {
|
||||
'content': '',
|
||||
'timings': {
|
||||
'predicted_n': 0
|
||||
}
|
||||
}
|
||||
if async_client:
|
||||
origin = 'llama.cpp'
|
||||
headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f'{base_url}/v1/chat/completions',
|
||||
json=payload,
|
||||
headers=headers) as response:
|
||||
if enable_streaming:
|
||||
assert response.status == 200
|
||||
assert response.headers['Access-Control-Allow-Origin'] == origin
|
||||
assert response.headers['Content-Type'] == "text/event-stream"
|
||||
event_received = True
|
||||
while event_received:
|
||||
event_received = False
|
||||
async for line_in_bytes in response.content:
|
||||
line = line_in_bytes.decode('utf8')
|
||||
line = line.rstrip('\n').rstrip('\r')
|
||||
if line == '':
|
||||
continue
|
||||
event_data = line.split(': ', 1)
|
||||
assert event_data[0] == 'data', f'Bad event code received: ```{event_data}```'
|
||||
chunk_raw = event_data[1]
|
||||
|
||||
chunk = json.loads(chunk_raw)
|
||||
assert len(chunk['choices']) == 1, f"no choices provided, line ```{line}```"
|
||||
delta = chunk['choices'][0]['delta']
|
||||
if 'content' in delta:
|
||||
completion_response['content'] += delta['content']
|
||||
completion_response['timings']['predicted_n'] += 1
|
||||
else:
|
||||
if expect_api_error is None or not expect_api_error:
|
||||
assert response.status == 200
|
||||
assert response.headers['Access-Control-Allow-Origin'] == origin
|
||||
assert response.headers['Content-Type'] == "application/json; charset=utf-8"
|
||||
chat_completion_raw = await response.json()
|
||||
completion_response = {
|
||||
'content': chat_completion_raw['choices'][0]['message'],
|
||||
'timings': {
|
||||
'predicted_n': chat_completion_raw['usage']['completion_tokens']
|
||||
}
|
||||
}
|
||||
else:
|
||||
return response.status
|
||||
else:
|
||||
try:
|
||||
openai.api_key = user_api_key
|
||||
openai.api_base = f'{base_url}/v1/chat'
|
||||
chat_completion = openai.Completion.create(
|
||||
messages=payload['messages'],
|
||||
model=model,
|
||||
max_tokens=n_predict,
|
||||
stream=enable_streaming,
|
||||
seed=seed
|
||||
)
|
||||
except openai.error.APIError as e:
|
||||
if expect_api_error is not None and expect_api_error:
|
||||
return 401
|
||||
else:
|
||||
assert False, f'error raised: {e}'
|
||||
|
||||
if enable_streaming:
|
||||
for chunk in chat_completion:
|
||||
assert len(chunk.choices) == 1
|
||||
delta = chunk.choices[0].delta
|
||||
if 'content' in delta:
|
||||
completion_response['content'] += delta['content']
|
||||
completion_response['timings']['predicted_n'] += 1
|
||||
else:
|
||||
assert len(chat_completion.choices) == 1
|
||||
completion_response = {
|
||||
'content': chat_completion.choices[0].message.content,
|
||||
'timings': {
|
||||
'predicted_n': chat_completion.usage.completion_tokens
|
||||
}
|
||||
}
|
||||
if debug:
|
||||
print("OAI response formatted to llama.cpp:", completion_response)
|
||||
return completion_response
|
||||
|
||||
|
||||
async def request_embedding(content, base_url):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f'{base_url}/embedding',
|
||||
json={
|
||||
"content": content,
|
||||
}) as response:
|
||||
assert response.status == 200
|
||||
response_json = await response.json()
|
||||
return response_json['embedding']
|
||||
|
||||
|
||||
def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re_content=None):
|
||||
content = completion_response['content']
|
||||
n_predicted = completion_response['timings']['predicted_n']
|
||||
assert len(content) > 0, "no token predicted"
|
||||
if expected_predicted_n is not None:
|
||||
assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
|
||||
f' {n_predicted} <> {expected_predicted_n}')
|
||||
if re_content is not None:
|
||||
re_content = '^.*' + re_content.replace('<or>', '|') + '.*$'
|
||||
assert re.match(re_content, content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL), (
|
||||
f'invalid tokens predicted:'
|
||||
f' ```\n{content}\n``` do not match /{re_content}/')
|
||||
|
||||
|
||||
async def gather_tasks_results(context):
|
||||
n_tasks = len(context.concurrent_tasks)
|
||||
if context.debug:
|
||||
print(f"Waiting for all {n_tasks} tasks results...")
|
||||
for task_no in range(n_tasks):
|
||||
context.tasks_result.append(await context.concurrent_tasks.pop())
|
||||
n_completions = len(context.tasks_result)
|
||||
return n_completions
|
||||
|
||||
|
||||
async def wait_for_health_status(context,
|
||||
base_url,
|
||||
expected_http_status_code,
|
||||
expected_health_status,
|
||||
params=None,
|
||||
slots_idle=None,
|
||||
slots_processing=None,
|
||||
expected_slots=None):
|
||||
if context.debug:
|
||||
print(f"Starting checking for health for expected_health_status={expected_health_status}")
|
||||
timeout = 3 # seconds
|
||||
interval = 0.5
|
||||
counter = 0
|
||||
async with aiohttp.ClientSession() as session:
|
||||
while True:
|
||||
async with await session.get(f'{base_url}/health', params=params) as health_response:
|
||||
status_code = health_response.status
|
||||
health = await health_response.json()
|
||||
if context.debug:
|
||||
print(f"HEALTH - response for expected health status='{expected_health_status}' on "
|
||||
f"'{base_url}/health'?{params} is {health}")
|
||||
if (status_code == expected_http_status_code
|
||||
and health['status'] == expected_health_status
|
||||
and (slots_idle is None or health['slots_idle'] == slots_idle)
|
||||
and (slots_processing is None or health['slots_processing'] == slots_processing)):
|
||||
if expected_slots is not None:
|
||||
assert_slots_status(health['slots'], expected_slots)
|
||||
return
|
||||
if (status_code == expected_http_status_code
|
||||
and health['status'] == expected_health_status
|
||||
and (slots_idle is None or health['slots_idle'] == slots_idle)
|
||||
and (slots_processing is None or health['slots_processing'] == slots_processing)):
|
||||
if expected_slots is not None:
|
||||
assert_slots_status(health['slots'], expected_slots)
|
||||
return
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
counter += interval
|
||||
if counter >= timeout:
|
||||
# Sometimes health requests are triggered after completions are predicted
|
||||
if expected_http_status_code == 503:
|
||||
if len(context.tasks_result) == 0:
|
||||
print("\x1b[5;37;43mWARNING: forcing concurrent tasks,"
|
||||
" busy health check missed, probably too fast inference\x1b[0m")
|
||||
n_completions = await gather_tasks_results(context)
|
||||
if n_completions > 0:
|
||||
return
|
||||
|
||||
assert False, 'timeout exceeded'
|
||||
|
||||
|
||||
def assert_embeddings(embeddings):
|
||||
assert len(embeddings) > 0
|
||||
embeddings_computed = False
|
||||
for emb in embeddings:
|
||||
if emb != 0:
|
||||
embeddings_computed = True
|
||||
assert embeddings_computed, f"Embeddings: {embeddings}"
|
||||
|
||||
|
||||
async def request_slots_status(context, expected_slots):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with await session.get(f'{context.base_url}/slots') as slots_response:
|
||||
assert slots_response.status == 200
|
||||
slots = await slots_response.json()
|
||||
assert_slots_status(slots, expected_slots)
|
||||
|
||||
|
||||
def assert_slots_status(slots, expected_slots):
|
||||
assert len(slots) == len(expected_slots)
|
||||
for slot_id, (expected, slot) in enumerate(zip(expected_slots, slots)):
|
||||
for key in expected:
|
||||
assert expected[key] == slot[key], (f"invalid slot {slot_id}"
|
||||
f" expected[{key}] != slot[{key}]"
|
||||
f" = {expected[key]} != {slot[key]}")
|
||||
|
||||
|
||||
def start_server_background(context):
|
||||
context.server_path = '../../../build/bin/server'
|
||||
if 'LLAMA_SERVER_BIN_PATH' in os.environ:
|
||||
context.server_path = os.environ['LLAMA_SERVER_BIN_PATH']
|
||||
server_args = [
|
||||
'--host', context.server_fqdn,
|
||||
'--port', context.server_port,
|
||||
'--model', context.model_file
|
||||
]
|
||||
if context.server_continuous_batching:
|
||||
server_args.append('--cont-batching')
|
||||
if context.server_embeddings:
|
||||
server_args.append('--embedding')
|
||||
if context.model_alias is not None:
|
||||
server_args.extend(['--alias', context.model_alias])
|
||||
if context.n_ctx is not None:
|
||||
server_args.extend(['--ctx-size', context.n_ctx])
|
||||
if context.n_slots is not None:
|
||||
server_args.extend(['--parallel', context.n_slots])
|
||||
if context.n_server_predict is not None:
|
||||
server_args.extend(['--n-predict', context.n_server_predict])
|
||||
if context.server_api_key is not None:
|
||||
server_args.extend(['--api-key', context.server_api_key])
|
||||
if context.debug:
|
||||
server_args.append('--verbose')
|
||||
print(f"starting server with: {context.server_path}", *server_args)
|
||||
context.server_process = subprocess.Popen(
|
||||
[str(arg) for arg in [context.server_path, *server_args]],
|
||||
close_fds=True)
|
||||
print(f"server pid={context.server_process.pid}")
|
21
examples/server/tests/features/wrong_usages.feature
Normal file
21
examples/server/tests/features/wrong_usages.feature
Normal file
@ -0,0 +1,21 @@
|
||||
# run with ./test.sh --tags wrong_usage
|
||||
@wrong_usage
|
||||
Feature: Wrong usage of llama.cpp server
|
||||
|
||||
#3969 The user must always set --n-predict option
|
||||
# to cap the number of tokens any completion request can generate
|
||||
# or pass n_predict/max_tokens in the request.
|
||||
Scenario: Infinite loop
|
||||
Given a server listening on localhost:8080
|
||||
And a model file stories260K.gguf
|
||||
# Uncomment below to fix the issue
|
||||
#And 64 server max tokens to predict
|
||||
Then the server is starting
|
||||
Given a prompt:
|
||||
"""
|
||||
Go to: infinite loop
|
||||
"""
|
||||
# Uncomment below to fix the issue
|
||||
#And 128 max tokens to predict
|
||||
Given concurrent completion requests
|
||||
Then all prompts are predicted
|
3
examples/server/tests/requirements.txt
Normal file
3
examples/server/tests/requirements.txt
Normal file
@ -0,0 +1,3 @@
|
||||
aiohttp~=3.9.3
|
||||
behave~=1.2.6
|
||||
openai~=0.25.0
|
12
examples/server/tests/tests.sh
Executable file
12
examples/server/tests/tests.sh
Executable file
@ -0,0 +1,12 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -eu
|
||||
|
||||
if [ $# -lt 1 ]
|
||||
then
|
||||
# Start @llama.cpp scenario
|
||||
behave --summary --stop --no-capture --exclude 'issues|wrong_usages' --tags llama.cpp
|
||||
else
|
||||
behave "$@"
|
||||
fi
|
||||
|
Loading…
Reference in New Issue
Block a user