mirror of https://github.com/ggerganov/llama.cpp.git, synced 2025-01-07 09:11:46 +00:00
cont : avoid extra loop in temperature sampler for sub-zero temp
Some checks failed: flake8 Lint / Lint (push) has been cancelled
ggml-ci
parent 4a5b5870f1
commit 06159898e1
@@ -66,18 +66,15 @@ static void llama_log_softmax(float * array, size_t size) {
 static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
     if (temp <= 0.0f) {
         // find the token with the highest logit and set the rest to -inf
-        llama_token max_id = cur_p->data[0].id;
-        float max_logit = cur_p->data[0].logit;
+        size_t max_i = 0;
+        float  max_l = cur_p->data[0].logit;
 
         for (size_t i = 1; i < cur_p->size; ++i) {
-            if (cur_p->data[i].logit > max_logit) {
-                max_id = cur_p->data[i].id;
-                max_logit = cur_p->data[i].logit;
-            }
-        }
-
-        for (size_t i = 0; i < cur_p->size; ++i) {
-            if (cur_p->data[i].id != max_id) {
+            if (cur_p->data[i    ].logit > max_l) {
+                cur_p->data[max_i].logit = -INFINITY;
+                max_i = i;
+                max_l = cur_p->data[i].logit;
+            } else {
                 cur_p->data[i].logit = -INFINITY;
             }
         }
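For context, the change collapses the two passes of the old temp <= 0 path (one loop to find the maximum logit, a second loop to mask everything else) into a single pass that tracks the running maximum and masks the losers as it goes. Below is a minimal standalone sketch of that single-pass idea; the token_data struct and greedy_keep_max helper are illustrative names for this sketch only, not llama.cpp API.

#include <math.h>
#include <stddef.h>
#include <stdio.h>

// Illustrative stand-in for the (id, logit) pairs in llama_token_data_array.
typedef struct {
    int   id;
    float logit;
} token_data;

// Single pass: remember the index of the best logit seen so far and set
// every other entry to -INFINITY while walking the array, demoting the
// previous best whenever a larger logit appears.
static void greedy_keep_max(token_data * data, size_t n) {
    size_t max_i = 0;
    float  max_l = data[0].logit;

    for (size_t i = 1; i < n; ++i) {
        if (data[i].logit > max_l) {
            data[max_i].logit = -INFINITY; // previous best is no longer the max
            max_i = i;
            max_l = data[i].logit;
        } else {
            data[i].logit = -INFINITY;     // not the max, mask it out
        }
    }
}

int main(void) {
    token_data data[] = { {0, 1.0f}, {1, 3.5f}, {2, 2.0f}, {3, 0.5f} };
    const size_t n = sizeof(data)/sizeof(data[0]);

    greedy_keep_max(data, n);

    // only id=1 keeps its logit (3.5); the rest are -inf
    for (size_t i = 0; i < n; ++i) {
        printf("id=%d logit=%f\n", data[i].id, data[i].logit);
    }
    return 0;
}

Compared with the old two-loop version, each element is touched exactly once, which removes the second walk over the whole candidate list referred to in the commit title.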