mirror of https://github.com/ggerganov/llama.cpp.git
server: fix reported top tokens for temperature 0 (#7203)
This commit is contained in:
parent b83cc3f5b3
commit 5ae3426b0b
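In short: with temperature 0 the sampler takes the greedy path and never computes a probability distribution over the candidate list, so the values left in cur_p are stale. The old n_considered counter nevertheless claimed cur_p.size valid entries, and the server reported those stale values as top-token probabilities. This commit renames the counter to n_valid, forces it to 0 on the greedy path, and has the server re-rank the top tokens itself before reporting them.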
common/sampling.cpp
@@ -35,7 +35,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
 
     result->prev.resize(params.n_prev);
 
-    result->n_considered = 0;
+    result->n_valid = 0;
 
     llama_sampling_set_rng_seed(result, params.seed);
 
@@ -66,7 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) {
 
     std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
     ctx->cur.clear();
-    ctx->n_considered = 0;
+    ctx->n_valid = 0;
 }
 
 void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) {
@@ -256,7 +256,7 @@ static llama_token llama_sampling_sample_impl(
         }
     }
 
-    ctx_sampling->n_considered = cur_p.size;
+    ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size;
 
     return id;
 }
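The added line above is the heart of the fix. A minimal, self-contained sketch of the rule (the struct and function names here are hypothetical stand-ins, not the llama.cpp API):

    #include <cstddef>
    #include <vector>

    // Hypothetical stand-ins for llama_token_data / llama_sampling_context.
    struct token_data { int id; float logit; float p; };

    struct sampling_state {
        std::vector<token_data> cur; // candidate tokens for the current step
        size_t n_valid = 0;          // front entries whose probabilities are meaningful
    };

    // After sampling, record how many candidates carry real probabilities.
    // The greedy path (temp == 0) skips the softmax entirely, so none do.
    int sample_sketch(sampling_state & st, float temp) {
        int id = st.cur.front().id; // pretend this token was sampled (assumes cur is non-empty)
        st.n_valid = temp == 0.0f ? 0 : st.cur.size();
        return id;
    }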
common/sampling.h
@@ -81,7 +81,7 @@ struct llama_sampling_context {
     // TODO: replace with ring-buffer
     std::vector<llama_token> prev;
     std::vector<llama_token_data> cur;
-    size_t n_considered;
+    size_t n_valid; // Number of correct top tokens with correct probabilities.
 
     std::mt19937 rng;
 };
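A usage note on the new field: consumers can trust cur[0..n_valid) to hold tokens with correct probabilities; anything past that index (and everything, when temperature is 0) must be treated as having no meaningful probability.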
examples/server/server.cpp
@@ -2270,10 +2270,10 @@ struct server_context {
 
         const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
         if (n_probs > 0) {
-            const size_t n_considered = slot.ctx_sampling->n_considered;
+            const size_t n_valid = slot.ctx_sampling->n_valid;
 
             // Make sure at least n_probs top tokens are at the front of the vector:
-            if (slot.sparams.temp == 0.0f && n_probs > n_considered) {
+            if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
                 llama_sample_top_k(ctx, &cur_p, n_probs, 0);
             }
 
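Since n_valid is 0 whenever temperature is 0, the condition above always triggers on the greedy path, and llama_sample_top_k serves purely as a sorting step here. A sketch of its effect under that assumption (token_data again being a hypothetical stand-in for llama_token_data):

    #include <algorithm>
    #include <cstddef>

    struct token_data { int id; float logit; float p; };

    // Move the k highest-logit candidates, sorted, to the front of the array, so
    // the server can report the true top tokens even though greedy sampling
    // left the candidate list unsorted.
    void top_k_front(token_data * data, size_t size, size_t k) {
        if (k > size) k = size;
        std::partial_sort(data, data + k, data + size,
            [](const token_data & a, const token_data & b) { return a.logit > b.logit; });
    }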
@@ -2289,7 +2289,7 @@ struct server_context {
             for (size_t i = 0; i < n_probs; ++i) {
                 result.probs.push_back({
                     cur_p.data[i].id,
-                    i >= n_considered ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
+                    i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
                 });
             }
         }
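Net effect at temperature 0: the reported top tokens now carry the correct ids (thanks to the top-k re-sort above), and because n_valid is 0, every i >= n_valid holds, so each is reported with probability 0.0f instead of a stale value. At temperature > 0 the behavior is unchanged, since n_valid equals cur_p.size exactly as n_considered did.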