From 89febfed9322c8849520dc63c93ee4f5fd72556e Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Wed, 21 Feb 2024 10:33:54 -0500 Subject: [PATCH] examples : do not assume BOS when shifting context (#5622) --- examples/main/main.cpp | 12 +++++++----- examples/server/server.cpp | 13 +++++++------ 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index f5d2f4893..7555dffe4 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -334,6 +334,8 @@ int main(int argc, char ** argv) { // number of tokens to keep when resetting context if (params.n_keep < 0 || params.n_keep > (int) embd_inp.size() || params.instruct || params.chatml) { params.n_keep = (int)embd_inp.size(); + } else { + params.n_keep += add_bos; // always keep the BOS token } // prefix & suffix for instruct mode @@ -383,8 +385,8 @@ int main(int argc, char ** argv) { } } - if (params.n_keep > 0) { - LOG_TEE("%s: static prompt based on n_keep: '", __func__); + if (params.n_keep > add_bos) { + LOG_TEE("%s: static prompt based on n_keep: '", __func__); for (int i = 0; i < params.n_keep; i++) { LOG_TEE("%s", llama_token_to_piece(ctx, embd_inp[i]).c_str()); } @@ -540,14 +542,14 @@ int main(int argc, char ** argv) { break; } - const int n_left = n_past - params.n_keep - 1; + const int n_left = n_past - params.n_keep; const int n_discard = n_left/2; LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", n_past, n_left, n_ctx, params.n_keep, n_discard); - llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); - llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard); + llama_kv_cache_seq_shift(ctx, 0, params.n_keep + n_discard, n_past, -n_discard); n_past -= n_discard; diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1c4479512..c84719a0d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1487,14 +1487,15 @@ struct llama_server_context if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx) { // Shift context - const int n_left = system_tokens.size() + slot.n_past - slot.params.n_keep - 1; + const int n_keep = slot.params.n_keep + add_bos_token; + const int n_left = system_tokens.size() + slot.n_past - n_keep; const int n_discard = n_left / 2; - LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard); - llama_kv_cache_seq_rm (ctx, slot.id, slot.params.n_keep + 1 , slot.params.n_keep + n_discard + 1); - llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, system_tokens.size() + slot.n_past, -n_discard); + LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, n_keep, n_left, n_discard); + llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard); + llama_kv_cache_seq_shift(ctx, slot.id, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); - for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++) + for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { slot.cache_tokens[i - n_discard] = slot.cache_tokens[i]; } @@ -1507,7 +1508,7 @@ struct llama_server_context LOG_VERBOSE("context shift", { { "n_ctx", n_ctx }, - { "n_keep", params.n_keep }, + { "n_keep", n_keep }, { "n_left", n_left }, }); }