From f2c9800dfb70369f8436c583d2c361b811690abb Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 7 Jan 2024 17:52:12 +0200
Subject: [PATCH] passkey : simplify n_past logic

---
 examples/passkey/passkey.cpp | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp
index 682f90e83..862dde996 100644
--- a/examples/passkey/passkey.cpp
+++ b/examples/passkey/passkey.cpp
@@ -128,12 +128,14 @@ int main(int argc, char ** argv) {
 
     llama_batch batch = llama_batch_init(512, 0, 1);
 
+    int n_past = 0;
+
     // fill the KV cache
     for (int i = 0; i < n_ctx; i += n_batch) {
         llama_batch_clear(batch);
 
         for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
-            llama_batch_add(batch, tokens_list[i + j], i + j, { 0 }, false);
+            llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
         }
 
         if (i + n_batch >= n_tokens_all) {
@@ -160,10 +162,12 @@ int main(int argc, char ** argv) {
         llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
         llama_kv_cache_seq_shift(ctx, 0, n_keep + n_discard, n_ctx,        -n_discard);
 
+        n_past -= n_discard;
+
         llama_batch_clear(batch);
 
         for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) {
-            llama_batch_add(batch, tokens_list[i + j], n_ctx - n_discard + j, { 0 }, false);
+            llama_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false);
         }
 
         if (i + n_batch >= n_tokens_all) {
@@ -178,8 +182,6 @@ int main(int argc, char ** argv) {
         LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
     }
 
-    int n_past = batch.pos[batch.n_tokens - 1];
-
     {
         const int n_discard = n_past - n_ctx + n_predict;
 
@@ -236,13 +238,12 @@ int main(int argc, char ** argv) {
             fflush(stdout);
 
             n_decode += 1;
-            n_past   += 1;
 
             // prepare the next batch
             llama_batch_clear(batch);
 
             // push this new token for next evaluation
-            llama_batch_add(batch, new_token_id, n_past, { 0 }, true);
+            llama_batch_add(batch, new_token_id, n_past++, { 0 }, true);
         }
 
         n_cur += 1;
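
For reference, a minimal standalone sketch of the bookkeeping this patch converges on: one running n_past counter that is incremented once per queued token and rewound by n_discard when the context is shifted, instead of deriving positions from loop indices or from batch.pos. The add_token helper, the token values, and the printed output are hypothetical stand-ins, not the llama.cpp API; it only illustrates the counter pattern and compiles on its own.

#include <cstdio>
#include <initializer_list>
#include <utility>
#include <vector>

int main() {
    std::vector<std::pair<int, int>> batch; // (token id, KV cache position)
    int n_past = 0;                         // single running position counter

    // hypothetical stand-in for llama_batch_add: record the token at the
    // current position and advance the counter
    auto add_token = [&](int token) { batch.emplace_back(token, n_past++); };

    // prompt ingestion: positions 0..4
    for (int tok : {101, 102, 103, 104, 105}) {
        add_token(tok);
    }

    // context shift: pretend n_discard positions were dropped from the KV
    // cache, so the counter is simply rewound by the same amount
    const int n_discard = 2;
    n_past -= n_discard;

    // subsequent tokens continue from the rewound position (3, 4, ...)
    add_token(106);
    add_token(107);

    for (const auto & p : batch) {
        std::printf("token %d -> pos %d\n", p.first, p.second);
    }

    return 0;
}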