server : add "tokens" output

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-12-16 21:03:24 +02:00
parent 08ea539df2
commit 79a8176883
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -468,7 +468,10 @@ struct completion_token_output {
struct server_task_result_cmpl_final : server_task_result { struct server_task_result_cmpl_final : server_task_result {
int index = 0; int index = 0;
std::string content;
std::string content;
llama_tokens tokens;
bool stream; bool stream;
result_timings timings; result_timings timings;
std::string prompt; std::string prompt;
@ -510,6 +513,7 @@ struct server_task_result_cmpl_final : server_task_result {
json res = json { json res = json {
{"index", index}, {"index", index},
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
{"tokens", stream ? llama_tokens {} : tokens},
{"id_slot", id_slot}, {"id_slot", id_slot},
{"stop", true}, {"stop", true},
{"model", oaicompat_model}, {"model", oaicompat_model},
@ -541,7 +545,8 @@ struct server_task_result_cmpl_final : server_task_result {
{"index", 0}, {"index", 0},
{"message", json{ {"message", json{
{"content", content}, {"content", content},
{"role", "assistant"} {"tokens", tokens},
{"role", "assistant"}
} }
}}}); }}});
@ -605,7 +610,9 @@ struct server_task_result_cmpl_final : server_task_result {
struct server_task_result_cmpl_partial : server_task_result { struct server_task_result_cmpl_partial : server_task_result {
int index = 0; int index = 0;
std::string content;
std::string content;
llama_tokens tokens;
int32_t n_decoded; int32_t n_decoded;
int32_t n_prompt_tokens; int32_t n_prompt_tokens;
@ -637,6 +644,7 @@ struct server_task_result_cmpl_partial : server_task_result {
json res = json { json res = json {
{"index", index}, {"index", index},
{"content", content}, {"content", content},
{"tokens", tokens},
{"stop", false}, {"stop", false},
{"id_slot", id_slot}, {"id_slot", id_slot},
{"tokens_predicted", n_decoded}, {"tokens_predicted", n_decoded},
@ -679,7 +687,8 @@ struct server_task_result_cmpl_partial : server_task_result {
{"choices", json::array({json{{"finish_reason", nullptr}, {"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0}, {"index", 0},
{"delta", json{ {"delta", json{
{"content", content}}} {"content", content},
{"tokens", tokens}}}
}})}, }})},
{"created", t}, {"created", t},
{"id", oaicompat_cmpl_id}, {"id", oaicompat_cmpl_id},
@ -695,6 +704,7 @@ struct server_task_result_cmpl_partial : server_task_result {
{"delta", {"delta",
json{ json{
{"content", content}, {"content", content},
{"tokens", tokens}
}}, }},
}}); }});
} }
@ -949,8 +959,11 @@ struct server_slot {
size_t last_nl_pos = 0; size_t last_nl_pos = 0;
std::string generated_text; std::string generated_text;
llama_tokens generated_tokens;
llama_tokens cache_tokens; llama_tokens cache_tokens;
std::vector<completion_token_output> generated_token_probs; std::vector<completion_token_output> generated_token_probs;
bool has_next_token = true; bool has_next_token = true;
@ -985,6 +998,7 @@ struct server_slot {
n_prompt_tokens = 0; n_prompt_tokens = 0;
last_nl_pos = 0; last_nl_pos = 0;
generated_text = ""; generated_text = "";
generated_tokens = {};
has_new_line = false; has_new_line = false;
truncated = false; truncated = false;
stop = STOP_TYPE_NONE; stop = STOP_TYPE_NONE;
@ -1736,6 +1750,7 @@ struct server_context {
// search stop word and delete it // search stop word and delete it
slot.generated_text += token_str; slot.generated_text += token_str;
slot.generated_tokens.push_back(result.tok);
slot.has_next_token = true; slot.has_next_token = true;
// check if there is incomplete UTF-8 character at the end // check if there is incomplete UTF-8 character at the end
@ -1912,6 +1927,7 @@ struct server_context {
res->id = slot.id_task; res->id = slot.id_task;
res->index = slot.index; res->index = slot.index;
res->content = tkn.text_to_send; res->content = tkn.text_to_send;
res->tokens = { tkn.tok };
res->n_decoded = slot.n_decoded; res->n_decoded = slot.n_decoded;
res->n_prompt_tokens = slot.n_prompt_tokens; res->n_prompt_tokens = slot.n_prompt_tokens;
@ -1952,6 +1968,7 @@ struct server_context {
res->index = slot.index; res->index = slot.index;
res->content = slot.generated_text; res->content = slot.generated_text;
res->tokens = slot.generated_tokens;
res->timings = slot.get_timings(); res->timings = slot.get_timings();
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true); res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);