server : refactor middleware and /health endpoint (#9056)

* server : refactor middleware and /health endpoint

* move "fail_on_no_slot" to /slots

* Update examples/server/server.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* fix server tests

* fix CI

* update server docs

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Xuan Son Nguyen 2024-08-16 17:19:05 +02:00 committed by GitHub
parent d565bb2fd5
commit 8b3befc0e2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 177 additions and 217 deletions

View File

@ -368,15 +368,16 @@ node index.js
## API Endpoints ## API Endpoints
### GET `/health`: Returns the current state of the server ### GET `/health`: Returns heath check result
- 503 -> `{"status": "loading model"}` if the model is still being loaded. **Response format**
- 500 -> `{"status": "error"}` if the model failed to load.
- 200 -> `{"status": "ok", "slots_idle": 1, "slots_processing": 2 }` if the model is successfully loaded and the server is ready for further requests mentioned below.
- 200 -> `{"status": "no slot available", "slots_idle": 0, "slots_processing": 32}` if no slots are currently available.
- 503 -> `{"status": "no slot available", "slots_idle": 0, "slots_processing": 32}` if the query parameter `fail_on_no_slot` is provided and no slots are currently available.
If the query parameter `include_slots` is passed, `slots` field will contain internal slots data except if `--slots-endpoint-disable` is set. - HTTP status code 503
- Body: `{"error": {"code": 503, "message": "Loading model", "type": "unavailable_error"}}`
- Explanation: the model is still being loaded.
- HTTP status code 200
- Body: `{"status": "ok" }`
- Explanation: the model is successfully loaded and the server is ready.
### POST `/completion`: Given a `prompt`, it returns the predicted completion. ### POST `/completion`: Given a `prompt`, it returns the predicted completion.
@ -639,10 +640,16 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte
}' }'
``` ```
### GET `/slots`: Returns the current slots processing state. Can be disabled with `--slots-endpoint-disable`. ### GET `/slots`: Returns the current slots processing state
This endpoint can be disabled with `--no-slots`
If query param `?fail_on_no_slot=1` is set, this endpoint will respond with status code 503 if there is no available slots.
**Response format** **Response format**
Example:
```json ```json
[ [
{ {
@ -702,7 +709,13 @@ Given a ChatML-formatted json description in `messages`, it returns the predicte
] ]
``` ```
### GET `/metrics`: Prometheus compatible metrics exporter endpoint if `--metrics` is enabled: Possible values for `slot[i].state` are:
- `0`: SLOT_STATE_IDLE
- `1`: SLOT_STATE_PROCESSING
### GET `/metrics`: Prometheus compatible metrics exporter
This endpoint is only accessible if `--metrics` is set.
Available metrics: Available metrics:
- `llamacpp:prompt_tokens_total`: Number of prompt tokens processed. - `llamacpp:prompt_tokens_total`: Number of prompt tokens processed.
@ -767,6 +780,10 @@ Available metrics:
### GET `/lora-adapters`: Get list of all LoRA adapters ### GET `/lora-adapters`: Get list of all LoRA adapters
This endpoint returns the loaded LoRA adapters. You can add adapters using `--lora` when starting the server, for example: `--lora my_adapter_1.gguf --lora my_adapter_2.gguf ...`
By default, all adapters will be loaded with scale set to 1. To initialize all adapters scale to 0, add `--lora-init-without-apply`
If an adapter is disabled, the scale will be set to 0. If an adapter is disabled, the scale will be set to 0.
**Response format** **Response format**

View File

@ -15,6 +15,8 @@
// Change JSON_ASSERT from assert() to GGML_ASSERT: // Change JSON_ASSERT from assert() to GGML_ASSERT:
#define JSON_ASSERT GGML_ASSERT #define JSON_ASSERT GGML_ASSERT
#include "json.hpp" #include "json.hpp"
// mime type for sending response
#define MIMETYPE_JSON "application/json; charset=utf-8"
// auto generated files (update with ./deps.sh) // auto generated files (update with ./deps.sh)
#include "colorthemes.css.hpp" #include "colorthemes.css.hpp"
@ -67,7 +69,6 @@ enum slot_command {
enum server_state { enum server_state {
SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
SERVER_STATE_READY, // Server is ready and model is loaded SERVER_STATE_READY, // Server is ready and model is loaded
SERVER_STATE_ERROR // An error occurred, load_model failed
}; };
enum server_task_type { enum server_task_type {
@ -695,6 +696,7 @@ struct server_context {
add_bos_token = llama_add_bos_token(model); add_bos_token = llama_add_bos_token(model);
has_eos_token = !llama_add_eos_token(model); has_eos_token = !llama_add_eos_token(model);
return true; return true;
} }
@ -2555,19 +2557,19 @@ int main(int argc, char ** argv) {
svr->set_default_headers({{"Server", "llama.cpp"}}); svr->set_default_headers({{"Server", "llama.cpp"}});
// CORS preflight // CORS preflight
svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) { svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); // Access-Control-Allow-Origin is already set by middleware
res.set_header("Access-Control-Allow-Credentials", "true"); res.set_header("Access-Control-Allow-Credentials", "true");
res.set_header("Access-Control-Allow-Methods", "POST"); res.set_header("Access-Control-Allow-Methods", "POST");
res.set_header("Access-Control-Allow-Headers", "*"); res.set_header("Access-Control-Allow-Headers", "*");
return res.set_content("", "application/json; charset=utf-8"); return res.set_content("", "text/html"); // blank response, no data
}); });
svr->set_logger(log_server_request); svr->set_logger(log_server_request);
auto res_error = [](httplib::Response & res, json error_data) { auto res_error = [](httplib::Response & res, json error_data) {
json final_response {{"error", error_data}}; json final_response {{"error", error_data}};
res.set_content(final_response.dump(), "application/json; charset=utf-8"); res.set_content(final_response.dump(), MIMETYPE_JSON);
res.status = json_value(error_data, "code", 500); res.status = json_value(error_data, "code", 500);
}; };
@ -2597,11 +2599,6 @@ int main(int argc, char ** argv) {
svr->set_read_timeout (params.timeout_read); svr->set_read_timeout (params.timeout_read);
svr->set_write_timeout(params.timeout_write); svr->set_write_timeout(params.timeout_write);
if (!svr->bind_to_port(params.hostname, params.port)) {
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", params.hostname.c_str(), params.port);
return 1;
}
std::unordered_map<std::string, std::string> log_data; std::unordered_map<std::string, std::string> log_data;
log_data["hostname"] = params.hostname; log_data["hostname"] = params.hostname;
@ -2617,35 +2614,6 @@ int main(int argc, char ** argv) {
// Necessary similarity of prompt for slot selection // Necessary similarity of prompt for slot selection
ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
// load the model
if (!ctx_server.load_model(params)) {
state.store(SERVER_STATE_ERROR);
return 1;
} else {
ctx_server.init();
state.store(SERVER_STATE_READY);
}
LOG_INFO("model loaded", {});
const auto model_meta = ctx_server.model_meta();
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
if (params.chat_template.empty()) {
if (!ctx_server.validate_model_chat_template()) {
LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
params.chat_template = "chatml";
}
}
// print sample chat example to make it clear which template is used
{
LOG_INFO("chat template", {
{"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
{"built_in", params.chat_template.empty()},
});
}
// //
// Middlewares // Middlewares
// //
@ -2689,8 +2657,6 @@ int main(int argc, char ** argv) {
} }
// API key is invalid or not provided // API key is invalid or not provided
// TODO: make another middleware for CORS related logic
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION)); res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
LOG_WARNING("Unauthorized: Invalid API Key", {}); LOG_WARNING("Unauthorized: Invalid API Key", {});
@ -2698,8 +2664,21 @@ int main(int argc, char ** argv) {
return false; return false;
}; };
auto middleware_server_state = [&res_error, &state](const httplib::Request &, httplib::Response & res) {
server_state current_state = state.load();
if (current_state == SERVER_STATE_LOADING_MODEL) {
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
return false;
}
return true;
};
// register server middlewares // register server middlewares
svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) { svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
if (!middleware_server_state(req, res)) {
return httplib::Server::HandlerResponse::Handled;
}
if (!middleware_validate_api_key(req, res)) { if (!middleware_validate_api_key(req, res)) {
return httplib::Server::HandlerResponse::Handled; return httplib::Server::HandlerResponse::Handled;
} }
@ -2710,62 +2689,15 @@ int main(int argc, char ** argv) {
// Route handlers (or controllers) // Route handlers (or controllers)
// //
const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) { const auto handle_health = [&](const httplib::Request &, httplib::Response & res) {
server_state current_state = state.load(); // error and loading states are handled by middleware
switch (current_state) { json health = {{"status", "ok"}};
case SERVER_STATE_READY:
{
// request slots data using task queue
server_task task;
task.id = ctx_server.queue_tasks.get_new_id();
task.type = SERVER_TASK_TYPE_METRICS;
task.id_target = -1;
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);
// get the result
server_task_result result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);
const int n_idle_slots = result.data.at("idle");
const int n_processing_slots = result.data.at("processing");
json health = {
{"status", "ok"},
{"slots_idle", n_idle_slots},
{"slots_processing", n_processing_slots}
};
res.status = 200; // HTTP OK
if (params.endpoint_slots && req.has_param("include_slots")) {
health["slots"] = result.data.at("slots");
}
if (n_idle_slots == 0) {
health["status"] = "no slot available";
if (req.has_param("fail_on_no_slot")) {
res.status = 503; // HTTP Service Unavailable
}
}
res.set_content(health.dump(), "application/json"); res.set_content(health.dump(), "application/json");
break;
}
case SERVER_STATE_LOADING_MODEL:
{
res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
} break;
case SERVER_STATE_ERROR:
{
res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
} break;
}
}; };
const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) { const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
if (!params.endpoint_slots) { if (!params.endpoint_slots) {
res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
@ -2783,13 +2715,22 @@ int main(int argc, char ** argv) {
server_task_result result = ctx_server.queue_results.recv(task.id); server_task_result result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id); ctx_server.queue_results.remove_waiting_task_id(task.id);
res.set_content(result.data.at("slots").dump(), "application/json"); // optionally return "fail_on_no_slot" error
const int n_idle_slots = result.data.at("idle");
if (req.has_param("fail_on_no_slot")) {
if (n_idle_slots == 0) {
res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
return;
}
}
res.set_content(result.data.at("slots").dump(), MIMETYPE_JSON);
res.status = 200; // HTTP OK res.status = 200; // HTTP OK
}; };
const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
if (!params.endpoint_metrics) { if (!params.endpoint_metrics) {
res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
@ -2914,7 +2855,7 @@ int main(int argc, char ** argv) {
if (result.error) { if (result.error) {
res_error(res, result.data); res_error(res, result.data);
} else { } else {
res.set_content(result.data.dump(), "application/json"); res.set_content(result.data.dump(), MIMETYPE_JSON);
} }
}; };
@ -2944,7 +2885,7 @@ int main(int argc, char ** argv) {
if (result.error) { if (result.error) {
res_error(res, result.data); res_error(res, result.data);
} else { } else {
res.set_content(result.data.dump(), "application/json"); res.set_content(result.data.dump(), MIMETYPE_JSON);
} }
}; };
@ -2964,13 +2905,11 @@ int main(int argc, char ** argv) {
if (result.error) { if (result.error) {
res_error(res, result.data); res_error(res, result.data);
} else { } else {
res.set_content(result.data.dump(), "application/json"); res.set_content(result.data.dump(), MIMETYPE_JSON);
} }
}; };
const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
std::string id_slot_str = req.path_params.at("id_slot"); std::string id_slot_str = req.path_params.at("id_slot");
int id_slot; int id_slot;
@ -2994,7 +2933,7 @@ int main(int argc, char ** argv) {
} }
}; };
const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) { const auto handle_props = [&ctx_server](const httplib::Request &, httplib::Response & res) {
std::string template_key = "tokenizer.chat_template", curr_tmpl; std::string template_key = "tokenizer.chat_template", curr_tmpl;
int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0); int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
if (tlen > 0) { if (tlen > 0) {
@ -3003,7 +2942,6 @@ int main(int argc, char ** argv) {
curr_tmpl = std::string(curr_tmpl_buf.data(), tlen); curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
} }
} }
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = { json data = {
{ "system_prompt", ctx_server.system_prompt.c_str() }, { "system_prompt", ctx_server.system_prompt.c_str() },
{ "default_generation_settings", ctx_server.default_generation_settings_for_props }, { "default_generation_settings", ctx_server.default_generation_settings_for_props },
@ -3011,7 +2949,7 @@ int main(int argc, char ** argv) {
{ "chat_template", curr_tmpl.c_str() } { "chat_template", curr_tmpl.c_str() }
}; };
res.set_content(data.dump(), "application/json; charset=utf-8"); res.set_content(data.dump(), MIMETYPE_JSON);
}; };
const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
@ -3020,8 +2958,6 @@ int main(int argc, char ** argv) {
return; return;
} }
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = json::parse(req.body); json data = json::parse(req.body);
const int id_task = ctx_server.queue_tasks.get_new_id(); const int id_task = ctx_server.queue_tasks.get_new_id();
@ -3032,7 +2968,7 @@ int main(int argc, char ** argv) {
if (!json_value(data, "stream", false)) { if (!json_value(data, "stream", false)) {
server_task_result result = ctx_server.queue_results.recv(id_task); server_task_result result = ctx_server.queue_results.recv(id_task);
if (!result.error && result.stop) { if (!result.error && result.stop) {
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
} else { } else {
res_error(res, result.data); res_error(res, result.data);
} }
@ -3095,9 +3031,7 @@ int main(int argc, char ** argv) {
} }
}; };
const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) { const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json models = { json models = {
{"object", "list"}, {"object", "list"},
{"data", { {"data", {
@ -3106,12 +3040,12 @@ int main(int argc, char ** argv) {
{"object", "model"}, {"object", "model"},
{"created", std::time(0)}, {"created", std::time(0)},
{"owned_by", "llamacpp"}, {"owned_by", "llamacpp"},
{"meta", model_meta} {"meta", ctx_server.model_meta()}
}, },
}} }}
}; };
res.set_content(models.dump(), "application/json; charset=utf-8"); res.set_content(models.dump(), MIMETYPE_JSON);
}; };
const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) { const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
@ -3119,8 +3053,6 @@ int main(int argc, char ** argv) {
res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return; return;
} }
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template); json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
const int id_task = ctx_server.queue_tasks.get_new_id(); const int id_task = ctx_server.queue_tasks.get_new_id();
@ -3135,7 +3067,7 @@ int main(int argc, char ** argv) {
if (!result.error && result.stop) { if (!result.error && result.stop) {
json result_oai = format_final_response_oaicompat(data, result.data, completion_id); json result_oai = format_final_response_oaicompat(data, result.data, completion_id);
res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
} else { } else {
res_error(res, result.data); res_error(res, result.data);
} }
@ -3197,8 +3129,6 @@ int main(int argc, char ** argv) {
return; return;
} }
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = json::parse(req.body); json data = json::parse(req.body);
const int id_task = ctx_server.queue_tasks.get_new_id(); const int id_task = ctx_server.queue_tasks.get_new_id();
@ -3209,7 +3139,7 @@ int main(int argc, char ** argv) {
if (!json_value(data, "stream", false)) { if (!json_value(data, "stream", false)) {
server_task_result result = ctx_server.queue_results.recv(id_task); server_task_result result = ctx_server.queue_results.recv(id_task);
if (!result.error && result.stop) { if (!result.error && result.stop) {
res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8"); res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
} else { } else {
res_error(res, result.data); res_error(res, result.data);
} }
@ -3257,7 +3187,6 @@ int main(int argc, char ** argv) {
}; };
const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body); const json body = json::parse(req.body);
std::vector<llama_token> tokens; std::vector<llama_token> tokens;
@ -3266,11 +3195,10 @@ int main(int argc, char ** argv) {
tokens = ctx_server.tokenize(body.at("content"), add_special); tokens = ctx_server.tokenize(body.at("content"), add_special);
} }
const json data = format_tokenizer_response(tokens); const json data = format_tokenizer_response(tokens);
return res.set_content(data.dump(), "application/json; charset=utf-8"); return res.set_content(data.dump(), MIMETYPE_JSON);
}; };
const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) { const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body); const json body = json::parse(req.body);
std::string content; std::string content;
@ -3280,12 +3208,10 @@ int main(int argc, char ** argv) {
} }
const json data = format_detokenized_response(content); const json data = format_detokenized_response(content);
return res.set_content(data.dump(), "application/json; charset=utf-8"); return res.set_content(data.dump(), MIMETYPE_JSON);
}; };
const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) { const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body); const json body = json::parse(req.body);
bool is_openai = false; bool is_openai = false;
@ -3331,11 +3257,10 @@ int main(int argc, char ** argv) {
json root = is_openai json root = is_openai
? format_embeddings_response_oaicompat(body, responses) ? format_embeddings_response_oaicompat(body, responses)
: responses[0]; : responses[0];
return res.set_content(root.dump(), "application/json; charset=utf-8"); return res.set_content(root.dump(), MIMETYPE_JSON);
}; };
const auto handle_lora_adapters_list = [&](const httplib::Request & req, httplib::Response & res) { const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json result = json::array(); json result = json::array();
for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) { for (size_t i = 0; i < ctx_server.lora_adapters.size(); ++i) {
auto & la = ctx_server.lora_adapters[i]; auto & la = ctx_server.lora_adapters[i];
@ -3345,13 +3270,11 @@ int main(int argc, char ** argv) {
{"scale", la.scale}, {"scale", la.scale},
}); });
} }
res.set_content(result.dump(), "application/json"); res.set_content(result.dump(), MIMETYPE_JSON);
res.status = 200; // HTTP OK res.status = 200; // HTTP OK
}; };
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) { const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const std::vector<json> body = json::parse(req.body); const std::vector<json> body = json::parse(req.body);
int max_idx = ctx_server.lora_adapters.size(); int max_idx = ctx_server.lora_adapters.size();
@ -3379,7 +3302,7 @@ int main(int argc, char ** argv) {
server_task_result result = ctx_server.queue_results.recv(id_task); server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task);
res.set_content(result.data.dump(), "application/json"); res.set_content(result.data.dump(), MIMETYPE_JSON);
res.status = 200; // HTTP OK res.status = 200; // HTTP OK
}; };
@ -3455,17 +3378,55 @@ int main(int argc, char ** argv) {
log_data["n_threads_http"] = std::to_string(params.n_threads_http); log_data["n_threads_http"] = std::to_string(params.n_threads_http);
svr->new_task_queue = [&params] { return new httplib::ThreadPool(params.n_threads_http); }; svr->new_task_queue = [&params] { return new httplib::ThreadPool(params.n_threads_http); };
LOG_INFO("HTTP server listening", log_data); // clean up function, to be called before exit
auto clean_up = [&svr]() {
svr->stop();
llama_backend_free();
};
// run the HTTP server in a thread - see comment below // bind HTTP listen port, run the HTTP server in a thread
std::thread t([&]() { if (!svr->bind_to_port(params.hostname, params.port)) {
if (!svr->listen_after_bind()) { LOG_ERROR("couldn't bind HTTP server socket", {
state.store(SERVER_STATE_ERROR); {"hostname", params.hostname},
{"port", params.port},
});
clean_up();
LOG_ERROR("exiting due to HTTP server error", {});
return 1; return 1;
} }
std::thread t([&]() { svr->listen_after_bind(); });
svr->wait_until_ready();
return 0; LOG_INFO("HTTP server is listening", log_data);
// load the model
LOG_INFO("loading model", log_data);
if (!ctx_server.load_model(params)) {
clean_up();
t.join();
LOG_ERROR("exiting due to model loading error", {});
return 1;
} else {
ctx_server.init();
state.store(SERVER_STATE_READY);
LOG_INFO("model loaded", {});
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
if (params.chat_template.empty()) {
if (!ctx_server.validate_model_chat_template()) {
LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
params.chat_template = "chatml";
}
}
// print sample chat example to make it clear which template is used
{
LOG_INFO("chat template", {
{"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
{"built_in", params.chat_template.empty()},
}); });
}
ctx_server.queue_tasks.on_new_task(std::bind( ctx_server.queue_tasks.on_new_task(std::bind(
&server_context::process_single_task, &ctx_server, std::placeholders::_1)); &server_context::process_single_task, &ctx_server, std::placeholders::_1));
@ -3484,6 +3445,8 @@ int main(int argc, char ** argv) {
shutdown_handler = [&](int) { shutdown_handler = [&](int) {
ctx_server.queue_tasks.terminate(); ctx_server.queue_tasks.terminate();
}; };
ctx_server.queue_tasks.start_loop();
}
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
struct sigaction sigint_action; struct sigaction sigint_action;
@ -3499,12 +3462,8 @@ int main(int argc, char ** argv) {
SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true); SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif #endif
ctx_server.queue_tasks.start_loop(); clean_up();
svr->stop();
t.join(); t.join();
llama_backend_free();
return 0; return 0;
} }

View File

@ -205,27 +205,20 @@ def step_start_server(context):
async def step_wait_for_the_server_to_be_started(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str): async def step_wait_for_the_server_to_be_started(context, expecting_status: Literal['healthy', 'ready', 'idle', 'busy'] | str):
match expecting_status: match expecting_status:
case 'healthy': case 'healthy':
await wait_for_health_status(context, context.base_url, 200, 'ok', await wait_for_slots_status(context, context.base_url, 200,
timeout=30) timeout=30)
case 'ready' | 'idle': case 'ready' | 'idle':
await wait_for_health_status(context, context.base_url, 200, 'ok', await wait_for_slots_status(context, context.base_url, 200,
timeout=30, timeout=30,
params={'fail_on_no_slot': 0, 'include_slots': 0}, params={'fail_on_no_slot': 1},
slots_idle=context.n_slots, slots_idle=context.n_slots,
slots_processing=0, slots_processing=0)
expected_slots=[{'id': slot_id, 'state': 0}
for slot_id in
range(context.n_slots if context.n_slots else 1)])
case 'busy': case 'busy':
await wait_for_health_status(context, context.base_url, 503, await wait_for_slots_status(context, context.base_url, 503,
'no slot available', params={'fail_on_no_slot': 1},
params={'fail_on_no_slot': 0, 'include_slots': 0},
slots_idle=0, slots_idle=0,
slots_processing=context.n_slots, slots_processing=context.n_slots)
expected_slots=[{'id': slot_id, 'state': 1}
for slot_id in
range(context.n_slots if context.n_slots else 1)])
case _: case _:
assert False, "unknown status" assert False, "unknown status"
@ -1187,17 +1180,15 @@ async def gather_tasks_results(context):
return n_completions return n_completions
async def wait_for_health_status(context, async def wait_for_slots_status(context,
base_url, base_url,
expected_http_status_code, expected_http_status_code,
expected_health_status,
timeout=3, timeout=3,
params=None, params=None,
slots_idle=None, slots_idle=None,
slots_processing=None, slots_processing=None):
expected_slots=None):
if context.debug: if context.debug:
print(f"Starting checking for health for expected_health_status={expected_health_status}") print(f"Starting checking for health for expected_http_status_code={expected_http_status_code}")
interval = 0.5 interval = 0.5
counter = 0 counter = 0
if 'GITHUB_ACTIONS' in os.environ: if 'GITHUB_ACTIONS' in os.environ:
@ -1205,25 +1196,18 @@ async def wait_for_health_status(context,
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
while True: while True:
async with await session.get(f'{base_url}/health', params=params) as health_response: async with await session.get(f'{base_url}/slots', params=params) as slots_response:
status_code = health_response.status status_code = slots_response.status
health = await health_response.json() slots = await slots_response.json()
if context.debug: if context.debug:
print(f"HEALTH - response for expected health status='{expected_health_status}' on " print(f"slots responses {slots}\n")
f"'{base_url}/health'?{params} is {health}\n") if status_code == 503 and status_code == expected_http_status_code:
if (status_code == expected_http_status_code
and health['status'] == expected_health_status
and (slots_idle is None or health['slots_idle'] == slots_idle)
and (slots_processing is None or health['slots_processing'] == slots_processing)):
if expected_slots is not None:
assert_slots_status(health['slots'], expected_slots)
return return
if (status_code == expected_http_status_code if status_code == 200 and status_code == expected_http_status_code:
and health['status'] == expected_health_status n_slots_idle = sum(1 if slot["state"] == 0 else 0 for slot in slots)
and (slots_idle is None or health['slots_idle'] == slots_idle) n_slots_processing = sum(1 if slot["state"] != 0 else 0 for slot in slots)
and (slots_processing is None or health['slots_processing'] == slots_processing)): if ((slots_idle is None or slots_idle == n_slots_idle)
if expected_slots is not None: and (slots_processing is None or slots_processing == n_slots_processing)):
assert_slots_status(health['slots'], expected_slots)
return return
await asyncio.sleep(interval) await asyncio.sleep(interval)
@ -1238,7 +1222,7 @@ async def wait_for_health_status(context,
if n_completions > 0: if n_completions > 0:
return return
assert False, f'{expected_health_status} timeout exceeded {counter}s>={timeout}' assert False, f'slots check timeout exceeded {counter}s>={timeout}'
def assert_embeddings(embeddings): def assert_embeddings(embeddings):