From 0b3bf966f47bf2ba88e5d4e3ed429602008c7e63 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 23 Sep 2024 22:23:54 +0200 Subject: [PATCH] server : add --no-context-shift option (#9607) * server : add --no-context-shift option * small fix * Update examples/server/tests/features/embeddings.feature Co-authored-by: Georgi Gerganov * tests : minor fix * revert usage of GGML_ASSERT * update server documentation --------- Co-authored-by: Georgi Gerganov --- common/arg.cpp | 2 +- examples/server/README.md | 20 +++--- examples/server/server.cpp | 27 +++++++- .../server/tests/features/ctx_shift.feature | 62 +++++++++++++++++++ .../server/tests/features/embeddings.feature | 22 +++++-- examples/server/tests/features/steps/steps.py | 28 ++++++--- 6 files changed, 139 insertions(+), 22 deletions(-) create mode 100644 examples/server/tests/features/ctx_shift.feature diff --git a/common/arg.cpp b/common/arg.cpp index 922391069..c1ec3c4f9 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -691,7 +691,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex, [](gpt_params & params) { params.ctx_shift = false; } - ).set_examples({LLAMA_EXAMPLE_MAIN})); + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER})); add_opt(llama_arg( {"--chunks"}, "N", format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks), diff --git a/examples/server/README.md b/examples/server/README.md index 326e05e1e..741950c8a 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -21,8 +21,6 @@ The project is under active development, and we are [looking for feedback and co | -------- | ----------- | | `-h, --help, --usage` | print usage and exit | | `--version` | show version and build info | -| `-v, --verbose` | print verbose information | -| `--verbosity N` | set specific verbosity level (default: 0) | | `-t, --threads N` | number of threads to use during generation (default: -1)
(env: LLAMA_ARG_THREADS) | | `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) | | `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") | @@ -40,15 +38,18 @@ The project is under active development, and we are [looking for feedback and co | `-b, --batch-size N` | logical maximum batch size (default: 2048)
(env: LLAMA_ARG_BATCH) | | `-ub, --ubatch-size N` | physical maximum batch size (default: 512)
(env: LLAMA_ARG_UBATCH) | | `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) | +| `--no-context-shift` | disables context shift on inifinite text generation (default: disabled) | | `-fa, --flash-attn` | enable Flash Attention (default: disabled)
(env: LLAMA_ARG_FLASH_ATTN) | | `-p, --prompt PROMPT` | prompt to start generation with | +| `--no-perf` | disable internal libllama performance timings (default: false)
(env: LLAMA_ARG_NO_PERF) | | `-f, --file FNAME` | a file containing the prompt (default: none) | | `-bf, --binary-file FNAME` | binary file containing the prompt (default: none) | | `-e, --escape` | process escapes sequences (\n, \r, \t, \', \", \\) (default: true) | | `--no-escape` | do not process escape sequences | +| `-sp, --special` | special tokens output enabled (default: false) | | `--spm-infill` | use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: disabled) | | `--samplers SAMPLERS` | samplers that will be used for generation in the order, separated by ';'
(default: top_k;tfs_z;typ_p;top_p;min_p;temperature) | -| `-s, --seed SEED` | RNG seed (default: -1, use random seed for < 0) | +| `-s, --seed SEED` | RNG seed (default: 4294967295, use random seed for 4294967295) | | `--sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: kfypmt) | | `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) | | `--penalize-nl` | penalize newline tokens (default: false) | @@ -87,7 +88,7 @@ The project is under active development, and we are [looking for feedback and co | `-ctk, --cache-type-k TYPE` | KV cache data type for K (default: f16) | | `-ctv, --cache-type-v TYPE` | KV cache data type for V (default: f16) | | `-dt, --defrag-thold N` | KV cache defragmentation threshold (default: -1.0, < 0 - disabled)
(env: LLAMA_ARG_DEFRAG_THOLD) | -| `-np, --parallel N` | number of parallel sequences to decode (default: 1)
(env: LLAMA_ARG_N_PARALLEL) | +| `-np, --parallel N` | number of parallel sequences to decode (default: 1)
(env: LLAMA_ARG_N_PARALLEL) | | `-cb, --cont-batching` | enable continuous batching (a.k.a dynamic batching) (default: enabled)
(env: LLAMA_ARG_CONT_BATCHING) | | `-nocb, --no-cont-batching` | disable continuous batching
(env: LLAMA_ARG_NO_CONT_BATCHING) | | `--mlock` | force system to keep model in RAM rather than swapping or compressing | @@ -128,12 +129,13 @@ The project is under active development, and we are [looking for feedback and co | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)
| | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) | | `-ld, --logdir LOGDIR` | path under which to save YAML logs (no logging if unset) | -| `--log-test` | Log test | | `--log-disable` | Log disable | -| `--log-enable` | Log enable | -| `--log-new` | Log new | -| `--log-append` | Log append | -| `--log-file FNAME` | Log file | +| `--log-file FNAME` | Log to file | +| `--log-colors` | Enable colored logging
(env: LLAMA_LOG_COLORS) | +| `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) | +| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored.
(env: LLAMA_LOG_VERBOSITY) | +| `--log-prefix` | Enable prefx in log messages
(env: LLAMA_LOG_PREFIX) | +| `--log-timestamps` | Enable timestamps in log messages
(env: LLAMA_LOG_TIMESTAMPS) | Note: If both command line argument and environment variable are both set for the same param, the argument will take precedence over env var. diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 0ca999994..8655c097a 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1180,6 +1180,15 @@ struct server_context { SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict); } + // if context shift is disabled, we stop when it reaches the context limit + if (slot.n_decoded >= slot.n_ctx) { + slot.truncated = true; + slot.stopped_limit = true; + slot.has_next_token = false; + + SLT_DBG(slot, "stopped due to running out of context capacity, n_decoded = %d, n_ctx = %d\n", slot.n_decoded, slot.n_ctx); + } + if (llama_token_is_eog(model, result.tok)) { slot.stopped_eos = true; slot.has_next_token = false; @@ -1480,7 +1489,7 @@ struct server_context { if (result.error) { error_handler(result.data); cancel_tasks(id_tasks); - break; + return; } size_t idx = result.data["index"]; @@ -1827,6 +1836,14 @@ struct server_context { for (server_slot & slot : slots) { if (slot.ga_n == 1) { if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) { + if (!params.ctx_shift) { + // this check is redundant (for good) + // we should never get here, because generation should already stopped in process_token() + slot.release(); + send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); + continue; + } + // Shift context const int n_keep = slot.params.n_keep + add_bos_token; const int n_left = (int) system_tokens.size() + slot.n_past - n_keep; @@ -1961,6 +1978,14 @@ struct server_context { continue; } } else { + if (!params.ctx_shift) { + // if context shift is disabled, we make sure prompt size is smaller than KV size + if ((int) system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) { + slot.release(); + send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST); + continue; + } + } if (slot.params.n_keep < 0) { slot.params.n_keep = slot.n_prompt_tokens; } diff --git a/examples/server/tests/features/ctx_shift.feature b/examples/server/tests/features/ctx_shift.feature new file mode 100644 index 000000000..ba3afcf06 --- /dev/null +++ b/examples/server/tests/features/ctx_shift.feature @@ -0,0 +1,62 @@ +@llama.cpp +@ctx_shift +Feature: llama.cpp server + + Background: Server startup + Given a server listening on localhost:8080 + And a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models + And a model file test-model.gguf + And a model alias tinyllama-2 + And BOS token is 1 + And 42 as server seed + And 256 KV cache size + And 32 as batch size + And 2 slots + + Scenario: Inference with context shift + And 64 server max tokens to predict + Then the server is starting + Then the server is healthy + Given a prompt: + """ + Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. + Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. + Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. + Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. + """ + And a completion request with no api error + Then 64 tokens are predicted matching fun|Annaks|popcorns|pictry|bowl + And the completion is truncated + And 109 prompt tokens are processed + + Scenario Outline: Inference without context shift + And server max tokens to predict + And disable context shifting + Then the server is starting + Then the server is healthy + Given a prompt: + """ + Hi how are you + """ + And a completion request with no api error + Then tokens are predicted matching twind|Anna + And the completion is truncated + And 8 prompt tokens are processed + Examples: + | n_predict | n_token_output | truncated | + | 64 | 64 | not | + | -1 | 120 | | + + Scenario: Inference without context shift (expected error: prompt too long) + And disable context shifting + Then the server is starting + Then the server is healthy + Given a prompt: + """ + Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. + Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. + Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. + Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. + """ + And a completion request with 400 api error + diff --git a/examples/server/tests/features/embeddings.feature b/examples/server/tests/features/embeddings.feature index e1eade6cd..818ea3beb 100644 --- a/examples/server/tests/features/embeddings.feature +++ b/examples/server/tests/features/embeddings.feature @@ -10,11 +10,11 @@ Feature: llama.cpp server And 42 as server seed And 2 slots # the bert-bge-small model has context size of 512 - # since the generated prompts are as big as the batch size, we need to set the batch size to 512 + # since the generated prompts are as big as the batch size, we need to set the batch size to <= 512 # ref: https://huggingface.co/BAAI/bge-small-en-v1.5/blob/5c38ec7c405ec4b44b94cc5a9bb96e735b38267a/config.json#L20 - And 512 as batch size - And 512 as ubatch size - And 2048 KV cache size + And 128 as batch size + And 128 as ubatch size + And 512 KV cache size And embeddings extraction Then the server is starting Then the server is healthy @@ -26,6 +26,20 @@ Feature: llama.cpp server """ Then embeddings are generated + Scenario: Embedding (error: prompt too long) + When embeddings are computed for: + """ + Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. + Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. + Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. + Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. + Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. + Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. + Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. + Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum. + """ + And embeddings request with 500 api error + Scenario: OAI Embeddings compatibility Given a model bert-bge-small When an OAI compatible embeddings computation request for: diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py index 062f084be..0fea0fe87 100644 --- a/examples/server/tests/features/steps/steps.py +++ b/examples/server/tests/features/steps/steps.py @@ -77,6 +77,7 @@ def step_server_config(context, server_fqdn: str, server_port: str): context.response_format = None context.temperature = None context.lora_file = None + context.disable_ctx_shift = False context.tasks_result = [] context.concurrent_tasks = [] @@ -148,7 +149,7 @@ def step_n_slots(context, n_slots: int): @step('{n_predict:d} server max tokens to predict') def step_server_n_predict(context, n_predict: int): - context.n_server_predict = n_predict + context.n_server_predict = n_predict if n_predict > 0 else None @step('{slot_save_path} as slot save path') @@ -180,6 +181,9 @@ def step_server_embeddings(context): def step_server_metrics(context): context.server_metrics = True +@step('disable context shifting') +def step_server_disable_ctx_shift(context): + context.disable_ctx_shift = True @step("the server is starting") def step_start_server(context): @@ -257,7 +261,7 @@ async def step_all_slots_status(context, expected_slot_status_string: Literal['i @step('a completion request with {api_error} api error') @async_run_until_complete async def step_request_completion(context, api_error: Literal['raised'] | str): - expect_api_error = api_error == 'raised' + expect_api_error = api_error == 'raised' or api_error != 'no' seeds = await completions_seed(context, num_seeds=1) completion = await request_completion(context.prompts.pop(), seeds[0] if seeds is not None else seeds, @@ -272,8 +276,11 @@ async def step_request_completion(context, api_error: Literal['raised'] | str): context.tasks_result.append(completion) if context.debug: print(f"Completion response: {completion}") - if expect_api_error: + if api_error == 'raised': assert completion == 401, f"completion must be an 401 status code: {completion}" + elif api_error.isdigit(): + api_error_code = int(api_error) + assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}" @step('{predicted_n:d} tokens are predicted matching {re_content}') @@ -645,6 +652,9 @@ def step_assert_embeddings(context): for embedding in context.embeddings: assert_embeddings(embedding) +@step('embeddings request with {api_error_code:d} api error') +def step_assert_embeddings(context, api_error_code: int): + assert context.embeddings == api_error_code, f"embeddings request must return code {api_error_code}, but got {context.embeddings}" @step('an OAI compatible embeddings computation request for') @async_run_until_complete @@ -1089,15 +1099,17 @@ async def oai_chat_completions(user_prompt, return completion_response -async def request_embedding(content, seed, base_url=None) -> list[list[float]]: +async def request_embedding(content, seed, base_url=None) -> list[list[float]] | int: async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: async with session.post(f'{base_url}/embedding', json={ "content": content, }) as response: - assert response.status == 200 - response_json = await response.json() - return [response_json['embedding']] + if response.status == 200: + response_json = await response.json() + return [response_json['embedding']] + else: + return response.status async def request_oai_embeddings(input, seed, @@ -1372,6 +1384,8 @@ def start_server_background(context): server_args.append('--verbose') if context.lora_file: server_args.extend(['--lora', context.lora_file]) + if context.disable_ctx_shift: + server_args.extend(['--no-context-shift']) args = [str(arg) for arg in [context.server_path, *server_args]] print(f"bench: starting server with: {' '.join(args)}")