Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-24 10:24:35 +00:00)
server : add "/chat/completions" alias for "/v1/..." (#5722)

* Add "/chat/completions" as alias for "/v1/chat/completions"
* merge to upstream master
* minor : fix trailing whitespace

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
parent 7c4263d426
commit efc72253f7
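Before the diff itself, a brief client-side illustration of the change: after this commit the same OpenAI-style request body can be POSTed to either /v1/chat/completions or the new /chat/completions alias. The sketch below is not part of the commit; it assumes a llama.cpp server is already listening on http://localhost:8080 and uses aiohttp, as the test suite touched further down does.

import asyncio

import aiohttp

BASE_URL = "http://localhost:8080"  # assumption: a llama.cpp server is already running here

PAYLOAD = {
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Say hello."},
    ],
    "max_tokens": 16,
}

async def main():
    async with aiohttp.ClientSession() as session:
        # After this commit both paths are served by the same chat_completions handler.
        for path in ("/v1/chat/completions", "/chat/completions"):
            async with session.post(f"{BASE_URL}{path}", json=PAYLOAD) as response:
                body = await response.json()
                print(path, response.status, body["choices"][0]["message"]["content"])

asyncio.run(main())

Registering a single named handler under both routes, rather than duplicating the handler, keeps the two endpoints identical by construction.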
examples/server/server.cpp
@@ -3211,87 +3211,88 @@ int main(int argc, char **argv)
                 res.set_content(models.dump(), "application/json; charset=utf-8");
             });

-    // TODO: add mount point without "/v1" prefix -- how?
-    svr.Post("/v1/chat/completions", [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
+    const auto chat_completions = [&llama, &validate_api_key, &sparams](const httplib::Request &req, httplib::Response &res)
     {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         if (!validate_api_key(req, res)) {
             return;
         }
         json data = oaicompat_completion_params_parse(llama.model, json::parse(req.body), sparams.chat_template);

         const int task_id = llama.queue_tasks.get_new_id();
         llama.queue_results.add_waiting_task_id(task_id);
         llama.request_completion(task_id, data, false, false, -1);

         if (!json_value(data, "stream", false)) {
             std::string completion_text;
             task_result result = llama.queue_results.recv(task_id);

             if (!result.error && result.stop) {
                 json oaicompat_result = format_final_response_oaicompat(data, result);

                 res.set_content(oaicompat_result.dump(-1, ' ', false,
                                     json::error_handler_t::replace),
                                     "application/json; charset=utf-8");
             } else {
                 res.status = 500;
                 res.set_content(result.result_json["content"], "text/plain; charset=utf-8");
             }
             llama.queue_results.remove_waiting_task_id(task_id);
         } else {
             const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
                 while (true) {
                     task_result llama_result = llama.queue_results.recv(task_id);
                     if (!llama_result.error) {
                         std::vector<json> result_array = format_partial_response_oaicompat( llama_result);

                         for (auto it = result_array.begin(); it != result_array.end(); ++it)
                         {
                             if (!it->empty()) {
                                 const std::string str =
                                     "data: " +
                                     it->dump(-1, ' ', false, json::error_handler_t::replace) +
                                     "\n\n";
                                 LOG_VERBOSE("data stream", {{"to_send", str}});
                                 if (!sink.write(str.c_str(), str.size())) {
                                     llama.queue_results.remove_waiting_task_id(task_id);
                                     return false;
                                 }
                             }
                         }
                         if (llama_result.stop) {
                             break;
                         }
                     } else {
                         const std::string str =
                             "error: " +
                             llama_result.result_json.dump(-1, ' ', false,
                                 json::error_handler_t::replace) +
                             "\n\n";
                         LOG_VERBOSE("data stream", {{"to_send", str}});
                         if (!sink.write(str.c_str(), str.size())) {
                             llama.queue_results.remove_waiting_task_id(task_id);
                             return false;
                         }
                         break;
                     }
                 }
                 sink.done();
                 llama.queue_results.remove_waiting_task_id(task_id);
                 return true;
             };

             auto on_complete = [task_id, &llama](bool) {
                 // cancel request
                 llama.request_cancel(task_id);
                 llama.queue_results.remove_waiting_task_id(task_id);
             };

             res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
         }
-    });
+    };
+
+    svr.Post("/chat/completions", chat_completions);
+    svr.Post("/v1/chat/completions", chat_completions);

     svr.Post("/infill", [&llama, &validate_api_key](const httplib::Request &req, httplib::Response &res)
     {
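The streaming branch of the handler above writes Server-Sent Events of the form "data: {...}\n\n" and stops once llama_result.stop is set. A minimal sketch of consuming that stream through the new non-/v1 path follows; it is illustrative only, not part of the commit, and assumes a server on http://localhost:8080.

import asyncio
import json

import aiohttp

BASE_URL = "http://localhost:8080"  # assumption: a llama.cpp server is already running here

async def stream_chat():
    payload = {
        "messages": [{"role": "user", "content": "Tell me a short story."}],
        "max_tokens": 32,
        "stream": True,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{BASE_URL}/chat/completions", json=payload) as response:
            # The handler emits "data: <json>\n\n" events; iterate over the raw lines.
            async for raw in response.content:
                line = raw.decode("utf-8").strip()
                if not line.startswith("data: "):
                    continue
                data = line[len("data: "):]
                if data == "[DONE]":  # defensive: some OAI-compatible servers send a terminator
                    break
                delta = json.loads(data)["choices"][0]["delta"]
                print(delta.get("content", ""), end="", flush=True)
            print()

asyncio.run(stream_chat())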
examples/server/tests/features/parallel.feature
@@ -54,6 +54,28 @@ Feature: Parallel
       | disabled  | 128       |
       | enabled   | 64        |

+  Scenario Outline: Multi users OAI completions compatibility no v1
+    Given a system prompt You are a writer.
+    And a model tinyllama-2
+    Given a prompt:
+      """
+      Write a very long book.
+      """
+    And a prompt:
+      """
+      Write another a poem.
+      """
+    And <n_predict> max tokens to predict
+    And streaming is <streaming>
+    Given concurrent OAI completions requests no v1
+    Then the server is busy
+    Then the server is idle
+    Then all prompts are predicted with <n_predict> tokens
+    Examples:
+      | streaming | n_predict |
+      | disabled  | 128       |
+      | enabled   | 64        |
+
   Scenario: Multi users with total number of tokens to predict exceeds the KV Cache size #3969
     Given a prompt:
       """
examples/server/tests/features/steps/steps.py
@@ -231,6 +231,7 @@ async def step_oai_chat_completions(context, api_error):
     completion = await oai_chat_completions(context.prompts.pop(),
                                             context.system_prompt,
                                             context.base_url,
+                                            '/v1/chat',
                                             False,
                                             model=context.model if hasattr(context, 'model') else None,
@@ -288,6 +289,28 @@ async def step_oai_chat_completions(context):
                               # user_prompt is inserted automatically
                               context.system_prompt,
                               context.base_url,
+                              '/v1/chat/completions',
+                              True,  # async_client
+                              model=context.model
+                              if hasattr(context, 'model') else None,
+                              n_predict=context.n_predict
+                              if hasattr(context, 'n_predict') else None,
+                              enable_streaming=context.enable_streaming
+                              if hasattr(context, 'enable_streaming') else None,
+                              server_seed=context.server_seed
+                              if hasattr(context, 'server_seed') else None,
+                              user_api_key=context.user_api_key
+                              if hasattr(context, 'user_api_key') else None)
+
+
+@step(u'concurrent OAI completions requests no v1')
+@async_run_until_complete
+async def step_oai_chat_completions(context):
+    await concurrent_requests(context, oai_chat_completions,
+                              # user_prompt is inserted automatically
+                              context.system_prompt,
+                              context.base_url,
+                              '/chat/completions',
                               True,  # async_client
                               model=context.model
                               if hasattr(context, 'model') else None,
@@ -497,6 +520,7 @@ async def request_completion(prompt,
 async def oai_chat_completions(user_prompt,
                                system_prompt,
                                base_url,
+                               base_path,
                                async_client,
                                debug=False,
                                model=None,
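With base_path now a positional parameter between base_url and async_client, every caller has to pass the path explicitly. A hypothetical call site targeting the alias path might look roughly like this (values are illustrative, and the helper is assumed to be in scope):

# Hypothetical call site; context is a behave context object as used in the steps above.
async def example_step(context):
    await oai_chat_completions("Write a very long book.",  # user_prompt
                               "You are a writer.",        # system_prompt
                               context.base_url,           # e.g. "http://localhost:8080"
                               "/chat/completions",        # new base_path argument (no /v1 prefix)
                               True,                       # async_client
                               n_predict=64)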
@@ -537,7 +561,7 @@ async def oai_chat_completions(user_prompt,
         origin = 'llama.cpp'
         headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin}
         async with aiohttp.ClientSession() as session:
-            async with session.post(f'{base_url}/v1/chat/completions',
+            async with session.post(f'{base_url}{base_path}',
                                     json=payload,
                                     headers=headers) as response:
                 if enable_streaming:
@@ -579,7 +603,7 @@ async def oai_chat_completions(user_prompt,
     else:
         try:
             openai.api_key = user_api_key
-            openai.api_base = f'{base_url}/v1/chat'
+            openai.api_base = f'{base_url}{base_path}'
             chat_completion = openai.Completion.create(
                 messages=payload['messages'],
                 model=model,