mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 19:50:17 +00:00
server : bring back info of final chunk in stream mode (#10722)
* server : bring back info to final chunk in stream mode * clarify a bit * trailing space
This commit is contained in:
parent
06d70147e6
commit
e52522b869
@ -392,7 +392,7 @@ struct server_task_result {
|
||||
return false;
|
||||
}
|
||||
virtual bool is_stop() {
|
||||
// only used by server_task_result_cmpl_partial
|
||||
// only used by server_task_result_cmpl_*
|
||||
return false;
|
||||
}
|
||||
virtual int get_index() {
|
||||
@ -478,14 +478,20 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
return index;
|
||||
}
|
||||
|
||||
// Overrides the base-class is_stop(); this result type always returns true.
virtual bool is_stop() override {
|
||||
return true; // in stream mode, final responses are considered stop
|
||||
}
|
||||
|
||||
virtual json to_json() override {
|
||||
return oaicompat ? to_json_oaicompat_chat() : to_json_non_oaicompat();
|
||||
return oaicompat
|
||||
? (stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat())
|
||||
: to_json_non_oaicompat();
|
||||
}
|
||||
|
||||
json to_json_non_oaicompat() {
|
||||
json res = json {
|
||||
{"index", index},
|
||||
{"content", content},
|
||||
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
||||
{"id_slot", id_slot},
|
||||
{"stop", true},
|
||||
{"model", oaicompat_model},
|
||||
@ -546,18 +552,46 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// Builds the JSON for the final result as a "chat.completion.chunk" object.
// (to_json() selects this form when both `oaicompat` and `stream` are set.)
//
// The returned object contains:
//   - "choices": a one-element array with index 0, an empty "delta" object, and
//     a "finish_reason" of "stop" when `stop` is STOP_TYPE_WORD or STOP_TYPE_EOS,
//     or "length" for any other value of `stop`;
//   - "created", "id", "model", "object";
//   - "usage": completion/prompt/total token counts taken from n_decoded and
//     n_prompt_tokens;
//   - "timings": appended only when timings.prompt_n >= 0.
json to_json_oaicompat_chat_stream() {
|
||||
std::time_t t = std::time(0); // current time, emitted below as "created"
|
||||
std::string finish_reason = "length"; // default; overridden below for word/EOS stops
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
finish_reason = "stop";
|
||||
}
|
||||
|
||||
json choices = json::array({json{{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"delta", json::object()}}}); // delta is left empty in this chunk
|
||||
|
||||
json ret = json {
|
||||
{"choices", choices},
|
||||
{"created", t},
|
||||
{"id", oaicompat_cmpl_id},
|
||||
{"model", oaicompat_model},
|
||||
{"object", "chat.completion.chunk"},
|
||||
{"usage", json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens},
|
||||
}},
|
||||
};
|
||||
|
||||
// timings are optional: only attached when prompt_n is non-negative
if (timings.prompt_n >= 0) {
|
||||
ret.push_back({"timings", timings.to_json()});
|
||||
}
|
||||
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
struct server_task_result_cmpl_partial : server_task_result {
|
||||
int index = 0;
|
||||
std::string content;
|
||||
|
||||
bool truncated;
|
||||
int32_t n_decoded;
|
||||
int32_t n_prompt_tokens;
|
||||
|
||||
stop_type stop = STOP_TYPE_NONE;
|
||||
|
||||
std::vector<completion_token_output> probs_output;
|
||||
result_timings timings;
|
||||
|
||||
@ -573,20 +607,19 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||
}
|
||||
|
||||
virtual bool is_stop() override {
|
||||
return stop != STOP_TYPE_NONE;
|
||||
return false; // in stream mode, partial responses are not considered stop
|
||||
}
|
||||
|
||||
virtual json to_json() override {
|
||||
if (oaicompat) {
|
||||
return to_json_oaicompat();
|
||||
return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
|
||||
}
|
||||
bool is_stop = stop != STOP_TYPE_NONE;
|
||||
|
||||
json to_json_non_oaicompat() {
|
||||
// non-OAI-compat JSON
|
||||
json res = json {
|
||||
{"index", index},
|
||||
{"content", content},
|
||||
{"stop_type", stop_type_to_str(stop)},
|
||||
{"stop", is_stop},
|
||||
{"stop", false},
|
||||
{"id_slot", id_slot},
|
||||
{"tokens_predicted", n_decoded},
|
||||
{"tokens_evaluated", n_prompt_tokens},
|
||||
@ -598,31 +631,14 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||
if (!probs_output.empty()) {
|
||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
|
||||
}
|
||||
if (is_stop) {
|
||||
res.push_back({"truncated", truncated});
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
json to_json_oaicompat() {
|
||||
bool first = n_decoded == 0;
|
||||
|
||||
std::string finish_reason;
|
||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||
finish_reason = "stop";
|
||||
} else if (stop == STOP_TYPE_LIMIT) {
|
||||
finish_reason = "length";
|
||||
}
|
||||
|
||||
std::time_t t = std::time(0);
|
||||
|
||||
json choices;
|
||||
|
||||
if (!finish_reason.empty()) {
|
||||
choices = json::array({json{{"finish_reason", finish_reason},
|
||||
{"index", 0},
|
||||
{"delta", json::object()}}});
|
||||
} else {
|
||||
if (first) {
|
||||
if (content.empty()) {
|
||||
choices = json::array({json{{"finish_reason", nullptr},
|
||||
@ -664,7 +680,6 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||
}},
|
||||
}});
|
||||
}
|
||||
}
|
||||
|
||||
json ret = json {
|
||||
{"choices", choices},
|
||||
@ -678,14 +693,6 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||
ret.push_back({"timings", timings.to_json()});
|
||||
}
|
||||
|
||||
if (!finish_reason.empty()) {
|
||||
ret.push_back({"usage", json {
|
||||
{"completion_tokens", n_decoded},
|
||||
{"prompt_tokens", n_prompt_tokens},
|
||||
{"total_tokens", n_decoded + n_prompt_tokens},
|
||||
}});
|
||||
}
|
||||
|
||||
return std::vector<json>({ret});
|
||||
}
|
||||
};
|
||||
@ -1888,12 +1895,9 @@ struct server_context {
|
||||
res->index = slot.index;
|
||||
res->content = tkn.text_to_send;
|
||||
|
||||
res->truncated = slot.truncated;
|
||||
res->n_decoded = slot.n_decoded;
|
||||
res->n_prompt_tokens = slot.n_prompt_tokens;
|
||||
|
||||
res->stop = slot.stop;
|
||||
|
||||
res->verbose = slot.params.verbose;
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_chat = slot.params.oaicompat_chat;
|
||||
@ -1924,12 +1928,6 @@ struct server_context {
|
||||
}
|
||||
|
||||
void send_final_response(server_slot & slot) {
|
||||
if (slot.params.stream) {
|
||||
// if in stream mode, send the last partial response
|
||||
send_partial_response(slot, {0, "", {}});
|
||||
return;
|
||||
}
|
||||
|
||||
auto res = std::make_unique<server_task_result_cmpl_final>();
|
||||
res->id = slot.id_task;
|
||||
res->id_slot = slot.id;
|
||||
@ -1948,6 +1946,7 @@ struct server_context {
|
||||
res->stop = slot.stop;
|
||||
|
||||
res->verbose = slot.params.verbose;
|
||||
res->stream = slot.params.stream;
|
||||
res->oaicompat = slot.params.oaicompat;
|
||||
res->oaicompat_chat = slot.params.oaicompat_chat;
|
||||
res->oaicompat_model = slot.params.oaicompat_model;
|
||||
@ -2100,7 +2099,10 @@ struct server_context {
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr);
|
||||
GGML_ASSERT(
|
||||
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|
||||
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
|
||||
);
|
||||
if (!result_handler(result)) {
|
||||
cancel_tasks(id_tasks);
|
||||
break;
|
||||
|
@ -42,10 +42,16 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
|
||||
})
|
||||
content = ""
|
||||
for data in res:
|
||||
assert "stop" in data and type(data["stop"]) == bool
|
||||
if data["stop"]:
|
||||
assert data["timings"]["prompt_n"] == n_prompt
|
||||
assert data["timings"]["predicted_n"] == n_predicted
|
||||
assert data["truncated"] == truncated
|
||||
assert data["stop_type"] == "limit"
|
||||
assert "generation_settings" in data
|
||||
assert server.n_predict is not None
|
||||
assert data["generation_settings"]["n_predict"] == min(n_predict, server.n_predict)
|
||||
assert data["generation_settings"]["seed"] == server.seed
|
||||
assert match_regex(re_content, content)
|
||||
else:
|
||||
content += data["content"]
|
||||
|
Loading…
Reference in New Issue
Block a user