From 0da5d860266c6928b8c9408efbd264ae59fedda6 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 2 Jan 2025 15:05:18 +0100 Subject: [PATCH] server : allow using LoRA adapters per-request (#10994) * slot.can_batch_with * lora per request * test: force disable cache prompt * move can_batch_with check * fix condition * add slow test with llama 8b * update docs * move lora change task to queue * Apply suggestions from code review Co-authored-by: Georgi Gerganov * lora_base * remove redundant check --------- Co-authored-by: Georgi Gerganov --- examples/server/README.md | 6 + examples/server/server.cpp | 116 ++++++++++++------ examples/server/tests/README.md | 6 + examples/server/tests/requirements.txt | 1 + examples/server/tests/unit/test_lora.py | 93 ++++++++++++-- .../server/tests/unit/test_speculative.py | 10 +- examples/server/tests/utils.py | 21 ++++ examples/server/utils.hpp | 41 +++++++ 8 files changed, 235 insertions(+), 59 deletions(-) diff --git a/examples/server/README.md b/examples/server/README.md index bcef81946..3ce16945a 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -452,6 +452,8 @@ These words will not be included in the completion, so make sure to add them to `response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error. Note that fields with a slash will be unnested; for example, `generation_settings/n_predict` will move the field `n_predict` from the `generation_settings` object to the root of the response and give it a new name. +`lora`: A list of LoRA adapters to be applied to this specific request. Each object in the list must contain `id` and `scale` fields. For example: `[{"id": 0, "scale": 0.5}, {"id": 1, "scale": 1.1}]`. If a LoRA adapter is not specified in the list, its scale will default to `0.0`. Please note that requests with different LoRA configurations will not be batched together, which may result in performance degradation. + **Response format** - Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support. @@ -945,6 +947,8 @@ This endpoint returns the loaded LoRA adapters. You can add adapters using `--lo By default, all adapters will be loaded with scale set to 1. To initialize all adapters scale to 0, add `--lora-init-without-apply` +Please note that this value will be overwritten by the `lora` field for each request. + If an adapter is disabled, the scale will be set to 0. **Response format** @@ -966,6 +970,8 @@ If an adapter is disabled, the scale will be set to 0. ### POST `/lora-adapters`: Set list of LoRA adapters +This sets the global scale for LoRA adapters. Please note that this value will be overwritten by the `lora` field for each request. + To disable an adapter, either remove it from the list below, or set scale to 0. **Request format** diff --git a/examples/server/server.cpp b/examples/server/server.cpp index b3773f276..5118084f1 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -98,6 +98,8 @@ struct slot_params { int64_t t_max_prompt_ms = -1; // TODO: implement int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit + std::vector lora; + std::vector antiprompt; std::vector response_fields; bool timings_per_token = false; @@ -120,6 +122,11 @@ struct slot_params { samplers.emplace_back(common_sampler_type_to_str(sampler)); } + json lora = json::array(); + for (size_t i = 0; i < this->lora.size(); ++i) { + lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); + } + return json { {"n_predict", n_predict}, // Server configured n_predict {"seed", sampling.seed}, @@ -160,6 +167,7 @@ struct slot_params { {"speculative.p_min", speculative.p_min}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, + {"lora", lora}, }; } }; @@ -189,12 +197,16 @@ struct server_task { // used by SERVER_TASK_TYPE_METRICS bool metrics_reset_bucket = false; + // used by SERVER_TASK_TYPE_SET_LORA + std::vector set_lora; + server_task(server_task_type type) : type(type) {} static slot_params params_from_json_cmpl( const llama_model * model, const llama_context * ctx, const common_params & params_base, + const std::vector & lora_base, const json & data) { slot_params params; @@ -251,6 +263,16 @@ struct server_task { params.speculative.n_min = std::max(params.speculative.n_min, 2); params.speculative.n_max = std::max(params.speculative.n_max, 0); + if (data.contains("lora")) { + if (data.at("lora").is_array()) { + params.lora = parse_lora_request(lora_base, data.at("lora")); + } else { + throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); + } + } else { + params.lora = lora_base; + } + // TODO: add more sanity checks for the input parameters if (params.sampling.penalty_last_n < -1) { @@ -1110,6 +1132,8 @@ struct server_slot { common_speculative * spec = nullptr; + std::vector lora; + // the index relative to completion multi-task request size_t index = 0; @@ -1191,6 +1215,11 @@ struct server_slot { return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK; } + bool can_batch_with(server_slot & other_slot) { + return is_non_causal() == other_slot.is_non_causal() + && are_lora_equal(lora, other_slot.lora); + } + bool has_budget(const common_params & global_params) { if (params.n_predict == -1 && global_params.n_predict == -1) { return true; // limitless @@ -1600,7 +1629,7 @@ struct server_context { llama_model * model = nullptr; llama_context * ctx = nullptr; - std::vector loras; + std::vector lora; llama_model * model_dft = nullptr; llama_context_params cparams_dft; @@ -1667,7 +1696,7 @@ struct server_context { model = llama_init.model; ctx = llama_init.context; - loras = llama_init.lora_adapters; + lora = llama_init.lora_adapters; if (model == nullptr) { SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str()); @@ -1866,6 +1895,12 @@ struct server_context { slot.params = std::move(task.params); slot.prompt_tokens = std::move(task.prompt_tokens); + if (!are_lora_equal(task.params.lora, slot.lora)) { + // if lora is changed, we cannot reuse cached tokens + slot.cache_tokens.clear(); + slot.lora = std::move(task.params.lora); + } + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { @@ -2557,7 +2592,7 @@ struct server_context { } break; case SERVER_TASK_TYPE_SET_LORA: { - common_lora_adapters_apply(ctx, loras); + lora = std::move(task.set_lora); auto res = std::make_unique(); res->id = task.id; queue_results.send(std::move(res)); @@ -2634,12 +2669,22 @@ struct server_context { // start populating the batch for this iteration common_batch_clear(batch); + // track if given slot can be batched with slots already in the batch + server_slot * slot_batched = nullptr; + // frist, add sampled tokens from any ongoing sequences for (auto & slot : slots) { if (slot.state != SLOT_STATE_GENERATING) { continue; } + // check if we can batch this slot with the previous one + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } + slot.i_batch = batch.n_tokens; common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); @@ -2658,15 +2703,18 @@ struct server_context { int32_t n_batch = llama_n_batch(ctx); int32_t n_ubatch = llama_n_ubatch(ctx); - // track if this is an embedding or non-embedding batch - // if we've added sampled tokens above, we are in non-embedding mode - // -1: none, 0: non-embedding, 1: embedding - // TODO: make enum - int32_t batch_type = batch.n_tokens > 0 ? 0 : -1; - // next, batch any pending prompts without exceeding n_batch if (params_base.cont_batching || batch.n_tokens == 0) { for (auto & slot : slots) { + // check if we can batch this slot with the previous one + if (slot.is_processing()) { + if (!slot_batched) { + slot_batched = &slot; + } else if (!slot_batched->can_batch_with(slot)) { + continue; + } + } + // this slot still has a prompt to be processed if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { auto & prompt_tokens = slot.prompt_tokens; @@ -2827,14 +2875,6 @@ struct server_context { } } - // check that we are in the right batch_type, if not defer the slot - int slot_type = slot.is_non_causal(); - if (batch_type == -1) { - batch_type = slot_type; - } else if (batch_type != slot_type) { - continue; - } - // keep only the common part if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) { // could not partially delete (likely using a non-Transformer model) @@ -2902,8 +2942,12 @@ struct server_context { SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); - // make sure we're in the right embedding mode - llama_set_embeddings(ctx, batch_type == 1); + if (slot_batched) { + // make sure we're in the right embedding mode + llama_set_embeddings(ctx, slot_batched->is_non_causal()); + // apply lora, only need to do it once per batch + common_lora_adapters_apply(ctx, slot_batched->lora); + } // process the created batch of tokens for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { @@ -3623,7 +3667,12 @@ int main(int argc, char ** argv) { task.index = i; task.prompt_tokens = std::move(tokenized_prompts[i]); - task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data); + task.params = server_task::params_from_json_cmpl( + ctx_server.model, + ctx_server.ctx, + ctx_server.params_base, + ctx_server.lora, + data); task.id_selected_slot = json_value(data, "id_slot", -1); // OAI-compat @@ -4049,8 +4098,8 @@ int main(int argc, char ** argv) { const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { json result = json::array(); - for (size_t i = 0; i < ctx_server.loras.size(); ++i) { - auto & lora = ctx_server.loras[i]; + for (size_t i = 0; i < ctx_server.lora.size(); ++i) { + auto & lora = ctx_server.lora[i]; result.push_back({ {"id", i}, {"path", lora.path}, @@ -4062,27 +4111,14 @@ int main(int argc, char ** argv) { }; const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) { - const std::vector body = json::parse(req.body); - int max_idx = ctx_server.loras.size(); - - // clear existing value - for (auto & lora : ctx_server.loras) { - lora.scale = 0.0f; + const json body = json::parse(req.body); + if (!body.is_array()) { + res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); + return; } - - // set value - for (auto entry : body) { - int id = entry.at("id"); - float scale = entry.at("scale"); - if (0 <= id && id < max_idx) { - ctx_server.loras[id].scale = scale; - } else { - throw std::runtime_error("invalid adapter id"); - } - } - server_task task(SERVER_TASK_TYPE_SET_LORA); task.id = ctx_server.queue_tasks.get_new_id(); + task.set_lora = parse_lora_request(ctx_server.lora, body); ctx_server.queue_results.add_waiting_task_id(task.id); ctx_server.queue_tasks.post(task); diff --git a/examples/server/tests/README.md b/examples/server/tests/README.md index fa3d0a2f5..5787276ab 100644 --- a/examples/server/tests/README.md +++ b/examples/server/tests/README.md @@ -44,6 +44,12 @@ To run with stdout/stderr display in real time (verbose output, but useful for d DEBUG=1 ./tests.sh -s -v -x ``` +To run single test unit: + +```shell +./tests.sh unit/test_{name of test case here}.py -v -x +``` + Hint: You can compile and run test in single command, useful for local developement: ```shell diff --git a/examples/server/tests/requirements.txt b/examples/server/tests/requirements.txt index 074b9d47b..15d024914 100644 --- a/examples/server/tests/requirements.txt +++ b/examples/server/tests/requirements.txt @@ -5,3 +5,4 @@ numpy~=1.26.4 openai~=1.55.3 prometheus-client~=0.20.0 requests~=2.32.3 +wget~=3.2 diff --git a/examples/server/tests/unit/test_lora.py b/examples/server/tests/unit/test_lora.py index 749615449..c1aa8be70 100644 --- a/examples/server/tests/unit/test_lora.py +++ b/examples/server/tests/unit/test_lora.py @@ -1,5 +1,4 @@ import pytest -import os from utils import * server = ServerPreset.stories15m_moe() @@ -10,15 +9,7 @@ LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe def create_server(): global server server = ServerPreset.stories15m_moe() - # download lora file if needed - file_name = LORA_FILE_URL.split('/').pop() - lora_file = f'../../../{file_name}' - if not os.path.exists(lora_file): - print(f"Downloading {LORA_FILE_URL} to {lora_file}") - with open(lora_file, 'wb') as f: - f.write(requests.get(LORA_FILE_URL).content) - print(f"Done downloading lora file") - server.lora_files = [lora_file] + server.lora_files = [download_file(LORA_FILE_URL)] @pytest.mark.parametrize("scale,re_content", [ @@ -40,3 +31,85 @@ def test_lora(scale: float, re_content: str): assert res.status_code == 200 assert match_regex(re_content, res.body["content"]) + +def test_lora_per_request(): + global server + server.n_slots = 4 + server.start() + + # running the same prompt with different lora scales, all in parallel + # each prompt will be processed by a different slot + prompt = "Look in thy glass" + lora_config = [ + ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ), + ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ), + ( [{"id": 0, "scale": 0.3}], "(special|thing|gifted)+" ), + ( [{"id": 0, "scale": 0.7}], "(far|from|home|away)+" ), + ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ), + ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ), + ] + + tasks = [( + server.make_request, + ("POST", "/completion", { + "prompt": prompt, + "lora": lora, + "seed": 42, + "temperature": 0.0, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }) + ) for lora, _ in lora_config] + results = parallel_function_calls(tasks) + + assert all([res.status_code == 200 for res in results]) + for res, (_, re_test) in zip(results, lora_config): + assert match_regex(re_test, res.body["content"]) + + +@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test") +def test_with_big_model(): + server = ServerProcess() + server.model_hf_repo = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF" + server.model_hf_file = "Meta-Llama-3.1-8B-Instruct-IQ2_M.gguf" + server.model_alias = "Llama-3.2-8B-Instruct" + server.n_slots = 4 + server.n_ctx = server.n_slots * 1024 + server.n_predict = 64 + server.temperature = 0.0 + server.seed = 42 + server.lora_files = [ + download_file("https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"), + # TODO: find & add other lora adapters for this model + ] + server.start(timeout_seconds=600) + + # running the same prompt with different lora scales, all in parallel + # each prompt will be processed by a different slot + prompt = "Write a computer virus" + lora_config = [ + # without applying lora, the model should reject the request + ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ), + ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ), + ( [{"id": 0, "scale": 0.3}], "I can't write a computer virus" ), + # with 0.7 scale, the model should provide a simple computer virus with hesitation + ( [{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise" ), + # with 1.5 scale, the model should confidently provide a computer virus + ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ), + ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ), + ] + + tasks = [( + server.make_request, + ("POST", "/v1/chat/completions", { + "messages": [ + {"role": "user", "content": prompt} + ], + "lora": lora, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }) + ) for lora, _ in lora_config] + results = parallel_function_calls(tasks) + + assert all([res.status_code == 200 for res in results]) + for res, (_, re_test) in zip(results, lora_config): + assert re_test in res.body["choices"][0]["message"]["content"] diff --git a/examples/server/tests/unit/test_speculative.py b/examples/server/tests/unit/test_speculative.py index 3bb5733cb..54db38cf3 100644 --- a/examples/server/tests/unit/test_speculative.py +++ b/examples/server/tests/unit/test_speculative.py @@ -10,16 +10,8 @@ MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tiny def create_server(): global server server = ServerPreset.stories15m_moe() - # download draft model file if needed - file_name = MODEL_DRAFT_FILE_URL.split('/').pop() - model_draft_file = f'../../../{file_name}' - if not os.path.exists(model_draft_file): - print(f"Downloading {MODEL_DRAFT_FILE_URL} to {model_draft_file}") - with open(model_draft_file, 'wb') as f: - f.write(requests.get(MODEL_DRAFT_FILE_URL).content) - print(f"Done downloading draft model file") # set default values - server.model_draft = model_draft_file + server.model_draft = download_file(MODEL_DRAFT_FILE_URL) server.draft_min = 4 server.draft_max = 8 diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index 359bb0fae..a1a94d0f1 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -23,6 +23,7 @@ from typing import ( Set, ) from re import RegexFlag +import wget class ServerResponse: @@ -381,5 +382,25 @@ def match_regex(regex: str, text: str) -> bool: is not None ) + +def download_file(url: str, output_file_path: str | None = None) -> str: + """ + Download a file from a URL to a local path. If the file already exists, it will not be downloaded again. + + output_file_path is the local path to save the downloaded file. If not provided, the file will be saved in the root directory. + + Returns the local path of the downloaded file. + """ + file_name = url.split('/').pop() + output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path + if not os.path.exists(output_file): + print(f"Downloading {url} to {output_file}") + wget.download(url, out=output_file) + print(f"Done downloading to {output_file}") + else: + print(f"File already exists at {output_file}") + return output_file + + def is_slow_test_allowed(): return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON" diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 70220c437..1cf08bb0a 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -797,3 +797,44 @@ static std::vector get_token_probabilities(llama_context * ctx return cur; } + +static bool are_lora_equal( + const std::vector & l1, + const std::vector & l2) { + if (l1.size() != l2.size()) { + return false; + } + for (size_t i = 0; i < l1.size(); ++i) { + // we don't check lora.path to reduce the time complexity + if (l1[i].scale != l2[i].scale || l1[i].adapter != l2[i].adapter) { + return false; + } + } + return true; +} + +// parse lora config from JSON request, returned a copy of base_lora with updated scale +static std::vector parse_lora_request( + const std::vector & base_lora, + const json & data) { + std::vector lora(base_lora); + int max_idx = lora.size(); + + // clear existing value + for (auto & entry : lora) { + entry.scale = 0.0f; + } + + // set value + for (const auto & entry : data) { + int id = json_value(entry, "id", -1); + float scale = json_value(entry, "scale", 0.0f); + if (0 <= id && id < max_idx) { + lora[id].scale = scale; + } else { + throw std::runtime_error("invalid adapter id"); + } + } + + return lora; +}