From cb75bebcad8b4f06bf4a03c23a5d9ad1d625ae7d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 17 Oct 2024 17:19:23 +0300 Subject: [PATCH] sampling : change temperature sampler logic For t <= 0.0f, keep the max logit intact and set the rest to -inf --- common/sampling.cpp | 3 ++- include/llama.h | 6 ++++-- src/llama-sampling.cpp | 23 +++++++++++++++++++++++ tests/test-sampling.cpp | 3 +++ 4 files changed, 32 insertions(+), 3 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index f536c1e0a..8d9a39ef0 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -171,7 +171,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co params.penalize_nl, params.ignore_eos)); - if (params.temp > 0.0f) { + if (params.temp >= 0.0f) { if (params.mirostat == 0) { for (const auto & cnstr : params.samplers) { switch (cnstr) { @@ -214,6 +214,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co GGML_ASSERT(false && "unknown mirostat version"); } } else { + // negative temperatures will trigger "greedy" sampling: simply take the most likely token each time if (params.n_probs > 0) { // some use cases require to sample greedily, but still obtain the probabilities of the top tokens // ref: https://github.com/ggerganov/llama.cpp/pull/9605 diff --git a/include/llama.h b/include/llama.h index 2206ef27d..581469034 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1082,8 +1082,8 @@ extern "C" { // available samplers: - LLAMA_API struct llama_sampler * llama_sampler_init_greedy (void); - LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed); + LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void); + LLAMA_API struct llama_sampler * llama_sampler_init_dist (uint32_t seed); /// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. 
For example, apply top-k or top-p sampling first. @@ -1104,6 +1104,8 @@ extern "C" { /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. LLAMA_API struct llama_sampler * llama_sampler_init_typical (float p, size_t min_keep); + + /// @details Updates the logits l_i` = l_i/t. When t <= 0.0f, the maximum logit is kept at its original value, the rest are set to -inf LLAMA_API struct llama_sampler * llama_sampler_init_temp (float t); /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772. diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index af5117e88..fb6668fac 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -915,6 +915,28 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl* static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { const auto * ctx = (llama_sampler_temp *) smpl->ctx; + + if (ctx->temp <= 0.0f) { + // find the token with the highest logit and set the rest to -inf + llama_token max_id = cur_p->data[0].id; + float max_logit = cur_p->data[0].logit; + + for (size_t i = 1; i < cur_p->size; ++i) { + if (cur_p->data[i].logit > max_logit) { + max_id = cur_p->data[i].id; + max_logit = cur_p->data[i].logit; + } + } + + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].id != max_id) { + cur_p->data[i].logit = -INFINITY; + } + } + + return; + } + for (size_t i = 0; i < cur_p->size; ++i) { cur_p->data[i].logit /= ctx->temp; } @@ -964,6 +986,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke if (ctx->delta > 0) { const float min_temp = std::max(0.0f, ctx->temp - ctx->delta); const float max_temp = ctx->temp + ctx->delta; + float exponent_val = ctx->exponent; // no need to do anything if there is only one (or zero) candidates diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 
df62c8bec..8960ced8f 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -274,6 +274,9 @@ static void test_perf() { int main(void) { ggml_time_init(); + test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f); + test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f); + test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1); test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3); test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);