cont : avoid extra loop in temperature sampler for non-positive temp (temp <= 0)
Some checks failed
flake8 Lint / Lint (push) Has been cancelled

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-10-18 15:58:52 +03:00
parent 4a5b5870f1
commit 06159898e1
No known key found for this signature in database
GPG Key ID: BF970631944C16B7

View File

@@ -66,18 +66,15 @@ static void llama_log_softmax(float * array, size_t size) {
static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
if (temp <= 0.0f) {
// find the token with the highest logit and set the rest to -inf
llama_token max_id = cur_p->data[0].id;
float max_logit = cur_p->data[0].logit;
size_t max_i = 0;
float max_l = cur_p->data[0].logit;
for (size_t i = 1; i < cur_p->size; ++i) {
if (cur_p->data[i].logit > max_logit) {
max_id = cur_p->data[i].id;
max_logit = cur_p->data[i].logit;
}
}
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].id != max_id) {
if (cur_p->data[i ].logit > max_l) {
cur_p->data[max_i].logit = -INFINITY;
max_i = i;
max_l = cur_p->data[i].logit;
} else {
cur_p->data[i].logit = -INFINITY;
}
}