mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 12:10:18 +00:00
speculative : fix off-by-one for n_drafted
This commit is contained in:
parent
373d782d42
commit
f07cd35da4
@ -336,7 +336,7 @@ int main(int argc, char ** argv) {
|
|||||||
llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
|
llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
|
||||||
|
|
||||||
// no need to evaluate the last drafted token, since we won't use the result
|
// no need to evaluate the last drafted token, since we won't use the result
|
||||||
if (batch_tgt.n_tokens == n_draft) {
|
if (batch_tgt.n_tokens > n_draft) {
|
||||||
drafts[s].drafting = false;
|
drafts[s].drafting = false;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -358,11 +358,14 @@ int main(int argc, char ** argv) {
|
|||||||
++n_past_cur;
|
++n_past_cur;
|
||||||
++n_drafted;
|
++n_drafted;
|
||||||
|
|
||||||
if (batch_tgt.n_tokens >= n_draft) {
|
if (batch_tgt.n_tokens > n_draft) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// account for the last drafted token that we didn't evaluate
|
||||||
|
++n_drafted;
|
||||||
|
|
||||||
// evaluate the target model on the drafted tokens
|
// evaluate the target model on the drafted tokens
|
||||||
{
|
{
|
||||||
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
||||||
|
Loading…
Reference in New Issue
Block a user