speculative : fix default RNG seed + set sparams.n_probs

This commit is contained in:
Georgi Gerganov 2024-09-23 12:44:28 +03:00
parent 8241bc71b5
commit 3cb33a8e29
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -32,6 +32,9 @@ struct seq_draft {
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
// needed to get candidate probs even for temp <= 0.0
params.sparams.n_probs = 128;
if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) { if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
return 1; return 1;
} }
@ -49,7 +52,7 @@ int main(int argc, char ** argv) {
// probability threshold for splitting a draft branch (only for n_seq_dft > 1) // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
const float p_split = params.p_split; const float p_split = params.p_split;
std::default_random_engine rng(params.sparams.seed); std::default_random_engine rng(params.sparams.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sparams.seed);
std::uniform_real_distribution<> u_dist; std::uniform_real_distribution<> u_dist;
// init llama.cpp // init llama.cpp