From a09db95eabb5f75a5534f804882cf82e1bb5cadd Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 29 Apr 2024 10:24:45 -0400 Subject: [PATCH] llama : rename many llama_kv_cache_* functions --- llama.cpp | 111 +++++++++++++++++++++++++++++++++++++----------------- llama.h | 72 +++++++++++++++++++++++++++++------ 2 files changed, 138 insertions(+), 45 deletions(-) diff --git a/llama.cpp b/llama.cpp index 9d887c6db..f972c3472 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2032,7 +2032,6 @@ struct llama_rs_seq_meta { // ring-buffered tree of cached recurrent state data struct llama_rs_cache { - bool do_copy = false; uint32_t head = 0; // first state used for the last slot uint32_t size = 0; @@ -2769,7 +2768,7 @@ struct llama_context { }; // -// kv cache helpers +// kv and rs cache helpers // static bool llama_cache_init( @@ -2898,7 +2897,7 @@ static bool llama_cache_init( // updates the cache head // Note: On success, it's important that cache.head points // to the first cell of the slot. -static bool llama_kv_cache_find_slot( +static bool llama_cache_find_slot( struct llama_cache & cache, const struct llama_batch & batch) { const uint32_t kv_size = cache.kv.size; @@ -3181,7 +3180,6 @@ static void llama_cache_clear(struct llama_cache & cache) { rs_cell.tail_rc = 0; rs_cell.seq_nodes.clear(); } - cache.rs.do_copy = false; cache.rs.head = 0; cache.rs.used = 0; cache.rs.n_seqs = 0; @@ -3412,8 +3410,8 @@ static llama_pos llama_cache_seq_add( llama_pos p1, llama_pos delta) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } llama_pos n_past = p0; @@ -3535,7 +3533,7 @@ static llama_pos llama_cache_seq_div( } static llama_pos llama_cache_seq_pos_max(struct llama_cache & cache, llama_seq_id seq_id) { - llama_pos result = 0; + llama_pos result = -1; if (cache.rs.size > 0) { int32_t cell_id = cache.rs.seq_tails[seq_id].tail; @@ -11174,7 +11172,7 @@ static int llama_decode_internal( if (hparams.causal_attn) { llama_kv_cache_update(&lctx); - if (!llama_kv_cache_find_slot(lctx.cache, u_batch)) { + if (!llama_cache_find_slot(lctx.cache, u_batch)) { return 1; } @@ -15790,6 +15788,10 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k } } +bool llama_rs_cache_rebuild(struct llama_context * ctx, bool debug) { + return ctx->cache.rs.rebuild(debug); +} + int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx) { int result = 0; @@ -15804,55 +15806,96 @@ int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) { return ctx->cache.kv.used; } -void llama_kv_cache_clear(struct llama_context * ctx) { +int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx) { + return ctx->cache.rs.used; +} + +void llama_cache_clear(struct llama_context * ctx) { llama_cache_clear(ctx->cache); } +// deprecated +void llama_kv_cache_clear(struct llama_context * ctx) { + llama_cache_clear(ctx); +} + +llama_pos llama_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } + return llama_cache_seq_rm(ctx->cache, seq_id, p0, p1); +} + +// deprecated bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return false; } - llama_pos n_past = llama_cache_seq_rm(ctx->cache, seq_id, p0, p1); + llama_pos n_past = llama_cache_seq_rm(ctx, seq_id, p0, p1); return n_past >= p0; } -void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + +llama_pos llama_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { uint32_t n_seq_max = llama_n_seq_max(ctx); if (seq_id_src < 0 || seq_id_dst < 0 || (uint32_t) seq_id_src >= n_seq_max || (uint32_t) seq_id_dst >= n_seq_max) { - return; + return 0; } if (seq_id_src == seq_id_dst) { - return; + return llama_cache_seq_pos_max(ctx->cache, seq_id_dst) + 1; } - llama_cache_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1); + return llama_cache_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1); } -void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { +// deprecated +void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + llama_cache_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); +} + +void llama_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } llama_cache_seq_keep(ctx->cache, seq_id); } -void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - if (delta == 0) { - return; - } - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } - - llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); +// deprecated +void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { + llama_cache_seq_keep(ctx, seq_id); } -void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - if (d == 1) { - return; - } - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } - - llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d); -} - -llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { +llama_pos llama_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } + if (delta == 0) { + return llama_cache_seq_pos_max(ctx->cache, seq_id) + 1; + } + + return llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); +} + +// deprecated +void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + llama_cache_seq_add(ctx, seq_id, p0, p1, delta); +} + +llama_pos llama_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } + if (d == 1) { + return llama_cache_seq_pos_max(ctx->cache, seq_id) + 1; + } + + return llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d); +} + +// deprecated +void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + llama_cache_seq_div(ctx, seq_id, p0, p1, d); +} + +llama_pos llama_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return -1; } return llama_cache_seq_pos_max(ctx->cache, seq_id); } +// deprecated +llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { + llama_pos max_pos = llama_cache_seq_pos_max(ctx, seq_id); + return max_pos < 0 ? 0 : max_pos; +} + void llama_kv_cache_defrag(struct llama_context * ctx) { llama_kv_cache_defrag(ctx->cache.kv); } @@ -16597,7 +16640,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, batch.n_seq_id[i] = 1; batch.seq_id[i][0] = dest_seq_id; } - if (!llama_kv_cache_find_slot(cache, batch)) { + if (!llama_cache_find_slot(cache, batch)) { llama_batch_free(batch); LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); return 0; diff --git a/llama.h b/llama.h index b770a275f..c211ca592 100644 --- a/llama.h +++ b/llama.h @@ -515,6 +515,12 @@ extern "C" { // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view); + // Rebuild and check the validity of the recurrent state cache's tree of sequences. + // (slow, use only for debugging purposes) + // Returns whether or not the rs cache was valid. + // The errors are always corrected, but only logged when debug is true. + LLAMA_API bool llama_rs_cache_rebuild(struct llama_context * ctx, bool debug); + // Returns the number of tokens in the KV cache (slow, use only for debug) // If a KV cell has multiple sequences assigned to it, it will be counted multiple times LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx); @@ -522,36 +528,60 @@ extern "C" { // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx); - // Clear the KV cache - LLAMA_API void llama_kv_cache_clear( + // Returns the number of used recurrent state cells (i.e. have at least one sequence assigned to them) + LLAMA_API int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx); + + // Clear the KV and recurrent state caches + LLAMA_API void llama_cache_clear( struct llama_context * ctx); + LLAMA_API DEPRECATED(void llama_kv_cache_clear( + struct llama_context * ctx), + "use llama_cache_clear instead"); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) - // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails // seq_id < 0 : match any sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API bool llama_kv_cache_seq_rm( + // Returns n_past + LLAMA_API llama_pos llama_cache_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); + LLAMA_API DEPRECATED(bool llama_kv_cache_seq_rm( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1), + "use llama_cache_seq_rm instead, and handle its return value for partial removals"); // Copy all tokens that belong to the specified sequence to another sequence - // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence + // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_cp( + // Returns n_past + LLAMA_API llama_pos llama_cache_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_cp( + struct llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1), + "use llama_cache_seq_cp instead, and handle its return value for partial copies"); // Removes all tokens that do not belong to the specified sequence - LLAMA_API void llama_kv_cache_seq_keep( + LLAMA_API void llama_cache_seq_keep( struct llama_context * ctx, llama_seq_id seq_id); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_keep( + struct llama_context * ctx, + llama_seq_id seq_id), + "use llama_cache_seq_keep instead"); // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -559,12 +589,20 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_add( + // Returns n_past + LLAMA_API llama_pos llama_cache_seq_add( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_add( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta), + "use llama_cache_seq_add instead"); // Integer division of the positions by factor of `d > 1` // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -572,17 +610,29 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_div( + // Returns n_past + LLAMA_API llama_pos llama_cache_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_div( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d), + "use llama_cache_seq_div instead"); - // Returns the largest position present in the KV cache for the specified sequence - LLAMA_API llama_pos llama_kv_cache_seq_pos_max( + // Returns the largest position present in the KV and/or RS cache for the specified sequence + LLAMA_API llama_pos llama_cache_seq_pos_max( struct llama_context * ctx, llama_seq_id seq_id); + LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max( + struct llama_context * ctx, + llama_seq_id seq_id), + "use llama_cache_seq_pos_max instead, which also now returns -1 instead of 0 when the seq_id has no cells"); // Defragment the KV cache // This will be applied: