mirror of https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
Server: Use multi-task for embeddings endpoint (#6001)
* use multitask for embd endpoint
* specify types
* remove redundant {"n_predict", 0}
This commit is contained in:
parent 306d34be7a
commit 99b71c068f
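The gist of the change: the embeddings endpoint used to loop over prompts and queue one completion task per prompt (with a TODO asking whether multi-task could be used). It now queues a single task whose "prompt" may be a string or an array, and lets the server's task queue fan an array out into parallel subtasks. As a minimal sketch of how the two accepted request shapes reduce to one "prompt" value (illustrative only; normalize_input is not a function in this codebase):

    #include <nlohmann/json.hpp>
    #include <string>
    #include <vector>

    using json = nlohmann::json;

    // Illustrative helper: map the endpoint's two request shapes onto the
    // single "prompt" value that is handed to one queued task.
    static json normalize_input(const json & body) {
        if (body.count("input") != 0) {
            // OpenAI-compatible: "input" is a string, or an array of
            // strings / token lists; it is forwarded as-is
            return body["input"];
        }
        if (body.count("content") != 0) {
            // legacy shape: one string, wrapped in a one-element list
            return std::vector<std::string>{body["content"]};
        }
        // caller responds: "input" or "content" must be provided
        return nullptr;
    }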
examples/server/server.cpp

@@ -2763,6 +2763,7 @@ int main(int argc, char ** argv) {
         res.set_header("Access-Control-Allow-Credentials", "true");
         res.set_header("Access-Control-Allow-Methods", "POST");
         res.set_header("Access-Control-Allow-Headers", "*");
+        return res.set_content("", "application/json; charset=utf-8");
     });
 
     svr->set_logger(log_server_request);
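For context, the CORS preflight handler reads roughly like this after the hunk above; the surrounding lines are reconstructed from the hunk's context (and from the handler removed further down), and only the final return line is new. It makes OPTIONS requests terminate with an explicit empty JSON body instead of falling through:

    svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
        res.set_header("Access-Control-Allow-Credentials", "true");
        res.set_header("Access-Control-Allow-Methods", "POST");
        res.set_header("Access-Control-Allow-Headers", "*");
        // new: terminate the preflight with an explicit empty JSON body
        return res.set_content("", "application/json; charset=utf-8");
    });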
@@ -3371,44 +3372,37 @@ int main(int argc, char ** argv) {
         const json body = json::parse(req.body);
         bool is_openai = false;
 
-        // an input prompt can string or a list of tokens (integer)
-        std::vector<json> prompts;
+        // an input prompt can be a string or a list of tokens (integer)
+        json prompt;
         if (body.count("input") != 0) {
             is_openai = true;
-            if (body["input"].is_array()) {
-                // support multiple prompts
-                for (const json & elem : body["input"]) {
-                    prompts.push_back(elem);
-                }
-            } else {
-                // single input prompt
-                prompts.push_back(body["input"]);
-            }
+            prompt = body["input"];
         } else if (body.count("content") != 0) {
-            // only support single prompt here
-            std::string content = body["content"];
-            prompts.push_back(content);
+            // with "content", we only support single prompt
+            prompt = std::vector<std::string>{body["content"]};
         } else {
             res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
 
-        // process all prompts
-        json responses = json::array();
-        for (auto & prompt : prompts) {
-            // TODO @ngxson : maybe support multitask for this endpoint?
-            // create and queue the task
+        // create and queue the task
+        json responses;
+        {
             const int id_task = ctx_server.queue_tasks.get_new_id();
 
             ctx_server.queue_results.add_waiting_task_id(id_task);
-            ctx_server.request_completion(id_task, -1, { {"prompt", prompt}, { "n_predict", 0}}, false, true);
+            ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true);
 
             // get the result
             server_task_result result = ctx_server.queue_results.recv(id_task);
             ctx_server.queue_results.remove_waiting_task_id(id_task);
             if (!result.error) {
-                // append to the responses
-                responses.push_back(result.data);
+                if (result.data.count("results")) {
+                    // result for multi-task
+                    responses = result.data["results"];
+                } else {
+                    // result for single task
+                    responses = std::vector<json>{result.data};
+                }
             } else {
                 // error received, ignore everything else
                 res_error(res, result.data);
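The single queued task can still represent many prompts: when "prompt" is an array, the per-prompt outputs come back aggregated under a "results" key, while a plain string prompt yields one result object. A minimal sketch of that normalization in isolation (result_to_responses is an illustrative name, not part of the codebase):

    // Illustrative helper: flatten a task result into a uniform JSON array,
    // whether it came back as a multi-task ("results") or as a single task.
    static json result_to_responses(const json & data) {
        if (data.count("results")) {
            return data["results"];         // multi-task: already an array
        }
        return std::vector<json>{data};     // single task: wrap in an array
    }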
@@ -3417,24 +3411,19 @@ int main(int argc, char ** argv) {
             }
         }
 
         // write JSON response
-        json root;
-        if (is_openai) {
-            json res_oai = json::array();
-            int i = 0;
-            for (auto & elem : responses) {
-                res_oai.push_back(json{
-                    {"embedding", json_value(elem, "embedding", json::array())},
-                    {"index", i++},
-                    {"object", "embedding"}
-                });
-            }
-            root = format_embeddings_response_oaicompat(body, res_oai);
-        } else {
-            root = responses[0];
-        }
+        json root = is_openai
+            ? format_embeddings_response_oaicompat(body, responses)
+            : responses[0];
         return res.set_content(root.dump(), "application/json; charset=utf-8");
     };
 
+    auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
+        return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
+            res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
+            return false;
+        };
+    };
+
     //
     // Router
     //
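handle_static_file is a factory: it captures a pointer to an embedded buffer and returns the closure that httplib invokes per request. Hoisting it above the Router section means it is already defined when the route table uses it, as the next hunk's context shows, e.g.:

    // each route serves one embedded buffer; the factory returns the
    // closure that httplib invokes per request
    svr->Get("/",         handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
    svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));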
@@ -3446,17 +3435,6 @@ int main(int argc, char ** argv) {
     }
 
     // using embedded static files
-    auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
-        return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
-            res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
-            return false;
-        };
-    };
-
-    svr->Options(R"(/.*)", [](const httplib::Request &, httplib::Response & res) {
-        // TODO @ngxson : I have no idea what it is... maybe this is redundant?
-        return res.set_content("", "application/json; charset=utf-8");
-    });
     svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
     svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
     svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
examples/server/utils.hpp

@@ -529,6 +529,16 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
 }
 
 static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
+    json data = json::array();
+    int i = 0;
+    for (auto & elem : embeddings) {
+        data.push_back(json{
+            {"embedding", json_value(elem, "embedding", json::array())},
+            {"index", i++},
+            {"object", "embedding"}
+        });
+    }
+
     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
@@ -536,7 +546,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
             {"prompt_tokens", 0},
             {"total_tokens", 0}
         }},
-        {"data", embeddings}
+        {"data", data}
     };
 
     return res;
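With the per-item formatting folded into this helper, the endpoint passes raw task results straight through and gets the full OpenAI-style envelope back. For two input strings, the returned JSON looks roughly like this (embedding values abbreviated; usage counts are hard-coded to 0 in this version, as the hunk shows):

    {
        "model": "<requested model, or the DEFAULT_OAICOMPAT_MODEL fallback>",
        "object": "list",
        "usage": {
            "prompt_tokens": 0,
            "total_tokens": 0
        },
        "data": [
            {"embedding": [0.0017, -0.0284, ...], "index": 0, "object": "embedding"},
            {"embedding": [0.0103,  0.0591, ...], "index": 1, "object": "embedding"}
        ]
    }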