diff --git a/llama.cpp b/llama.cpp index 92bff6b90..15f7ca43a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2887,7 +2887,6 @@ static bool llama_cache_init( bool offload) { const struct llama_hparams & hparams = model.hparams; - // TODO: per layer n_embd_* const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); @@ -3010,6 +3009,8 @@ static bool llama_cache_find_slot( const uint32_t rs_size = cache.rs.size; const uint32_t n_tokens = batch.n_tokens; + // FIXME: on failure, leave all caches in a consistent state. + if (rs_size > 0) { // For recurrent state architectures (like Mamba), // each cache cell can store the state for a whole sequence. @@ -3509,7 +3510,7 @@ static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id } } -static llama_pos llama_cache_seq_add( +static void llama_cache_seq_add( struct llama_cache & cache, llama_seq_id seq_id, llama_pos p0, @@ -3519,8 +3520,6 @@ static llama_pos llama_cache_seq_add( if (p0 < 0) { p0 = 0; } if (p1 < 0) { p1 = std::numeric_limits::max(); } - llama_pos n_past = p0; - if (cache.rs.size > 0) { // for Mamba-like models, only the pos needs to be shifted auto & seq = cache.rs.seq_tails[seq_id]; @@ -3541,9 +3540,6 @@ static llama_pos llama_cache_seq_add( } } } - if (n_past <= rs_cell.pos) { - n_past = rs_cell.pos + 1; - } } // If we freed up a slot, set head to it so searching can start there. @@ -3573,9 +3569,6 @@ static llama_pos llama_cache_seq_add( } } } - if (n_past <= kv_cell.pos) { - n_past = kv_cell.pos + 1; - } } } @@ -3583,11 +3576,9 @@ static llama_pos llama_cache_seq_add( // Otherwise we just start the next search from the beginning. cache.kv.head = new_head != cache.kv.size ? new_head : 0; } - - return n_past; } -static llama_pos llama_cache_seq_div( +static void llama_cache_seq_div( struct llama_cache & cache, llama_seq_id seq_id, llama_pos p0, @@ -3596,8 +3587,6 @@ static llama_pos llama_cache_seq_div( if (p0 < 0) { p0 = 0; } if (p1 < 0) { p1 = std::numeric_limits::max(); } - llama_pos n_past = p0; - if (cache.rs.size > 0) { // for Mamba-like models, only the pos needs to be changed auto & seq = cache.rs.seq_tails[seq_id]; @@ -3609,9 +3598,6 @@ static llama_pos llama_cache_seq_div( rs_cell.pos /= d; } cell_id = rs_cell.prev; - if (n_past <= rs_cell.pos) { - n_past = rs_cell.pos + 1; - } } } @@ -3628,14 +3614,9 @@ static llama_pos llama_cache_seq_div( kv_cell.delta += kv_cell.pos - p_old; } } - if (n_past <= kv_cell.pos) { - n_past = kv_cell.pos + 1; - } } } } - - return n_past; } static llama_pos llama_cache_seq_pos_max(struct llama_cache & cache, llama_seq_id seq_id) { @@ -16935,13 +16916,11 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { llama_cache_seq_keep(ctx, seq_id); } -llama_pos llama_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } - if (delta == 0) { - return llama_cache_seq_pos_max(ctx->cache, seq_id) + 1; - } +void llama_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } + if (delta == 0) { return; } - return llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); + llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); } // deprecated @@ -16949,13 +16928,11 @@ void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, lla llama_cache_seq_add(ctx, seq_id, p0, p1, delta); } -llama_pos llama_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } - if (d == 1) { - return llama_cache_seq_pos_max(ctx->cache, seq_id) + 1; - } +void llama_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } + if (d == 1) { return; } - return llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d); + llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d); } // deprecated diff --git a/llama.h b/llama.h index fa6d0b586..bf0f4a9e1 100644 --- a/llama.h +++ b/llama.h @@ -562,7 +562,8 @@ extern "C" { // seq_id < 0 : match any sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - // Returns n_past + // Returns n_past (one more than the largest remaining pos in the seq_id) + // which is only meaningful to handle for partial removals. LLAMA_API llama_pos llama_cache_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, @@ -579,7 +580,8 @@ extern "C" { // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - // Returns n_past + // Returns n_past (one more than the largest remaining pos in the destination seq_id) + // which is only meaningful to handle when partially copying. LLAMA_API llama_pos llama_cache_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, @@ -609,8 +611,7 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - // Returns n_past - LLAMA_API llama_pos llama_cache_seq_add( + LLAMA_API void llama_cache_seq_add( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -630,8 +631,7 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - // Returns n_past - LLAMA_API llama_pos llama_cache_seq_div( + LLAMA_API void llama_cache_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -652,7 +652,7 @@ extern "C" { LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max( struct llama_context * ctx, llama_seq_id seq_id), - "use llama_cache_seq_pos_max instead, which also now returns -1 instead of 0 when the seq_id has no cells"); + "use llama_cache_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells"); // Defragment the KV cache // This will be applied: