llama : rethink recurrent state cell counts

* llama : begin work on support for variable GQA

This will also be useful for Jamba if we consider the Mamba layers
to have 0 KV heads.

* llama : gracefully fail when not finding hybrid slot
This commit is contained in:
Francis Couture-Harpin 2024-05-24 16:19:25 -04:00
parent 3b57b55c6f
commit 7e13f19fb5

586
llama.cpp
View File

@ -1753,6 +1753,9 @@ struct llama_hparams {
uint32_t n_expert_used = 0; uint32_t n_expert_used = 0;
uint32_t n_vocab_type = 0; // for BERT-style token types uint32_t n_vocab_type = 0; // for BERT-style token types
// TODO: find a more compact way to add more per-layer hyper-parameters
std::vector<int32_t> n_head_kv_vec;
float f_norm_eps; float f_norm_eps;
float f_norm_rms_eps; float f_norm_rms_eps;
@ -1793,6 +1796,8 @@ struct llama_hparams {
if (this->n_expert != other.n_expert) return true; if (this->n_expert != other.n_expert) return true;
if (this->n_expert_used != other.n_expert_used) return true; if (this->n_expert_used != other.n_expert_used) return true;
if (this->n_head_kv_vec != other.n_head_kv_vec) return true;
if (this->rope_finetuned != other.rope_finetuned) return true; if (this->rope_finetuned != other.rope_finetuned) return true;
if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true; if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
@ -1812,29 +1817,46 @@ struct llama_hparams {
return false; return false;
} }
uint32_t n_gqa() const { uint32_t n_head_kv_l(uint32_t layer) const {
if (layer < n_head_kv_vec.size()) {
int32_t n_hkv_l = n_head_kv_vec[layer];
// TODO: what should happen when it's negative?
GGML_ASSERT(n_hkv_l >= 0);
return n_hkv_l;
}
return n_head_kv;
}
uint32_t n_gqa(uint32_t layer = 0) const {
uint32_t n_head_kv = n_head_kv_l(layer);
if (n_head_kv == 0) { if (n_head_kv == 0) {
return 0; return 0;
} }
return n_head/n_head_kv; return n_head/n_head_kv;
} }
uint32_t n_embd_k_gqa() const { // dimension of key embeddings across all k-v heads uint32_t n_embd_k_gqa(uint32_t layer = 0) const { // dimension of key embeddings across all k-v heads
uint32_t n_head_kv = n_head_kv_l(layer);
return n_embd_head_k * n_head_kv; return n_embd_head_k * n_head_kv;
} }
uint32_t n_embd_v_gqa() const { // dimension of value embeddings across all k-v heads uint32_t n_embd_v_gqa(uint32_t layer = 0) const { // dimension of value embeddings across all k-v heads
uint32_t n_head_kv = n_head_kv_l(layer);
return n_embd_head_v * n_head_kv; return n_embd_head_v * n_head_kv;
} }
uint32_t n_embd_r() const { // dimension of the rolling state embeddings uint32_t n_embd_r(uint32_t layer) const { // dimension of the rolling state embeddings
// TODO: support using an SSM in place of the MLP of a Transformer
if (n_head_kv_l(layer) != 0) { return 0; }
// corresponds to Mamba's conv_states size // corresponds to Mamba's conv_states size
// TODO: maybe support other convolution strides than 1 // TODO: maybe support other convolution strides than 1
// NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed
return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner;
} }
uint32_t n_embd_s() const { // dimension of the recurrent state embeddings uint32_t n_embd_s(uint32_t layer) const { // dimension of the recurrent state embeddings
// TODO: support using an SSM in place of the MLP of a Transformer
if (n_head_kv_l(layer) != 0) { return 0; }
// corresponds to Mamba's ssm_states size // corresponds to Mamba's ssm_states size
return ssm_d_state * ssm_d_inner; return ssm_d_state * ssm_d_inner;
} }
@ -2078,10 +2100,12 @@ struct llama_rs_cache {
// computed when finding a slot // computed when finding a slot
uint32_t n = 0; // range of states used for the last slot uint32_t n = 0; // range of states used for the last slot
// useful to know the minimum reserved cell count per seq_id // only counts cells which are tails of all of their sequences.
// only counts sequences which have a non-shared tail // useful to know the minimum reserved cell count per seq_id.
uint32_t n_seqs = 0; uint32_t n_seqs = 0;
// cells part of multiple sequences AND which have at least one tail // cells part of multiple sequences,
// but which are only the tail of some of them.
// useful to dismiss sequences used as a shared prompt
uint32_t n_shared_tail_cells = 0; uint32_t n_shared_tail_cells = 0;
// with state models, a cell can hold the state for more than one past token // with state models, a cell can hold the state for more than one past token
@ -2279,10 +2303,8 @@ struct llama_rs_cache {
for (uint32_t cell_id = 0; (uint32_t) cell_id < size; ++cell_id) { for (uint32_t cell_id = 0; (uint32_t) cell_id < size; ++cell_id) {
llama_rs_cell & rs_cell = cells[cell_id]; llama_rs_cell & rs_cell = cells[cell_id];
if (!rs_cell.seq_nodes.empty()) { if (!rs_cell.seq_nodes.empty()) {
if (rs_cell.seq_nodes.size() == 1) { if (rs_cell.seq_nodes.size() == rs_cell.tail_rc) {
if (rs_cell.tail_rc == 1) { n_seqs_verif += 1;
n_seqs_verif += 1;
}
} else if (rs_cell.tail_rc > 0) { } else if (rs_cell.tail_rc > 0) {
n_shared_tail_cells_verif += 1; n_shared_tail_cells_verif += 1;
} }
@ -2308,9 +2330,11 @@ struct llama_rs_cache {
} }
// returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed. // returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed.
// Why an iterator? Because it allows using std::vector<T>::erase.
std::vector<llama_rs_seq_node>::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector<llama_rs_seq_node>::iterator node_iter) { std::vector<llama_rs_seq_node>::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector<llama_rs_seq_node>::iterator node_iter) {
GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size());
// TODO: assert the iterator points inside the correct vector // The iterator needs to point inside the correct vector
GGML_ASSERT(node_iter.base() >= rs_cell.seq_nodes.data() && node_iter.base() < rs_cell.seq_nodes.data() + rs_cell.seq_nodes.size());
if (node_iter != rs_cell.seq_nodes.end()) { if (node_iter != rs_cell.seq_nodes.end()) {
// update the tree // update the tree
llama_rs_seq_node node = *node_iter; llama_rs_seq_node node = *node_iter;
@ -2325,12 +2349,20 @@ struct llama_rs_cache {
GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); GGML_ASSERT(prev_node != prev_cell.seq_nodes.end());
prev_node->next_cell = node.next_cell; prev_node->next_cell = node.next_cell;
if (node.is_tail()) { if (node.is_tail()) {
// move the tail back to the previous cell
if (prev_cell.seq_nodes.size() > 1) { if (prev_cell.seq_nodes.size() > 1) {
if (prev_cell.tail_rc == 0) { if (rs_cell.tail_rc == rs_cell.seq_nodes.size()) {
n_shared_tail_cells += 1; if (prev_cell.tail_rc == 0) {
} n_shared_tail_cells += 1;
if (rs_cell.seq_nodes.size() == 1) { }
n_seqs -= 1;
// o oo oo
// |/ -> o/
// | |
// e.g. when removing the leaf with a single tail
if (rs_cell.tail_rc == 1 && prev_cell.tail_rc != prev_cell.seq_nodes.size()) {
n_seqs -= 1;
}
} }
} }
prev_cell.tail_rc += 1; prev_cell.tail_rc += 1;
@ -2341,17 +2373,22 @@ struct llama_rs_cache {
if (node.is_tail()) { if (node.is_tail()) {
seq.tail = rs_cell.prev; seq.tail = rs_cell.prev;
if (rs_cell.tail_rc == 1) { if (rs_cell.tail_rc == 1) {
if (rs_cell.seq_nodes.size() > 1) { if (seq.tail < 0) {
// assuming the previous cell of a shared cell is also shared,
// this was a shared tail cell, but will no longer be a tail cell
n_shared_tail_cells -= 1;
} else if (seq.tail < 0) {
// no more tail, no more sequence // no more tail, no more sequence
n_seqs -= 1; if (rs_cell.seq_nodes.size() > 1) {
n_shared_tail_cells -= 1;
} else {
n_seqs -= 1;
}
} }
} }
GGML_ASSERT(rs_cell.tail_rc > 0); GGML_ASSERT(rs_cell.tail_rc > 0);
rs_cell.tail_rc -= 1; rs_cell.tail_rc -= 1;
} else if (rs_cell.tail_rc == rs_cell.seq_nodes.size() - 1) {
// will fully become a tail cell
if (rs_cell.tail_rc > 0) {
n_seqs += 1;
}
} }
if (node_iter == rs_cell.seq_nodes.begin()) { if (node_iter == rs_cell.seq_nodes.begin()) {
// this seq_id was the first in the list // this seq_id was the first in the list
@ -2363,14 +2400,6 @@ struct llama_rs_cache {
if ((uint32_t) next_node->seq_id < seq_tails.size()) { if ((uint32_t) next_node->seq_id < seq_tails.size()) {
auto & next_seq = seq_tails[next_node->seq_id]; auto & next_seq = seq_tails[next_node->seq_id];
next_seq.n_cells += 1; next_seq.n_cells += 1;
// only the tail ref count from the other seq_ids are left in tail_rc
if (rs_cell.tail_rc > 0) {
// will become a non-shared cell
if (rs_cell.seq_nodes.size() == 2) {
n_shared_tail_cells -= 1;
n_seqs += 1;
}
}
} else { } else {
GGML_ASSERT(false && "invalid seq_id"); GGML_ASSERT(false && "invalid seq_id");
} }
@ -2433,43 +2462,41 @@ struct llama_rs_cache {
rs_cell.pos = prev_cell.pos + 1; rs_cell.pos = prev_cell.pos + 1;
rs_cell.src = prev_cell.src; rs_cell.src = prev_cell.src;
} }
prev_cell.tail_rc -= 1;
prev_node->next_cell = i_cell; prev_node->next_cell = i_cell;
rs_cell.prev = prev; rs_cell.prev = prev;
if (seq.tail == prev) { if (seq.tail == prev) {
// What to do when the tail moves... // What to do when the tail moves...
// from unique to shared (n_seqs--) // (Legend: tail: O, one or more non-tails: o, one or more tails O+, empty: _)
// if the new cell has one seq_id or has no tails (n_shared_tail_cells++) // O -> oO (n_seqs--, n_shared_tail_cells++)
// if the new cell has one seq_id and a tail (n_seqs-- (yes, another time)) // O -> O (seq.n_cells++)
// from unique to unique (seq.n_cells++) // OO+ -> oO (n_seqs--, n_shared_tail_cells += 2)
// from empty to unique (seq.n_cells++, n_seqs++) // OO+ -> O+ (n_shared_tail_cells++ (the previous cell becomes oO+))
// from empty to shared // _ -> oO (n_shared_tail_cells++)
// if the new cell only has one seq_id or has no tail (n_shared_tail_cells++) // _ -> O (seq.n_cells++, n_seqs++)
// if the new cell only has one seq_id and has one tail (n_seqs--) // Oo -> O (seq.n_cells++, n_seqs++, n_shared_tail_cell--)
// from shared to shared // Oo -> OO+ (n_shared_tail_cell--)
// if the last cell has no tails (n_shared_tail_cells--) // OOo -> O (seq.n_cells++, n_seqs++)
// if the new cell has no tails or has one seq_id (n_shared_tail_cells++) if (prev_cell.seq_nodes.size() == prev_cell.tail_rc) {
// if the new cell only has one seq_id and has one tail (n_seqs--) // from fully tail
// from shared to unique (seq.n_cells++) if (prev_cell.tail_rc > 1) {
// if this seq_id was not the first of the last cell (n_seqs++) // the previous tail becomes shared with a non-tail
// if the last cell has no tails (n_shared_tail_cells--) n_shared_tail_cells += 1;
if (prev_cell.seq_nodes.size() > 1) {
// from shared
if (rs_cell.is_empty()) {
// to unique
if (prev_cell.seq_nodes[0].seq_id != id) {
n_seqs += 1;
}
} }
// the previous cell is no longer a shared tail if (!rs_cell.is_empty() && rs_cell.tail_rc == 0) {
if (prev_cell.tail_rc == 0) { // the new tail cell was previously a fully non-tail cell
n_shared_tail_cells += 1;
n_seqs -= 1;
}
} else if (rs_cell.is_empty()) {
// from shared to unique
n_seqs += 1;
if (prev_cell.tail_rc == 1) {
// it was the last tail of the previous cell
n_shared_tail_cells -= 1; n_shared_tail_cells -= 1;
} }
} else if (!rs_cell.is_empty()) {
// from unique to shared
n_seqs -= 1;
} }
} }
prev_cell.tail_rc -= 1;
} }
if (rs_cell.is_empty()) { if (rs_cell.is_empty()) {
// to unique // to unique
@ -2482,15 +2509,10 @@ struct llama_rs_cache {
rs_cell.src = -1; rs_cell.src = -1;
} }
used += 1; used += 1;
} else { } else if (rs_cell.tail_rc == 0) {
// to shared // to shared
if (rs_cell.seq_nodes.size() == 1) { if (seq.tail < 0) {
// a lone tail becomes a shared cell // from empty to shared
if (rs_cell.tail_rc > 0) {
n_seqs -= 1;
}
n_shared_tail_cells += 1;
} else if (rs_cell.tail_rc == 0) {
n_shared_tail_cells += 1; n_shared_tail_cells += 1;
} }
} }
@ -2910,26 +2932,18 @@ static bool llama_cache_init(
const llama_context * ctx, const llama_context * ctx,
ggml_type type_k, ggml_type type_k,
ggml_type type_v, ggml_type type_v,
uint32_t n_ctx,
uint32_t n_seq_max,
bool offload) { bool offload) {
const llama_model & model = ctx->model; const llama_model & model = ctx->model;
const llama_cparams & cparams = ctx->cparams; const llama_cparams & cparams = ctx->cparams;
const struct llama_hparams & hparams = model.hparams; const struct llama_hparams & hparams = model.hparams;
// TODO: per layer n_embd_* const bool has_kv = hparams.n_head_kv != 0 && hparams.causal_attn;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const bool has_r = hparams.ssm_d_conv != 0 && hparams.ssm_d_inner != 0;
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const bool has_s = hparams.ssm_d_state != 0 && hparams.ssm_d_inner != 0;
const uint32_t n_embd_r = hparams.n_embd_r();
const uint32_t n_embd_s = hparams.n_embd_s();
const bool has_kv = hparams.n_head != 0 && hparams.causal_attn;
const bool has_r = n_embd_r != 0;
const bool has_s = n_embd_s != 0;
const bool has_rs = has_r || has_s; const bool has_rs = has_r || has_s;
const uint32_t kv_size = has_kv ? n_ctx : 0; const uint32_t kv_size = has_kv ? cparams.n_ctx : 0;
const uint32_t rs_size = has_rs ? n_seq_max : 0; const uint32_t rs_size = has_rs ? cparams.n_seq_max : 0;
// TODO: per cache type layer count
const int64_t n_layer = hparams.n_layer; const int64_t n_layer = hparams.n_layer;
cache.kv.size = kv_size; cache.kv.size = kv_size;
@ -2967,6 +2981,7 @@ static bool llama_cache_init(
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map; std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
for (auto & it : buft_layer_count) { for (auto & it : buft_layer_count) {
int n_layers = it.second; int n_layers = it.second;
// TODO: for mixed architectures, avoid allocating empty recurrent state or kv cache tensors
struct ggml_init_params params = { struct ggml_init_params params = {
/*.mem_size =*/ (2*has_kv + has_r+has_s)*n_layers*ggml_tensor_overhead(), /*.mem_size =*/ (2*has_kv + has_r+has_s)*n_layers*ggml_tensor_overhead(),
/*.mem_buffer =*/ NULL, /*.mem_buffer =*/ NULL,
@ -2995,20 +3010,20 @@ static bool llama_cache_init(
for (int i = 0; i < (int) n_layer; i++) { for (int i = 0; i < (int) n_layer; i++) {
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
if (has_kv) { if (has_kv) {
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i)*kv_size);
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i)*kv_size);
ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(k, "cache_k_l%d", i);
ggml_format_name(v, "cache_v_l%d", i); ggml_format_name(v, "cache_v_l%d", i);
cache.kv.k_l.push_back(k); cache.kv.k_l.push_back(k);
cache.kv.v_l.push_back(v); cache.kv.v_l.push_back(v);
} }
if (has_r) { if (has_r) {
ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd_r*rs_size); ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i)*rs_size);
ggml_format_name(r, "cache_r_l%d", i); ggml_format_name(r, "cache_r_l%d", i);
cache.rs.r_l.push_back(r); cache.rs.r_l.push_back(r);
} }
if (has_s) { if (has_s) {
ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd_s*rs_size); ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i)*rs_size);
ggml_format_name(s, "cache_s_l%d", i); ggml_format_name(s, "cache_s_l%d", i);
cache.rs.s_l.push_back(s); cache.rs.s_l.push_back(s);
} }
@ -3024,7 +3039,7 @@ static bool llama_cache_init(
return false; return false;
} }
ggml_backend_buffer_clear(buf, 0); ggml_backend_buffer_clear(buf, 0);
LLAMA_LOG_INFO("%s: %10s ctx buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); LLAMA_LOG_INFO("%s: %10s cache buf size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0);
cache.bufs.push_back(buf); cache.bufs.push_back(buf);
} }
@ -3042,177 +3057,21 @@ static bool llama_cache_find_slot(
const uint32_t rs_size = cache.rs.size; const uint32_t rs_size = cache.rs.size;
const uint32_t n_tokens = batch.n_tokens; const uint32_t n_tokens = batch.n_tokens;
// FIXME: on failure, leave all caches in a consistent state. // only check first, to allow failing gracefully
if (rs_size > 0) { if (rs_size > 0) {
// For recurrent state architectures (like Mamba), // everything should fit if all seq_ids are smaller than the max
// each cache cell can store the state for a whole sequence.
// TODO: find a way to always make the rs slot contiguous
llama_seq_id min_seq = cache.rs.size - 1;
llama_seq_id max_seq = 0;
uint32_t min_cell = cache.rs.size - 1;
uint32_t max_cell = 0;
for (uint32_t i = 0; i < n_tokens; ++i) { for (uint32_t i = 0; i < n_tokens; ++i) {
int32_t target_cell = -1; // ensure all the sequences of a token get the same cell int32_t n_seq_id = batch.n_seq_id[i];
int32_t n_seq_ids = batch.n_seq_id[i]; for (int32_t j = 0; j < n_seq_id; ++j) {
for (int32_t j = 0; j < n_seq_ids; ++j) {
llama_seq_id seq_id = batch.seq_id[i][j]; llama_seq_id seq_id = batch.seq_id[i][j];
bool need_new_cell = false;
// Everything should fit assuming the biggest seq_id < rs_size
if ((uint32_t) seq_id < rs_size) {
llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id];
if (seq_id > max_seq) { max_seq = seq_id; }
if (seq_id < min_seq) { min_seq = seq_id; }
if (!seq.in_ubatch && target_cell >= 0) { if (seq_id < 0 || (uint32_t) seq_id >= rs_size) {
// never saw this seq_id before,
// but there's already a cell reserved for this token, use it
cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id);
} else if (seq.tail < 0) {
need_new_cell = true;
} else {
llama_rs_cell & tail = cache.rs.cells[seq.tail];
if (seq.in_ubatch) {
// this seq_id was already seen before in the batch
// assuming the tail cell already "has" this seq_id
tail.pos += 1;
target_cell = seq.tail;
} else {
// first time this sequence is seen,
// there's no reserved cell yet;
// if it's not the first sequence of the token, how could it even get here?
GGML_ASSERT(j == 0);
bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids;
if (has_same_seqs) {
// the tail cell of a seq_id is assumed to already be part of the seq_id,
// hence the skip of the first seq_id
for (int32_t k = 1; k < n_seq_ids; ++k) {
if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) {
has_same_seqs = false;
}
}
}
// TODO: make the checkpoint interval configurable
if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) {
// a checkpoint should be saved
need_new_cell = true;
} else {
// re-use last tail
tail.pos += 1;
target_cell = seq.tail;
}
}
}
if (need_new_cell && target_cell < 0) {
const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq);
uint32_t cell_id = cache.rs.size;
bool looped_once = false;
while (true) {
if (cache.rs.head >= cache.rs.size) {
cache.rs.head = 0;
if (looped_once) {
// avoid infinite loop
// NOTE: this should not happen, but gracefully fail anyway
LLAMA_LOG_ERROR("%s: recurrent state cache seems full, but should not. This is a bug.\n", __func__);
return false;
}
looped_once = true;
}
cell_id = cache.rs.head;
llama_rs_cell & candidate = cache.rs.cells[cell_id];
if (candidate.is_empty()) { break; }
if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) {
// the candidate is the old tail
if (candidate.seq_nodes.size() > 1) {
// prune out the other seq_ids, because they diverge
// TODO(maybe): hande this in insert_seq_tail_to_cell_id
// (hopefully doesn't happen too often)
for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) {
if (node_iter->seq_id == seq_id) {
node_iter = std::next(node_iter);
} else {
node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter);
}
}
}
// re-use the tail cell to avoid not finding anything
candidate.pos += 1;
break;
}
if (candidate.tail_rc > 0) {
// skip tails of other sequences
cache.rs.head += 1;
continue;
}
if (candidate.seq_nodes.size() > 1) {
// shared prompts are not usually backtracked, so they can be pruned
cache.rs.clear_cell(candidate);
break;
}
// prune too-long sequences
llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id;
if (seq_id_to_prune == seq_id) {
// TODO: selectively skip some cells to keep older states
cache.rs.clear_cell(candidate);
break;
}
GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size());
auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune];
if (seq_to_prune.n_cells > min_cells_per_seq) {
cache.rs.clear_cell(candidate);
break;
}
cache.rs.head += 1;
}
if (cell_id < cache.rs.size) {
cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id);
target_cell = cell_id;
}
}
if (seq.tail >= 0) {
if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; }
if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; }
seq.in_ubatch = true;
}
// Assuming the tokens are in-order
if (batch.pos[i] != cache.rs.cells[seq.tail].pos) {
// What should happen when the pos backtracks or skips a value?
// Clearing the state mid-batch would require special-casing which isn't done.
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
__func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id);
}
} else {
// too big seq_id // too big seq_id
// TODO: would it be possible to resize the rs cache size instead? // TODO: would it be possible to resize the rs cache size instead?
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size); LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size);
return false; return false;
} }
} }
cache.rs.head = target_cell + 1;
}
for (llama_seq_id i = min_seq; i <= max_seq; ++i) {
// make sure it's cleared for next time
cache.rs.seq_tails[i].in_ubatch = false;
}
// allow getting the range of used cells, from head to head + n
cache.rs.head = min_cell;
cache.rs.n = max_cell - min_cell + 1;
// sanity check
if (max_seq < min_seq || max_cell < min_cell) {
return false;
} }
} }
@ -3257,7 +3116,174 @@ static bool llama_cache_find_slot(
return false; return false;
} }
} }
}
// now modification can be done, and should NOT fail
if (rs_size > 0) {
// For recurrent state architectures (like Mamba),
// each cache cell can store the state for a whole sequence.
// TODO: find a way to always make the rs slot contiguous
llama_seq_id min_seq = cache.rs.size - 1;
llama_seq_id max_seq = 0;
uint32_t min_cell = cache.rs.size - 1;
uint32_t max_cell = 0;
for (uint32_t i = 0; i < n_tokens; ++i) {
int32_t target_cell = -1; // ensure all the sequences of a token get the same cell
int32_t n_seq_ids = batch.n_seq_id[i];
for (int32_t j = 0; j < n_seq_ids; ++j) {
llama_seq_id seq_id = batch.seq_id[i][j];
bool need_new_cell = false;
// Everything should fit assuming the biggest seq_id < rs_size
GGML_ASSERT((uint32_t) seq_id < rs_size);
llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id];
if (seq_id > max_seq) { max_seq = seq_id; }
if (seq_id < min_seq) { min_seq = seq_id; }
if (!seq.in_ubatch && target_cell >= 0) {
// never saw this seq_id before,
// but there's already a cell reserved for this token, use it
cache.rs.insert_seq_tail_to_cell_id(target_cell, seq_id);
} else if (seq.tail < 0) {
// this seq_id has no tail (and is empty)
need_new_cell = true;
} else {
llama_rs_cell & tail = cache.rs.cells[seq.tail];
if (seq.in_ubatch) {
// this seq_id was already seen before in the batch
// assuming the tail cell already "has" this seq_id
tail.pos += 1;
target_cell = seq.tail;
} else {
// first time this sequence is seen,
// there's no reserved cell yet;
// if it's not the first sequence of the token, how could it even get here?
GGML_ASSERT(j == 0);
bool has_same_seqs = tail.seq_nodes.size() == (size_t) n_seq_ids;
if (has_same_seqs) {
// the tail cell of a seq_id is assumed to already be part of the seq_id,
// hence the skip of the first seq_id
for (int32_t k = 1; k < n_seq_ids; ++k) {
if (batch.seq_id[i][k] != tail.seq_nodes[k].seq_id) {
has_same_seqs = false;
}
}
}
// TODO: make the checkpoint interval configurable
if (!has_same_seqs || tail.prev < 0 || tail.pos - cache.rs.cells[tail.prev].pos >= 16) {
// a checkpoint should be saved
need_new_cell = true;
} else {
// re-use last tail
tail.pos += 1;
target_cell = seq.tail;
}
}
}
// reserve a cell for this seq_id
if (need_new_cell && target_cell < 0) {
const int32_t min_cells_per_seq = cache.rs.min_cells_per_seq(seq);
uint32_t cell_id = cache.rs.size;
bool looped_once = false;
while (true) {
if (cache.rs.head >= cache.rs.size) {
cache.rs.head = 0;
// avoid infinite loop
// NOTE: this should not fail; if it does, it's a bug.
GGML_ASSERT(!looped_once && "recurrent state cache seems full, but should not.");
looped_once = true;
}
cell_id = cache.rs.head;
llama_rs_cell & candidate = cache.rs.cells[cell_id];
if (candidate.is_empty()) { break; }
if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) {
// the candidate is the old tail
if (candidate.seq_nodes.size() > 1) {
// prune out the other seq_ids, because they diverge
// TODO(maybe): hande this in insert_seq_tail_to_cell_id
// (hopefully doesn't happen too often)
for (auto node_iter = candidate.seq_nodes.begin(); node_iter != candidate.seq_nodes.end();) {
if (node_iter->seq_id == seq_id) {
node_iter = std::next(node_iter);
} else {
node_iter = cache.rs.remove_seq_node_from_cell(candidate, node_iter);
}
}
}
// re-use the tail cell to avoid not finding anything
candidate.pos += 1;
break;
}
if (candidate.tail_rc > 0) {
// skip tails of other sequences
cache.rs.head += 1;
continue;
}
if (candidate.seq_nodes.size() > 1) {
// shared prompts are not usually backtracked, so they can be pruned
cache.rs.clear_cell(candidate);
break;
}
// prune too-long sequences
llama_seq_id seq_id_to_prune = candidate.seq_nodes[0].seq_id;
if (seq_id_to_prune == seq_id) {
// TODO: selectively skip some cells to keep older states
cache.rs.clear_cell(candidate);
break;
}
GGML_ASSERT((size_t) seq_id_to_prune < cache.rs.seq_tails.size());
auto & seq_to_prune = cache.rs.seq_tails[seq_id_to_prune];
if (seq_to_prune.n_cells > min_cells_per_seq) {
cache.rs.clear_cell(candidate);
break;
}
cache.rs.head += 1;
}
if (cell_id < cache.rs.size) {
cache.rs.insert_seq_tail_to_cell_id(cell_id, seq_id);
target_cell = cell_id;
}
}
if (seq.tail >= 0) {
if (min_cell > (uint32_t) seq.tail) { min_cell = seq.tail; }
if (max_cell < (uint32_t) seq.tail) { max_cell = seq.tail; }
seq.in_ubatch = true;
}
// Assuming the tokens are in-order
if (batch.pos[i] != cache.rs.cells[seq.tail].pos) {
// What should happen when the pos backtracks or skips a value?
// Clearing the state mid-batch would require special-casing which isn't done.
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
__func__, batch.pos[i], cache.rs.cells[cache.rs.head].pos - 1, seq_id);
}
}
cache.rs.head = target_cell + 1;
}
for (llama_seq_id i = min_seq; i <= max_seq; ++i) {
// make sure it's cleared for next time
cache.rs.seq_tails[i].in_ubatch = false;
}
// allow getting the range of used cells, from head to head + n
cache.rs.head = min_cell;
cache.rs.n = max_cell - min_cell + 1;
// sanity check
GGML_ASSERT(min_seq <= max_seq && min_cell <= max_cell);
}
if (kv_size > 0) {
for (uint32_t i = 0; i < n_tokens; i++) { for (uint32_t i = 0; i < n_tokens; i++) {
cache.kv.cells[cache.kv.head + i].pos = batch.pos[i]; cache.kv.cells[cache.kv.head + i].pos = batch.pos[i];
@ -4194,9 +4220,9 @@ struct llama_model_loader {
bool get_arr(const std::string & key, std::vector<T> & result, const bool required = true) { bool get_arr(const std::string & key, std::vector<T> & result, const bool required = true) {
const int kid = gguf_find_key(meta, key.c_str()); const int kid = gguf_find_key(meta, key.c_str());
if (kid < 0) { if (kid < 0 || gguf_get_kv_type(meta, kid) != GGUF_TYPE_ARRAY) {
if (required) { if (required) {
throw std::runtime_error(format("key not found in model: %s", key.c_str())); throw std::runtime_error(format("array key not found in model: %s", key.c_str()));
} }
return false; return false;
} }
@ -4204,16 +4230,17 @@ struct llama_model_loader {
struct GGUFMeta::ArrayInfo arr_info = struct GGUFMeta::ArrayInfo arr_info =
GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid); GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(meta, kid);
if (arr_info.gt != GGUF_TYPE_FLOAT32 && arr_info.gt != GGUF_TYPE_INT32) { // TODO: allow ANY lossless cast
throw std::runtime_error(format("%s is not a float32 or int32 array", key.c_str())); // GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T));
switch (arr_info.gt) {
case GGUF_TYPE_FLOAT32: GGML_ASSERT((std::is_same<T, float>::value)); break;
case GGUF_TYPE_INT32: GGML_ASSERT((std::is_same<T, int32_t>::value)); break;
default:
throw std::runtime_error(format("%s is not a float32, int32 array", key.c_str()));
} }
// GGML_ASSERT(gguf_type_size(arr_info.gt) == sizeof(T)); result.reserve(arr_info.length);
GGML_ASSERT((arr_info.gt != GGUF_TYPE_FLOAT32 || std::is_same<T, float>::value)); result.assign((const T *)arr_info.data, (const T *)arr_info.data + arr_info.length);
GGML_ASSERT((arr_info.gt != GGUF_TYPE_INT32 || std::is_same<T, int>::value));
result.resize(arr_info.length);
result.assign((const T*)arr_info.data, (const T *)arr_info.data + arr_info.length);
return true; return true;
} }
@ -4750,7 +4777,12 @@ static void llm_load_hparams(
// n_head_kv is optional, default to n_head // n_head_kv is optional, default to n_head
hparams.n_head_kv = hparams.n_head; hparams.n_head_kv = hparams.n_head;
ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false);
// per-layer n_head_kv
if (!ml.get_arr(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv_vec, false)) {
// global/fallback n_head_kv
ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT_KV, hparams.n_head_kv, false);
}
bool rope_finetuned = false; bool rope_finetuned = false;
ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false); ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false);
@ -6704,10 +6736,7 @@ static bool llm_load_tensors(
model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading model.output = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); // same as tok_embd, duplicated to allow offloading
const int64_t n_ff = hparams.n_ff;
const int64_t n_embd_head_k = hparams.n_embd_head_k; const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
for (uint32_t i = 0; i < n_layer; ++i) { for (uint32_t i = 0; i < n_layer; ++i) {
ggml_context * ctx_layer = ctx_for_layer(i); ggml_context * ctx_layer = ctx_for_layer(i);
@ -7198,8 +7227,8 @@ static void llm_build_kv_store(
int64_t il) { int64_t il) {
const int64_t n_ctx = cparams.n_ctx; const int64_t n_ctx = cparams.n_ctx;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
GGML_ASSERT(kv.size == n_ctx); GGML_ASSERT(kv.size == n_ctx);
@ -7465,9 +7494,9 @@ static struct ggml_tensor * llm_build_kqv(
int il) { int il) {
const int64_t n_ctx = cparams.n_ctx; const int64_t n_ctx = cparams.n_ctx;
const int64_t n_head = hparams.n_head; const int64_t n_head = hparams.n_head;
const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_head_kv = hparams.n_head_kv_l(il);
const int64_t n_embd_head_k = hparams.n_embd_head_k; const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_head_v = hparams.n_embd_head_v; const int64_t n_embd_head_v = hparams.n_embd_head_v;
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
@ -7619,9 +7648,7 @@ struct llm_build_context {
const int64_t n_head; const int64_t n_head;
const int64_t n_head_kv; const int64_t n_head_kv;
const int64_t n_embd_head_k; const int64_t n_embd_head_k;
const int64_t n_embd_k_gqa;
const int64_t n_embd_head_v; const int64_t n_embd_head_v;
const int64_t n_embd_v_gqa;
const int64_t n_expert; const int64_t n_expert;
const int64_t n_expert_used; const int64_t n_expert_used;
@ -7673,9 +7700,7 @@ struct llm_build_context {
n_head (hparams.n_head), n_head (hparams.n_head),
n_head_kv (hparams.n_head_kv), n_head_kv (hparams.n_head_kv),
n_embd_head_k (hparams.n_embd_head_k), n_embd_head_k (hparams.n_embd_head_k),
n_embd_k_gqa (hparams.n_embd_k_gqa()),
n_embd_head_v (hparams.n_embd_head_v), n_embd_head_v (hparams.n_embd_head_v),
n_embd_v_gqa (hparams.n_embd_v_gqa()),
n_expert (hparams.n_expert), n_expert (hparams.n_expert),
n_expert_used (hparams.n_expert_used), n_expert_used (hparams.n_expert_used),
freq_base (cparams.rope_freq_base), freq_base (cparams.rope_freq_base),
@ -7746,9 +7771,9 @@ struct llm_build_context {
// we rotate only the first n_rot dimensions // we rotate only the first n_rot dimensions
ggml_rope_ext_inplace(ctx0, ggml_rope_ext_inplace(ctx0,
ggml_view_3d(ctx0, kv_self.k_l[il], ggml_view_3d(ctx0, kv_self.k_l[il],
n_embd_head_k, n_head_kv, n_ctx, n_embd_head_k, hparams.n_head_kv_l(il), n_ctx,
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), ggml_row_size(kv_self.k_l[il]->type, hparams.n_embd_k_gqa(il)),
0), 0),
lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
@ -7777,6 +7802,9 @@ struct llm_build_context {
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il], ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
n_embd_k_gqa, nm, n_embd_k_gqa, nm,
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
@ -11014,8 +11042,8 @@ struct llm_build_context {
struct ggml_tensor * state_seq = build_inp_s_seq(); struct ggml_tensor * state_seq = build_inp_s_seq();
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, rs_self.r_l[il], hparams.n_embd_r(), rs_self.size); struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, rs_self.r_l[il], hparams.n_embd_r(il), rs_self.size);
struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, rs_self.s_l[il], hparams.n_embd_s(), rs_self.size); struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, rs_self.s_l[il], hparams.n_embd_s(il), rs_self.size);
// copy states // copy states
{ {
@ -16452,7 +16480,7 @@ struct llama_context * llama_new_context_with_model(
} }
ctx->backends.push_back(ctx->backend_cpu); ctx->backends.push_back(ctx->backend_cpu);
if (!llama_cache_init(ctx->cache, ctx, type_k, type_v, cparams.n_ctx, cparams.n_seq_max, cparams.offload_kqv)) { if (!llama_cache_init(ctx->cache, ctx, type_k, type_v, cparams.offload_kqv)) {
LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
llama_free(ctx); llama_free(ctx);
return nullptr; return nullptr;
@ -17282,7 +17310,7 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
const auto & hparams = ctx->model.hparams; const auto & hparams = ctx->model.hparams;
const uint32_t n_layer = hparams.n_layer; const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
// NOTE: kv_size and kv_buf_size are mostly used for sanity checks // NOTE: kv_size and kv_buf_size are mostly used for sanity checks
@ -17434,7 +17462,7 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
const auto & hparams = ctx->model.hparams; const auto & hparams = ctx->model.hparams;
const uint32_t n_layer = hparams.n_layer; const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
size_t kv_buf_size; size_t kv_buf_size;
@ -17627,7 +17655,7 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
const auto & hparams = ctx->model.hparams; const auto & hparams = ctx->model.hparams;
const uint32_t n_layer = hparams.n_layer; const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
for (uint32_t i = 0; i < kv_self.size; ++i) { for (uint32_t i = 0; i < kv_self.size; ++i) {
@ -17713,7 +17741,7 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
const auto & hparams = ctx->model.hparams; const auto & hparams = ctx->model.hparams;
const uint32_t n_layer = hparams.n_layer; const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
// Write the layer count // Write the layer count
@ -17843,7 +17871,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
// Sanity check model compatibility // Sanity check model compatibility
const auto & hparams = ctx->model.hparams; const auto & hparams = ctx->model.hparams;
const uint32_t n_layer = hparams.n_layer; const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); // FIXME: per layer
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
if (n_layer != n_layer_ref) { if (n_layer != n_layer_ref) {
LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref); LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref);