From 6e08281e588bbba1a5d180290a94a43f167f3a1a Mon Sep 17 00:00:00 2001 From: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com> Date: Sun, 29 Oct 2023 11:31:40 -0600 Subject: [PATCH] Extend llama_kv_cache_seq_rm to allow matching any sequence (#3843) * Extend llama_kv_cache_seq_rm to allow matichng any sequence * Replace llama_kv_cache_tokens_rm with llama_kv_cache_clear Use llama_kv_cache_clear for cache clearing Change calls to llama_kv_cache_tokens_rm that want to delete by position to use llama_kv_cache_seq_rm functionality --- common/common.cpp | 2 +- examples/batched-bench/batched-bench.cpp | 2 +- examples/llama-bench/llama-bench.cpp | 4 ++-- examples/main/main.cpp | 2 +- examples/perplexity/perplexity.cpp | 6 ++--- examples/server/server.cpp | 2 +- llama.cpp | 29 ++++++++++++------------ llama.h | 15 +++++------- 8 files changed, 30 insertions(+), 32 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index f81f4d354..c187128d6 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -889,7 +889,7 @@ std::tuple llama_init_from_gpt_par std::vector tmp = { llama_token_bos(model), llama_token_eos(model), }; llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); - llama_kv_cache_tokens_rm(lctx, -1, -1); + llama_kv_cache_clear(lctx); llama_reset_timings(lctx); } diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 43f9c971d..533c55c17 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -185,7 +185,7 @@ int main(int argc, char ** argv) { const auto t_pp_start = ggml_time_us(); - llama_kv_cache_tokens_rm(ctx, -1, -1); + llama_kv_cache_clear(ctx); if (!decode_helper(ctx, batch, ctx_params.n_batch)) { LOG_TEE("%s: llama_decode() failed\n", __func__); diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 20767d555..780398184 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1037,7 +1037,7 @@ int main(int argc, char ** argv) { test t(inst, lmodel, ctx); - llama_kv_cache_tokens_rm(ctx, -1, -1); + llama_kv_cache_clear(ctx); // warmup run if (t.n_prompt > 0) { @@ -1048,7 +1048,7 @@ int main(int argc, char ** argv) { } for (int i = 0; i < params.reps; i++) { - llama_kv_cache_tokens_rm(ctx, -1, -1); + llama_kv_cache_clear(ctx); uint64_t t_start = get_time_ns(); if (t.n_prompt > 0) { diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 3d9f670b9..8a43b6ab8 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -298,7 +298,7 @@ int main(int argc, char ** argv) { } // remove any "future" tokens that we might have inherited from the previous session - llama_kv_cache_tokens_rm(ctx, n_matching_session_tokens, -1); + llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1); } LOGLN( diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 3c2542e8c..bd2c73d87 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -210,7 +210,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_tokens_rm(ctx, -1, -1); + llama_kv_cache_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -339,7 +339,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_tokens_rm(ctx, -1, -1); + llama_kv_cache_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -573,7 +573,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { } // clear the KV cache - llama_kv_cache_tokens_rm(ctx, -1, -1); + llama_kv_cache_clear(ctx); auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab); if (logits.empty()) { diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 5b7e4139d..c163c7f8e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -857,7 +857,7 @@ struct llama_server_context void kv_cache_clear() { // clear the entire KV cache - llama_kv_cache_tokens_rm(ctx, -1, -1); + llama_kv_cache_clear(ctx); clean_kv_cache = false; } diff --git a/llama.cpp b/llama.cpp index d8510a5cf..a4340d527 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1466,17 +1466,12 @@ static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { return 0; } -static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) { - if (c0 < 0) c0 = 0; - if (c1 < 0) c1 = cache.size; - - for (int32_t i = c0; i < c1; ++i) { +static void llama_kv_cache_clear(struct llama_kv_cache & cache) { + for (int32_t i = 0; i < cache.size; ++i) { cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); } - - // Searching for a free slot can start here since we know it will be empty. - cache.head = uint32_t(c0); + cache.head = 0; } static void llama_kv_cache_seq_rm( @@ -1490,8 +1485,14 @@ static void llama_kv_cache_seq_rm( if (p1 < 0) p1 = std::numeric_limits::max(); for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.cells[i].seq_id.erase(seq_id); + if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + if (seq_id < 0) { + cache.cells[i].seq_id.clear(); + } else if (cache.cells[i].has_seq_id(seq_id)) { + cache.cells[i].seq_id.erase(seq_id); + } else { + continue; + } if (cache.cells[i].seq_id.empty()) { cache.cells[i].pos = -1; if (new_head == cache.size) new_head = i; @@ -9207,8 +9208,8 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) { return ctx->kv_self.head; } -void llama_kv_cache_tokens_rm(struct llama_context * ctx, int32_t c0, int32_t c1) { - llama_kv_cache_tokens_rm(ctx->kv_self, c0, c1); +void llama_kv_cache_clear(struct llama_context * ctx) { + llama_kv_cache_clear(ctx->kv_self); } void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { @@ -9654,7 +9655,7 @@ int llama_eval( llama_token * tokens, int32_t n_tokens, int n_past) { - llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); + llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1); const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0)); if (ret < 0) { @@ -9669,7 +9670,7 @@ int llama_eval_embd( float * embd, int32_t n_tokens, int n_past) { - llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1); + llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1); llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, }; diff --git a/llama.h b/llama.h index 6927bd601..d727dbd9f 100644 --- a/llama.h +++ b/llama.h @@ -334,17 +334,14 @@ extern "C" { LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx), "avoid using this, it will be removed in the future, instead - count the tokens in user code"); - // Remove all tokens data of cells in [c0, c1) - // c0 < 0 : [0, c1] - // c1 < 0 : [c0, inf) - LLAMA_API void llama_kv_cache_tokens_rm( - struct llama_context * ctx, - int32_t c0, - int32_t c1); + // Clear the KV cache + LLAMA_API void llama_kv_cache_clear( + struct llama_context * ctx); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) - // p0 < 0 : [0, p1] - // p1 < 0 : [p0, inf) + // seq_id < 0 : match any sequence + // p0 < 0 : [0, p1] + // p1 < 0 : [p0, inf) LLAMA_API void llama_kv_cache_seq_rm( struct llama_context * ctx, llama_seq_id seq_id,