mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 19:50:17 +00:00
speculative : fix the n_drafted fix + p constants
This commit is contained in:
parent
f07cd35da4
commit
e6dd81f0bc
@ -36,6 +36,10 @@ int main(int argc, char ** argv) {
|
||||
// max number of parallel drafting sequences (i.e. tree branches)
|
||||
const int n_seq_dft = params.n_parallel;
|
||||
|
||||
// TODO: make this configurable
|
||||
const float p_accept = 0.4f;
|
||||
const float p_split = 0.3f;
|
||||
|
||||
#ifndef LOG_DISABLE_LOGS
|
||||
log_set_target(log_filename_generator("speculative", "log"));
|
||||
LOG_TEE("Log start\n");
|
||||
@ -272,8 +276,7 @@ int main(int argc, char ** argv) {
|
||||
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
|
||||
}
|
||||
|
||||
// TODO: make this configurable
|
||||
if (cur_p[0].p < 0.4) {
|
||||
if (cur_p[0].p < p_accept) {
|
||||
LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p);
|
||||
drafts[s].drafting = false;
|
||||
continue;
|
||||
@ -283,8 +286,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// attempt to split the branch if the probability is high enough
|
||||
for (int f = 1; f < 8; ++f) {
|
||||
// TODO: make this configurable
|
||||
if (n_seq_cur < n_seq_dft && cur_p[f].p > 0.3) {
|
||||
if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
|
||||
LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
|
||||
|
||||
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
|
||||
@ -364,7 +366,9 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// account for the last drafted token that we didn't evaluate
|
||||
++n_drafted;
|
||||
if (batch_tgt.n_tokens > n_draft) {
|
||||
++n_drafted;
|
||||
}
|
||||
|
||||
// evaluate the target model on the drafted tokens
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user