diff --git a/examples/server/oai.hpp b/examples/server/oai.hpp
index bc5db6eef..43410f803 100644
--- a/examples/server/oai.hpp
+++ b/examples/server/oai.hpp
@@ -206,3 +206,18 @@ inline static std::vector<json> format_partial_response_oaicompat(const task_res
     return std::vector<json>({ret});
 }
 
+
+inline static json format_embeddings_response_oaicompat(const json &request, const json &embeddings)
+{
+    json res =
+        json{
+            {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+            {"object", "list"},
+            {"usage",
+                json{{"prompt_tokens", 0},
+                     {"total_tokens", 0}}},
+            {"data", embeddings}
+        };
+    return res;
+}
+
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index a48582ad9..11dd82c33 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2929,6 +2929,66 @@ int main(int argc, char **argv)
                 return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
            });
 
+    svr.Post("/v1/embeddings", [&llama](const httplib::Request &req, httplib::Response &res)
+            {
+                res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+                const json body = json::parse(req.body);
+
+                json prompt;
+                if (body.count("input") != 0)
+                {
+                    prompt = body["input"];
+                    // batch
+                    if(prompt.is_array()) {
+                        json data = json::array();
+                        int i = 0;
+                        for (const json &elem : prompt) {
+                            const int task_id = llama.queue_tasks.get_new_id();
+                            llama.queue_results.add_waiting_task_id(task_id);
+                            llama.request_completion(task_id, { {"prompt", elem}, { "n_predict", 0} }, false, true, -1);
+
+                            // get the result
+                            task_result result = llama.queue_results.recv(task_id);
+                            llama.queue_results.remove_waiting_task_id(task_id);
+
+                            json embedding = json{
+                                {"embedding", json_value(result.result_json, "embedding", json::array())},
+                                {"index", i++},
+                                {"object", "embedding"}
+                            };
+                            data.push_back(embedding);
+                        }
+                        json result = format_embeddings_response_oaicompat(body, data);
+                        return res.set_content(result.dump(), "application/json; charset=utf-8");
+                    }
+                }
+                else
+                {
+                    prompt = "";
+                }
+
+                // create and queue the task
+                const int task_id = llama.queue_tasks.get_new_id();
+                llama.queue_results.add_waiting_task_id(task_id);
+                llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}}, false, true, -1);
+
+                // get the result
+                task_result result = llama.queue_results.recv(task_id);
+                llama.queue_results.remove_waiting_task_id(task_id);
+
+                json data = json::array({json{
+                        {"embedding", json_value(result.result_json, "embedding", json::array())},
+                        {"index", 0},
+                        {"object", "embedding"}
+                    }}
+                );
+
+                json root = format_embeddings_response_oaicompat(body, data);
+
+                // send the result
+                return res.set_content(root.dump(), "application/json; charset=utf-8");
+            });
+
     // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
     // "Bus error: 10" - this is on macOS, it does not crash on Linux
     //std::thread t2([&]()
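
For reference, a minimal sketch of exercising the new endpoint with cpp-httplib and nlohmann::json (both already vendored by the server example). This is not part of the patch; the host/port (localhost:8080, the server default) and the "model" field value are assumptions for illustration.

// Hypothetical smoke test for POST /v1/embeddings; assumes a server
// built from this patch is running locally on its default port.
#include "httplib.h"
#include "json.hpp"
#include <cstdio>

using json = nlohmann::json;

int main() {
    httplib::Client cli("localhost", 8080);

    // "input" may be a single string or, per the batch branch above,
    // an array of strings.
    const json body = {
        {"input", {"first sentence", "second sentence"}},
        {"model", "default"}  // echoed back by format_embeddings_response_oaicompat
    };

    auto res = cli.Post("/v1/embeddings", body.dump(), "application/json");
    if (!res || res->status != 200) {
        return 1;
    }

    // Response shape: {"model", "object": "list", "usage", "data": [...]},
    // with one {"embedding", "index", "object"} entry per input element.
    const json parsed = json::parse(res->body);
    for (const auto &entry : parsed["data"]) {
        printf("index %d: %zu floats\n",
               entry["index"].get<int>(),
               entry["embedding"].size());
    }
    return 0;
}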