From 06159898e10c6281aeca45fb65398d94f532a887 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 18 Oct 2024 15:58:52 +0300 Subject: [PATCH] cont : avoid extra loop in temperature sampler for sub-zero temp ggml-ci --- src/llama-sampling.cpp | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index 3b2dcfbfc..29852ddf3 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -66,18 +66,15 @@ static void llama_log_softmax(float * array, size_t size) { static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) { if (temp <= 0.0f) { // find the token with the highest logit and set the rest to -inf - llama_token max_id = cur_p->data[0].id; - float max_logit = cur_p->data[0].logit; + size_t max_i = 0; + float max_l = cur_p->data[0].logit; for (size_t i = 1; i < cur_p->size; ++i) { - if (cur_p->data[i].logit > max_logit) { - max_id = cur_p->data[i].id; - max_logit = cur_p->data[i].logit; - } - } - - for (size_t i = 0; i < cur_p->size; ++i) { - if (cur_p->data[i].id != max_id) { + if (cur_p->data[i ].logit > max_l) { + cur_p->data[max_i].logit = -INFINITY; + max_i = i; + max_l = cur_p->data[i].logit; + } else { cur_p->data[i].logit = -INFINITY; } }