server : add lora hotswap endpoint (WIP) (#8857)

* server : add lora hotswap endpoint

* handle lora_no_apply

* fix build

* updae docs

* clean up struct def

* fix build

* add LoRA test

* fix style
This commit is contained in:
Xuan Son Nguyen 2024-08-06 17:33:39 +02:00 committed by GitHub
parent 641f5dd2a6
commit 1e6f6554aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 251 additions and 92 deletions

View File

@ -684,14 +684,24 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
} }
if (arg == "--lora") { if (arg == "--lora") {
CHECK_ARG CHECK_ARG
params.lora_adapter.emplace_back(argv[i], 1.0f); params.lora_adapters.push_back({
std::string(argv[i]),
1.0,
});
return true; return true;
} }
if (arg == "--lora-scaled") { if (arg == "--lora-scaled") {
CHECK_ARG CHECK_ARG
const char* lora_adapter = argv[i]; std::string lora_adapter = argv[i];
CHECK_ARG CHECK_ARG
params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i])); params.lora_adapters.push_back({
lora_adapter,
std::stof(argv[i]),
});
return true;
}
if (arg == "--lora-init-without-apply") {
params.lora_init_without_apply = true;
return true; return true;
} }
if (arg == "--control-vector") { if (arg == "--control-vector") {
@ -1654,6 +1664,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
"https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" }); "https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template" });
options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY", options.push_back({ "server", "-sps, --slot-prompt-similarity SIMILARITY",
"how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity }); "how much the prompt of a request must match the prompt of a slot in order to use that slot (default: %.2f, 0.0 = disabled)\n", params.slot_prompt_similarity });
options.push_back({ "server", " --lora-init-without-apply", "load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"});
#ifndef LOG_DISABLE_LOGS #ifndef LOG_DISABLE_LOGS
options.push_back({ "logging" }); options.push_back({ "logging" });
@ -2091,17 +2102,22 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
} }
} }
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) { // load and optionally apply lora adapters
const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]); for (auto & la : params.lora_adapters) {
float lora_scale = std::get<1>(params.lora_adapter[i]); llama_lora_adapter_container loaded_la;
auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str()); loaded_la.path = la.path;
if (adapter == nullptr) { loaded_la.scale = la.scale;
fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str());
if (loaded_la.adapter == nullptr) {
fprintf(stderr, "%s: error: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
llama_free(lctx); llama_free(lctx);
llama_free_model(model); llama_free_model(model);
return iparams; return iparams;
} }
llama_lora_adapter_set(lctx, adapter, lora_scale); iparams.lora_adapters.push_back(loaded_la); // copy to list of loaded adapters
}
if (!params.lora_init_without_apply) {
llama_lora_adapters_apply(lctx, iparams.lora_adapters);
} }
if (params.ignore_eos) { if (params.ignore_eos) {
@ -2140,6 +2156,15 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) {
return iparams; return iparams;
} }
void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lora_adapter_container> & lora_adapters) {
llama_lora_adapter_clear(ctx);
for (auto & la : lora_adapters) {
if (la.scale != 0.0f) {
llama_lora_adapter_set(ctx, la.adapter, la.scale);
}
}
}
struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) { struct llama_model_params llama_model_params_from_gpt_params(const gpt_params & params) {
auto mparams = llama_model_default_params(); auto mparams = llama_model_default_params();
@ -3162,19 +3187,18 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l
} }
fprintf(stream, "lora:\n"); fprintf(stream, "lora:\n");
for (std::tuple<std::string, float> la : params.lora_adapter) { for (auto & la : params.lora_adapters) {
if (std::get<1>(la) != 1.0f) { if (la.scale == 1.0f) {
continue; fprintf(stream, " - %s\n", la.path.c_str());
} }
fprintf(stream, " - %s\n", std::get<0>(la).c_str());
} }
fprintf(stream, "lora_scaled:\n"); fprintf(stream, "lora_scaled:\n");
for (std::tuple<std::string, float> la : params.lora_adapter) { for (auto & la : params.lora_adapters) {
if (std::get<1>(la) == 1.0f) { if (la.scale != 1.0f) {
continue; fprintf(stream, " - %s: %f\n", la.path.c_str(), la.scale);
} }
fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la));
} }
fprintf(stream, "lora_init_without_apply: %s # default: false\n", params.lora_init_without_apply ? "true" : "false");
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep); fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep);
fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat); fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);

