server : fix build + rename enums (#4870)

This commit is contained in:
Georgi Gerganov 2024-01-11 09:10:34 +02:00 committed by GitHub
parent cd108e641d
commit 5c1980d8d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -147,15 +147,15 @@ static std::vector<uint8_t> base64_decode(const std::string & encoded_string)
// parallel // parallel
// //
enum ServerState { enum server_state {
LOADING_MODEL, // Server is starting up, model not fully loaded yet SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
READY, // Server is ready and model is loaded SERVER_STATE_READY, // Server is ready and model is loaded
ERROR // An error occurred, load_model failed SERVER_STATE_ERROR // An error occurred, load_model failed
}; };
enum task_type { enum task_type {
COMPLETION_TASK, TASK_TYPE_COMPLETION,
CANCEL_TASK TASK_TYPE_CANCEL,
}; };
struct task_server { struct task_server {
@ -1402,7 +1402,7 @@ struct llama_server_context
task.data = std::move(data); task.data = std::move(data);
task.infill_mode = infill; task.infill_mode = infill;
task.embedding_mode = embedding; task.embedding_mode = embedding;
task.type = COMPLETION_TASK; task.type = TASK_TYPE_COMPLETION;
task.multitask_id = multitask_id; task.multitask_id = multitask_id;
// when a completion task's prompt array is not a singleton, we split it into multiple requests // when a completion task's prompt array is not a singleton, we split it into multiple requests
@ -1524,7 +1524,7 @@ struct llama_server_context
std::unique_lock<std::mutex> lock(mutex_tasks); std::unique_lock<std::mutex> lock(mutex_tasks);
task_server task; task_server task;
task.id = id_gen++; task.id = id_gen++;
task.type = CANCEL_TASK; task.type = TASK_TYPE_CANCEL;
task.target_id = task_id; task.target_id = task_id;
queue_tasks.push_back(task); queue_tasks.push_back(task);
condition_tasks.notify_one(); condition_tasks.notify_one();
@ -1560,7 +1560,7 @@ struct llama_server_context
queue_tasks.erase(queue_tasks.begin()); queue_tasks.erase(queue_tasks.begin());
switch (task.type) switch (task.type)
{ {
case COMPLETION_TASK: { case TASK_TYPE_COMPLETION: {
llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1)); llama_client_slot *slot = get_slot(json_value(task.data, "slot_id", -1));
if (slot == nullptr) if (slot == nullptr)
{ {
@ -1589,7 +1589,7 @@ struct llama_server_context
break; break;
} }
} break; } break;
case CANCEL_TASK: { // release slot linked with the task id case TASK_TYPE_CANCEL: { // release slot linked with the task id
for (auto & slot : slots) for (auto & slot : slots)
{ {
if (slot.task_id == task.target_id) if (slot.task_id == task.target_id)
@ -2798,24 +2798,24 @@ int main(int argc, char **argv)
httplib::Server svr; httplib::Server svr;
std::atomic<ServerState> server_state{LOADING_MODEL}; std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
svr.set_default_headers({{"Server", "llama.cpp"}, svr.set_default_headers({{"Server", "llama.cpp"},
{"Access-Control-Allow-Origin", "*"}, {"Access-Control-Allow-Origin", "*"},
{"Access-Control-Allow-Headers", "content-type"}}); {"Access-Control-Allow-Headers", "content-type"}});
svr.Get("/health", [&](const httplib::Request&, httplib::Response& res) { svr.Get("/health", [&](const httplib::Request&, httplib::Response& res) {
ServerState current_state = server_state.load(); server_state current_state = state.load();
switch(current_state) { switch(current_state) {
case READY: case SERVER_STATE_READY:
res.set_content(R"({"status": "ok"})", "application/json"); res.set_content(R"({"status": "ok"})", "application/json");
res.status = 200; // HTTP OK res.status = 200; // HTTP OK
break; break;
case LOADING_MODEL: case SERVER_STATE_LOADING_MODEL:
res.set_content(R"({"status": "loading model"})", "application/json"); res.set_content(R"({"status": "loading model"})", "application/json");
res.status = 503; // HTTP Service Unavailable res.status = 503; // HTTP Service Unavailable
break; break;
case ERROR: case SERVER_STATE_ERROR:
res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json"); res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json");
res.status = 500; // HTTP Internal Server Error res.status = 500; // HTTP Internal Server Error
break; break;
@ -2891,7 +2891,7 @@ int main(int argc, char **argv)
{ {
if (!svr.listen_after_bind()) if (!svr.listen_after_bind())
{ {
server_state.store(ERROR); state.store(SERVER_STATE_ERROR);
return 1; return 1;
} }
@ -2901,11 +2901,11 @@ int main(int argc, char **argv)
// load the model // load the model
if (!llama.load_model(params)) if (!llama.load_model(params))
{ {
server_state.store(ERROR); state.store(SERVER_STATE_ERROR);
return 1; return 1;
} else { } else {
llama.initialize(); llama.initialize();
server_state.store(READY); state.store(SERVER_STATE_READY);
} }
// Middleware for API key validation // Middleware for API key validation