speculative : fix off-by-one for n_drafted

This commit is contained in:
Georgi Gerganov 2023-10-17 11:40:09 +03:00
parent 373d782d42
commit f07cd35da4
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -336,7 +336,7 @@ int main(int argc, char ** argv) {
llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
// no need to evaluate the last drafted token, since we won't use the result
- if (batch_tgt.n_tokens == n_draft) {
+ if (batch_tgt.n_tokens > n_draft) {
drafts[s].drafting = false;
continue;
}
@@ -358,11 +358,14 @@ int main(int argc, char ** argv) {
++n_past_cur;
++n_drafted;
- if (batch_tgt.n_tokens >= n_draft) {
+ if (batch_tgt.n_tokens > n_draft) {
break;
}
}
+ // account for the last drafted token that we didn't evaluate
+ ++n_drafted;
// evaluate the target model on the drafted tokens
{
llama_kv_cache_seq_keep(ctx_tgt, 0);