View File

@ -33,6 +33,15 @@
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf" #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
struct llama_lora_adapter_info {
std::string path;
float scale;
};
struct llama_lora_adapter_container : llama_lora_adapter_info {
struct llama_lora_adapter * adapter;
};
// build info // build info
extern int LLAMA_BUILD_NUMBER; extern int LLAMA_BUILD_NUMBER;
extern char const * LLAMA_COMMIT; extern char const * LLAMA_COMMIT;
@ -126,8 +135,8 @@ struct gpt_params {
std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
std::vector<llama_model_kv_override> kv_overrides; std::vector<llama_model_kv_override> kv_overrides;
// TODO: avoid tuple, use struct bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale std::vector<llama_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale
std::vector<llama_control_vector_load_info> control_vectors; // control vector with user defined scale std::vector<llama_control_vector_load_info> control_vectors; // control vector with user defined scale
@ -311,6 +320,7 @@ std::string fs_get_cache_file(const std::string & filename);
struct llama_init_result { struct llama_init_result {
struct llama_model * model = nullptr; struct llama_model * model = nullptr;
struct llama_context * context = nullptr; struct llama_context * context = nullptr;
std::vector<llama_lora_adapter_container> lora_adapters;
}; };
struct llama_init_result llama_init_from_gpt_params(gpt_params & params); struct llama_init_result llama_init_from_gpt_params(gpt_params & params);
@ -321,6 +331,9 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params); struct llama_model * llama_load_model_from_url(const char * model_url, const char * path_model, const char * hf_token, const struct llama_model_params & params);
struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params); struct llama_model * llama_load_model_from_hf(const char * repo, const char * file, const char * path_model, const char * hf_token, const struct llama_model_params & params);
// clear LoRA adapters from context, then apply new list of adapters
void llama_lora_adapters_apply(struct llama_context * ctx, std::vector<llama_lora_adapter_container> & lora_adapters);
// Batch utils // Batch utils
void llama_batch_clear(struct llama_batch & batch); void llama_batch_clear(struct llama_batch & batch);

View File

