sampling : temp == 0.0 -> no probs, temp < 0.0 -> probs

This commit is contained in:
Georgi Gerganov 2023-10-28 14:04:57 +03:00
parent c86cca8061
commit bbfc62ac2f
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 5 additions and 4 deletions

View File

@ -224,6 +224,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break; break;
} }
sparams.temp = std::stof(argv[i]); sparams.temp = std::stof(argv[i]);
sparams.temp = std::max(sparams.temp, 0.0f);
} else if (arg == "--tfs") { } else if (arg == "--tfs") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;

View File

@ -168,12 +168,12 @@ llama_token llama_sampling_sample(
} }
if (temp < 0.0) { if (temp < 0.0) {
// greedy sampling, no probs
id = llama_sample_token_greedy(ctx_main, &cur_p);
} else if (temp == 0.0) {
// greedy sampling, with probs // greedy sampling, with probs
llama_sample_softmax(ctx_main, &cur_p); llama_sample_softmax(ctx_main, &cur_p);
id = cur_p.data[0].id; id = cur_p.data[0].id;
} else if (temp == 0.0) {
// greedy sampling, no probs
id = llama_sample_token_greedy(ctx_main, &cur_p);
} else { } else {
if (mirostat == 1) { if (mirostat == 1) {
const int mirostat_m = 100; const int mirostat_m = 100;

View File

@ -118,7 +118,7 @@ int main(int argc, char ** argv) {
std::vector<seq_draft> drafts(n_seq_dft); std::vector<seq_draft> drafts(n_seq_dft);
params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
params.sparams.temp = 0.0f; params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
for (int s = 0; s < n_seq_dft; ++s) { for (int s = 0; s < n_seq_dft; ++s) {
drafts[s].ctx_sampling = llama_sampling_init(params.sparams); drafts[s].ctx_sampling = llama_sampling_init(params.sparams);