speculative : fix the n_drafted fix + p constants

This commit is contained in:
Georgi Gerganov 2023-10-17 17:04:31 +03:00
parent f07cd35da4
commit e6dd81f0bc
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -36,6 +36,10 @@ int main(int argc, char ** argv) {
// max number of parallel drafting sequences (i.e. tree branches) // max number of parallel drafting sequences (i.e. tree branches)
const int n_seq_dft = params.n_parallel; const int n_seq_dft = params.n_parallel;
// TODO: make this configurable
const float p_accept = 0.4f;
const float p_split = 0.3f;
#ifndef LOG_DISABLE_LOGS #ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("speculative", "log")); log_set_target(log_filename_generator("speculative", "log"));
LOG_TEE("Log start\n"); LOG_TEE("Log start\n");
@ -272,8 +276,7 @@ int main(int argc, char ** argv) {
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str()); k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
} }
// TODO: make this configurable if (cur_p[0].p < p_accept) {
if (cur_p[0].p < 0.4) {
LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p); LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p);
drafts[s].drafting = false; drafts[s].drafting = false;
continue; continue;
@ -283,8 +286,7 @@ int main(int argc, char ** argv) {
// attempt to split the branch if the probability is high enough // attempt to split the branch if the probability is high enough
for (int f = 1; f < 8; ++f) { for (int f = 1; f < 8; ++f) {
// TODO: make this configurable if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
if (n_seq_cur < n_seq_dft && cur_p[f].p > 0.3) {
LOG("splitting seq %3d into %3d\n", s, n_seq_cur); LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
@ -364,7 +366,9 @@ int main(int argc, char ** argv) {
} }
// account for the last drafted token that we didn't evaluate // account for the last drafted token that we didn't evaluate
++n_drafted; if (batch_tgt.n_tokens > n_draft) {
++n_drafted;
}
// evaluate the target model on the drafted tokens // evaluate the target model on the drafted tokens
{ {