server: add option to output probabilities for completion (#1962)

* server: add option to output probabilities for completion
* server: fix issue when handling probability output for incomplete tokens for multibyte character generation
* server: fix llama_sample_top_k order
* examples/common.h: put all bool variables in gpt_params together
This commit is contained in:
WangHaoranRobin 2023-07-03 05:38:44 +08:00 committed by GitHub
parent 46088f7231
commit d7d2e6a0f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 122 additions and 31 deletions

View File

@ -31,7 +31,7 @@ struct gpt_params {
int32_t n_gpu_layers = 0; // number of layers to store in VRAM int32_t n_gpu_layers = 0; // number of layers to store in VRAM
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
bool low_vram = 0; // if true, reduce VRAM usage at the cost of performance int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
// sampling parameters // sampling parameters
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
@ -59,6 +59,7 @@ struct gpt_params {
std::string lora_adapter = ""; // lora adapter path std::string lora_adapter = ""; // lora adapter path
std::string lora_base = ""; // base model path for the lora adapter std::string lora_base = ""; // base model path for the lora adapter
bool low_vram = false; // if true, reduce VRAM usage at the cost of performance
bool memory_f16 = true; // use f16 instead of f32 for memory kv bool memory_f16 = true; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided bool random_prompt = false; // do not randomize prompt if none provided
bool use_color = false; // use color to distinguish generations and inputs bool use_color = false; // use color to distinguish generations and inputs

View File

@ -26,6 +26,17 @@ struct server_params {
int32_t write_timeout = 600; int32_t write_timeout = 600;
}; };
// completion token output with probabilities
struct completion_token_output {
struct token_prob {
llama_token tok;
float prob;
};
std::vector<token_prob> probs;
llama_token tok;
};
static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) { static size_t common_part(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
size_t i; size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
@ -86,6 +97,40 @@ static void server_log(const char * level, const char * function, int line,
fflush(stdout); fflush(stdout);
} }
// format incomplete utf-8 multibyte character for output
static std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) {
std::string out = token == -1 ? "" : llama_token_to_str(ctx, token);
// if first bit is 1, meaning it's a partial character
if (out.size() > 0 && (out[0] & 0x80) == 0x80) {
std::stringstream ss;
ss<< std::hex << (out[0] & 0xff);
std::string res ( ss.str() );
out = "byte: \\x" + res;
}
return out;
}
// convert a vector of completion_token_output to json
static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> probs) {
json out = json::array();
for (const auto & prob : probs) {
json probs_for_token = json::array();
for (const auto & p : prob.probs) {
std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
probs_for_token.push_back(json {
{ "tok_str", tok_str },
{ "prob", p.prob },
});
}
std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
out.push_back(json {
{"content", tok_str},
{"probs", probs_for_token},
});
}
return out;
}
static bool server_verbose = false; static bool server_verbose = false;
#if SERVER_VERBOSE != 1 #if SERVER_VERBOSE != 1
@ -107,6 +152,7 @@ struct llama_server_context {
bool stream = false; bool stream = false;
bool has_next_token = false; bool has_next_token = false;
std::string generated_text; std::string generated_text;
std::vector<completion_token_output> generated_token_probs;
size_t num_tokens_predicted = 0; size_t num_tokens_predicted = 0;
size_t n_past = 0; size_t n_past = 0;
@ -142,6 +188,7 @@ struct llama_server_context {
num_tokens_predicted = 0; num_tokens_predicted = 0;
generated_text = ""; generated_text = "";
generated_text.reserve(params.n_ctx); generated_text.reserve(params.n_ctx);
generated_token_probs.clear();
truncated = false; truncated = false;
stopped_eos = false; stopped_eos = false;
stopped_word = false; stopped_word = false;
@ -221,8 +268,9 @@ struct llama_server_context {
llama_set_rng_seed(ctx, params.seed); llama_set_rng_seed(ctx, params.seed);
} }
llama_token nextToken() { completion_token_output nextToken() {
llama_token result = -1; completion_token_output result;
result.tok = -1;
if (embd.size() >= (size_t)params.n_ctx) { if (embd.size() >= (size_t)params.n_ctx) {
// Reset context // Reset context
@ -261,7 +309,8 @@ struct llama_server_context {
if (params.n_predict == 0) { if (params.n_predict == 0) {
has_next_token = false; has_next_token = false;
return llama_token_eos(); result.tok = llama_token_eos();
return result;
} }
// out of user input, sample next token // out of user input, sample next token
@ -278,7 +327,7 @@ struct llama_server_context {
const float mirostat_tau = params.mirostat_tau; const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta; const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl; const bool penalize_nl = params.penalize_nl;
llama_token id = 0; const int32_t n_probs = params.n_probs;
{ {
auto * logits = llama_get_logits(ctx); auto * logits = llama_get_logits(ctx);
@ -312,35 +361,42 @@ struct llama_server_context {
if (temp <= 0) { if (temp <= 0) {
// Greedy sampling // Greedy sampling
id = llama_sample_token_greedy(ctx, &candidates_p); result.tok = llama_sample_token_greedy(ctx, &candidates_p);
if (n_probs > 0) {
llama_sample_softmax(ctx, &candidates_p);
}
} else { } else {
if (mirostat == 1) { if (mirostat == 1) {
static float mirostat_mu = 2.0f * mirostat_tau; static float mirostat_mu = 2.0f * mirostat_tau;
const int mirostat_m = 100; const int mirostat_m = 100;
llama_sample_temperature(ctx, &candidates_p, temp); llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu); result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
} else if (mirostat == 2) { } else if (mirostat == 2) {
static float mirostat_mu = 2.0f * mirostat_tau; static float mirostat_mu = 2.0f * mirostat_tau;
llama_sample_temperature(ctx, &candidates_p, temp); llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu); result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
} else { } else {
// Temperature sampling // Temperature sampling
llama_sample_top_k(ctx, &candidates_p, top_k, 1); size_t min_keep = std::max(1, n_probs);
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1); llama_sample_top_k(ctx, &candidates_p, top_k, min_keep);
llama_sample_typical(ctx, &candidates_p, typical_p, 1); llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep);
llama_sample_top_p(ctx, &candidates_p, top_p, 1); llama_sample_typical(ctx, &candidates_p, typical_p, min_keep);
llama_sample_top_p(ctx, &candidates_p, top_p, min_keep);
llama_sample_temperature(ctx, &candidates_p, temp); llama_sample_temperature(ctx, &candidates_p, temp);
id = llama_sample_token(ctx, &candidates_p); result.tok = llama_sample_token(ctx, &candidates_p);
} }
} }
for (size_t i = 0; i < std::min(candidates_p.size, (size_t) n_probs); ++i) {
result.probs.push_back({candidates_p.data[i].id, candidates_p.data[i].p});
}
last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.erase(last_n_tokens.begin());
last_n_tokens.push_back(id); last_n_tokens.push_back(result.tok);
num_tokens_predicted++; num_tokens_predicted++;
} }
// add it to the context // add it to the context
embd.push_back(id); embd.push_back(result.tok);
result = id;
// decrement remaining sampling budget // decrement remaining sampling budget
--n_remain; --n_remain;
@ -382,12 +438,16 @@ struct llama_server_context {
return stop_pos; return stop_pos;
} }
std::string doCompletion() { completion_token_output doCompletion() {
const llama_token token = nextToken(); const completion_token_output token_with_probs = nextToken();
const std::string token_text = token == -1 ? "" : llama_token_to_str(ctx, token); const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(ctx, token_with_probs.tok);
generated_text += token_text; generated_text += token_text;
if (params.n_probs > 0) {
generated_token_probs.push_back(token_with_probs);
}
if (multibyte_pending > 0) { if (multibyte_pending > 0) {
multibyte_pending -= token_text.size(); multibyte_pending -= token_text.size();
} else if (token_text.size() == 1) { } else if (token_text.size() == 1) {
@ -416,8 +476,8 @@ struct llama_server_context {
} }
LOG_VERBOSE("next token", { LOG_VERBOSE("next token", {
{ "token", token }, { "token", token_with_probs.tok },
{ "token_text", llama_token_to_str(ctx, token) }, { "token_text", tokens_to_output_formatted_string(ctx, token_with_probs.tok) },
{ "has_next_token", has_next_token }, { "has_next_token", has_next_token },
{ "n_remain", n_remain }, { "n_remain", n_remain },
{ "num_tokens_predicted", num_tokens_predicted }, { "num_tokens_predicted", num_tokens_predicted },
@ -427,7 +487,7 @@ struct llama_server_context {
{ "stopping_word", stopping_word }, { "stopping_word", stopping_word },
}); });
return token_text; return token_with_probs;
} }
std::vector<float> getEmbedding() { std::vector<float> getEmbedding() {
@ -669,6 +729,7 @@ static json format_generation_settings(llama_server_context & llama) {
{ "ignore_eos", ignore_eos }, { "ignore_eos", ignore_eos },
{ "stream", llama.stream }, { "stream", llama.stream },
{ "logit_bias", llama.params.logit_bias }, { "logit_bias", llama.params.logit_bias },
{ "n_probs", llama.params.n_probs },
}; };
} }
@ -678,8 +739,9 @@ static json format_embedding_response(llama_server_context & llama) {
}; };
} }
static json format_final_response(llama_server_context & llama, const std::string & content) { static json format_final_response(llama_server_context & llama, const std::string & content, const std::vector<completion_token_output> & probs) {
return json {
json res = json {
{ "content", content }, { "content", content },
{ "stop", true }, { "stop", true },
{ "model", llama.params.model_alias }, { "model", llama.params.model_alias },
@ -692,13 +754,25 @@ static json format_final_response(llama_server_context & llama, const std::strin
{ "stopped_limit", llama.stopped_limit }, { "stopped_limit", llama.stopped_limit },
{ "stopping_word", llama.stopping_word }, { "stopping_word", llama.stopping_word },
}; };
if (llama.params.n_probs > 0) {
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
} }
static json format_partial_response(const std::string & content) { return res;
return json { }
static json format_partial_response(llama_server_context & llama, const std::string & content, const std::vector<completion_token_output> & probs) {
json res = json {
{ "content", content }, { "content", content },
{ "stop", false }, { "stop", false },
}; };
if (llama.params.n_probs > 0) {
res["completion_probabilities"] = probs_vector_to_json(llama.ctx, probs);
}
return res;
} }
static json format_tokenizer_response(const std::vector<llama_token> & tokens) { static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
@ -728,6 +802,7 @@ static void parse_options_completion(const json & body, llama_server_context & l
llama.params.n_keep = body.value("n_keep", default_params.n_keep); llama.params.n_keep = body.value("n_keep", default_params.n_keep);
llama.params.seed = body.value("seed", default_params.seed); llama.params.seed = body.value("seed", default_params.seed);
llama.params.prompt = body.value("prompt", default_params.prompt); llama.params.prompt = body.value("prompt", default_params.prompt);
llama.params.n_probs = body.value("n_probs", default_params.n_probs);
llama.params.logit_bias.clear(); llama.params.logit_bias.clear();
if (body.value("ignore_eos", false)) { if (body.value("ignore_eos", false)) {
@ -830,7 +905,8 @@ int main(int argc, char ** argv) {
size_t stop_pos = std::string::npos; size_t stop_pos = std::string::npos;
while (llama.has_next_token) { while (llama.has_next_token) {
const std::string token_text = llama.doCompletion(); const completion_token_output token_with_probs = llama.doCompletion();
const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);
stop_pos = llama.findStoppingStrings(llama.generated_text, stop_pos = llama.findStoppingStrings(llama.generated_text,
token_text.size(), STOP_FULL); token_text.size(), STOP_FULL);
@ -844,7 +920,7 @@ int main(int argc, char ** argv) {
llama.generated_text.end()); llama.generated_text.end());
} }
const json data = format_final_response(llama, llama.generated_text); const json data = format_final_response(llama, llama.generated_text, llama.generated_token_probs);
llama_print_timings(llama.ctx); llama_print_timings(llama.ctx);
@ -853,9 +929,11 @@ int main(int argc, char ** argv) {
} else { } else {
const auto chunked_content_provider = [&](size_t, DataSink & sink) { const auto chunked_content_provider = [&](size_t, DataSink & sink) {
size_t sent_count = 0; size_t sent_count = 0;
size_t sent_token_probs_index = 0;
while (llama.has_next_token) { while (llama.has_next_token) {
const std::string token_text = llama.doCompletion(); const completion_token_output token_with_probs = llama.doCompletion();
const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);
if (llama.multibyte_pending > 0) { if (llama.multibyte_pending > 0) {
continue; continue;
} }
@ -878,10 +956,22 @@ int main(int argc, char ** argv) {
const std::string to_send = llama.generated_text.substr(pos, stop_pos); const std::string to_send = llama.generated_text.substr(pos, stop_pos);
sent_count += to_send.size(); sent_count += to_send.size();
std::vector<completion_token_output> probs_output = {};
if (llama.params.n_probs > 0) {
const std::vector<llama_token> to_send_toks = llama_tokenize(llama.ctx, to_send, false);
size_t probs_pos = std::min(sent_token_probs_index, llama.generated_token_probs.size());
size_t probs_stop_pos = std::min(sent_token_probs_index + to_send_toks.size(), llama.generated_token_probs.size());
if (probs_pos < probs_stop_pos) {
probs_output = std::vector<completion_token_output>(llama.generated_token_probs.begin() + probs_pos, llama.generated_token_probs.begin() + probs_stop_pos);
}
sent_token_probs_index = probs_stop_pos;
}
const json data = llama.has_next_token const json data = llama.has_next_token
? format_partial_response(to_send) ? format_partial_response(llama, to_send, probs_output)
// Generation is done, send extra information. // Generation is done, send extra information.
: format_final_response(llama, to_send); : format_final_response(llama, to_send, llama.generated_token_probs);
const std::string str = const std::string str =
"data: " + "data: " +