From af0a5b616359809ce886ea433acedebb39b12969 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Tue, 7 May 2024 23:07:58 +0200
Subject: [PATCH] server: fix incorrectly reported token probabilities (#7125)

* server: normalize token probabilities

* fix temperature == 0.0f
---
 common/sampling.cpp        |  5 +++++
 common/sampling.h          |  1 +
 examples/server/README.md  |  2 +-
 examples/server/server.cpp | 34 ++++++++++++++++++++++++----------
 4 files changed, 31 insertions(+), 11 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index cc83600d9..3715a7985 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -35,6 +35,8 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 
     result->prev.resize(params.n_prev);
 
+    result->n_considered = 0;
+
     llama_sampling_set_rng_seed(result, params.seed);
 
     return result;
@@ -64,6 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
 
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
+    ctx->n_considered = 0;
 }
 
 void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
@@ -253,6 +256,8 @@ static llama_token llama_sampling_sample_impl(
         }
     }
 
+    ctx_sampling->n_considered = cur_p.size;
+
     return id;
 }
 
diff --git a/common/sampling.h b/common/sampling.h
index cf7081e36..5b73ecdcd 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -81,6 +81,7 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token>      prev;
     std::vector<llama_token_data> cur;
+    size_t n_considered;
 
     std::mt19937 rng;
 };
diff --git a/examples/server/README.md b/examples/server/README.md
index bf3713640..a7c3f0b5f 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -272,7 +272,7 @@ node index.js
 
     `logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced. The tokens can also be represented as strings, e.g. `[["Hello, World!",-0.5]]` will reduce the likelihood of all the individual tokens that represent the string `Hello, World!`, just like the `presence_penalty` does. Default: `[]`
 
-    `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token. Default: `0`
+    `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token given the sampling settings. Note that for temperature < 0 the tokens are sampled greedily but token probabilities are still being calculated via a simple softmax of the logits without considering any other sampler settings. Default: `0`
 
     `min_keep`: If greater than 0, force samplers to return N possible tokens at minimum. Default: `0`
 
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index ff0814b2f..85ae1ad96 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -2266,17 +2266,31 @@ struct server_context {
                 llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
                 result.tok = id;
 
-                const int32_t n_probs = slot.sparams.n_probs;
-                if (slot.sparams.temp <= 0 && n_probs > 0) {
-                    // for llama_sample_token_greedy we need to sort candidates
-                    llama_sample_softmax(ctx, &cur_p);
-                }
+                const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
+                if (n_probs > 0) {
+                    const size_t n_considered = slot.ctx_sampling->n_considered;
 
-                for (size_t i = 0; i < std::min(cur_p.size, (size_t) n_probs); ++i) {
-                    result.probs.push_back({
-                        cur_p.data[i].id,
-                        cur_p.data[i].p
-                    });
+                    // Make sure at least n_probs top tokens are at the front of the vector:
+                    if (slot.sparams.temp == 0.0f && n_probs > n_considered) {
+                        llama_sample_top_k(ctx, &cur_p, n_probs, 0);
+                    }
+
+                    if (slot.sparams.temp == 0.0f) {
+                        // With greedy sampling the probabilities have possibly not been calculated.
+                        for (size_t i = 0; i < n_probs; ++i) {
+                            result.probs.push_back({
+                                cur_p.data[i].id,
+                                i == 0 ? 1.0f : 0.0f
+                            });
+                        }
+                    } else {
+                        for (size_t i = 0; i < n_probs; ++i) {
+                            result.probs.push_back({
+                                cur_p.data[i].id,
+                                i >= n_considered ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
+                            });
+                        }
+                    }
                 }
 
                 if (!process_token(result, slot)) {
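For reference, below is a minimal, self-contained sketch of the reporting rule the server.cpp hunk implements: with greedy sampling (`temp == 0.0f`) only the selected top token is reported with probability 1, and tokens that the samplers filtered out (index `>= n_considered`) are reported with probability 0. This is not part of the patch and does not use the llama.cpp API; `candidate`, `token_prob`, and `report_probs` are hypothetical stand-ins.

```cpp
// Standalone illustration only; the real logic lives in examples/server/server.cpp.
#include <algorithm>
#include <cstdio>
#include <vector>

struct candidate  { int id; float p; };  // token id and its (already softmaxed) probability
struct token_prob { int id; float p; };  // what would be reported per generated token

static std::vector<token_prob> report_probs(
        const std::vector<candidate> & cur,  // candidates, sorted by probability
        size_t n_probs,                      // how many top tokens to report
        size_t n_considered,                 // how many tokens survived the samplers
        float  temp) {
    std::vector<token_prob> out;
    const size_t n = std::min(cur.size(), n_probs);
    for (size_t i = 0; i < n; ++i) {
        float p;
        if (temp == 0.0f) {
            p = i == 0 ? 1.0f : 0.0f;                // greedy: only the top token was sampled
        } else {
            p = i >= n_considered ? 0.0f : cur[i].p; // filtered-out tokens get 0 probability
        }
        out.push_back({cur[i].id, p});
    }
    return out;
}

int main() {
    const std::vector<candidate> cur = {{42, 0.7f}, {7, 0.2f}, {13, 0.1f}};
    // With n_probs = 3 but only 2 tokens considered by the samplers,
    // the third token is reported with probability 0.
    for (const auto & tp : report_probs(cur, 3, 2, 0.8f)) {
        std::printf("token %5d -> %.2f\n", tp.id, tp.p);
    }
    return 0;
}
```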