From 4a5b5870f191aa8fd938046e0fba3de8dc3c2279 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 17 Oct 2024 22:53:22 +0300 Subject: [PATCH] llama : handle temp <= 0.0 in the temp_ext sampler too ggml-ci --- src/llama-sampling.cpp | 60 ++++++++++++++++++++--------------------- tests/test-sampling.cpp | 14 ++++++++++ 2 files changed, 44 insertions(+), 30 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index fb6668fac..3b2dcfbfc 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -63,6 +63,33 @@ static void llama_log_softmax(float * array, size_t size) { } */ +static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) { + if (temp <= 0.0f) { + // find the token with the highest logit and set the rest to -inf + llama_token max_id = cur_p->data[0].id; + float max_logit = cur_p->data[0].logit; + + for (size_t i = 1; i < cur_p->size; ++i) { + if (cur_p->data[i].logit > max_logit) { + max_id = cur_p->data[i].id; + max_logit = cur_p->data[i].logit; + } + } + + for (size_t i = 0; i < cur_p->size; ++i) { + if (cur_p->data[i].id != max_id) { + cur_p->data[i].logit = -INFINITY; + } + } + + return; + } + + for (size_t i = 0; i < cur_p->size; ++i) { + cur_p->data[i].logit /= temp; + } +} + static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) { GGML_ASSERT(cur_p->size > 0); @@ -916,30 +943,7 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl* static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { const auto * ctx = (llama_sampler_temp *) smpl->ctx; - if (ctx->temp <= 0.0f) { - // find the token with the highest logit and set the rest to -inf - llama_token max_id = cur_p->data[0].id; - float max_logit = cur_p->data[0].logit; - - for (size_t i = 1; i < cur_p->size; ++i) { - if (cur_p->data[i].logit > max_logit) { - max_id = cur_p->data[i].id; - max_logit = cur_p->data[i].logit; - } - } - - for (size_t i = 0; i < cur_p->size; ++i) { - if (cur_p->data[i].id != max_id) { - cur_p->data[i].logit = -INFINITY; - } - } - - return; - } - - for (size_t i = 0; i < cur_p->size; ++i) { - cur_p->data[i].logit /= ctx->temp; - } + llama_sampler_temp_impl(cur_p, ctx->temp); } static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) { @@ -1024,9 +1028,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke #endif // Apply the dynamically calculated temperature scaling - for (size_t i = 0; i < cur_p->size; ++i) { - cur_p->data[i].logit /= dyn_temp; - } + llama_sampler_temp_impl(cur_p, dyn_temp); // Re-compute softmax probabilities after scaling logits with dynamic temperature const double max_l_double = cur_p->data[0].logit; @@ -1050,9 +1052,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke } #endif } else { - for (size_t i = 0; i < cur_p->size; ++i) { - cur_p->data[i].logit /= ctx->temp; - } + llama_sampler_temp_impl(cur_p, ctx->temp); } } diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index e9dada795..05600e6f5 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -70,6 +70,17 @@ static void test_temp(const std::vector & probs, const std::vector tester.check(); } +static void test_temp_ext(const std::vector & probs, const std::vector & probs_expected, float temp, float delta, float exponent) { + sampler_tester tester(probs, probs_expected); + + DUMP(&tester.cur_p); + tester.apply(llama_sampler_init_temp_ext(temp, delta, exponent)); + tester.apply(llama_sampler_init_dist (0)); + DUMP(&tester.cur_p); + + tester.check(); +} + static void test_top_k(const std::vector & probs, const std::vector & probs_expected, int k) { sampler_tester tester(probs, probs_expected); @@ -277,6 +288,9 @@ int main(void) { test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f); test_temp({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f); + test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1.0f, 0.0f, 1.0f); + test_temp_ext({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f, 0.0f, 0.0f, 0.0f}, 0.0f, 0.0f, 1.0f); + test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {1.0f}, 1); test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.44444f, 0.33333f, 0.22222f}, 3); test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 4);