sampling : do not set min_keep to n_probs (#5564)

This commit is contained in:
Georgi Gerganov 2024-02-18 19:38:06 +02:00
parent f3f28c5395
commit 689a091bbe
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -121,7 +121,7 @@ static void sampler_queue(
struct llama_context * ctx_main, struct llama_context * ctx_main,
const llama_sampling_params & params, const llama_sampling_params & params,
llama_token_data_array & cur_p, llama_token_data_array & cur_p,
size_t & min_keep) { size_t min_keep) {
const float temp = params.temp; const float temp = params.temp;
const float dynatemp_range = params.dynatemp_range; const float dynatemp_range = params.dynatemp_range;
const float dynatemp_exponent = params.dynatemp_exponent; const float dynatemp_exponent = params.dynatemp_exponent;
@ -248,10 +248,7 @@ static llama_token llama_sampling_sample_impl(
llama_sample_temp(ctx_main, &cur_p, temp); llama_sample_temp(ctx_main, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
} else { } else {
// temperature sampling sampler_queue(ctx_main, params, cur_p, 1);
size_t min_keep = std::max(1, params.n_probs);
sampler_queue(ctx_main, params, cur_p, min_keep);
id = llama_sample_token(ctx_main, &cur_p); id = llama_sample_token(ctx_main, &cur_p);