From 4330bd83feb39683de4bd7a34cfcf672ff8ac3e4 Mon Sep 17 00:00:00 2001 From: Laura Date: Thu, 11 Jan 2024 19:02:48 +0100 Subject: [PATCH] server : implement credentialed CORS (#4514) * Implement credentialed CORS according to MDN * Fix syntax error * Move validate_api_key up so it is defined before its first usage --- examples/server/server.cpp | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 345004fa1..031824e14 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2822,9 +2822,15 @@ int main(int argc, char **argv) std::atomic state{SERVER_STATE_LOADING_MODEL}; - svr.set_default_headers({{"Server", "llama.cpp"}, - {"Access-Control-Allow-Origin", "*"}, - {"Access-Control-Allow-Headers", "content-type"}}); + svr.set_default_headers({{"Server", "llama.cpp"}}); + + // CORS preflight + svr.Options(R"(.*)", [](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + res.set_header("Access-Control-Allow-Credentials", "true"); + res.set_header("Access-Control-Allow-Methods", "POST"); + res.set_header("Access-Control-Allow-Headers", "*"); + }); svr.Get("/health", [&](const httplib::Request&, httplib::Response& res) { server_state current_state = state.load(); @@ -2987,9 +2993,9 @@ int main(int argc, char **argv) return false; }); - svr.Get("/props", [&llama](const httplib::Request & /*req*/, httplib::Response &res) + svr.Get("/props", [&llama](const httplib::Request & req, httplib::Response &res) { - res.set_header("Access-Control-Allow-Origin", "*"); + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = { { "user_name", llama.name_user.c_str() }, { "assistant_name", llama.name_assistant.c_str() } @@ -2999,6 +3005,7 @@ int main(int argc, char **argv) svr.Post("/completion", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; } @@ -3066,8 +3073,9 @@ int main(int argc, char **argv) } }); - svr.Get("/v1/models", [¶ms](const httplib::Request&, httplib::Response& res) + svr.Get("/v1/models", [¶ms](const httplib::Request& req, httplib::Response& res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); std::time_t t = std::time(0); json models = { @@ -3085,9 +3093,11 @@ int main(int argc, char **argv) res.set_content(models.dump(), "application/json; charset=utf-8"); }); + // TODO: add mount point without "/v1" prefix -- how? svr.Post("/v1/chat/completions", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; } @@ -3161,6 +3171,7 @@ int main(int argc, char **argv) svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); if (!validate_api_key(req, res)) { return; } @@ -3233,6 +3244,7 @@ int main(int argc, char **argv) svr.Post("/tokenize", [&llama](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::vector tokens; if (body.count("content") != 0) @@ -3245,6 +3257,7 @@ int main(int argc, char **argv) svr.Post("/detokenize", [&llama](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); std::string content; if (body.count("tokens") != 0) @@ -3259,6 +3272,7 @@ int main(int argc, char **argv) svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); const json body = json::parse(req.body); json prompt; if (body.count("content") != 0)