From 7e13f19fb527b62ca87930841608b7369d86173a Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Fri, 24 May 2024 16:19:25 -0400 Subject: [PATCH] llama : rethink recurrent state cell counts * llama : begin work on support for variable GQA This will also be useful for Jamba if we consider the Mamba layers to have 0 KV heads. * llama : gracefully fail when not finding hybrid slot --- llama.cpp | 586 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 307 insertions(+), 279 deletions(-) diff --git a/llama.cpp b/llama.cpp index 3501163ba..969249126 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1753,6 +1753,9 @@ struct llama_hparams { uint32_t n_expert_used = 0; uint32_t n_vocab_type = 0; // for BERT-style token types + // TODO: find a more compact way to add more per-layer hyper-parameters + std::vector n_head_kv_vec; + float f_norm_eps; float f_norm_rms_eps; @@ -1793,6 +1796,8 @@ struct llama_hparams { if (this->n_expert != other.n_expert) return true; if (this->n_expert_used != other.n_expert_used) return true; + if (this->n_head_kv_vec != other.n_head_kv_vec) return true; + if (this->rope_finetuned != other.rope_finetuned) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; @@ -1812,29 +1817,46 @@ struct llama_hparams { return false; } - uint32_t n_gqa() const { + uint32_t n_head_kv_l(uint32_t layer) const { + if (layer < n_head_kv_vec.size()) { + int32_t n_hkv_l = n_head_kv_vec[layer]; + // TODO: what should happen when it's negative? + GGML_ASSERT(n_hkv_l >= 0); + return n_hkv_l; + } + return n_head_kv; + } + + uint32_t n_gqa(uint32_t layer = 0) const { + uint32_t n_head_kv = n_head_kv_l(layer); if (n_head_kv == 0) { return 0; } return n_head/n_head_kv; } - uint32_t n_embd_k_gqa() const { // dimension of key embeddings across all k-v heads + uint32_t n_embd_k_gqa(uint32_t layer = 0) const { // dimension of key embeddings across all k-v heads + uint32_t n_head_kv = n_head_kv_l(layer); return n_embd_head_k * n_head_kv; } - uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads + uint32_t n_embd_v_gqa(uint32_t layer = 0) const { // dimension of value embeddings across all k-v heads + uint32_t n_head_kv = n_head_kv_l(layer); return n_embd_head_v * n_head_kv; } - uint32_t n_embd_r() const { // dimension of the rolling state embeddings + uint32_t n_embd_r(uint32_t layer) const { // dimension of the rolling state embeddings + // TODO: support using an SSM in place of the MLP of a Transformer + if (n_head_kv_l(layer) != 0) { return 0; } // corresponds to Mamba's conv_states size // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; } - uint32_t n_embd_s() const { // dimension of the recurrent state embeddings + uint32_t n_embd_s(uint32_t layer) const { // dimension of the recurrent state embeddings + // TODO: support using an SSM in place of the MLP of a Transformer + if (n_head_kv_l(layer) != 0) { return 0; } // corresponds to Mamba's ssm_states size return ssm_d_state * ssm_d_inner; } @@ -2078,10 +2100,12 @@ struct llama_rs_cache { // computed when finding a slot uint32_t n = 0; // range of states used for the last slot - // useful to know the minimum reserved cell count per seq_id - // only counts sequences which have a non-shared tail + // only counts cells which are tails of all of their sequences. + // useful to know the minimum reserved cell count per seq_id. uint32_t n_seqs = 0; - // cells part of multiple sequences AND which have at least one tail + // cells part of multiple sequences, + // but which are only the tail of some of them. + // useful to dismiss sequences used as a shared prompt uint32_t n_shared_tail_cells = 0; // with state models, a cell can hold the state for more than one past token @@ -2279,10 +2303,8 @@ struct llama_rs_cache { for (uint32_t cell_id = 0; (uint32_t) cell_id < size; ++cell_id) { llama_rs_cell & rs_cell = cells[cell_id]; if (!rs_cell.seq_nodes.empty()) { - if (rs_cell.seq_nodes.size() == 1) { - if (rs_cell.tail_rc == 1) { - n_seqs_verif += 1; - } + if (rs_cell.seq_nodes.size() == rs_cell.tail_rc) { + n_seqs_verif += 1; } else if (rs_cell.tail_rc > 0) { n_shared_tail_cells_verif += 1; } @@ -2308,9 +2330,11 @@ struct llama_rs_cache { } // returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed. + // Why an iterator? Because it allows using std::vector::erase. std::vector::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector::iterator node_iter) { GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); - // TODO: assert the iterator points inside the correct vector + // The iterator needs to point inside the correct vector + GGML_ASSERT(node_iter.base() >= rs_cell.seq_nodes.data() && node_iter.base() < rs_cell.seq_nodes.data() + rs_cell.seq_nodes.size()); if (node_iter != rs_cell.seq_nodes.end()) { // update the tree llama_rs_seq_node node = *node_iter; @@ -2325,12 +2349,20 @@ struct llama_rs_cache { GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); prev_node->next_cell = node.next_cell; if (node.is_tail()) { + // move the tail back to the previous cell if (prev_cell.seq_nodes.size() > 1) { - if (prev_cell.tail_rc == 0) { - n_shared_tail_cells += 1; - } - if (rs_cell.seq_nodes.size() == 1) { - n_seqs -= 1; + if (rs_cell.tail_rc == rs_cell.seq_nodes.size()) { + if (prev_cell.tail_rc == 0) { + n_shared_tail_cells += 1; + } + + // o oo oo + // |/ -> o/ + // | | + // e.g. when removing the leaf with a single tail + if (rs_cell.tail_rc == 1 && prev_cell.tail_rc != prev_cell.seq_nodes.size()) { + n_seqs -= 1; + } } } prev_cell.tail_rc += 1; @@ -2341,17 +2373,22 @@ struct llama_rs_cache { if (node.is_tail()) { seq.tail = rs_cell.prev; if (rs_cell.tail_rc == 1) { - if (rs_cell.seq_nodes.size() > 1) { - // assuming the previous cell of a shared cell is also shared, - // this was a shared tail cell, but will no longer be a tail cell - n_shared_tail_cells -= 1; - } else if (seq.tail < 0) { + if (seq.tail < 0) { // no more tail, no more sequence - n_seqs -= 1; + if (rs_cell.seq_nodes.size() > 1) { + n_shared_tail_cells -= 1; + } else { + n_seqs -= 1; + } } } GGML_ASSERT(rs_cell.tail_rc > 0); rs_cell.tail_rc -= 1; + } else if (rs_cell.tail_rc == rs_cell.seq_nodes.size() - 1) { + // will fully become a tail cell + if (rs_cell.tail_rc > 0) { + n_seqs += 1; + } } if (node_iter == rs_cell.seq_nodes.begin()) { // this seq_id was the first in the list @@ -2363,14 +2400,6 @@ struct llama_rs_cache { if ((uint32_t) next_node->seq_id < seq_tails.size()) { auto & next_seq = seq_tails[next_node->seq_id]; next_seq.n_cells += 1; - // only the tail ref count from the other seq_ids are left in tail_rc - if (rs_cell.tail_rc > 0) { - // will become a non-shared cell - if (rs_cell.seq_nodes.size() == 2) { - n_shared_tail_cells -= 1; - n_seqs += 1; - } - } } else { GGML_ASSERT(false && "invalid seq_id"); } @@ -2433,43 +2462,41 @@ struct llama_rs_cache { rs_cell.pos = prev_cell.pos + 1; rs_cell.src = prev_cell.src; } - prev_cell.tail_rc -= 1; prev_node->next_cell = i_cell; rs_cell.prev = prev; if (seq.tail == prev) { // What to do when the tail moves... - // from unique to shared (n_seqs--) - // if the new cell has one seq_id or has no tails (n_shared_tail_cells++) - // if the new cell has one seq_id and a tail (n_seqs-- (yes, another time)) - // from unique to unique (seq.n_cells++) - // from empty to unique (seq.n_cells++, n_seqs++) - // from empty to shared - // if the new cell only has one seq_id or has no tail (n_shared_tail_cells++) - // if the new cell only has one seq_id and has one tail (n_seqs--) - // from shared to shared - // if the last cell has no tails (n_shared_tail_cells--) - // if the new cell has no tails or has one seq_id (n_shared_tail_cells++) - // if the new cell only has one seq_id and has one tail (n_seqs--) - // from shared to unique (seq.n_cells++) - // if this seq_id was not the first of the last cell (n_seqs++) - // if the last cell has no tails (n_shared_tail_cells--) - if (prev_cell.seq_nodes.size() > 1) { - // from shared - if (rs_cell.is_empty()) { - // to unique - if (prev_cell.seq_nodes[0].seq_id != id) { - n_seqs += 1; - } + // (Legend: tail: O, one or more non-tails: o, one or more tails O+, empty: _) + // O -> oO (n_seqs--, n_shared_tail_cells++) + // O -> O (seq.n_cells++) + // OO+ -> oO (n_seqs--, n_shared_tail_cells += 2) + // OO+ -> O+ (n_shared_tail_cells++ (the previous cell becomes oO+)) + // _ -> oO (n_shared_tail_cells++) + // _ -> O (seq.n_cells++, n_seqs++) + // Oo -> O (seq.n_cells++, n_seqs++, n_shared_tail_cell--) + // Oo -> OO+ (n_shared_tail_cell--) + // OOo -> O (seq.n_cells++, n_seqs++) + if (prev_cell.seq_nodes.size() == prev_cell.tail_rc) { + // from fully tail + if (prev_cell.tail_rc > 1) { + // the previous tail becomes shared with a non-tail + n_shared_tail_cells += 1; } - // the previous cell is no longer a shared tail - if (prev_cell.tail_rc == 0) { + if (!rs_cell.is_empty() && rs_cell.tail_rc == 0) { + // the new tail cell was previously a fully non-tail cell + n_shared_tail_cells += 1; + n_seqs -= 1; + } + } else if (rs_cell.is_empty()) { + // from shared to unique + n_seqs += 1; + if (prev_cell.tail_rc == 1) { + // it was the last tail of the previous cell n_shared_tail_cells -= 1; } - } else if (!rs_cell.is_empty()) { - // from unique to shared - n_seqs -= 1; } } + prev_cell.tail_rc -= 1; } if (rs_cell.is_empty()) { // to unique @@ -2482,15 +2509,10 @@ struct llama_rs_cache { rs_cell.src = -1; } used += 1; - } else { + } else if (rs_cell.tail_rc == 0) { // to shared - if (rs_cell.seq_nodes.size() == 1) { - // a lone tail becomes a shared cell - if (rs_cell.tail_rc > 0) { - n_seqs -= 1; - } - n_shared_tail_cells += 1; - } else if (rs_cell.tail_rc == 0) { + if (seq.tail < 0) { + // from empty to shared n_shared_tail_cells += 1; } } @@ -2910,26 +2932,18 @@ static bool llama_cache_init( const llama_context * ctx, ggml_type type_k, ggml_type type_v, - uint32_t n_ctx, - uint32_t n_seq_max, bool offload) { const llama_model & model = ctx->model; const llama_cparams & cparams = ctx->cparams; const struct llama_hparams & hparams = model.hparams; - // TODO: per layer n_embd_* - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - const uint32_t n_embd_r = hparams.n_embd_r(); - const uint32_t n_embd_s = hparams.n_embd_s(); - const bool has_kv = hparams.n_head != 0 && hparams.causal_attn; - const bool has_r = n_embd_r != 0; - const bool has_s = n_embd_s != 0; + const bool has_kv = hparams.n_head_kv != 0 && hparams.causal_attn; + const bool has_r = hparams.ssm_d_conv != 0 && hparams.ssm_d_inner != 0; + const bool has_s = hparams.ssm_d_state != 0 && hparams.ssm_d_inner != 0; const bool has_rs = has_r || has_s; - const uint32_t kv_size = has_kv ? n_ctx : 0; - const uint32_t rs_size = has_rs ? n_seq_max : 0; - // TODO: per cache type layer count + const uint32_t kv_size = has_kv ? cparams.n_ctx : 0; + const uint32_t rs_size = has_rs ? cparams.n_seq_max : 0; const int64_t n_layer = hparams.n_layer; cache.kv.size = kv_size; @@ -2967,6 +2981,7 @@ static bool llama_cache_init( std::map ctx_map; for (auto & it : buft_layer_count) { int n_layers = it.second; + // TODO: for mixed architectures, avoid allocating empty recurrent state or kv cache tensors struct ggml_init_params params = { /*.mem_size =*/ (2*has_kv + has_r+has_s)*n_layers*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, @@ -2995,20 +3010,20 @@ static bool llama_cache_init( for (int i = 0; i < (int) n_layer; i++) { struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); if (has_kv) { - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i)*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i)*kv_size); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); cache.kv.k_l.push_back(k); cache.kv.v_l.push_back(v); } if (has_r) { - ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd_r*rs_size); + ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i)*rs_size); ggml_format_name(r, "cache_r_l%d", i); cache.rs.r_l.push_back(r); } if (has_s) { - ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd_s*rs_size); + ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i)*rs_size); ggml_format_name(s, "cache_s_l%d", i); cache.rs.s_l.push_back(s); } @@ -3024,7 +3039,7 @@ static bool llama_cache_init( return false; } ggml_backend_buffer_clear(buf, 0); - LLAMA_LOG_INFO("%s: %10s ctx buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + LLAMA_LOG_INFO("%s: %10s cache buf size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); cache.bufs.push_back(buf); } @@ -3042,177 +3057,21 @@ static bool llama_cache_find_slot( const uint32_t rs_size = cache.rs.size; const uint32_t n_tokens = batch.n_tokens; - // FIXME: on failure, leave all caches in a consistent state. - + // only check first, to allow failing gracefully if (rs_size > 0) { - // For recurrent state architectures (like Mamba), - // each cache cell can store the state for a whole sequence. - // TODO: find a way to always make the rs slot contiguous - - llama_seq_id min_seq = cache.rs.size - 1; - llama_seq_id max_seq = 0; - uint32_t min_cell = cache.rs.size - 1; - uint32_t max_cell = 0; - + // everything should fit if all seq_ids are smaller than the max for (uint32_t i = 0; i < n_tokens; ++i) { - int32_t target_cell = -1; // ensure all the sequences of a token get the same cell - int32_t n_seq_ids = batch.n_seq_id[i]; - for (int32_t j = 0; j < n_seq_ids; ++j) { + int32_t n_seq_id = batch.n_seq_id[i]; + for (int32_t j = 0; j < n_seq_id; ++j) { llama_seq_id seq_id = batch.seq_id[i][j]; - bool need_new_cell = false; - // Everything should fit assuming the biggest seq_id < rs_size - if ((uint32_t) seq_id < rs_size) { - llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; - if (seq_id > max_seq) { max_seq = seq_id; } - if (seq_id < min_seq) { min_seq = seq_id; } - if (!seq.in_ubatch && target_cell >= 0) { - // never saw this seq_id before, - // but there's already a cell reserved for this token, use it - cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id); - } else if (seq.tail < 0) { - need_new_cell = true; - } else { - llama_rs_cell & tail = cache.rs.cells[seq.tail]; - if (seq.in_ubatch) { - // this seq_id was already seen before in the batch - // assuming the tail cell already "has" this seq_id - tail.pos += 1; - target_cell = seq.tail; - } else { - // first time this sequence is seen, - // there's no reserved cell yet; - // if it's not the first sequence of the token, how could it even get here? - GGML_ASSERT(j == 0); - - bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids; - if (has_same_seqs) { - // the tail cell of a seq_id is assumed to already be part of the seq_id, - // hence the skip of the first seq_id - for (int32_t k = 1; k < n_seq_ids; ++k) { - if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) { - has_same_seqs = false; - } - } - } - - // TODO: make the checkpoint interval configurable - if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) { - // a checkpoint should be saved - need_new_cell = true; - } else { - // re-use last tail - tail.pos += 1; - target_cell = seq.tail; - } - } - } - - if (need_new_cell && target_cell < 0) { - const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq); - - uint32_t cell_id = cache.rs.size; - bool looped_once = false; - - while (true) { - if (cache.rs.head >= cache.rs.size) { - cache.rs.head = 0; - if (looped_once) { - // avoid infinite loop - // NOTE: this should not happen, but gracefully fail anyway - LLAMA_LOG_ERROR("%s: recurrent state cache seems full, but should not. This is a bug.\n", __func__); - return false; - } - looped_once = true; - } - cell_id = cache.rs.head; - llama_rs_cell & candidate = cache.rs.cells[cell_id]; - if (candidate.is_empty()) { break; } - if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) { - // the candidate is the old tail - if (candidate.seq_nodes.size() > 1) { - // prune out the other seq_ids, because they diverge - // TODO(maybe): hande this in insert_seq_tail_to_cell_id - // (hopefully doesn't happen too often) - for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) { - if (node_iter->seq_id == seq_id) { - node_iter = std::next(node_iter); - } else { - node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter); - } - } - } - // re-use the tail cell to avoid not finding anything - candidate.pos += 1; - break; - } - if (candidate.tail_rc > 0) { - // skip tails of other sequences - cache.rs.head += 1; - continue; - } - if (candidate.seq_nodes.size() > 1) { - // shared prompts are not usually backtracked, so they can be pruned - cache.rs.clear_cell(candidate); - break; - } - - // prune too-long sequences - llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id; - if (seq_id_to_prune == seq_id) { - // TODO: selectively skip some cells to keep older states - cache.rs.clear_cell(candidate); - break; - } - GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size()); - auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune]; - if (seq_to_prune.n_cells > min_cells_per_seq) { - cache.rs.clear_cell(candidate); - break; - } - cache.rs.head += 1; - } - if (cell_id < cache.rs.size) { - cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id); - target_cell = cell_id; - } - } - - if (seq.tail >= 0) { - if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; } - if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; } - seq.in_ubatch = true; - } - - // Assuming the tokens are in-order - if (batch.pos[i] != cache.rs.cells[seq.tail].pos) { - // What should happen when the pos backtracks or skips a value? - // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", - __func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id); - } - } else { + if (seq_id < 0 || (uint32_t) seq_id >= rs_size) { // too big seq_id // TODO: would it be possible to resize the rs cache size instead? LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size); return false; } } - cache.rs.head = target_cell + 1; - } - - for (llama_seq_id i = min_seq; i <= max_seq; ++i) { - // make sure it's cleared for next time - cache.rs.seq_tails[i].in_ubatch = false; - } - - // allow getting the range of used cells, from head to head + n - cache.rs.head = min_cell; - cache.rs.n = max_cell - min_cell + 1; - - // sanity check - if (max_seq < min_seq || max_cell < min_cell) { - return false; } } @@ -3257,7 +3116,174 @@ static bool llama_cache_find_slot( return false; } } + } + // now modification can be done, and should NOT fail + + if (rs_size > 0) { + // For recurrent state architectures (like Mamba), + // each cache cell can store the state for a whole sequence. + // TODO: find a way to always make the rs slot contiguous + + llama_seq_id min_seq = cache.rs.size - 1; + llama_seq_id max_seq = 0; + uint32_t min_cell = cache.rs.size - 1; + uint32_t max_cell = 0; + + for (uint32_t i = 0; i < n_tokens; ++i) { + int32_t target_cell = -1; // ensure all the sequences of a token get the same cell + int32_t n_seq_ids = batch.n_seq_id[i]; + for (int32_t j = 0; j < n_seq_ids; ++j) { + llama_seq_id seq_id = batch.seq_id[i][j]; + bool need_new_cell = false; + // Everything should fit assuming the biggest seq_id < rs_size + GGML_ASSERT((uint32_t) seq_id < rs_size); + llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; + if (seq_id > max_seq) { max_seq = seq_id; } + if (seq_id < min_seq) { min_seq = seq_id; } + + if (!seq.in_ubatch && target_cell >= 0) { + // never saw this seq_id before, + // but there's already a cell reserved for this token, use it + cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id); + } else if (seq.tail < 0) { + // this seq_id has no tail (and is empty) + need_new_cell = true; + } else { + llama_rs_cell & tail = cache.rs.cells[seq.tail]; + if (seq.in_ubatch) { + // this seq_id was already seen before in the batch + // assuming the tail cell already "has" this seq_id + tail.pos += 1; + target_cell = seq.tail; + } else { + // first time this sequence is seen, + // there's no reserved cell yet; + // if it's not the first sequence of the token, how could it even get here? + GGML_ASSERT(j == 0); + + bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids; + if (has_same_seqs) { + // the tail cell of a seq_id is assumed to already be part of the seq_id, + // hence the skip of the first seq_id + for (int32_t k = 1; k < n_seq_ids; ++k) { + if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) { + has_same_seqs = false; + } + } + } + + // TODO: make the checkpoint interval configurable + if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) { + // a checkpoint should be saved + need_new_cell = true; + } else { + // re-use last tail + tail.pos += 1; + target_cell = seq.tail; + } + } + } + + // reserve a cell for this seq_id + if (need_new_cell && target_cell < 0) { + const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq); + + uint32_t cell_id = cache.rs.size; + bool looped_once = false; + + while (true) { + if (cache.rs.head >= cache.rs.size) { + cache.rs.head = 0; + // avoid infinite loop + // NOTE: this should not fail; if it does, it's a bug. + GGML_ASSERT(!looped_once && "recurrent state cache seems full, but should not."); + looped_once = true; + } + cell_id = cache.rs.head; + llama_rs_cell & candidate = cache.rs.cells[cell_id]; + if (candidate.is_empty()) { break; } + if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) { + // the candidate is the old tail + if (candidate.seq_nodes.size() > 1) { + // prune out the other seq_ids, because they diverge + // TODO(maybe): hande this in insert_seq_tail_to_cell_id + // (hopefully doesn't happen too often) + for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) { + if (node_iter->seq_id == seq_id) { + node_iter = std::next(node_iter); + } else { + node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter); + } + } + } + // re-use the tail cell to avoid not finding anything + candidate.pos += 1; + break; + } + if (candidate.tail_rc > 0) { + // skip tails of other sequences + cache.rs.head += 1; + continue; + } + if (candidate.seq_nodes.size() > 1) { + // shared prompts are not usually backtracked, so they can be pruned + cache.rs.clear_cell(candidate); + break; + } + + // prune too-long sequences + llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id; + if (seq_id_to_prune == seq_id) { + // TODO: selectively skip some cells to keep older states + cache.rs.clear_cell(candidate); + break; + } + GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size()); + auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune]; + if (seq_to_prune.n_cells > min_cells_per_seq) { + cache.rs.clear_cell(candidate); + break; + } + cache.rs.head += 1; + } + if (cell_id < cache.rs.size) { + cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id); + target_cell = cell_id; + } + } + + if (seq.tail >= 0) { + if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; } + if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; } + seq.in_ubatch = true; + } + + // Assuming the tokens are in-order + if (batch.pos[i] != cache.rs.cells[seq.tail].pos) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", + __func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id); + } + } + cache.rs.head = target_cell + 1; + } + + for (llama_seq_id i = min_seq; i <= max_seq; ++i) { + // make sure it's cleared for next time + cache.rs.seq_tails[i].in_ubatch = false; + } + + // allow getting the range of used cells, from head to head + n + cache.rs.head = min_cell; + cache.rs.n = max_cell - min_cell + 1; + + // sanity check + GGML_ASSERT(min_seq <= max_seq && min_cell <= max_cell); + } + + if (kv_size > 0) { for (uint32_t i = 0; i < n_tokens; i++) { cache.kv.cells[cache.kv.head + i].pos = batch.pos[i]; @@ -4194,9 +4220,9 @@ struct llama_model_loader { bool get_arr(const std::string & key, std::vector & result, const bool required = true) { const int kid = gguf_find_key(meta, key.c_str()); - if (kid < 0) { + if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) { if (required) { - throw std::runtime_error(format("key not found in model: %s", key.c_str())); + throw std::runtime_error(format("array key not found in model: %s", key.c_str())); } return false; } @@ -4204,16 +4230,17 @@ struct llama_model_loader { struct GGUFMeta::ArrayInfo arr_info = GGUFMeta::GKV::get_kv(meta, kid); - if (arr_info.gt != GGUF_TYPE_FLOAT32 && arr_info.gt != GGUF_TYPE_INT32) { - throw std::runtime_error(format("%s is not a float32 or int32 array", key.c_str())); + // TODO: allow ANY lossless cast + // GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T)); + switch (arr_info.gt) { + case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same::value)); break; + case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same::value)); break; + default: + throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str())); } - // GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T)); - GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same::value)); - GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32 || std::is_same::value)); - - result.resize(arr_info.length); - result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length); + result.reserve(arr_info.length); + result.assign((const T *)arr_info.data, (const T *)arr_info.data + arr_info.length); return true; } @@ -4750,7 +4777,12 @@ static void llm_load_hparams( // n_head_kv is optional, default to n_head hparams.n_head_kv = hparams.n_head; - ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false); + + // per-layer n_head_kv + if (!ml.get_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_vec, false)) { + // global/fallback n_head_kv + ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false); + } bool rope_finetuned = false; ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); @@ -6704,10 +6736,7 @@ static bool llm_load_tensors( model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading - const int64_t n_ff = hparams.n_ff; const int64_t n_embd_head_k = hparams.n_embd_head_k; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); for (uint32_t i = 0; i < n_layer; ++i) { ggml_context * ctx_layer = ctx_for_layer(i); @@ -7198,8 +7227,8 @@ static void llm_build_kv_store( int64_t il) { const int64_t n_ctx = cparams.n_ctx; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); GGML_ASSERT(kv.size == n_ctx); @@ -7465,9 +7494,9 @@ static struct ggml_tensor * llm_build_kqv( int il) { const int64_t n_ctx = cparams.n_ctx; const int64_t n_head = hparams.n_head; - const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_head_kv = hparams.n_head_kv_l(il); const int64_t n_embd_head_k = hparams.n_embd_head_k; - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const int64_t n_embd_head_v = hparams.n_embd_head_v; const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); @@ -7619,9 +7648,7 @@ struct llm_build_context { const int64_t n_head; const int64_t n_head_kv; const int64_t n_embd_head_k; - const int64_t n_embd_k_gqa; const int64_t n_embd_head_v; - const int64_t n_embd_v_gqa; const int64_t n_expert; const int64_t n_expert_used; @@ -7673,9 +7700,7 @@ struct llm_build_context { n_head (hparams.n_head), n_head_kv (hparams.n_head_kv), n_embd_head_k (hparams.n_embd_head_k), - n_embd_k_gqa (hparams.n_embd_k_gqa()), n_embd_head_v (hparams.n_embd_head_v), - n_embd_v_gqa (hparams.n_embd_v_gqa()), n_expert (hparams.n_expert), n_expert_used (hparams.n_expert_used), freq_base (cparams.rope_freq_base), @@ -7746,9 +7771,9 @@ struct llm_build_context { // we rotate only the first n_rot dimensions ggml_rope_ext_inplace(ctx0, ggml_view_3d(ctx0, kv_self.k_l[il], - n_embd_head_k, n_head_kv, n_ctx, + n_embd_head_k, hparams.n_head_kv_l(il), n_ctx, ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), - ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv_self.k_l[il]->type, hparams.n_embd_k_gqa(il)), 0), lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); @@ -7777,6 +7802,9 @@ struct llm_build_context { } for (int il = 0; il < n_layer; ++il) { + int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il], n_embd_k_gqa, nm, ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), @@ -11014,8 +11042,8 @@ struct llm_build_context { struct ggml_tensor * state_seq = build_inp_s_seq(); for (int il = 0; il < n_layer; ++il) { - struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, rs_self.r_l[il], hparams.n_embd_r(), rs_self.size); - struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, rs_self.s_l[il], hparams.n_embd_s(), rs_self.size); + struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, rs_self.r_l[il], hparams.n_embd_r(il), rs_self.size); + struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, rs_self.s_l[il], hparams.n_embd_s(il), rs_self.size); // copy states { @@ -16452,7 +16480,7 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_cache_init(ctx->cache, ctx, type_k, type_v, cparams.n_ctx, cparams.n_seq_max, cparams.offload_kqv)) { + if (!llama_cache_init(ctx->cache, ctx, type_k, type_v, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; @@ -17282,7 +17310,7 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); // NOTE: kv_size and kv_buf_size are mostly used for sanity checks @@ -17434,7 +17462,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) { const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); size_t kv_buf_size; @@ -17627,7 +17655,7 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); for (uint32_t i = 0; i < kv_self.size; ++i) { @@ -17713,7 +17741,7 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); // Write the layer count @@ -17843,7 +17871,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, // Sanity check model compatibility const auto & hparams = ctx->model.hparams; const uint32_t n_layer = hparams.n_layer; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); if (n_layer != n_layer_ref) { LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref);