server : maintain chat completion id for streaming responses (#5988)

* server: maintain chat completion id for streaming responses

* Update examples/server/utils.hpp

* Update examples/server/utils.hpp

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Minsoo Cheong, 2024-03-11 17:09:32 +09:00 (committed by GitHub)
commit 332bdfd798, parent ecab1c75de
2 changed files with 10 additions and 9 deletions
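
Why the change: before this commit, format_final_response_oaicompat() and format_partial_response_oaicompat() each called gen_chatcmplid() themselves, so every streamed chunk of a single chat completion carried a different "id". OpenAI-compatible clients expect all chunks of one streamed response to share one id, so the id is now generated once per request and passed into both formatters. For reference, a minimal sketch of an id generator of this shape (illustrative only; the real gen_chatcmplid() lives in examples/server/utils.hpp and may differ in detail):

    #include <random>
    #include <string>

    // Illustrative sketch: any unique "chatcmpl-<suffix>" string works,
    // as long as it is generated once per request and then reused for
    // every chunk of that request.
    static std::string gen_chatcmplid_sketch() {
        static const char alphabet[] =
            "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
        std::mt19937 rng(std::random_device{}());
        std::uniform_int_distribution<int> pick(0, sizeof(alphabet) - 2);
        std::string id = "chatcmpl-";
        for (int i = 0; i < 30; ++i) {
            id += alphabet[pick(rng)];
        }
        return id;
    }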

--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp

@@ -3195,11 +3195,12 @@ int main(int argc, char ** argv) {
         ctx_server.queue_results.add_waiting_task_id(id_task);
         ctx_server.request_completion(id_task, -1, data, false, false);
 
+        const auto completion_id = gen_chatcmplid();
         if (!json_value(data, "stream", false)) {
             server_task_result result = ctx_server.queue_results.recv(id_task);
 
             if (!result.error && result.stop) {
-                json result_oai = format_final_response_oaicompat(data, result.data);
+                json result_oai = format_final_response_oaicompat(data, result.data, completion_id);
 
                 res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
             } else {
@@ -3208,11 +3209,11 @@ int main(int argc, char ** argv) {
             }
 
             ctx_server.queue_results.remove_waiting_task_id(id_task);
         } else {
-            const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
+            const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
                 while (true) {
                     server_task_result result = ctx_server.queue_results.recv(id_task);
                     if (!result.error) {
-                        std::vector<json> result_array = format_partial_response_oaicompat(result.data);
+                        std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
                         for (auto it = result_array.begin(); it != result_array.end(); ++it) {
                             if (!it->empty()) {
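
A detail worth noting in the hunk above: completion_id joins the chunked_content_provider capture list by value, while ctx_server is captured by reference. The provider is invoked by httplib after the request handler has returned, so a by-reference capture of the handler-local string would dangle. A reduced sketch (names hypothetical) of why the copy matters:

    #include <functional>
    #include <iostream>
    #include <string>

    // Hypothetical reduction of the handler: the returned callable runs
    // after make_provider() exits, so the handler-local id must be copied
    // into the closure; capturing it by reference would leave a dangling
    // reference.
    static std::function<void()> make_provider() {
        const std::string completion_id = "chatcmpl-example";
        return [completion_id]() {            // by-value capture: safe
            std::cout << completion_id << "\n";
        };
    }

    int main() {
        auto provider = make_provider();      // completion_id's scope ends here
        provider();                           // still prints the copied id
        return 0;
    }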

--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp

@@ -378,7 +378,7 @@ static json oaicompat_completion_params_parse(
     return llama_params;
 }
 
-static json format_final_response_oaicompat(const json & request, json result, bool streaming = false) {
+static json format_final_response_oaicompat(const json & request, json result, const std::string & completion_id, bool streaming = false) {
     bool stopped_word = result.count("stopped_word") != 0;
     bool stopped_eos = json_value(result, "stopped_eos", false);
     int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
@@ -412,7 +412,7 @@ static json format_final_response_oaicompat(const json & request, json result, bool streaming = false) {
             {"prompt_tokens", num_prompt_tokens},
             {"total_tokens", num_tokens_predicted + num_prompt_tokens}
         }},
-        {"id", gen_chatcmplid()}
+        {"id", completion_id}
     };
 
     if (server_verbose) {
@@ -427,7 +427,7 @@ static json format_final_response_oaicompat(const json & request, json result, bool streaming = false) {
 }
 
 // return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(json result) {
+static std::vector<json> format_partial_response_oaicompat(json result, const std::string & completion_id) {
     if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
         return std::vector<json>({result});
     }
@@ -471,7 +471,7 @@ static std::vector<json> format_partial_response_oaicompat(json result) {
                         {"role", "assistant"}
                     }}}})},
                 {"created", t},
-                {"id", gen_chatcmplid()},
+                {"id", completion_id},
                 {"model", modelname},
                 {"object", "chat.completion.chunk"}};
 
@@ -482,7 +482,7 @@ static std::vector<json> format_partial_response_oaicompat(json result) {
                         {"content", content}}}
                 }})},
                 {"created", t},
-                {"id", gen_chatcmplid()},
+                {"id", completion_id},
                 {"model", modelname},
                 {"object", "chat.completion.chunk"}};
 
@@ -509,7 +509,7 @@ static std::vector<json> format_partial_response_oaicompat(json result) {
     json ret = json {
         {"choices", choices},
         {"created", t},
-        {"id", gen_chatcmplid()},
+        {"id", completion_id},
        {"model", modelname},
         {"object", "chat.completion.chunk"}
     };
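
With the id threaded through as a parameter, every event of one streamed completion now reports the same "id", which is what clients use to correlate chunks. An illustrative (abridged, made-up id) slice of the resulting SSE stream after this change:

    data: {"id":"chatcmpl-abc123","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"role":"assistant"}}]}
    data: {"id":"chatcmpl-abc123","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"}}]}

The non-streaming path uses the same once-generated id for its single final response.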