@ -135,7 +135,7 @@ struct lora_merge_ctx {
lora_merge_ctx( lora_merge_ctx(
std::string & base_fname, std::string & base_fname,
std::vector<std::tuple<std::string, float>> & lora_files, std::vector<llama_lora_adapter_info> & lora_files,
std::string & outfile, std::string & outfile,
int n_threads) : base_model(base_fname, 0), n_threads(n_threads), fout(outfile, std::ios::binary) { int n_threads) : base_model(base_fname, 0), n_threads(n_threads), fout(outfile, std::ios::binary) {
fout.exceptions(std::ofstream::failbit); // fail fast on write errors fout.exceptions(std::ofstream::failbit); // fail fast on write errors
@ -144,9 +144,9 @@ struct lora_merge_ctx {
throw std::runtime_error("split model is not yet supported"); throw std::runtime_error("split model is not yet supported");
} }
for (auto lora_inp : lora_files) { for (auto & lora_inp : lora_files) {
auto fname = std::get<0>(lora_inp); auto fname = lora_inp.path;
auto scale = std::get<1>(lora_inp); auto scale = lora_inp.scale;
std::unique_ptr<file_input> adapter(new file_input(fname, scale)); std::unique_ptr<file_input> adapter(new file_input(fname, scale));
check_metadata_lora(adapter.get()); check_metadata_lora(adapter.get());
adapters.push_back(std::move(adapter)); adapters.push_back(std::move(adapter));
@ -407,7 +407,7 @@ int main(int argc, char ** argv) {
g_verbose = (params.verbosity == 1); g_verbose = (params.verbosity == 1);
try { try {
lora_merge_ctx ctx(params.model, params.lora_adapter, params.lora_outfile, params.n_threads); lora_merge_ctx ctx(params.model, params.lora_adapters, params.lora_outfile, params.n_threads);
ctx.run_merge(); ctx.run_merge();
} catch (const std::exception & err) { } catch (const std::exception & err) {
fprintf(stderr, "%s\n", err.what()); fprintf(stderr, "%s\n", err.what());

View File

@ -207,41 +207,6 @@ model:
-hff, --hf-file FILE Hugging Face model file (default: unused) -hff, --hf-file FILE Hugging Face model file (default: unused)
-hft, --hf-token TOKEN Hugging Face access token (default: value from HF_TOKEN environment variable) -hft, --hf-token TOKEN Hugging Face access token (default: value from HF_TOKEN environment variable)
retrieval:
--context-file FNAME file to load context from (repeat to specify multiple files)
--chunk-size N minimum length of embedded text chunks (default: 64)
--chunk-separator STRING
separator between chunks (default: '
')
passkey:
--junk N number of times to repeat the junk text (default: 250)
--pos N position of the passkey in the junk text (default: -1)
imatrix:
-o, --output FNAME output file (default: 'imatrix.dat')
--output-frequency N output the imatrix every N iterations (default: 10)
--save-frequency N save an imatrix copy every N iterations (default: 0)
--process-output collect data for the output tensor (default: false)
--no-ppl do not compute perplexity (default: true)
--chunk N start processing the input from chunk N (default: 0)
bench:
-pps is the prompt shared across parallel sequences (default: false)
-npp n0,n1,... number of prompt tokens
-ntg n0,n1,... number of text generation tokens
-npl n0,n1,... number of parallel prompts
embedding:
--embd-normalize normalisation for embendings (default: 2) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm)
--embd-output-format empty = default, "array" = [[],[]...], "json" = openai style, "json+" = same "json" + cosine similarity matrix
--embd-separator separator of embendings (default \n) for example "<#sep#>"
server: server:
--host HOST ip address to listen (default: 127.0.0.1) --host HOST ip address to listen (default: 127.0.0.1)
@ -267,7 +232,8 @@ server:
https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
-sps, --slot-prompt-similarity SIMILARITY -sps, --slot-prompt-similarity SIMILARITY
how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled) how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)
--lora-init-without-apply
load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled)
logging: logging:
@ -279,15 +245,6 @@ logging:
--log-file FNAME Specify a log filename (without extension) --log-file FNAME Specify a log filename (without extension)
--log-new Create a separate new log file on start. Each log file will have unique name: "<name>.<ID>.log" --log-new Create a separate new log file on start. Each log file will have unique name: "<name>.<ID>.log"
--log-append Don't truncate the old log file. --log-append Don't truncate the old log file.
cvector:
-o, --output FNAME output file (default: 'control_vector.gguf')
--positive-file FNAME positive prompts file, one prompt per line (default: 'examples/cvector-generator/positive.txt')
--negative-file FNAME negative prompts file, one prompt per line (default: 'examples/cvector-generator/negative.txt')
--pca-batch N batch size used for PCA. Larger batch runs faster, but uses more memory (default: 100)
--pca-iter N number of iterations used for PCA (default: 1000)
--method {pca,mean} dimensionality reduction method to be used (default: pca)
``` ```
@ -411,7 +368,8 @@ node index.js
## API Endpoints ## API Endpoints
- **GET** `/health`: Returns the current state of the server: ### GET `/health`: Returns the current state of the server
- 503 -> `{"status": "loading model"}` if the model is still being loaded. - 503 -> `{"status": "loading model"}` if the model is still being loaded.
- 500 -> `{"status": "error"}` if the model failed to load. - 500 -> `{"status": "error"}` if the model failed to load.
- 200 -> `{"status": "ok", "slots_idle": 1, "slots_processing": 2 }` if the model is successfully loaded and the server is ready for further requests mentioned below. - 200 -> `{"status": "ok", "slots_idle": 1, "slots_processing": 2 }` if the model is successfully loaded and the server is ready for further requests mentioned below.
@ -420,7 +378,7 @@ node index.js
If the query parameter `include_slots` is passed, `slots` field will contain internal slots data except if `--slots-endpoint-disable` is set. If the query parameter `include_slots` is passed, `slots` field will contain internal slots data except if `--slots-endpoint-disable` is set.
- **POST** `/completion`: Given a `prompt`, it returns the predicted completion. ### POST `/completion`: Given a `prompt`, it returns the predicted completion.
*Options:* *Options:*
@ -498,7 +456,7 @@ node index.js
`samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["top_k", "tfs_z", "typical_p", "top_p", "min_p", "temperature"]` - these are all the available values. `samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["top_k", "tfs_z", "typical_p", "top_p", "min_p", "temperature"]` - these are all the available values.
### Result JSON **Response format**
- Note: When using streaming mode (`stream`), only `content` and `stop` will be returned until end of completion. - Note: When using streaming mode (`stream`), only `content` and `stop` will be returned until end of completion.
@ -537,7 +495,7 @@ Notice that each `probs` is an array of length `n_probs`.
- `tokens_evaluated`: Number of tokens evaluated in total from the prompt - `tokens_evaluated`: Number of tokens evaluated in total from the prompt
- `truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`) - `truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`)
- **POST** `/tokenize`: Tokenize a given text. ### POST `/tokenize`: Tokenize a given text
*Options:* *Options:*
@ -545,13 +503,15 @@ Notice that each `probs` is an array of length `n_probs`.
`add_special`: Boolean indicating if special tokens, i.e. `BOS`, should be inserted. Default: `false` `add_special`: Boolean indicating if special tokens, i.e. `BOS`, should be inserted. Default: `false`
- **POST** `/detokenize`: Convert tokens to text. ### POST `/detokenize`: Convert tokens to text
*Options:* *Options:*
`tokens`: Set the tokens to detokenize. `tokens`: Set the tokens to detokenize.
- **POST** `/embedding`: Generate embedding of a given text just as [the embedding example](../embedding) does. ### POST `/embedding`: Generate embedding of a given text
The same as [the embedding example](../embedding) does.
*Options:* *Options:*
@ -559,7 +519,9 @@ Notice that each `probs` is an array of length `n_probs`.
`image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA. `image_data`: An array of objects to hold base64-encoded image `data` and its `id`s to be reference in `content`. You can determine the place of the image in the content as in the following: `Image: [img-21].\nCaption: This is a picture of a house`. In this case, `[img-21]` will be replaced by the embeddings of the image with id `21` in the following `image_data` array: `{..., "image_data": [{"data": "<BASE64_STRING>", "id": 21}]}`. Use `image_data` only with multimodal models, e.g., LLaVA.
- **POST** `/infill`: For code infilling. Takes a prefix and a suffix and returns the predicted completion as stream. ### POST `/infill`: For code infilling.
Takes a prefix and a suffix and returns the predicted completion as stream.
*Options:* *Options:*
@ -571,7 +533,7 @@ Notice that each `probs` is an array of length `n_probs`.
- **GET** `/props`: Return current server settings. - **GET** `/props`: Return current server settings.
### Result JSON **Response format**
```json ```json
{ {
@ -589,7 +551,9 @@ Notice that each `probs` is an array of length `n_probs`.
- `total_slots` - the total number of slots for process requests (defined by `--parallel` option) - `total_slots` - the total number of slots for process requests (defined by `--parallel` option)
- `chat_template` - the model's original Jinja2 prompt template - `chat_template` - the model's original Jinja2 prompt template
- **POST** `/v1/chat/completions`: OpenAI-compatible Chat Completions API. Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used. ### POST `/v1/chat/completions`: OpenAI-compatible Chat Completions API
Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used.
*Options:* *Options:*
@ -641,7 +605,7 @@ Notice that each `probs` is an array of length `n_probs`.
}' }'
``` ```
- **POST** `/v1/embeddings`: OpenAI-compatible embeddings API. ### POST `/v1/embeddings`: OpenAI-compatible embeddings API
*Options:* *Options:*
@ -675,9 +639,9 @@ Notice that each `probs` is an array of length `n_probs`.
}' }'
``` ```
- **GET** `/slots`: Returns the current slots processing state. Can be disabled with `--slots-endpoint-disable`. ### GET `/slots`: Returns the current slots processing state. Can be disabled with `--slots-endpoint-disable`.
### Result JSON **Response format**
```json ```json
[ [
@ -738,7 +702,7 @@ Notice that each `probs` is an array of length `n_probs`.
] ]
``` ```
- **GET** `/metrics`: [Prometheus](https://prometheus.io/) compatible metrics exporter endpoint if `--metrics` is enabled: ### GET `/metrics`: Prometheus compatible metrics exporter endpoint if `--metrics` is enabled:
Available metrics: Available metrics:
- `llamacpp:prompt_tokens_total`: Number of prompt tokens processed. - `llamacpp:prompt_tokens_total`: Number of prompt tokens processed.
@ -750,13 +714,13 @@ Available metrics:
- `llamacpp:requests_processing`: Number of requests processing. - `llamacpp:requests_processing`: Number of requests processing.
- `llamacpp:requests_deferred`: Number of requests deferred. - `llamacpp:requests_deferred`: Number of requests deferred.
- **POST** `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file. ### POST `/slots/{id_slot}?action=save`: Save the prompt cache of the specified slot to a file.
*Options:* *Options:*
`filename`: Name of the file to save the slot's prompt cache. The file will be saved in the directory specified by the `--slot-save-path` server parameter. `filename`: Name of the file to save the slot's prompt cache. The file will be saved in the directory specified by the `--slot-save-path` server parameter.
### Result JSON **Response format**
```json ```json
{ {
@ -770,13 +734,13 @@ Available metrics:
} }
``` ```
- **POST** `/slots/{id_slot}?action=restore`: Restore the prompt cache of the specified slot from a file. ### POST `/slots/{id_slot}?action=restore`: Restore the prompt cache of the specified slot from a file.
*Options:* *Options:*
`filename`: Name of the file to restore the slot's prompt cache from. The file should be located in the directory specified by the `--slot-save-path` server parameter. `filename`: Name of the file to restore the slot's prompt cache from. The file should be located in the directory specified by the `--slot-save-path` server parameter.
### Result JSON **Response format**
```json ```json
{ {
@ -790,9 +754,9 @@ Available metrics:
} }
``` ```
- **POST** `/slots/{id_slot}?action=erase`: Erase the prompt cache of the specified slot. ### POST `/slots/{id_slot}?action=erase`: Erase the prompt cache of the specified slot.
### Result JSON **Response format**
```json ```json
{ {
@ -801,6 +765,42 @@ Available metrics:
} }
``` ```
### GET `/lora-adapters`: Get list of all LoRA adapters
If an adapter is disabled, the scale will be set to 0.
**Response format**
```json
[
{
"id": 0,
"path": "my_adapter_1.gguf",
"scale": 0.0
},
{
"id": 1,
"path": "my_adapter_2.gguf",
"scale": 0.0
}
]
```
### POST `/lora-adapters`: Set list of LoRA adapters
To disable an adapter, either remove it from the list below, or set scale to 0.
**Request format**
To know the `id` of the adapter, use GET `/lora-adapters`
```json
[
{"id": 0, "scale": 0.2},
{"id": 1, "scale": 0.8}
]
```
## More examples ## More examples
### Change system prompt on runtime ### Change system prompt on runtime

View File

@ -78,6 +78,7 @@ enum server_task_type {
SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_SAVE,
SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_RESTORE,
SERVER_TASK_TYPE_SLOT_ERASE, SERVER_TASK_TYPE_SLOT_ERASE,
SERVER_TASK_TYPE_SET_LORA,
}; };
struct server_task { struct server_task {
@ -622,6 +623,7 @@ struct server_response {
struct server_context { struct server_context {
llama_model * model = nullptr; llama_model * model = nullptr;
llama_context * ctx = nullptr; llama_context * ctx = nullptr;
std::vector<llama_lora_adapter_container> lora_adapters;
gpt_params params; gpt_params params;
@ -681,6 +683,7 @@ struct server_context {
model = llama_init.model; model = llama_init.model;
ctx = llama_init.context; ctx = llama_init.context;
lora_adapters = llama_init.lora_adapters;
params.n_parallel -= 1; // but be sneaky about it params.n_parallel -= 1; // but be sneaky about it
if (model == nullptr) { if (model == nullptr) {
LOG_ERROR("unable to load model", {{"model", params.model}}); LOG_ERROR("unable to load model", {{"model", params.model}});
@ -1850,6 +1853,14 @@ struct server_context {
}; };
queue_results.send(result); queue_results.send(result);
} break; } break;
case SERVER_TASK_TYPE_SET_LORA:
{
llama_lora_adapters_apply(ctx, lora_adapters);
server_task_result result;
result.id = task.id;
result.data = json{{ "success", true }};
queue_results.send(result);
} break;
} }
} }
@ -3328,6 +3339,55 @@ int main(int argc, char ** argv) {
return res.set_content(root.dump(), "application/json; charset=utf-8"); return res.set_content(root.dump(), "application/json; charset=utf-8");
}; };
const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json result = json::array();
for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) {
auto & la = ctx_server.lora_adapters[i];
result.push_back({
{"id", i},
{"path", la.path},
{"scale", la.scale},
});
}
res.set_content(result.dump(), "application/json");
res.status = 200; // HTTP OK
};
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const std::vector<json> body = json::parse(req.body);
int max_idx = ctx_server.lora_adapters.size();
// clear existing value
for (auto & la : ctx_server.lora_adapters) {
la.scale = 0.0f;
}
// set value
for (auto entry : body) {
int id = entry.at("id");
float scale = entry.at("scale");
if (0 <= id && id < max_idx) {
ctx_server.lora_adapters[id].scale = scale;
} else {
throw std::runtime_error("invalid adapter id");
}
}
server_task task;
task.type = SERVER_TASK_TYPE_SET_LORA;
const int id_task = ctx_server.queue_tasks.post(task);
ctx_server.queue_results.add_waiting_task_id(id_task);
server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
res.set_content(result.data.dump(), "application/json");
res.status = 200; // HTTP OK
};
auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) { auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
return [content, len, mime_type](const httplib::Request &, httplib::Response & res) { return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
res.set_content(reinterpret_cast<const char*>(content), len, mime_type); res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
@ -3366,7 +3426,6 @@ int main(int argc, char ** argv) {
// register API routes // register API routes
svr->Get ("/health", handle_health); svr->Get ("/health", handle_health);
svr->Get ("/slots", handle_slots);
svr->Get ("/metrics", handle_metrics); svr->Get ("/metrics", handle_metrics);
svr->Get ("/props", handle_props); svr->Get ("/props", handle_props);
svr->Get ("/v1/models", handle_models); svr->Get ("/v1/models", handle_models);
@ -3381,6 +3440,11 @@ int main(int argc, char ** argv) {
svr->Post("/v1/embeddings", handle_embeddings); svr->Post("/v1/embeddings", handle_embeddings);
svr->Post("/tokenize", handle_tokenize); svr->Post("/tokenize", handle_tokenize);
svr->Post("/detokenize", handle_detokenize); svr->Post("/detokenize", handle_detokenize);
// LoRA adapters hotswap
svr->Get ("/lora-adapters", handle_lora_adapters_list);
svr->Post("/lora-adapters", handle_lora_adapters_apply);
// Save & load slots
svr->Get ("/slots", handle_slots);
if (!params.slot_save_path.empty()) { if (!params.slot_save_path.empty()) {
// only enable slot endpoints if slot_save_path is set // only enable slot endpoints if slot_save_path is set
svr->Post("/slots/:id_slot", handle_slots_action); svr->Post("/slots/:id_slot", handle_slots_action);

View File

@ -0,0 +1,36 @@
@llama.cpp
@lora
Feature: llama.cpp server
Background: Server startup
Given a server listening on localhost:8080
And a model url https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/stories15M_MOE-F16.gguf
And a model file stories15M_MOE-F16.gguf
And a model alias stories15M_MOE
And a lora adapter file from https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf
And 42 as server seed
And 1024 as batch size
And 1024 as ubatch size
And 2048 KV cache size
And 64 max tokens to predict
And 0.0 temperature
Then the server is starting
Then the server is healthy
Scenario: Completion LoRA disabled
Given switch off lora adapter 0
Given a prompt:
"""
Look in thy glass
"""
And a completion request with no api error
Then 64 tokens are predicted matching little|girl|three|years|old
Scenario: Completion LoRA enabled
Given switch on lora adapter 0
Given a prompt:
"""
Look in thy glass
"""
And a completion request with no api error
Then 64 tokens are predicted matching eye|love|glass|sun

View File

@ -7,6 +7,7 @@ import subprocess
import sys import sys
import threading import threading
import time import time
import requests
from collections.abc import Sequence from collections.abc import Sequence
from contextlib import closing from contextlib import closing
from re import RegexFlag from re import RegexFlag
@ -70,6 +71,7 @@ def step_server_config(context, server_fqdn: str, server_port: str):
context.user_api_key = None context.user_api_key = None
context.response_format = None context.response_format = None
context.temperature = None context.temperature = None
context.lora_file = None
context.tasks_result = [] context.tasks_result = []
context.concurrent_tasks = [] context.concurrent_tasks = []
@ -82,6 +84,12 @@ def step_download_hf_model(context, hf_file: str, hf_repo: str):
context.model_hf_file = hf_file context.model_hf_file = hf_file
context.model_file = os.path.basename(hf_file) context.model_file = os.path.basename(hf_file)
@step('a lora adapter file from {lora_file_url}')
def step_download_lora_file(context, lora_file_url: str):
file_name = lora_file_url.split('/').pop()
context.lora_file = f'../../../{file_name}'
with open(context.lora_file, 'wb') as f:
f.write(requests.get(lora_file_url).content)
@step('a model file {model_file}') @step('a model file {model_file}')
def step_model_file(context, model_file: str): def step_model_file(context, model_file: str):
@ -849,6 +857,17 @@ async def step_erase_slot(context, slot_id):
context.response = response context.response = response
@step('switch {on_or_off} lora adapter {lora_id:d}')
@async_run_until_complete
async def toggle_lora_adapter(context, on_or_off: str, lora_id: int):
async with aiohttp.ClientSession() as session:
async with session.post(f'{context.base_url}/lora-adapters',
json=[{'id': lora_id, 'scale': 1 if on_or_off == 'on' else 0}],
headers={"Content-Type": "application/json"}) as response:
context.response = response
print([{'id': lora_id, 'scale': 1 if on_or_off == 'on' else 0}])
@step('the server responds with status code {status_code:d}') @step('the server responds with status code {status_code:d}')
def step_server_responds_with_status_code(context, status_code): def step_server_responds_with_status_code(context, status_code):
assert context.response.status == status_code assert context.response.status == status_code
@ -1326,6 +1345,8 @@ def start_server_background(context):
server_args.extend(['--grp-attn-w', context.n_ga_w]) server_args.extend(['--grp-attn-w', context.n_ga_w])
if context.debug: if context.debug:
server_args.append('--verbose') server_args.append('--verbose')
if context.lora_file:
server_args.extend(['--lora', context.lora_file])
if 'SERVER_LOG_FORMAT_JSON' not in os.environ: if 'SERVER_LOG_FORMAT_JSON' not in os.environ:
server_args.extend(['--log-format', "text"]) server_args.extend(['--log-format', "text"])

View File

@ -4,3 +4,4 @@ huggingface_hub~=0.20.3
numpy~=1.26.4 numpy~=1.26.4
openai~=1.30.3 openai~=1.30.3
prometheus-client~=0.20.0 prometheus-client~=0.20.0
requests~=2.32.3

View File

@ -174,7 +174,7 @@ class Metadata:
org_component, model_full_name_component = None, model_id org_component, model_full_name_component = None, model_id
# Check if we erroneously matched against './' or '../' etc... # Check if we erroneously matched against './' or '../' etc...
if org_component is not None and org_component[0] == '.': if org_component is not None and len(org_component) > 0 and org_component[0] == '.':
org_component = None org_component = None
name_parts: list[str] = model_full_name_component.split('-') name_parts: list[str] = model_full_name_component.split('-')