From 271104c65c9b99d5b5aca4489d7bec103cd60db9 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 3 Apr 2024 11:07:16 -0400 Subject: [PATCH] wip: llama : separate recurrent states from the KV cache This will be necessary to support Jamba (and other recurrent models mixed with Attention). Doesn't compile yet, and finding a slot isn't yet done correctly for recurrent states. --- llama.cpp | 1424 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 979 insertions(+), 445 deletions(-) diff --git a/llama.cpp b/llama.cpp index 267ac4cc0..9ca8ca0f4 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1793,14 +1793,14 @@ struct llama_hparams { return n_embd_head_v * n_head_kv; } - uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings + uint32_t n_embd_r() const { // dimension of the rolling state embeddings // corresponds to Mamba's conv_states size // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; } - uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings + uint32_t n_embd_s() const { // dimension of the recurrent state embeddings // corresponds to Mamba's ssm_states size return ssm_d_state * ssm_d_inner; } @@ -1904,7 +1904,6 @@ struct llama_layer { struct llama_kv_cell { llama_pos pos = -1; llama_pos delta = 0; - int32_t src = 0; // used by recurrent state models to copy states std::set seq_id; @@ -1925,9 +1924,6 @@ struct llama_kv_cell { struct llama_kv_cache { bool has_shift = false; bool do_defrag = false; - bool do_copy = false; - // with recurrent state models, a cell can hold the state for more than one past token - bool recurrent = false; // Note: The value of head isn't only used to optimize searching // for a free KV slot. llama_decode_internal also uses it, so it @@ -1947,9 +1943,365 @@ struct llama_kv_cache { std::vector k_l; // per layer std::vector v_l; + size_t total_size() const { + size_t size = 0; + for (struct ggml_tensor * k : k_l) { + size += ggml_nrows(k) * ggml_row_size(k->type, k->ne[0]); + } + for (struct ggml_tensor * v : v_l) { + size += ggml_nrows(v) * ggml_row_size(v->type, v->ne[0]); + } + return size; + } +}; + +// for recurrent models, use a tree of sequences to simplify +// quickly finding the tail cell of each sequence +// TODO: drop the _rs_ infix +struct llama_rs_seq_node { + llama_seq_id seq_id = -1; + int32_t next_cell = -1; + + // needed for automatic typecasting with .find() + llama_rs_seq_node(const llama_seq_id s = -1, int32_t i = -1) : seq_id(s), next_cell(i) {} + + bool operator<(const llama_rs_seq_node & other) const { + return seq_id < other.seq_id; + } + + bool is_tail() const { + return next_cell < 0; + } +}; + +struct llama_rs_cell { + llama_pos pos = -1; + int32_t src = -1; // copy source id (cleared next when -1) + + // Link to previous cell in this sequence. + // Sequences can only diverge, never converge, + // so this works when there are multiple seq_ids per cell too. 
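+    // (-1 when this cell is the first of its sequence)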
+ int32_t prev = -1; + + // ref count of tails (should match the number of next_cell == -1 in seq_nodes) + uint32_t tail_rc = 0; + + // seq_ids by insertion order, to simplify updating n_cells compared to a set + std::vector seq_nodes; + + llama_rs_seq_node * get_node(const llama_seq_id & id) { + for (size_t i = 0; i < seq_nodes.size(); ++i) { + if (seq_nodes[i].seq_id == id) { + return &seq_nodes[i]; + } + } + return nullptr; + } + + void insert_node(const llama_rs_seq_node & node) { + llama_rs_seq_node * node_dest = get_node(node.seq_id); + if (node_dest == nullptr) { + seq_nodes.push_back(node); + } else { + *node_dest = node; + } + } + + bool remove_node(llama_rs_seq_node * node_ptr) { + if (node_ptr != nullptr && seq_nodes.data() <= node_ptr) { + size_t offset = node_ptr - seq_nodes.data(); + if (offset % sizeof(llama_rs_seq_node) == 0) { + offset /= sizeof(llama_rs_seq_node); + if (offset < seq_nodes.size()) { + for (size_t i = offset + 1; i < seq_nodes.size(); ++i) { + seq_nodes[i - 1] = seq_nodes[i]; + } + seq_nodes.resize(seq_nodes.size() - 1); + return true; + } + } + } + return false; + } + + bool has_seq_id(const llama_seq_id & id) const { + for (size_t i = 0; i < seq_nodes.size(); ++i) { + if (seq_nodes[i].seq_id == id) { + return true; + } + } + return false; + } + + bool is_empty() const { + return seq_nodes.empty(); + } +}; + + +struct llama_rs_seq_meta { + // cell id of the latest state of this seq_id + int32_t tail = -1; + // number of cells for which this seq_id is the first + // (useful to know if cells in this sequence should be pruned) + int32_t n_cells = 0; + // whether the tail is a cell part of multiple sequences + bool shared = false; +}; + +// ring-buffer of cached recurrent state data +struct llama_rs_cache { + bool do_copy = false; + + uint32_t head = 0; // first state used for the last slot + uint32_t size = 0; + uint32_t used = 0; + + // computed when finding a slot + uint32_t n = 0; // range of states used for the last slot + + // useful to know the minimum reserved cell count per seq_id + // only counts sequences with n_cells > 0 + uint32_t n_seqs = 0; + + // with state models, a cell can hold the state for more than one past token + // TODO: it's probably not possible to always use contiguous cells + std::vector cells; + + // find tail cells faster + std::vector seq_tails; // map seq_ids to cell ids + + // per layer + // NOTE: the naming of r and s is arbitrary + std::vector r_l; // rolling/shift states + std::vector s_l; // ssm (recurrent) states + + // returns whether or not a cell was freed + bool clear_cell(uint32_t i) { + if (i < size) { + llama_rs_cell & rs_cell = cells[i]; + if (!rs_cell.is_empty()) { + // update sequence tree links + bool first = true; + for (const llama_rs_seq_node & node : rs_cell.seq_nodes) { + if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + // NOTE: if all next cells are the same cell, this should still work + cells[node.next_cell].prev = rs_cell.prev; + } + if ((uint32_t) node.seq_id < seq_tails.size()) { + auto & seq = seq_tails[node.seq_id]; + // update tail + if (node.is_tail()) { + seq.tail = rs_cell.prev; + if (seq.tail >= 0 && (uint32_t) seq.tail < size) { + llama_rs_cell & new_tail = cells[seq.tail]; + new_tail.insert_node(node.seq_id); // ensures next_cell == -1 + new_tail.tail_rc += 1; + seq.shared = new_tail.seq_nodes.size() > 1; + } else { + seq.shared = false; + } + } + // cell counts + if (first) { + seq.n_cells -= 1; + if (seq.n_cells == 0) { + GGML_ASSERT(seq.tail < 0); + n_seqs -= 1; + } + 
first = false; + } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + } + rs_cell.pos = -1; + rs_cell.src = -1; + rs_cell.prev = -1; + rs_cell.tail_rc = 0; + rs_cell.seq_nodes.clear(); + used -= 1; + return true; + } + } + return false; + } + + // TODO: maybe use a simpler data structure than a tree + // returns whether or not a cell was freed + bool remove_seq_from_cell(uint32_t i_cell, const llama_seq_id & id) { + if (i_cell < size && (size_t) id < size) { + llama_rs_cell & rs_cell = cells[i_cell]; + auto * node_ptr = rs_cell.get_node(id); // search once + if (node_ptr != nullptr) { + if (rs_cell.seq_nodes.size() == 1) { + return clear_cell(i_cell); + } else { + // update tree + llama_rs_seq_node node = *node_ptr; + if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + cells[node.next_cell].prev = rs_cell.prev; + } + if ((uint32_t) node.seq_id < seq_tails.size()) { + auto & seq = seq_tails[node.seq_id]; + bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2; + if (node.is_tail()) { + seq.tail = rs_cell.prev; + if (seq.tail >= 0 && (uint32_t) seq.tail < size) { + llama_rs_cell & new_tail = cells[seq.tail]; + new_tail.insert_node(node.seq_id); // ensures next_cell == -1 + new_tail.tail_rc += 1; + seq.shared = cells[seq.tail].seq_nodes.size() > 1; + } else { + seq.shared = false; + } + GGML_ASSERT(rs_cell.tail_rc > 0); + rs_cell.tail_rc -= 1; + } + if (node_ptr == rs_cell.seq_nodes.data()) { + // this seq_id was the first in the list + seq.n_cells -= 1; + if (seq.n_cells == 0) { + n_seqs -= 1; + } + // the next node is the new first one, so update its n_cells + // (will never be out-of-bounds because the size is > 1) + llama_rs_seq_node next_node = node_ptr[1]; + if ((uint32_t) next_node.seq_id < seq_tails.size()) { + auto & next_seq = seq_tails[next_node.seq_id]; + next_seq.n_cells += 1; + if (next_seq.n_cells == 1) { + n_seqs += 1; + } + if (other_no_longer_shared) { + next_seq.shared = false; + } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + } else if (other_no_longer_shared) { + llama_rs_seq_node first_node = rs_cell.seq_nodes[0]; + if ((uint32_t) first_node.seq_id < seq_tails.size()) { + seq_tails[first_node.seq_id].shared = false; + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + const bool removed = rs_cell.remove_node(node_ptr); + GGML_ASSERT(removed); + } + } + } + return false; + } + + bool insert_seq_tail_to_cell(uint32_t i_cell, const llama_seq_id & id) { + if (i_cell < size && (size_t) id < seq_tails.size()) { + llama_rs_cell & rs_cell = cells[i_cell]; + auto & seq = seq_tails[id]; + int32_t prev = rs_cell.prev; + if ((uint32_t) seq.tail == i_cell) { + // the cell is already the tail of this seq_id + return false; + } + if (rs_cell.is_empty()) { + prev = seq.tail; + } + // ensure the new tail won't mess up the tree + GGML_ASSERT(seq.tail == -1 || seq.tail == prev); + if (prev >= 0 && (uint32_t) prev < size) { + // the targeted cell has a previous cell + llama_rs_cell & prev_cell = cells[prev]; + llama_rs_seq_node * prev_node = prev_cell.get_node(id); + GGML_ASSERT(prev_node != nullptr); // TODO: recursive insert instead of failing + GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken + if (rs_cell.pos < 0) { + GGML_ASSERT(rs_cell.is_empty()); + rs_cell.pos = prev_cell.pos + 1; + rs_cell.src = prev_cell.src; + } + prev_cell.tail_rc -= 1; + prev_node->next_cell = i_cell; + } + if (rs_cell.is_empty()) { + // only add after potential failures above + if 
(seq.n_cells == 0) { + n_seqs += 1; + } + seq.n_cells += 1; + // set pos if still unset + if (rs_cell.pos < 0) { + rs_cell.pos = 0; + rs_cell.src = -1; + } + } + // the target cell was not already a tail of this seq_id + rs_cell.insert_node(id); // next_cell == -1 by default + rs_cell.tail_rc += 1; + seq.tail = i_cell; + seq.shared = rs_cell.seq_nodes.size() > 1; + return true; + } + return false; + } + + // each seq_id should have access to at least this many cells + // (to use when pruning (to avoid over-pruning)) + // (but this over-prunes when the system prompt doesn't take lots of cells) + // Hmm. The system prompt does not need checkpoints... + size_t min_cells_per_seq() const { + return size / (n_seqs > 0 ? n_seqs : 1); + } + + // each seq_id can have at most this many cells + // (ignoring seqs which behave as a shared prompt) + // TODO: avoid recalculating system seq_ids + // (to use when pruning (to avoid over-pruning)) + // NOTE: this also limits the shared prompt to at most half the cells + // (but the shared prompt technically needs only one cell...) + // (IDEA: keep only one cell when `llama_kv_cache_seq_cp` is called on a sequence) + size_t max_cells_per_seq() const { + int32_t n_system_seqs = 0; + int32_t n_system_cells = 0; + for (size_t i = 0; i < seq_tails.size(); ++i) { + auto & seq = seq_tails[i]; + if (seq.tail >= 0 && (size_t) seq.tail < size) { + if (seq.shared && seq.n_cells > 0) { + n_system_seqs += 1; + n_system_cells += seq.n_cells; + } + } + } + int32_t n_other_seqs = n_seqs - n_system_seqs; + return (size - n_system_cells) / (n_other_seqs > 0 ? n_other_seqs : 1); + } + + size_t total_size() const { + size_t size = 0; + for (struct ggml_tensor * r : r_l) { + size += ggml_nrows(r) * ggml_row_size(r->type, r->ne[0]); + } + for (struct ggml_tensor * s : s_l) { + size += ggml_nrows(s) * ggml_row_size(s->type, s->ne[0]); + } + return size; + } +}; + +struct llama_cache { + // key + value cache for self attention + llama_kv_cache kv; + + // recurrent state cache for state space models + llama_rs_cache rs; + std::vector ctxs; std::vector bufs; + // NOTE: padding may make this bigger than kv.total_size() + rs.total_size() size_t total_size() const { size_t size = 0; for (ggml_backend_buffer_t buf : bufs) { @@ -1958,7 +2310,7 @@ struct llama_kv_cache { return size; } - ~llama_kv_cache() { + ~llama_cache() { for (struct ggml_context * ctx : ctxs) { ggml_free(ctx); } @@ -2146,8 +2498,8 @@ struct llama_context { const llama_model & model; - // key + value cache for the self attention - struct llama_kv_cache kv_self; + // key + value cache for self-attention, and/or recurrent state cache + struct llama_cache cache; std::mt19937 rng; @@ -2205,9 +2557,9 @@ struct llama_context { struct ggml_tensor * inp_K_shift; // I32 [kv_size] struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] - struct ggml_tensor * inp_s_copy; // I32 [kv_size] - struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] - struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] + struct ggml_tensor * inp_s_copy; // I32 [n_rs] + struct ggml_tensor * inp_s_mask; // F32 [1, n_rs] + struct ggml_tensor * inp_s_seq; // I32 [n_rs, n_batch] // control vectors struct llama_control_vector cvec; @@ -2221,47 +2573,45 @@ struct llama_context { // kv cache helpers // -static bool llama_kv_cache_init( - struct llama_kv_cache & cache, +static bool llama_cache_init( + struct llama_cache & cache, const llama_model & model, ggml_type type_k, ggml_type type_v, - uint32_t kv_size, + 
uint32_t n_ctx, + uint32_t n_seq_max, bool offload) { const struct llama_hparams & hparams = model.hparams; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); + + // TODO: per layer n_embd_* + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + const uint32_t n_embd_r = hparams.n_embd_r(); + const uint32_t n_embd_s = hparams.n_embd_s(); + const bool has_kv = hparams.n_head != 0 && hparams.causal_attn; + const bool has_r = n_embd_r != 0; + const bool has_s = n_embd_s != 0; + const bool has_rs = has_r || has_s; + const uint32_t kv_size = has_kv ? n_ctx : 0; + const uint32_t rs_size = has_rs ? n_seq_max : 0; + // TODO: per cache type layer count const int64_t n_layer = hparams.n_layer; - cache.has_shift = false; + cache.kv.size = kv_size; - // TODO: find a nicer way to add other recurrent model architectures - cache.recurrent = model.arch == LLM_ARCH_MAMBA; + cache.kv.type_k = type_k; + cache.kv.type_v = type_v; - // TODO: support mixed reccurent Transformer architectues - // NOTE: (!a || b) is a logical implication (a -> b) - GGML_ASSERT(!cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_s()); - GGML_ASSERT(!cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_s()); - GGML_ASSERT( cache.recurrent || n_embd_k_gqa == hparams.n_embd_k_gqa()); - GGML_ASSERT( cache.recurrent || n_embd_v_gqa == hparams.n_embd_v_gqa()); + cache.kv.cells.clear(); + cache.kv.cells.resize(kv_size); - cache.head = 0; - cache.size = kv_size; - cache.used = 0; + cache.rs.size = rs_size; - cache.type_k = type_k; - cache.type_v = type_v; - - cache.cells.clear(); - cache.cells.resize(kv_size); - - if (cache.recurrent) { - // init state copy sources - for (uint32_t i = 0; i < cache.size; ++i) { - cache.cells[i].src = i; - } - } + cache.rs.cells.clear(); + cache.rs.cells.resize(rs_size); + cache.rs.seq_tails.clear(); + cache.rs.seq_tails.resize(rs_size); #ifdef GGML_USE_CLBLAST offload = false; @@ -2282,7 +2632,7 @@ static bool llama_kv_cache_init( for (auto & it : buft_layer_count) { int n_layers = it.second; struct ggml_init_params params = { - /*.mem_size =*/ 2u*n_layers*ggml_tensor_overhead(), + /*.mem_size =*/ (2*has_kv + has_r+has_s)*n_layers*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -2295,17 +2645,37 @@ static bool llama_kv_cache_init( cache.ctxs.push_back(ctx); } - cache.k_l.reserve(n_layer); - cache.v_l.reserve(n_layer); + if (has_kv) { + cache.kv.k_l.reserve(n_layer); + cache.kv.v_l.reserve(n_layer); + } + if (has_r) { + cache.rs.r_l.reserve(n_layer); + } + if (has_s) { + cache.rs.s_l.reserve(n_layer); + } for (int i = 0; i < (int) n_layer; i++) { struct ggml_context * ctx = offload ? 
ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); - ggml_format_name(k, "cache_k_l%d", i); - ggml_format_name(v, "cache_v_l%d", i); - cache.k_l.push_back(k); - cache.v_l.push_back(v); + if (has_kv) { + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + cache.kv.k_l.push_back(k); + cache.kv.v_l.push_back(v); + } + if (has_r) { + ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd_r*rs_size); + ggml_format_name(r, "cache_r_l%d", i); + cache.rs.r_l.push_back(r); + } + if (has_s) { + ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd_s*rs_size); + ggml_format_name(s, "cache_s_l%d", i); + cache.rs.s_l.push_back(s); + } } // allocate tensors and initialize the buffers to avoid NaNs in the padding @@ -2330,23 +2700,30 @@ static bool llama_kv_cache_init( // Note: On success, it's important that cache.head points // to the first cell of the slot. static bool llama_kv_cache_find_slot( - struct llama_kv_cache & cache, - const struct llama_batch & batch) { - const uint32_t n_ctx = cache.size; + struct llama_cache & cache, + const struct llama_batch & batch) { + const uint32_t kv_size = cache.kv.size; + const uint32_t rs_size = cache.rs.size; const uint32_t n_tokens = batch.n_tokens; - if (cache.recurrent) { + if (rs_size > 0) { // For recurrent state architectures (like Mamba), - // each KV cache cell can store the state for a whole sequence. + // each cache cell can store the state for a whole sequence. + // TODO: real ring-buffer of states + // TODO: state chekpoints (multiple cells per sequence) + // TODO: find a way to always make the rs slot contiguous - llama_seq_id min = cache.size - 1; + // Okay, need to find a slot. Everything should fit assuming the biggest seq_id < rs_size + + + llama_seq_id min = cache.rs.size - 1; llama_seq_id max = 0; for (uint32_t i = 0; i < n_tokens; ++i) { for (int32_t j = 0; j < batch.n_seq_id[i]; ++j) { llama_seq_id seq_id = batch.seq_id[i][j]; // make sure it's a valid seq_id - if ((uint32_t) seq_id < cache.size) { + if ((uint32_t) seq_id < rs_size) { if (seq_id > max) { max = seq_id; } @@ -2354,83 +2731,93 @@ static bool llama_kv_cache_find_slot( min = seq_id; } // Assuming the tokens are in-order - if (batch.pos[i] != cache.cells[seq_id].pos + 1) { + if (batch.pos[i] != cache.rs.cells[seq_id].pos + 1) { // What should happen when the pos backtracks or skips a value? // Clearing the state mid-batch would require special-casing which isn't done. LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", - __func__, batch.pos[i], cache.cells[seq_id].pos, seq_id); + __func__, batch.pos[i], cache.rs.cells[seq_id].pos, seq_id); } - if (cache.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) { - cache.used += 1; + if (cache.rs.cells[seq_id].pos < 0 && 0 <= batch.pos[i]) { + cache.rs.used += 1; } - cache.cells[seq_id].pos = batch.pos[i]; - // NOTE: seq_ids are not inserted here; they are handled when the input tensors are set + cache.rs.cells[seq_id].pos = batch.pos[i]; + cache.rs.cells[seq_id].seq_nodes.insert(seq_id); } else { // too big seq_id // TODO: would it be possible to resize the KV cache size instead? 
- LLAMA_LOG_ERROR("%s: seq_id=%d >= kv_size=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size); + LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size); return false; } } } // allow getting the range of used cells, from head to head + n - cache.head = min; - cache.n = max - min + 1; + cache.rs.head = min; + cache.rs.n = max - min + 1; // sanity check - return max >= min; - } - // otherwise, one cell per token. - - if (n_tokens > n_ctx) { - LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, n_ctx); - return false; - } - - uint32_t n_tested = 0; - - while (true) { - if (cache.head + n_tokens > n_ctx) { - n_tested += n_ctx - cache.head; - cache.head = 0; - continue; - } - - bool found = true; - for (uint32_t i = 0; i < n_tokens; i++) { - if (cache.cells[cache.head + i].pos >= 0) { - found = false; - cache.head += i + 1; - n_tested += i + 1; - break; - } - } - - if (found) { - break; - } - - if (n_tested >= n_ctx) { - //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + if (max < min) { return false; } } - for (uint32_t i = 0; i < n_tokens; i++) { - cache.cells[cache.head + i].pos = batch.pos[i]; - - for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { - cache.cells[cache.head + i].seq_id.insert(batch.seq_id[i][j]); + if (kv_size > 0) { + // one KV cell per token + if (n_tokens > kv_size) { + LLAMA_LOG_ERROR("%s: n_tokens=%d > n_ctx=%d\n", __func__, n_tokens, kv_size); + return false; } - } - cache.used += n_tokens; + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (cache.kv.head > cache.kv.used + 2*n_tokens) { + cache.kv.head = 0; + } + + uint32_t n_tested = 0; + + while (true) { + if (cache.kv.head + n_tokens > kv_size) { + n_tested += kv_size - cache.kv.head; + cache.kv.head = 0; + continue; + } + + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + if (cache.kv.cells[cache.kv.head + i].pos >= 0) { + found = false; + cache.kv.head += i + 1; + n_tested += i + 1; + break; + } + } + + if (found) { + break; + } + + if (n_tested >= kv_size) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return false; + } + } + + for (uint32_t i = 0; i < n_tokens; i++) { + cache.kv.cells[cache.kv.head + i].pos = batch.pos[i]; + + for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { + cache.kv.cells[cache.kv.head + i].seq_id.insert(batch.seq_id[i][j]); + } + } + + cache.kv.used += n_tokens; + } return true; } -// find how many cells are currently in use +// find how many KV cells are currently in use static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { for (uint32_t i = cache.size; i > 0; --i) { const llama_kv_cell & cell = cache.cells[i - 1]; @@ -2443,214 +2830,381 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { return 0; } -static void llama_kv_cache_clear(struct llama_kv_cache & cache) { - for (int32_t i = 0; i < (int32_t) cache.size; ++i) { - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); +static uint32_t llama_rs_cache_cell_max(const struct llama_rs_cache & cache) { + for (uint32_t i = cache.size; i > 0; --i) { + const llama_rs_cell & cell = cache.cells[i - 1]; + + if (cell.pos >= 0 && !cell.is_empty()) { + return i; + } } - cache.head = 0; - cache.used = 0; + + return 0; } -static bool llama_kv_cache_seq_rm( - struct llama_kv_cache & cache, - llama_seq_id 
seq_id, - llama_pos p0, - llama_pos p1) { - uint32_t new_head = cache.size; +static void llama_cache_clear(struct llama_cache & cache) { + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + kv_cell.pos = -1; + kv_cell.delta = 0; + kv_cell.seq_id.clear(); + } + cache.kv.has_shift = false; + cache.kv.do_defrag = false; + cache.kv.head = 0; + cache.kv.used = 0; + } + if (cache.rs.size > 0) { + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + rs_cell.pos = -1; + rs_cell.src = -1; + rs_cell.seq_nodes.clear(); + } + cache.rs.do_copy = false; + cache.rs.head = 0; + cache.rs.used = 0; + cache.rs.n_seqs = 0; + cache.rs.seq_tails.clear(); + cache.rs.seq_tails.resize(cache.rs.size); + } +} + +static llama_pos llama_cache_seq_rm( + struct llama_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); - // models like Mamba can't have a state partially erased - if (cache.recurrent) { - if (seq_id >= (int64_t) cache.size) { + llama_pos n_past = p0; + + if (cache.rs.size > 0) { + if (seq_id >= (int64_t) cache.rs.size) { // could be fatal - return false; + return n_past; } - if (0 <= seq_id) { - // partial intersection is invalid - if ((0 < p0 && p0 <= cache.cells[seq_id].pos) || (0 < p1 && p1 <= cache.cells[seq_id].pos)) { - return false; - } - } else { - // seq_id is negative, then the range should include everything or nothing - if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { - return false; + uint32_t new_head = cache.rs.size; + // adjust p0 and p1 according to the states found + llama_pos new_p0 = 0; + llama_pos new_p1 = std::numeric_limits::max(); + + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + if (seq_id < 0 || rs_cell.has_seq_id(seq_id)) { + if (rs_cell.pos < p0) { + // move forward the new p0 further + if (rs_cell.pos >= new_p0) { + new_p0 = rs_cell.pos + 1; + } + } else if (rs_cell.pos >= p1) { + // move back the new p1 further + if (rs_cell.pos < new_p1) { + new_p1 = rs_cell.pos; + } + if (rs_cell.pos >= n_past) { + n_past = rs_cell.pos + 1; + } + } else { // (rs_cell.pos >= p0 && rs_cell.pos < p1) + if (seq_id < 0) { + cache.rs.clear_cell(i); + } else { // (rs_cell.has_seq_id(seq_id)) + cache.rs.remove_seq_from_cell(i, seq_id); + } + if (rs_cell.is_empty() && new_head == cache.rs.size) { + new_head = i; + } + } } } + p0 = new_p0; + p1 = new_p1; + // correctly set n_past when there's nothing after p1 + if (n_past < p0) { n_past = p0; } + + // If we freed up a slot, set head to it so searching can start there. 
+ if (new_head != cache.rs.size && new_head < cache.rs.head) { + cache.rs.head = new_head; + } } - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - if (seq_id < 0) { - cache.cells[i].seq_id.clear(); - } else if (cache.cells[i].has_seq_id(seq_id)) { - cache.cells[i].seq_id.erase(seq_id); - } else { - continue; - } - if (cache.cells[i].is_empty()) { - // keep count of the number of used cells - if (cache.cells[i].pos >= 0) cache.used--; + if (cache.kv.size > 0) { + uint32_t new_head = cache.kv.size; - cache.cells[i].pos = -1; - if (new_head == cache.size) new_head = i; + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + + if (seq_id < 0 || kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= p0 && kv_cell.pos < p1) { + if (seq_id < 0) { + kv_cell.seq_id.clear(); + } else { // (kv_cell.has_seq_id(seq_id)) + kv_cell.seq_id.erase(seq_id); + } + if (kv_cell.is_empty()) { + // keep count of the number of used cells + if (kv_cell.pos >= 0) { cache.kv.used--; } + + kv_cell.pos = -1; + if (new_head == cache.kv.size) { new_head = i; } + } + } else { + if (kv_cell.pos >= n_past) { + n_past = kv_cell.pos + 1; + } + } } } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.kv.size && new_head < cache.kv.head) { + cache.kv.head = new_head; + } } - // If we freed up a slot, set head to it so searching can start there. - if (new_head != cache.size && new_head < cache.head) cache.head = new_head; - - return true; + return n_past; } -static void llama_kv_cache_seq_cp( - struct llama_kv_cache & cache, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1) { +static llama_pos llama_cache_seq_cp( + struct llama_cache & cache, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); - if (cache.recurrent) { - if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) { - seq_id_src = cache.cells[seq_id_src].src; - GGML_ASSERT((uint32_t) seq_id_src < cache.size); - // intent to "copy from" - // supports copy chains thanks to taking the source of the source - cache.cells[seq_id_dst].src = seq_id_src; + // TODO: in practice this seems to be only used on whole sequences; + // should partial sequence copy be removed? 
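+    // (a whole-sequence copy is llama_kv_cache_seq_cp(ctx, seq_id_src, seq_id_dst, -1, -1))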
- // preserve the "keep or clear" status of the copied sequence - if (cache.cells[seq_id_src].has_seq_id(seq_id_src)) { - cache.cells[seq_id_dst].seq_id.insert(seq_id_dst); - } else { - cache.cells[seq_id_dst].seq_id.erase(seq_id_dst); + llama_pos n_past = 0; + + if (cache.rs.size > 0) { + // have to start from beginning for recurrent models + p0 = 0; + if ((uint32_t) seq_id_dst < cache.rs.size && (uint32_t) seq_id_src < cache.rs.size) { + auto seq_src = cache.rs.seq_tails[seq_id_src]; + int32_t src_tail = seq_src.tail; + // find the last tail of src in the pos range + while (src_tail >= 0 && (uint32_t) src_tail < cache.rs.size) { + llama_rs_cell & tail_cell = cache.rs.cells[src_tail]; + if (tail_cell.pos < p1) { + break; + } + src_tail = tail_cell.prev; } - cache.do_copy = true; + uint32_t new_head = cache.rs.size; - cache.cells[seq_id_dst].pos = cache.cells[seq_id_src].pos; + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + if (rs_cell.pos >= p0 && rs_cell.pos < p1 && rs_cell.has_seq_id(seq_id_src)) { + if (i == (uint32_t) src_tail) { + // need to be inserted in order, but there's only one + cache.rs.insert_seq_tail_to_cell(i, seq_id_dst); + } else { + // keep only the tail cell of the source + // assuming a copy means no rollback will be attempted afterwards + cache.rs.remove_seq_from_cell(i, seq_id_src); + if (new_head == cache.rs.size) { + new_head = i; + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.rs.size && new_head < cache.rs.head) { + cache.rs.head = new_head; + } } - return; + p1 = n_past; } - // otherwise, this is the KV cache of a Transformer-like model - cache.head = 0; + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.pos >= p0 && kv_cell.pos < p1 && kv_cell.has_seq_id(seq_id_src)) { + kv_cell.seq_id.insert(seq_id_dst); + if (kv_cell.pos >= n_past) { + n_past = kv_cell.pos + 1; + } + } + } + } - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.cells[i].seq_id.insert(seq_id_dst); + return n_past; +} + +static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id) { + if (cache.kv.size > 0) { + uint32_t new_head = cache.kv.size; + + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (!kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= 0) cache.kv.used--; + kv_cell.pos = -1; + kv_cell.seq_id.clear(); + if (new_head == cache.kv.size) new_head = i; + } else { + kv_cell.seq_id.clear(); + kv_cell.seq_id.insert(seq_id); + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.kv.size && new_head < cache.kv.head) { + cache.kv.head = new_head; } } } -static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { - uint32_t new_head = cache.size; - - for (uint32_t i = 0; i < cache.size; ++i) { - if (!cache.cells[i].has_seq_id(seq_id)) { - if (cache.cells[i].pos >= 0) cache.used--; - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); - if (new_head == cache.size) new_head = i; - } else { - cache.cells[i].seq_id.clear(); - cache.cells[i].seq_id.insert(seq_id); - } - } - - // If we freed up a slot, set head to it so searching can start there. 
- if (new_head != cache.size && new_head < cache.head) cache.head = new_head; -} - -static void llama_kv_cache_seq_add( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta) { - uint32_t new_head = cache.size; +static llama_pos llama_cache_seq_add( + struct llama_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); - if (cache.recurrent) { + llama_pos n_past = p0; + + if (cache.rs.size > 0) { // for Mamba-like models, only the pos needs to be shifted - if (0 <= seq_id && seq_id < (int64_t) cache.size) { - llama_kv_cell & cell = cache.cells[seq_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos += delta; + auto & seq = cache.rs.seq_tails[seq_id]; + // follow the sequence from its tail + int32_t cell_id = seq.tail; + while (cell_id >= 0) { + GGML_ASSERT((uint32_t) cell_id < cache.rs.size); + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + int32_t i = cell_id; + cell_id = rs_cell.prev; + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + rs_cell.pos += delta; + if (rs_cell.pos < 0) { + // NOTE: this affects the other sequences which share the cell + cache.rs.clear_cell(i); + // TODO: update cache.rs.head + } } - } - return; - } - - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.has_shift = true; - cache.cells[i].pos += delta; - cache.cells[i].delta += delta; - - if (cache.cells[i].pos < 0) { - if (!cache.cells[i].is_empty()) { - cache.used--; - } - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); - if (new_head == cache.size) { - new_head = i; - } + if (n_past <= rs_cell.pos) { + n_past = rs_cell.pos + 1; } } } - // If we freed up a slot, set head to it so searching can start there. - // Otherwise we just start the next search from the beginning. - cache.head = new_head != cache.size ? new_head : 0; + if (cache.kv.size > 0) { + uint32_t new_head = cache.kv.size; + + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= p0 && kv_cell.pos < p1) { + cache.kv.has_shift = true; + kv_cell.pos += delta; + kv_cell.delta += delta; + + if (kv_cell.pos < 0) { + if (!kv_cell.is_empty()) { + cache.kv.used--; + } + kv_cell.pos = -1; + kv_cell.seq_id.clear(); + if (new_head == cache.kv.size) { + new_head = i; + } + } + } + if (n_past <= kv_cell.pos) { + n_past = kv_cell.pos + 1; + } + } + } + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + cache.kv.head = new_head != cache.kv.size ? 
new_head : 0; + } + + return n_past; } -static void llama_kv_cache_seq_div( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d) { +static llama_pos llama_cache_seq_div( + struct llama_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); - if (cache.recurrent) { + llama_pos n_past = p0; + + if (cache.rs.size > 0) { // for Mamba-like models, only the pos needs to be changed - if (0 <= seq_id && seq_id < (int64_t) cache.size) { - llama_kv_cell & cell = cache.cells[seq_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos /= d; + auto & seq = cache.rs.seq_tails[seq_id]; + int32_t cell_id = seq.tail; + while (cell_id >= 0) { + GGML_ASSERT((uint32_t) cell_id < cache.rs.size); + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + rs_cell.pos /= d; } - } - return; - } - - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.has_shift = true; - - { - llama_pos p_old = cache.cells[i].pos; - cache.cells[i].pos /= d; - cache.cells[i].delta += cache.cells[i].pos - p_old; + cell_id = rs_cell.prev; + if (n_past <= rs_cell.pos) { + n_past = rs_cell.pos + 1; } } } + + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= p0 && kv_cell.pos < p1) { + cache.kv.has_shift = true; + + { + llama_pos p_old = kv_cell.pos; + kv_cell.pos /= d; + kv_cell.delta += kv_cell.pos - p_old; + } + } + if (n_past <= kv_cell.pos) { + n_past = kv_cell.pos + 1; + } + } + } + } + + return n_past; } -static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) { +static llama_pos llama_cache_seq_pos_max(struct llama_cache & cache, llama_seq_id seq_id) { llama_pos result = 0; - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id)) { - result = std::max(result, cache.cells[i].pos); + if (cache.rs.size > 0) { + int32_t cell_id = cache.rs.seq_tails[seq_id].tail; + if (cell_id >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + result = rs_cell.pos; + } + // exit early + return result; + } + + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.has_seq_id(seq_id)) { + result = std::max(result, kv_cell.pos); + } } } @@ -6009,6 +6563,7 @@ struct llm_build_context { const llama_cparams & cparams; const llama_batch & batch; const llama_kv_cache & kv_self; + const llama_rs_cache & rs_self; const int64_t n_embd; const int64_t n_layer; @@ -6034,8 +6589,10 @@ struct llm_build_context { const int32_t n_tokens; const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size) + const int32_t n_rs; const int32_t n_outputs; const int32_t kv_head; // index of where we store new KV data in the cache + const int32_t rs_head; const int32_t n_orig_ctx; const enum llama_pooling_type pooling_type; @@ -6058,7 +6615,8 @@ struct llm_build_context { hparams (model.hparams), cparams (lctx.cparams), batch (batch), - kv_self (lctx.kv_self), + kv_self (lctx.cache.kv), + rs_self (lctx.cache.rs), n_embd (hparams.n_embd), n_layer (hparams.n_layer), n_rot (hparams.n_rot), @@ -6081,8 +6639,10 @@ struct llm_build_context { norm_rms_eps (hparams.f_norm_rms_eps), n_tokens 
(batch.n_tokens), n_kv (worst_case ? kv_self.size : kv_self.n), - n_outputs (worst_case ? n_tokens : lctx.n_outputs), - kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), + n_rs (worst_case ? rs_self.size : rs_self.n), + n_outputs (worst_case ? n_tokens : lctx.n_outputs), + kv_head (worst_case ? kv_self.size - n_tokens : kv_self.head), + rs_head (worst_case ? 0 : rs_self.head), n_orig_ctx (cparams.n_yarn_orig_ctx), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), @@ -6148,29 +6708,6 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_s_copy() { - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); - - GGML_ASSERT(kv_self.recurrent); - - struct ggml_tensor * state_copy = build_inp_s_copy(); - - for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size); - - conv_states = ggml_get_rows(ctx0, conv_states, state_copy); - ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy); - - // TODO: name the intermediate tensors with cb() - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, conv_states, kv_self.k_l[il])); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, ssm_states, kv_self.v_l[il])); - } - - return gf; - } - struct ggml_cgraph * build_defrag(const std::vector & ids) { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -6267,21 +6804,21 @@ struct llm_build_context { } struct ggml_tensor * build_inp_s_copy() { - lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, kv_self.size); + lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs); cb(lctx.inp_s_copy, "inp_s_copy", -1); ggml_set_input(lctx.inp_s_copy); return lctx.inp_s_copy; } struct ggml_tensor * build_inp_s_mask() { - lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); + lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_rs); cb(lctx.inp_s_mask, "inp_s_mask", -1); ggml_set_input(lctx.inp_s_mask); return lctx.inp_s_mask; } struct ggml_tensor * build_inp_s_seq() { - lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_kv, n_tokens); + lctx.inp_s_seq = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_rs, n_tokens); cb(lctx.inp_s_seq, "inp_s_seq", -1); ggml_set_input(lctx.inp_s_seq); return lctx.inp_s_seq; @@ -9269,26 +9806,31 @@ struct llm_build_context { // {n_embd, n_tokens} inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + struct ggml_tensor * state_copy = build_inp_s_copy(); struct ggml_tensor * state_mask = build_inp_s_mask(); struct ggml_tensor * state_seq = build_inp_s_seq(); for (int il = 0; il < n_layer; ++il) { - // (ab)using the KV cache to store the states - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size); + struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, rs_self.r_l[il], hparams.n_embd_r(), rs_self.size); + struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, rs_self.s_l[il], hparams.n_embd_s(), rs_self.size); + + // copy states + { + // TODO: use some sort of read-only head and n to pass smaller tensors to ggml_get_rows + // NOTE: assuming the copy destinations are ALL contained in the current batch + // this shrinks the tensors's ne[1] to n_rs 
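+                // (state_copy holds one source cell id per cell of the slot, filled from rs_cell.src in llama_set_inputs)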
+ conv_states = ggml_get_rows(ctx0, conv_states, state_copy); + ssm_states = ggml_get_rows(ctx0, ssm_states, state_copy); + } // clear states of sequences which are starting at the beginning of this batch { - conv_states = ggml_mul(ctx0, - ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]), - state_mask); - ssm_states = ggml_mul(ctx0, - ggml_view_2d(ctx0, ssm_states, ssm_states->ne[0], n_kv, ssm_states->nb[1], kv_head*ssm_states->nb[1]), - state_mask); + conv_states = ggml_mul(ctx0, conv_states, state_mask); + ssm_states = ggml_mul(ctx0, ssm_states, state_mask); } - conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_kv); - ssm_states = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_kv); + conv_states = ggml_reshape_3d(ctx0, conv_states, d_conv - 1, d_inner, n_rs); + ssm_states = ggml_reshape_3d(ctx0, ssm_states, d_state, d_inner, n_rs); // norm cur = llm_build_norm(ctx0, inpL, hparams, @@ -9321,8 +9863,8 @@ struct llm_build_context { // store last (d_conv - 1) columns of the conv_state part of x_conv back into the KV cache ggml_build_forward_expand(gf, ggml_cpy(ctx0, - ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_kv, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), - ggml_view_1d(ctx0, kv_self.k_l[il], (d_conv - 1)*(d_inner)*(n_kv), kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); + ggml_view_2d(ctx0, x_conv, d_conv - 1, d_inner*n_rs, d_conv*ggml_element_size(x_conv), (1+d_inner*n_tokens)*ggml_element_size(x_conv)), + ggml_view_1d(ctx0, rs_self.r_l[il], (d_conv - 1)*(d_inner)*(n_rs), rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(x_conv)))); // extract x from x_conv x = ggml_view_2d(ctx0, x_conv, d_inner, n_tokens, d_inner*ggml_element_size(x_conv), 0); @@ -9348,15 +9890,15 @@ struct llm_build_context { // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. - // => {d_inner, n_tokens} and {d_state, d_inner, n_kv} combined, + // => {d_inner, n_tokens} and {d_state, d_inner, n_rs} combined, // because only a single tensor can be returned. 
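            // (the first d_inner*n_tokens values of y_ssm_states are the outputs y; the last d_state*d_inner*n_rs values are the updated ssm states)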
struct ggml_tensor * y_ssm_states = ggml_ssm_scan(ctx0, ssm_states, x, dt, model.layers[il].ssm_a, B, C, state_seq); // store last states (the second part of y_ssm_states) ggml_build_forward_expand(gf, ggml_cpy(ctx0, - ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_kv, d_inner*n_tokens*ggml_element_size(y_ssm_states)), - ggml_view_1d(ctx0, kv_self.v_l[il], d_state*d_inner*n_kv, kv_head*d_state*d_inner*ggml_element_size(ssm_states)))); + ggml_view_1d(ctx0, y_ssm_states, d_state*d_inner*n_rs, d_inner*n_tokens*ggml_element_size(y_ssm_states)), + ggml_view_1d(ctx0, rs_self.s_l[il], d_state*d_inner*n_rs, rs_head*d_state*d_inner*ggml_element_size(ssm_states)))); struct ggml_tensor * y = ggml_view_2d(ctx0, y_ssm_states, d_inner, n_tokens, d_inner*ggml_element_size(y_ssm_states), 0); @@ -9558,23 +10100,6 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) { return result; } -static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) { - llama_batch dummy; - dummy.n_tokens = 0; - - llm_build_cb cb = [&](struct ggml_tensor * , const char * , int ) { }; - - struct llm_build_context llm(lctx, dummy, cb, false); - - llm.init(); - - struct ggml_cgraph * result = llm.build_s_copy(); - - llm.free(); - - return result; -} - static struct ggml_cgraph * llama_build_graph( llama_context & lctx, const llama_batch & batch, @@ -9729,26 +10254,14 @@ static struct ggml_cgraph * llama_build_graph( } static void llama_set_k_shift(llama_context & lctx) { - const int64_t kv_size = lctx.kv_self.size; + const int64_t kv_size = lctx.cache.kv.size; assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer)); int32_t * data = (int32_t *) lctx.inp_K_shift->data; for (int i = 0; i < kv_size; ++i) { - data[i] = lctx.kv_self.cells[i].delta; - } -} - -static void llama_set_s_copy(llama_context & lctx) { - const int64_t kv_size = lctx.kv_self.size; - - assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); - - int32_t * data = (int32_t *) lctx.inp_s_copy->data; - - for (int i = 0; i < kv_size; ++i) { - data[i] = lctx.kv_self.cells[i].src; + data[i] = lctx.cache.kv.cells[i].delta; } } @@ -9759,7 +10272,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { const auto & hparams = lctx.model.hparams; const auto & cparams = lctx.cparams; - const auto & kv_self = lctx.kv_self; + const auto & kv_self = lctx.cache.kv; + const auto & rs_self = lctx.cache.rs; if (batch.token) { const int64_t n_tokens = batch.n_tokens; @@ -9835,7 +10349,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { for (int i = 0; i < n_kv; ++i) { float f; - if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { f = -INFINITY; } else { f = 0.0f; @@ -9886,7 +10400,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { float * data = (float *) lctx.inp_KQ_pos->data; for (int i = 0; i < n_kv; ++i) { - data[i] = float(lctx.kv_self.cells[i].pos); + data[i] = float(kv_self.cells[i].pos); } } @@ -9943,29 +10457,54 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (kv_self.recurrent) { - const int64_t n_kv = kv_self.n; + if (rs_self.size > 0) { + const int64_t n_rs = rs_self.n; if (lctx.inp_s_mask) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); float * data = (float *) lctx.inp_s_mask->data; - // states which are not affected by the current batch are left 
untouched - for (int i = 0; i < n_kv; ++i) { - llama_seq_id seq_id = i + lctx.kv_self.head; - llama_kv_cell & kv_cell = lctx.kv_self.cells[seq_id]; - bool has_self_seq = kv_cell.has_seq_id(seq_id); + // clear unused states + for (int i = 0; i < n_rs; ++i) { + uint32_t cell_id = i + rs_self.head; + llama_rs_cell & rs_cell = lctx.cache.rs.cells[cell_id]; - data[i] = (float) has_self_seq; + data[i] = (float) rs_cell.src >= 0; - // ensure current sequences will be kept - if (!has_self_seq && kv_cell.pos >= 0) { - kv_cell.seq_id.insert(seq_id); + // only clear once + if (rs_cell.src < 0) { + rs_cell.src = cell_id; } } } + + // checkpoints require copies between cells + if (lctx.inp_s_copy) { + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); + int32_t * data = (int32_t *) lctx.inp_s_copy->data; + + // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n + for (uint32_t i = 0; i < n_rs; ++i) { + const uint32_t cell_id = i + rs_self.head; + llama_rs_cell & rs_cell = lctx.cache.rs.cells[cell_id]; + + // prevent out-of-bound sources + if (rs_cell.src < 0 || (uint32_t) rs_cell.src >= rs_self.size) { + rs_cell.src = cell_id; + } + + data[i] = rs_cell.src; + + // ensure copy only happens once + if (rs_cell.src != (int32_t) cell_id) { + rs_cell.src = cell_id; + } + } + } + // For Mamba (and other recurrent architectures), // update the correct state(s)/sequence(s) for each token of the batch. + // Each row contains relative cell ids of the sequences for the associated token. // Like with the KQ_mask, if a token in the batch has multiple sequences, // they are assumed to be equivalent (not here, but in ggml_ssm_scan and ggml_ssm_conv). if (lctx.inp_s_seq) { @@ -9978,12 +10517,20 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { const int32_t n_seq = batch.n_seq_id[j]; GGML_ASSERT(0 < n_seq); // a token should be part of at least 1 sequence - for (int i = 0; i < n_kv; ++i) { + for (int i = 0; i < n_rs; ++i) { if (i < n_seq) { - // for this type of model, the head is the minimum seq_id of the batch - data[j*n_kv + i] = batch.seq_id[j][i] - kv_self.head; + llama_seq_id seq_id = batch.seq_id[j][i]; + GGML_ASSERT((uint32_t) seq_id < rs_self.seq_tails.size()); + const auto & seq = rs_self.seq_tails[seq_id]; + // all sequences of this batch should already be initialized + GGML_ASSERT(seq.tail >= 0); + // ensure the relative cell id will be positive but not too big + GGML_ASSERT((uint32_t) seq.tail >= rs_self.head); + GGML_ASSERT((uint32_t) seq.tail < rs_self.head + rs_self.n); + + data[j*n_rs + i] = seq.tail - rs_self.head; } else { - data[j*n_kv + i] = -1; + data[j*n_rs + i] = -1; } } } @@ -10129,7 +10676,8 @@ static int llama_decode_internal( //ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads); #endif - auto & kv_self = lctx.kv_self; + auto & kv_self = lctx.cache.kv; + auto & rs_self = lctx.cache.rs; const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; @@ -10245,17 +10793,11 @@ static int llama_decode_internal( if (hparams.causal_attn) { llama_kv_cache_update(&lctx); - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (kv_self.head > kv_self.used + 2*n_tokens) { - kv_self.head = 0; - } - - if (!llama_kv_cache_find_slot(kv_self, u_batch)) { + if (!llama_kv_cache_find_slot(lctx.cache, u_batch)) { return 1; } - if (!kv_self.recurrent) { + if (kv_self.size > 0) { // a heuristic, to avoid 
attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
                 // if we start defragmenting the cache, the benefit from this will be more important
@@ -10329,11 +10871,15 @@ static int llama_decode_internal(
         // update the kv ring buffer
         {
             kv_self.head += n_tokens;
+            rs_self.head += rs_self.n;
 
             // Ensure kv cache head points to a valid index.
             if (kv_self.head >= kv_self.size) {
                 kv_self.head = 0;
             }
+            if (rs_self.head >= rs_self.size) {
+                rs_self.head = 0;
+            }
         }
 
 #ifdef GGML_PERF
@@ -10430,7 +10976,7 @@ static int llama_decode_internal(
 
 // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
 static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
-    auto & kv_self = lctx.kv_self;
+    auto & kv_self = lctx.cache.kv;
 
     const auto & hparams = lctx.model.hparams;
 
@@ -10651,7 +11197,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
     bool need_reserve = false;
 
     // apply K-shift if needed
-    if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
+    if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.cache.kv.has_shift) {
         {
             ggml_backend_sched_reset(lctx.sched);
 
@@ -10667,7 +11213,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
         }
 
         {
-            auto & kv_self = lctx.kv_self;
+            auto & kv_self = lctx.cache.kv;
 
             kv_self.has_shift = false;
 
@@ -10677,39 +11223,13 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
         }
     }
 
-    if (lctx.kv_self.recurrent && lctx.kv_self.do_copy) {
-        {
-            ggml_backend_sched_reset(lctx.sched);
-
-            ggml_cgraph * gf = llama_build_graph_s_copy(lctx);
-
-            ggml_backend_sched_alloc_graph(lctx.sched, gf);
-
-            llama_set_s_copy(lctx);
-
-            llama_graph_compute(lctx, gf, lctx.cparams.n_threads);
-
-            need_reserve = true;
-        }
-
-        {
-            auto & kv_self = lctx.kv_self;
-
-            kv_self.do_copy = false;
-
-            for (uint32_t i = 0; i < kv_self.size; ++i) {
-                kv_self.cells[i].src = i;
-            }
-        }
-    }
-
     // defragment the KV cache if needed
-    if (lctx.kv_self.do_defrag) {
+    if (lctx.cache.kv.do_defrag) {
         llama_kv_cache_defrag_internal(lctx);
 
         need_reserve = true;
 
-        lctx.kv_self.do_defrag = false;
+        lctx.cache.kv.do_defrag = false;
     }
 
     // reserve a worst case graph again
@@ -14258,18 +14778,8 @@ struct llama_context * llama_new_context_with_model(
     ctx->rng = std::mt19937(params.seed);
     ctx->logits_all = params.logits_all;
 
-    uint32_t kv_size = cparams.n_ctx;
-    ggml_type type_k = params.type_k;
-    ggml_type type_v = params.type_v;
-
-    // Mamba only needs a constant number of KV cache cells per sequence
-    if (model->arch == LLM_ARCH_MAMBA) {
-        // Mamba needs at least as many KV cells as there are sequences kept at any time
-        kv_size = std::max((uint32_t) 1, params.n_seq_max);
-        // it's probably best to keep as much precision as possible for the states
-        type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
-        type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
-    }
+    const ggml_type type_k = params.type_k;
+    const ggml_type type_v = params.type_v;
 
     GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
     GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
@@ -14377,25 +14887,42 @@ struct llama_context * llama_new_context_with_model(
         }
         ctx->backends.push_back(ctx->backend_cpu);
 
-        if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, kv_size, cparams.offload_kqv)) {
+        if (!llama_cache_init(ctx->cache, ctx->model, type_k, type_v, cparams.n_ctx, cparams.n_seq_max, cparams.offload_kqv)) {
             LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
             llama_free(ctx);
             return nullptr;
         }
 
-        {
+        if (ctx->cache.rs.size > 0) {
+            size_t memory_size_r = 0;
+            size_t memory_size_s = 0;
+
+            for (auto & r : ctx->cache.rs.r_l) {
+                memory_size_r += ggml_nbytes(r);
+            }
+
+            for (auto & s : ctx->cache.rs.s_l) {
+                memory_size_s += ggml_nbytes(s);
+            }
+
+            LLAMA_LOG_INFO("%s: SSM state size = %7.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__,
+                (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f),
+                ggml_type_name(GGML_TYPE_F32), (float)memory_size_r / (1024.0f * 1024.0f),
+                ggml_type_name(GGML_TYPE_F32), (float)memory_size_s / (1024.0f * 1024.0f));
+        }
+        if (ctx->cache.kv.size > 0) {
             size_t memory_size_k = 0;
             size_t memory_size_v = 0;
 
-            for (auto & k : ctx->kv_self.k_l) {
+            for (auto & k : ctx->cache.kv.k_l) {
                 memory_size_k += ggml_nbytes(k);
             }
 
-            for (auto & v : ctx->kv_self.v_l) {
+            for (auto & v : ctx->cache.kv.v_l) {
                 memory_size_v += ggml_nbytes(v);
             }
 
-            LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+            LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
                 (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
                 ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
                 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
@@ -14513,7 +15040,11 @@ uint32_t llama_n_ubatch(const struct llama_context * ctx) {
 }
 
 uint32_t llama_n_seq_max(const struct llama_context * ctx) {
-    return ctx->kv_self.size;
+    if (ctx->cache.rs.size > 0) {
+        return ctx->cache.rs.size;
+    } else {
+        return ctx->cache.kv.size;
+    }
 }
 
 enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
@@ -14799,8 +15330,9 @@ void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
 }
 
 void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
-    if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
-        view->n_cells = int32_t(ctx->kv_self.size);
+    const llama_kv_cache & kv_self = ctx->cache.kv;
+    if (uint32_t(view->n_cells) < kv_self.size || view->cells == nullptr) {
+        view->n_cells = int32_t(kv_self.size);
         void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
         GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
         view->cells = (struct llama_kv_cache_view_cell *)p;
@@ -14809,7 +15341,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
         view->cells_sequences = (llama_seq_id *)p;
     }
 
-    const std::vector<llama_kv_cell> & kv_cells = ctx->kv_self.cells;
+    const std::vector<llama_kv_cell> & kv_cells = kv_self.cells;
     llama_kv_cache_view_cell * c_curr = view->cells;
    llama_seq_id * cs_curr = view->cells_sequences;
     int32_t used_cells = 0;
@@ -14818,7 +15350,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
     uint32_t max_contig = 0;
     int32_t max_contig_idx = -1;
 
-    for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_seq_max) {
+    for (int32_t i = 0; i < int32_t(kv_self.size); i++, c_curr++, cs_curr += view->n_seq_max) {
         const size_t curr_size = kv_cells[i].seq_id.size();
         token_count += curr_size;
         c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
@@ -14856,67 +15388,77 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k
     view->max_contiguous_idx = max_contig_idx;
     view->token_count = token_count;
     view->used_cells = used_cells;
-    if (uint32_t(used_cells) != ctx->kv_self.used) {
+    if (uint32_t(used_cells) != kv_self.used) {
         LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
-            __func__, ctx->kv_self.used, used_cells);
+            __func__, kv_self.used, used_cells);
     }
 }
 
 int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx) {
     int result = 0;
 
-    for (uint32_t i = 0; i < ctx->kv_self.size; i++) {
-        result += ctx->kv_self.cells[i].seq_id.size();
+    for (uint32_t i = 0; i < ctx->cache.kv.size; i++) {
+        result += ctx->cache.kv.cells[i].seq_id.size();
    }
 
     return result;
 }
 
 int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
-    return ctx->kv_self.used;
+    return ctx->cache.kv.used;
 }
 
 void llama_kv_cache_clear(struct llama_context * ctx) {
-    llama_kv_cache_clear(ctx->kv_self);
+    llama_cache_clear(ctx->cache);
 }
 
 bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
-    return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1);
+    if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return false; }
+    llama_pos n_past = llama_cache_seq_rm(ctx->cache, seq_id, p0, p1);
+    return n_past >= p0;
 }
 
 void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) {
+    uint32_t n_seq_max = llama_n_seq_max(ctx);
+    if (seq_id_src < 0 || seq_id_dst < 0 || (uint32_t) seq_id_src >= n_seq_max || (uint32_t) seq_id_dst >= n_seq_max) {
+        return;
+    }
     if (seq_id_src == seq_id_dst) {
         return;
     }
 
-    llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1);
+    llama_cache_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1);
 }
 
 void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
-    llama_kv_cache_seq_keep(ctx->kv_self, seq_id);
+    if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; }
+    llama_cache_seq_keep(ctx->cache, seq_id);
 }
 
 void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
     if (delta == 0) {
         return;
     }
+    if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; }
 
-    llama_kv_cache_seq_add(ctx->kv_self, seq_id, p0, p1, delta);
+    llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta);
 }
 
 void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
     if (d == 1) {
         return;
     }
+    if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; }
 
-    llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d);
+    llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d);
 }
 
 llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) {
-    return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
+    if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; }
+    return llama_cache_seq_pos_max(ctx->cache, seq_id);
 }
 
 void llama_kv_cache_defrag(struct llama_context * ctx) {
-    llama_kv_cache_defrag(ctx->kv_self);
+    llama_kv_cache_defrag(ctx->cache.kv);
 }
 
 void llama_kv_cache_update(struct llama_context * ctx) {
@@ -14944,9 +15486,10 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
     const size_t s_kv_head = sizeof(uint32_t);
     const size_t s_kv_size = sizeof(uint32_t);
     const size_t s_kv_used = sizeof(uint32_t);
-    const size_t s_kv = ctx->kv_self.total_size();
+    const size_t s_kv = ctx->cache.kv.total_size();
     const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + cparams.n_seq_max*sizeof(llama_seq_id);
-    const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell;
+    const size_t s_kv_cells = ctx->cache.kv.size * s_kv_cell;
+    // TODO: rs cache cells
     const size_t s_total = (
         + s_rng_size
@@ -15241,14 +15784,15 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
         }
     }
 
+    // FIXME: set rs cache too
     // set kv cache
     {
-        const auto & kv_self = ctx->kv_self;
+        const auto & kv_self = ctx->cache.kv;
         const auto & hparams = ctx->model.hparams;
 
         const uint32_t n_layer = hparams.n_layer;
-        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
-        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
+        const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
+        const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
 
         size_t kv_buf_size;
         uint32_t kv_head;
@@ -15279,16 +15823,6 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
             ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
             inp += k_size;
 
-            if (kv_self.recurrent) {
-                // v is contiguous for recurrent models
-                // TODO: use other tensors for state models than k and v
-                const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*kv_head);
-
-                ggml_backend_tensor_set(kv_self.v_l[il], inp, 0, v_size);
-                inp += v_size;
-                continue;
-            }
-
            // v is not contiguous, copy row by row
             const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
             const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_self.size);
@@ -15303,8 +15837,8 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
 
         llama_kv_cache_clear(ctx);
 
-        ctx->kv_self.head = kv_head;
-        ctx->kv_self.used = kv_used;
+        ctx->cache.kv.head = kv_head;
+        ctx->cache.kv.used = kv_used;
 
         for (uint32_t i = 0; i < kv_head; ++i) {
             llama_pos pos;
@@ -15313,13 +15847,13 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
             memcpy(&pos, inp, sizeof(pos)); inp += sizeof(pos);
             memcpy(&seq_id_size, inp, sizeof(seq_id_size)); inp += sizeof(seq_id_size);
 
-            ctx->kv_self.cells[i].pos = pos;
+            ctx->cache.kv.cells[i].pos = pos;
 
             llama_seq_id seq_id;
 
             for (size_t j = 0; j < seq_id_size; ++j) {
                 memcpy(&seq_id, inp, sizeof(seq_id)); inp += sizeof(seq_id);
-                ctx->kv_self.cells[i].seq_id.insert(seq_id);
+                ctx->cache.kv.cells[i].seq_id.insert(seq_id);
             }
         }
     }