server : add a /health endpoint (#4860)

* added /health endpoint to the server

* added comments on the additional /health endpoint

* Better handling of server state

When the model is being loaded, the server state is `LOADING_MODEL`. If model-loading fails, the server state becomes `ERROR`, otherwise it becomes `READY`. The `/health` endpoint provides more granular messages now according to the server_state value.

* initialized server_state

* fixed a typo

* starting http server before initializing the model

* Update server.cpp

* Update server.cpp

* fixes

* fixes

* fixes

* made ServerState atomic and turned two-line spaces into one-line
This commit is contained in:
Behnam M 2024-01-10 14:56:05 -05:00 committed by GitHub
parent 57d016ba2d
commit cd108e641d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -26,6 +26,7 @@
#include <mutex> #include <mutex>
#include <chrono> #include <chrono>
#include <condition_variable> #include <condition_variable>
#include <atomic>
#ifndef SERVER_VERBOSE #ifndef SERVER_VERBOSE
#define SERVER_VERBOSE 1 #define SERVER_VERBOSE 1
@ -146,6 +147,12 @@ static std::vector<uint8_t> base64_decode(const std::string & encoded_string)
// parallel // parallel
// //
enum ServerState {
LOADING_MODEL, // Server is starting up, model not fully loaded yet
READY, // Server is ready and model is loaded
ERROR // An error occurred, load_model failed
};
enum task_type { enum task_type {
COMPLETION_TASK, COMPLETION_TASK,
CANCEL_TASK CANCEL_TASK
@ -2453,7 +2460,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
} }
} }
static std::string random_string() static std::string random_string()
{ {
static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
@ -2790,15 +2796,117 @@ int main(int argc, char **argv)
{"system_info", llama_print_system_info()}, {"system_info", llama_print_system_info()},
}); });
// load the model httplib::Server svr;
if (!llama.load_model(params))
std::atomic<ServerState> server_state{LOADING_MODEL};
svr.set_default_headers({{"Server", "llama.cpp"},
{"Access-Control-Allow-Origin", "*"},
{"Access-Control-Allow-Headers", "content-type"}});
svr.Get("/health", [&](const httplib::Request&, httplib::Response& res) {
ServerState current_state = server_state.load();
switch(current_state) {
case READY:
res.set_content(R"({"status": "ok"})", "application/json");
res.status = 200; // HTTP OK
break;
case LOADING_MODEL:
res.set_content(R"({"status": "loading model"})", "application/json");
res.status = 503; // HTTP Service Unavailable
break;
case ERROR:
res.set_content(R"({"status": "error", "error": "Model failed to load"})", "application/json");
res.status = 500; // HTTP Internal Server Error
break;
}
});
svr.set_logger(log_server_request);
svr.set_exception_handler([](const httplib::Request &, httplib::Response &res, std::exception_ptr ep)
{ {
const char fmt[] = "500 Internal Server Error\n%s";
char buf[BUFSIZ];
try
{
std::rethrow_exception(std::move(ep));
}
catch (std::exception &e)
{
snprintf(buf, sizeof(buf), fmt, e.what());
}
catch (...)
{
snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
}
res.set_content(buf, "text/plain; charset=utf-8");
res.status = 500;
});
svr.set_error_handler([](const httplib::Request &, httplib::Response &res)
{
if (res.status == 401)
{
res.set_content("Unauthorized", "text/plain; charset=utf-8");
}
if (res.status == 400)
{
res.set_content("Invalid request", "text/plain; charset=utf-8");
}
else if (res.status == 404)
{
res.set_content("File Not Found", "text/plain; charset=utf-8");
res.status = 404;
}
});
// set timeouts and change hostname and port
svr.set_read_timeout (sparams.read_timeout);
svr.set_write_timeout(sparams.write_timeout);
if (!svr.bind_to_port(sparams.hostname, sparams.port))
{
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
return 1; return 1;
} }
llama.initialize(); // Set the base directory for serving static files
svr.set_base_dir(sparams.public_path);
httplib::Server svr; // to make it ctrl+clickable:
LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
std::unordered_map<std::string, std::string> log_data;
log_data["hostname"] = sparams.hostname;
log_data["port"] = std::to_string(sparams.port);
if (!sparams.api_key.empty()) {
log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4);
}
LOG_INFO("HTTP server listening", log_data);
// run the HTTP server in a thread - see comment below
std::thread t([&]()
{
if (!svr.listen_after_bind())
{
server_state.store(ERROR);
return 1;
}
return 0;
});
// load the model
if (!llama.load_model(params))
{
server_state.store(ERROR);
return 1;
} else {
llama.initialize();
server_state.store(READY);
}
// Middleware for API key validation // Middleware for API key validation
auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool { auto validate_api_key = [&sparams](const httplib::Request &req, httplib::Response &res) -> bool {
@ -2826,10 +2934,6 @@ int main(int argc, char **argv)
return false; return false;
}; };
svr.set_default_headers({{"Server", "llama.cpp"},
{"Access-Control-Allow-Origin", "*"},
{"Access-Control-Allow-Headers", "content-type"}});
// this is only called if no index.html is found in the public --path // this is only called if no index.html is found in the public --path
svr.Get("/", [](const httplib::Request &, httplib::Response &res) svr.Get("/", [](const httplib::Request &, httplib::Response &res)
{ {
@ -2937,8 +3041,6 @@ int main(int argc, char **argv)
} }
}); });
svr.Get("/v1/models", [&params](const httplib::Request&, httplib::Response& res) svr.Get("/v1/models", [&params](const httplib::Request&, httplib::Response& res)
{ {
std::time_t t = std::time(0); std::time_t t = std::time(0);
@ -3157,81 +3259,6 @@ int main(int argc, char **argv)
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8"); return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
}); });
svr.set_logger(log_server_request);
svr.set_exception_handler([](const httplib::Request &, httplib::Response &res, std::exception_ptr ep)
{
const char fmt[] = "500 Internal Server Error\n%s";
char buf[BUFSIZ];
try
{
std::rethrow_exception(std::move(ep));
}
catch (std::exception &e)
{
snprintf(buf, sizeof(buf), fmt, e.what());
}
catch (...)
{
snprintf(buf, sizeof(buf), fmt, "Unknown Exception");
}
res.set_content(buf, "text/plain; charset=utf-8");
res.status = 500;
});
svr.set_error_handler([](const httplib::Request &, httplib::Response &res)
{
if (res.status == 401)
{
res.set_content("Unauthorized", "text/plain; charset=utf-8");
}
if (res.status == 400)
{
res.set_content("Invalid request", "text/plain; charset=utf-8");
}
else if (res.status == 404)
{
res.set_content("File Not Found", "text/plain; charset=utf-8");
res.status = 404;
}
});
// set timeouts and change hostname and port
svr.set_read_timeout (sparams.read_timeout);
svr.set_write_timeout(sparams.write_timeout);
if (!svr.bind_to_port(sparams.hostname, sparams.port))
{
fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
return 1;
}
// Set the base directory for serving static files
svr.set_base_dir(sparams.public_path);
// to make it ctrl+clickable:
LOG_TEE("\nllama server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port);
std::unordered_map<std::string, std::string> log_data;
log_data["hostname"] = sparams.hostname;
log_data["port"] = std::to_string(sparams.port);
if (!sparams.api_key.empty()) {
log_data["api_key"] = "api_key: ****" + sparams.api_key.substr(sparams.api_key.length() - 4);
}
LOG_INFO("HTTP server listening", log_data);
// run the HTTP server in a thread - see comment below
std::thread t([&]()
{
if (!svr.listen_after_bind())
{
return 1;
}
return 0;
});
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!? // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
// "Bus error: 10" - this is on macOS, it does not crash on Linux // "Bus error: 10" - this is on macOS, it does not crash on Linux
//std::thread t2([&]() //std::thread t2([&]()