Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2025-01-10 02:31:46 +00:00
cont : no need for special "greedy" logic
top-k == 1 is the same
This commit is contained in:
parent cb75bebcad
commit 57fb835e5b
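The equivalence the commit message relies on ("top-k == 1 is the same") can be checked with a short standalone sketch. This is toy code with made-up logits, not the llama.cpp sampler API: after top-k filtering with k == 1 exactly one candidate survives, so any draw from the remaining distribution is deterministic and coincides with the greedy argmax pick.

// Toy sketch (hypothetical logits, not the llama.cpp sampler API):
// top-k with k == 1 followed by sampling equals greedy argmax.
#include <algorithm>
#include <cstdio>
#include <random>
#include <vector>

int main() {
    // made-up logits for a 5-token vocabulary
    const std::vector<float> logits = { 0.1f, 2.3f, -1.0f, 2.2f, 0.7f };

    // greedy: index of the largest logit
    const int greedy = (int) (std::max_element(logits.begin(), logits.end()) - logits.begin());

    // top-k, k == 1: keep only the single most likely candidate
    std::vector<int> cand(logits.size());
    for (int i = 0; i < (int) cand.size(); ++i) { cand[i] = i; }
    std::partial_sort(cand.begin(), cand.begin() + 1, cand.end(),
            [&](int a, int b) { return logits[a] > logits[b]; });
    cand.resize(1);

    // "dist" step: with one candidate left the draw is deterministic,
    // whatever the seed, so it always returns the same token as greedy
    std::mt19937 rng(1234);
    const int sampled = cand[std::uniform_int_distribution<int>(0, (int) cand.size() - 1)(rng)];

    printf("greedy = %d, top-k(1) + sample = %d\n", greedy, sampled); // both print 1
}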
@@ -171,59 +171,46 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 params.penalize_nl,
                 params.ignore_eos));
 
-    if (params.temp >= 0.0f) {
-        if (params.mirostat == 0) {
-            for (const auto & cnstr : params.samplers) {
-                switch (cnstr) {
-                    case COMMON_SAMPLER_TYPE_TOP_K:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
-                        break;
-                    case COMMON_SAMPLER_TYPE_TOP_P:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
-                        break;
-                    case COMMON_SAMPLER_TYPE_MIN_P:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
-                        break;
-                    case COMMON_SAMPLER_TYPE_XTC:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
-                        break;
-                    case COMMON_SAMPLER_TYPE_TFS_Z:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
-                        break;
-                    case COMMON_SAMPLER_TYPE_TYPICAL_P:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
-                        break;
-                    case COMMON_SAMPLER_TYPE_TEMPERATURE:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
-                        break;
-                    case COMMON_SAMPLER_TYPE_INFILL:
-                        llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (model));
-                        break;
-                    default:
-                        GGML_ASSERT(false && "unknown sampler type");
-                }
-            }
-            llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
-        } else if (params.mirostat == 1) {
-            llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
-        } else if (params.mirostat == 2) {
-            llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-            llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
-        } else {
-            GGML_ASSERT(false && "unknown mirostat version");
-        }
-    } else {
-        // negative temperatures will trigger "greedy" sampling: simply take the most likely token each time
-        if (params.n_probs > 0) {
-            // some use cases require to sample greedily, but still obtain the probabilities of the top tokens
-            // ref: https://github.com/ggerganov/llama.cpp/pull/9605
-            //
-            // the following will not produce exactly the same probs as applying softmax to the full vocabulary, but
-            // it is much faster, since we avoid sorting all tokens and should give a good approximation
-            llama_sampler_chain_add(result->chain, llama_sampler_init_top_k(params.n_probs));
-        }
-        llama_sampler_chain_add(result->chain, llama_sampler_init_greedy());
+    if (params.mirostat == 0) {
+        for (const auto & cnstr : params.samplers) {
+            switch (cnstr) {
+                case COMMON_SAMPLER_TYPE_TOP_K:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_k    (params.top_k));
+                    break;
+                case COMMON_SAMPLER_TYPE_TOP_P:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_top_p    (params.top_p, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_MIN_P:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_min_p    (params.min_p, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_XTC:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_xtc      (params.xtc_probability, params.xtc_threshold, params.min_keep, params.seed));
+                    break;
+                case COMMON_SAMPLER_TYPE_TFS_Z:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_tail_free(params.tfs_z, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_TYPICAL_P:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_typical  (params.typ_p, params.min_keep));
+                    break;
+                case COMMON_SAMPLER_TYPE_TEMPERATURE:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
+                    break;
+                case COMMON_SAMPLER_TYPE_INFILL:
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_infill   (model));
+                    break;
+                default:
+                    GGML_ASSERT(false && "unknown sampler type");
+            }
+        }
+        llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
+    } else if (params.mirostat == 1) {
+        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+    } else if (params.mirostat == 2) {
+        llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
+        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
+    } else {
+        GGML_ASSERT(false && "unknown mirostat version");
     }
 
     return result;
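One detail from the removed branch worth a closer look: its comment noted that applying top-k before softmax yields only approximate probabilities. A small numeric sketch of that arithmetic (toy values, assuming the logits are already sorted in descending order; not the llama.cpp implementation): renormalizing over just the top-k candidates inflates each kept token's probability by the factor sum_full / sum_topk.

// Toy numeric sketch of the approximation from the removed comment:
// softmax over the top-k logits vs. softmax over the full vocabulary.
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    // made-up logits, already sorted in descending order
    const std::vector<float> logits = { 4.0f, 3.5f, 1.0f, 0.5f, -2.0f };
    const int k = 2;

    double sum_full = 0.0;
    double sum_topk = 0.0;
    for (size_t i = 0; i < logits.size(); ++i) {
        sum_full += std::exp(logits[i]);
        if ((int) i < k) {
            sum_topk += std::exp(logits[i]);
        }
    }

    // each kept token's approximate prob overshoots the exact one
    // by the constant factor sum_full / sum_topk
    for (int i = 0; i < k; ++i) {
        printf("token %d: exact p = %.4f, top-%d approx p = %.4f\n",
                i, std::exp(logits[i]) / sum_full, k, std::exp(logits[i]) / sum_topk);
    }
}

With a large vocabulary the probability mass beyond the top few tokens is usually tiny, which is why the removed comment called this a fast but good approximation.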