diff --git a/common/sampling.cpp b/common/sampling.cpp index 3715a7985..f0f1b92d3 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -35,7 +35,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_ result->prev.resize(params.n_prev); - result->n_considered = 0; + result->n_valid = 0; llama_sampling_set_rng_seed(result, params.seed); @@ -66,7 +66,7 @@ void llama_sampling_reset(llama_sampling_context * ctx) { std::fill(ctx->prev.begin(), ctx->prev.end(), 0); ctx->cur.clear(); - ctx->n_considered = 0; + ctx->n_valid = 0; } void llama_sampling_set_rng_seed(struct llama_sampling_context * ctx, uint32_t seed) { @@ -256,7 +256,7 @@ static llama_token llama_sampling_sample_impl( } } - ctx_sampling->n_considered = cur_p.size; + ctx_sampling->n_valid = temp == 0.0f ? 0 : cur_p.size; return id; } diff --git a/common/sampling.h b/common/sampling.h index 5b73ecdcd..655732ad1 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -81,7 +81,7 @@ struct llama_sampling_context { // TODO: replace with ring-buffer std::vector prev; std::vector cur; - size_t n_considered; + size_t n_valid; // Number of correct top tokens with correct probabilities. std::mt19937 rng; }; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 305f79492..2bf4026d5 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2270,10 +2270,10 @@ struct server_context { const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs); if (n_probs > 0) { - const size_t n_considered = slot.ctx_sampling->n_considered; + const size_t n_valid = slot.ctx_sampling->n_valid; // Make sure at least n_probs top tokens are at the front of the vector: - if (slot.sparams.temp == 0.0f && n_probs > n_considered) { + if (slot.sparams.temp == 0.0f && n_probs > n_valid) { llama_sample_top_k(ctx, &cur_p, n_probs, 0); } @@ -2289,7 +2289,7 @@ struct server_context { for (size_t i = 0; i < n_probs; ++i) { result.probs.push_back({ cur_p.data[i].id, - i >= n_considered ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability. + i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability. }); } }