From 1f17ea631c863e50f292354c8916046de01aacf7 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Mon, 18 Sep 2023 19:01:20 +0300
Subject: [PATCH] speculative : fix KV cache management

---
 examples/speculative/speculative.cpp | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index 06173393c..053073397 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -172,6 +172,7 @@ int main(int argc, char ** argv) {
                 LOG("out of drafted tokens\n");
             }
 
+            llama_kv_cache_rm_seq(ctx_dft, 0, n_past_dft, n_ctx);
             llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
             ++n_past_dft;
 
@@ -217,6 +218,7 @@ int main(int argc, char ** argv) {
 
         // sample n_draft tokens from the draft model using greedy decoding
         int n_past_cur = n_past_dft;
+
         for (int i = 0; i < n_draft; ++i) {
             float * logits = llama_get_logits(ctx_dft);
 
@@ -256,6 +258,7 @@ int main(int argc, char ** argv) {
             }
 
             // evaluate the drafted token on the draft model
+            llama_kv_cache_rm_seq(ctx_dft, 0, n_past_cur, n_ctx);
             llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
             ++n_past_cur;
 
@@ -265,6 +268,7 @@ int main(int argc, char ** argv) {
         }
 
         // evaluate the target model on the drafted tokens
+        llama_kv_cache_rm_seq(ctx_tgt, 0, n_past_tgt, n_ctx);
         llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
         ++n_past_tgt;
 
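
All three added llama_kv_cache_rm_seq calls follow the same pattern: before
re-decoding at position p on sequence 0, drop any cells still cached in
[p, n_ctx), because speculative decoding rolls n_past back whenever drafted
tokens are rejected and a stale cache tail would otherwise leak into
attention. Below is a minimal sketch of that pattern, not part of the patch,
assuming the C API as of this commit (the 3-argument llama_decode and
llama_kv_cache_rm_seq, both of which appear in the diff above); the helper
name decode_at is hypothetical:

    #include "llama.h"

    // Hypothetical helper: decode one token at position n_past on sequence 0,
    // first clearing any stale cache tail left by a partially rejected draft.
    static void decode_at(llama_context * ctx, llama_token id,
                          int n_past, int n_ctx, int n_threads) {
        // drop cached cells for sequence 0 in [n_past, n_ctx)
        llama_kv_cache_rm_seq(ctx, 0, n_past, n_ctx);

        // evaluate the single token at position n_past
        llama_decode(ctx, llama_batch_get_one(&id, 1, n_past, 0), n_threads);
    }

Without the removal step, cells written for rejected draft tokens would remain
valid at positions >= n_past and the next decode would attend to them.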