diff --git a/examples/server/server.cpp b/examples/server/server.cpp index b63a6f243..3172d96dd 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2763,6 +2763,7 @@ int main(int argc, char ** argv) { res.set_header("Access-Control-Allow-Credentials", "true"); res.set_header("Access-Control-Allow-Methods", "POST"); res.set_header("Access-Control-Allow-Headers", "*"); + return res.set_content("", "application/json; charset=utf-8"); }); svr->set_logger(log_server_request); @@ -3371,44 +3372,37 @@ int main(int argc, char ** argv) { const json body = json::parse(req.body); bool is_openai = false; - // an input prompt can string or a list of tokens (integer) - std::vector prompts; + // an input prompt can be a string or a list of tokens (integer) + json prompt; if (body.count("input") != 0) { is_openai = true; - if (body["input"].is_array()) { - // support multiple prompts - for (const json & elem : body["input"]) { - prompts.push_back(elem); - } - } else { - // single input prompt - prompts.push_back(body["input"]); - } + prompt = body["input"]; } else if (body.count("content") != 0) { - // only support single prompt here - std::string content = body["content"]; - prompts.push_back(content); + // with "content", we only support single prompt + prompt = std::vector{body["content"]}; } else { res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); return; } - // process all prompts - json responses = json::array(); - for (auto & prompt : prompts) { - // TODO @ngxson : maybe support multitask for this endpoint? - // create and queue the task + // create and queue the task + json responses; + { const int id_task = ctx_server.queue_tasks.get_new_id(); - ctx_server.queue_results.add_waiting_task_id(id_task); - ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true); + ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true); // get the result server_task_result result = ctx_server.queue_results.recv(id_task); ctx_server.queue_results.remove_waiting_task_id(id_task); if (!result.error) { - // append to the responses - responses.push_back(result.data); + if (result.data.count("results")) { + // result for multi-task + responses = result.data["results"]; + } else { + // result for single task + responses = std::vector{result.data}; + } } else { // error received, ignore everything else res_error(res, result.data); @@ -3417,24 +3411,19 @@ int main(int argc, char ** argv) { } // write JSON response - json root; - if (is_openai) { - json res_oai = json::array(); - int i = 0; - for (auto & elem : responses) { - res_oai.push_back(json{ - {"embedding", json_value(elem, "embedding", json::array())}, - {"index", i++}, - {"object", "embedding"} - }); - } - root = format_embeddings_response_oaicompat(body, res_oai); - } else { - root = responses[0]; - } + json root = is_openai + ? format_embeddings_response_oaicompat(body, responses) + : responses[0]; return res.set_content(root.dump(), "application/json; charset=utf-8"); }; + auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) { + return [content, len, mime_type](const httplib::Request &, httplib::Response & res) { + res.set_content(reinterpret_cast(content), len, mime_type); + return false; + }; + }; + // // Router // @@ -3446,17 +3435,6 @@ int main(int argc, char ** argv) { } // using embedded static files - auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) { - return [content, len, mime_type](const httplib::Request &, httplib::Response & res) { - res.set_content(reinterpret_cast(content), len, mime_type); - return false; - }; - }; - - svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) { - // TODO @ngxson : I have no idea what it is... maybe this is redundant? - return res.set_content("", "application/json; charset=utf-8"); - }); svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8")); svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8")); svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8")); diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 48aeef4eb..2ddb2cd21 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -529,6 +529,16 @@ static std::vector format_partial_response_oaicompat(json result, const st } static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) { + json data = json::array(); + int i = 0; + for (auto & elem : embeddings) { + data.push_back(json{ + {"embedding", json_value(elem, "embedding", json::array())}, + {"index", i++}, + {"object", "embedding"} + }); + } + json res = json { {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))}, {"object", "list"}, @@ -536,7 +546,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso {"prompt_tokens", 0}, {"total_tokens", 0} }}, - {"data", embeddings} + {"data", data} }; return res;