server : bring back info of final chunk in stream mode (#10722)

* server : bring back into to final chunk in stream mode

* clarify a bit

* traling space
This commit is contained in:
Xuan Son Nguyen 2024-12-08 20:38:51 +01:00 committed by GitHub
parent 06d70147e6
commit e52522b869
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 94 additions and 86 deletions

View File

@ -392,7 +392,7 @@ struct server_task_result {
return false;
}
virtual bool is_stop() {
// only used by server_task_result_cmpl_partial
// only used by server_task_result_cmpl_*
return false;
}
virtual int get_index() {
@ -478,14 +478,20 @@ struct server_task_result_cmpl_final : server_task_result {
return index;
}
virtual bool is_stop() override {
return true; // in stream mode, final responses are considered stop
}
virtual json to_json() override {
return oaicompat ? to_json_oaicompat_chat() : to_json_non_oaicompat();
return oaicompat
? (stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat())
: to_json_non_oaicompat();
}
json to_json_non_oaicompat() {
json res = json {
{"index", index},
{"content", content},
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
{"id_slot", id_slot},
{"stop", true},
{"model", oaicompat_model},
@ -546,18 +552,46 @@ struct server_task_result_cmpl_final : server_task_result {
return res;
}
json to_json_oaicompat_chat_stream() {
std::time_t t = std::time(0);
std::string finish_reason = "length";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop";
}
json choices = json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}});
json ret = json {
{"choices", choices},
{"created", t},
{"id", oaicompat_cmpl_id},
{"model", oaicompat_model},
{"object", "chat.completion.chunk"},
{"usage", json {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens},
}},
};
if (timings.prompt_n >= 0) {
ret.push_back({"timings", timings.to_json()});
}
return ret;
}
};
struct server_task_result_cmpl_partial : server_task_result {
int index = 0;
std::string content;
bool truncated;
int32_t n_decoded;
int32_t n_prompt_tokens;
stop_type stop = STOP_TYPE_NONE;
std::vector<completion_token_output> probs_output;
result_timings timings;
@ -573,20 +607,19 @@ struct server_task_result_cmpl_partial : server_task_result {
}
virtual bool is_stop() override {
return stop != STOP_TYPE_NONE;
return false; // in stream mode, partial responses are not considered stop
}
virtual json to_json() override {
if (oaicompat) {
return to_json_oaicompat();
return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
}
bool is_stop = stop != STOP_TYPE_NONE;
json to_json_non_oaicompat() {
// non-OAI-compat JSON
json res = json {
{"index", index},
{"content", content},
{"stop_type", stop_type_to_str(stop)},
{"stop", is_stop},
{"stop", false},
{"id_slot", id_slot},
{"tokens_predicted", n_decoded},
{"tokens_evaluated", n_prompt_tokens},
@ -598,31 +631,14 @@ struct server_task_result_cmpl_partial : server_task_result {
if (!probs_output.empty()) {
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
}
if (is_stop) {
res.push_back({"truncated", truncated});
}
return res;
}
json to_json_oaicompat() {
bool first = n_decoded == 0;
std::string finish_reason;
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
finish_reason = "stop";
} else if (stop == STOP_TYPE_LIMIT) {
finish_reason = "length";
}
std::time_t t = std::time(0);
json choices;
if (!finish_reason.empty()) {
choices = json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}});
} else {
if (first) {
if (content.empty()) {
choices = json::array({json{{"finish_reason", nullptr},
@ -664,7 +680,6 @@ struct server_task_result_cmpl_partial : server_task_result {
}},
}});
}
}
json ret = json {
{"choices", choices},
@ -678,14 +693,6 @@ struct server_task_result_cmpl_partial : server_task_result {
ret.push_back({"timings", timings.to_json()});
}
if (!finish_reason.empty()) {
ret.push_back({"usage", json {
{"completion_tokens", n_decoded},
{"prompt_tokens", n_prompt_tokens},
{"total_tokens", n_decoded + n_prompt_tokens},
}});
}
return std::vector<json>({ret});
}
};
@ -1888,12 +1895,9 @@ struct server_context {
res->index = slot.index;
res->content = tkn.text_to_send;
res->truncated = slot.truncated;
res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.n_prompt_tokens;
res->stop = slot.stop;
res->verbose = slot.params.verbose;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_chat = slot.params.oaicompat_chat;
@ -1924,12 +1928,6 @@ struct server_context {
}
void send_final_response(server_slot & slot) {
if (slot.params.stream) {
// if in stream mode, send the last partial response
send_partial_response(slot, {0, "", {}});
return;
}
auto res = std::make_unique<server_task_result_cmpl_final>();
res->id = slot.id_task;
res->id_slot = slot.id;
@ -1948,6 +1946,7 @@ struct server_context {
res->stop = slot.stop;
res->verbose = slot.params.verbose;
res->stream = slot.params.stream;
res->oaicompat = slot.params.oaicompat;
res->oaicompat_chat = slot.params.oaicompat_chat;
res->oaicompat_model = slot.params.oaicompat_model;
@ -2100,7 +2099,10 @@ struct server_context {
return;
}
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr);
GGML_ASSERT(
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
);
if (!result_handler(result)) {
cancel_tasks(id_tasks);
break;

View File

@ -42,10 +42,16 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
})
content = ""
for data in res:
assert "stop" in data and type(data["stop"]) == bool
if data["stop"]:
assert data["timings"]["prompt_n"] == n_prompt
assert data["timings"]["predicted_n"] == n_predicted
assert data["truncated"] == truncated
assert data["stop_type"] == "limit"
assert "generation_settings" in data
assert server.n_predict is not None
assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict)
assert data["generation_settings"]["seed"] == server.seed
assert match_regex(re_content, content)
else:
content += data["content"]