diff --git a/common/sampling.cpp b/common/sampling.cpp index e51d07611..7ef1d2217 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -173,22 +173,23 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st if (params.temp > 0.0f) { if (params.mirostat == 0) { + size_t min_keep = std::max(1, params.min_keep); for (const auto & cnstr : params.samplers) { switch (cnstr) { case GPT_SAMPLER_TYPE_TOP_K: llama_sampler_chain_add(result->chain, llama_sampler_init_top_k (params.top_k)); break; case GPT_SAMPLER_TYPE_TOP_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, params.min_keep)); + llama_sampler_chain_add(result->chain, llama_sampler_init_top_p (params.top_p, min_keep)); break; case GPT_SAMPLER_TYPE_MIN_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, params.min_keep)); + llama_sampler_chain_add(result->chain, llama_sampler_init_min_p (params.min_p, min_keep)); break; case GPT_SAMPLER_TYPE_TFS_Z: - llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep)); + llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, min_keep)); break; case GPT_SAMPLER_TYPE_TYPICAL_P: - llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, params.min_keep)); + llama_sampler_chain_add(result->chain, llama_sampler_init_typical (params.typ_p, min_keep)); break; case GPT_SAMPLER_TYPE_TEMPERATURE: llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));