llama : correctly handle more edge cases for the rs cache

This commit is contained in:
Francis Couture-Harpin 2024-04-09 17:35:22 -04:00
parent 0028010d01
commit 0c8b3b2095

407
llama.cpp
View File

@ -2034,7 +2034,7 @@ struct llama_rs_cache {
uint32_t n = 0; // range of states used for the last slot uint32_t n = 0; // range of states used for the last slot
// useful to know the minimum reserved cell count per seq_id // useful to know the minimum reserved cell count per seq_id
// only counts sequences with n_cells > 0 AND which have a non-shared tail // only counts sequences which have a non-shared tail
uint32_t n_seqs = 0; uint32_t n_seqs = 0;
// cells part of multiple sequences AND which have at least one tail // cells part of multiple sequences AND which have at least one tail
uint32_t n_shared_tail_cells = 0; uint32_t n_shared_tail_cells = 0;
@ -2082,21 +2082,37 @@ struct llama_rs_cache {
llama_rs_cell & cell = cells[cell_id]; llama_rs_cell & cell = cells[cell_id];
if (cell.seq_nodes.empty()) { if (cell.seq_nodes.empty()) {
if (cell.pos >= 0) { if (cell.pos >= 0) {
if (debug) {
LLAMA_LOG_ERROR("%s: cells[%d].pos is %d while it's empty (should be -1)\n",
__func__, cell_id, cell.pos);
}
cell.pos = -1; cell.pos = -1;
was_valid = false; was_valid = false;
} }
} }
if (cell.pos < 0) { if (cell.pos < 0) {
if (cell.pos != -1) { if (cell.pos != -1) {
if (debug) {
LLAMA_LOG_ERROR("%s: cells[%d].pos is %d while it's empty (should be -1)\n",
__func__, cell_id, cell.pos);
}
cell.pos = -1; cell.pos = -1;
was_valid = false; was_valid = false;
} }
if (!cell.seq_nodes.empty()) { if (!cell.seq_nodes.empty()) {
if (debug) {
LLAMA_LOG_ERROR("%s: cells[%d] has %zu seq_ids while it's empty (should have none)\n",
__func__, cell_id, cell.seq_nodes.size());
}
cell.seq_nodes.clear(); cell.seq_nodes.clear();
was_valid = false; was_valid = false;
} }
cell.src = -1; cell.src = -1;
if (cell.prev != -1) { if (cell.prev != -1) {
if (debug) {
LLAMA_LOG_ERROR("%s: cells[%d].prev is %d while it's empty (should be -1)\n",
__func__, cell_id, cell.prev);
}
cell.prev = -1; cell.prev = -1;
was_valid = false; was_valid = false;
} }
@ -2213,17 +2229,15 @@ struct llama_rs_cache {
// n_seqs // n_seqs
uint32_t n_seqs_verif = 0; uint32_t n_seqs_verif = 0;
uint32_t n_shared_tail_cells_verif = 0; uint32_t n_shared_tail_cells_verif = 0;
for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { for (uint32_t cell_id = 0; (uint32_t) cell_id < size; ++cell_id) {
auto & seq = seq_tails[seq_id]; llama_rs_cell & rs_cell = cells[cell_id];
if (seq.tail >= 0) { if (!rs_cell.seq_nodes.empty()) {
llama_rs_cell & tail_cell = cells[seq.tail]; if (rs_cell.seq_nodes.size() == 1) {
// NOTE: could also have checked if n_cells > 0 if (rs_cell.tail_rc == 1) {
if (!tail_cell.seq_nodes.empty() && tail_cell.seq_nodes[0].seq_id == seq_id) {
if (tail_cell.seq_nodes.size() > 1) {
n_shared_tail_cells_verif += 1;
} else {
n_seqs_verif += 1; n_seqs_verif += 1;
} }
} else if (rs_cell.tail_rc > 0) {
n_shared_tail_cells_verif += 1;
} }
} }
} }
@ -2246,72 +2260,15 @@ struct llama_rs_cache {
return was_valid; return was_valid;
} }
// returns whether or not a cell was freed
void clear_cell(llama_rs_cell & rs_cell) {
GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size());
if (!rs_cell.is_empty()) {
// update sequence tree links
bool first = true;
for (const llama_rs_seq_node & node : rs_cell.seq_nodes) {
if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) {
// NOTE: if all next cells are the same cell, this should still work
cells[node.next_cell].prev = rs_cell.prev;
}
// next_cell of the nodes of the previous cell
if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) {
llama_rs_cell & prev_cell = cells[rs_cell.prev];
auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node);
// assuming the previous node is always found
GGML_ASSERT(prev_node != prev_cell.seq_nodes.end());
prev_node->next_cell = node.next_cell;
if (node.is_tail()) {
prev_cell.tail_rc += 1;
}
}
if ((uint32_t) node.seq_id < seq_tails.size()) {
auto & seq = seq_tails[node.seq_id];
// update tail
if (node.is_tail()) {
seq.tail = rs_cell.prev;
}
// cell counts
if (first) {
seq.n_cells -= 1;
if (rs_cell.tail_rc > 0 && seq.tail < 0) {
// last tail cell
if (rs_cell.seq_nodes.size() > 1) {
n_shared_tail_cells -= 1;
} else {
n_seqs -= 1;
}
}
first = false;
}
} else {
GGML_ASSERT(false && "invalid seq_id");
}
}
rs_cell.pos = -1;
rs_cell.src = -1;
rs_cell.prev = -1;
rs_cell.tail_rc = 0;
rs_cell.seq_nodes.clear();
used -= 1;
}
}
// returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed. // returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed.
std::vector<llama_rs_seq_node>::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector<llama_rs_seq_node>::iterator node_iter) { std::vector<llama_rs_seq_node>::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector<llama_rs_seq_node>::iterator node_iter) {
GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size());
// TODO: assert the iterator points inside the correct vector // TODO: assert the iterator points inside the correct vector
if (node_iter != rs_cell.seq_nodes.end()) { if (node_iter != rs_cell.seq_nodes.end()) {
if (rs_cell.seq_nodes.size() == 1) { // update the tree
clear_cell(rs_cell);
return rs_cell.seq_nodes.end();
}
// else update tree
llama_rs_seq_node node = *node_iter; llama_rs_seq_node node = *node_iter;
if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) {
// NOTE: because of this, partially removing seq_ids from cells should only be done from the tail
cells[node.next_cell].prev = rs_cell.prev; cells[node.next_cell].prev = rs_cell.prev;
} }
if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) { if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) {
@ -2321,6 +2278,14 @@ struct llama_rs_cache {
GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); GGML_ASSERT(prev_node != prev_cell.seq_nodes.end());
prev_node->next_cell = node.next_cell; prev_node->next_cell = node.next_cell;
if (node.is_tail()) { if (node.is_tail()) {
if (prev_cell.seq_nodes.size() > 1) {
if (prev_cell.tail_rc == 0) {
n_shared_tail_cells += 1;
}
if (rs_cell.seq_nodes.size() == 1) {
n_seqs -= 1;
}
}
prev_cell.tail_rc += 1; prev_cell.tail_rc += 1;
} }
} }
@ -2328,11 +2293,15 @@ struct llama_rs_cache {
auto & seq = seq_tails[node.seq_id]; auto & seq = seq_tails[node.seq_id];
if (node.is_tail()) { if (node.is_tail()) {
seq.tail = rs_cell.prev; seq.tail = rs_cell.prev;
if (seq.tail < 0 && rs_cell.tail_rc == 1) { if (rs_cell.tail_rc == 1) {
// assuming the previous cell of a shared cell is also shared, if (rs_cell.seq_nodes.size() > 1) {
// (no need to update the shared tail cells count elsewhere, then) // assuming the previous cell of a shared cell is also shared,
// this was a shared tail cell, but will no longer be a tail cell // this was a shared tail cell, but will no longer be a tail cell
n_shared_tail_cells -= 1; n_shared_tail_cells -= 1;
} else if (seq.tail < 0) {
// no more tail, no more sequence
n_seqs -= 1;
}
} }
GGML_ASSERT(rs_cell.tail_rc > 0); GGML_ASSERT(rs_cell.tail_rc > 0);
rs_cell.tail_rc -= 1; rs_cell.tail_rc -= 1;
@ -2341,21 +2310,30 @@ struct llama_rs_cache {
// this seq_id was the first in the list // this seq_id was the first in the list
seq.n_cells -= 1; seq.n_cells -= 1;
// the next node is the new first one, so update its n_cells auto next_node = std::next(node_iter);
// (will never be out-of-bounds because the size is > 1) if (next_node != rs_cell.seq_nodes.end()) {
llama_rs_seq_node next_node = *(std::next(node_iter)); // the next node is the new first one, so update its n_cells
if ((uint32_t) next_node.seq_id < seq_tails.size()) { if ((uint32_t) next_node->seq_id < seq_tails.size()) {
auto & next_seq = seq_tails[next_node.seq_id]; auto & next_seq = seq_tails[next_node->seq_id];
next_seq.n_cells += 1; next_seq.n_cells += 1;
// only the tail ref count from the other seq_ids are left in tail_rc // only the tail ref count from the other seq_ids are left in tail_rc
if (rs_cell.tail_rc > 0) { if (rs_cell.tail_rc > 0) {
// will become a non-shared cell // will become a non-shared cell
if (rs_cell.seq_nodes.size() == 2) { if (rs_cell.seq_nodes.size() == 2) {
n_seqs += 1; n_shared_tail_cells -= 1;
n_seqs += 1;
}
} }
} else {
GGML_ASSERT(false && "invalid seq_id");
} }
} else { } else {
GGML_ASSERT(false && "invalid seq_id"); // this was the last seq_id of the cell
used -= 1;
rs_cell.pos = -1;
rs_cell.src = -1;
rs_cell.prev = -1;
// the other fields *should* have already been updated elsewhere
} }
} }
} else { } else {
@ -2366,6 +2344,13 @@ struct llama_rs_cache {
return node_iter; return node_iter;
} }
void clear_cell(llama_rs_cell & rs_cell) {
GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size());
for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) {
node_iter = remove_seq_node_from_cell(rs_cell, node_iter);
}
}
// returns whether or not the seq_id was removed // returns whether or not the seq_id was removed
bool remove_seq_from_cell_id(uint32_t i_cell, const llama_seq_id & id) { bool remove_seq_from_cell_id(uint32_t i_cell, const llama_seq_id & id) {
if (i_cell < size && (size_t) id < size) { if (i_cell < size && (size_t) id < size) {
@ -2404,47 +2389,63 @@ struct llama_rs_cache {
prev_cell.tail_rc -= 1; prev_cell.tail_rc -= 1;
prev_node->next_cell = i_cell; prev_node->next_cell = i_cell;
rs_cell.prev = prev; rs_cell.prev = prev;
if (seq.tail == prev) {
// What to do when the tail moves...
// from unique to shared (n_seqs--)
// if the new cell has one seq_id or has no tails (n_shared_tail_cells++)
// if the new cell has one seq_id and a tail (n_seqs-- (yes, another time))
// from unique to unique (seq.n_cells++)
// from empty to unique (seq.n_cells++, n_seqs++)
// from empty to shared
// if the new cell only has one seq_id or has no tail (n_shared_tail_cells++)
// if the new cell only has one seq_id and has one tail (n_seqs--)
// from shared to shared
// if the last cell has no tails (n_shared_tail_cells--)
// if the new cell has no tails or has one seq_id (n_shared_tail_cells++)
// if the new cell only has one seq_id and has one tail (n_seqs--)
// from shared to unique (seq.n_cells++)
// if this seq_id was not the first of the last cell (n_seqs++)
// if the last cell has no tails (n_shared_tail_cells--)
if (prev_cell.seq_nodes.size() > 1) {
// from shared
if (rs_cell.is_empty()) {
// to unique
if (prev_cell.seq_nodes[0].seq_id != id) {
n_seqs += 1;
}
}
// the previous cell is no longer a shared tail
if (prev_cell.tail_rc == 0) {
n_shared_tail_cells -= 1;
}
} else if (!rs_cell.is_empty()) {
// from unique to shared
n_seqs -= 1;
}
}
} }
if (rs_cell.is_empty()) { if (rs_cell.is_empty()) {
// either the sequence didn't own any cells or had a shared tail cell // to unique
if (seq.n_cells == 0 || (seq.tail >= 0 && cells[seq.tail].seq_nodes.size() > 1)) {
n_seqs += 1;
}
seq.n_cells += 1; seq.n_cells += 1;
// set pos if still unset if (seq.tail < 0) {
if (rs_cell.pos < 0) { // from empty to unique
n_seqs += 1;
// pos was not yet set
rs_cell.pos = 0; rs_cell.pos = 0;
rs_cell.src = -1; rs_cell.src = -1;
} }
used += 1; used += 1;
} else if (rs_cell.seq_nodes.size() == 1 && rs_cell.tail_rc == 1) { } else {
// don't count shared-cell tails // to shared
// FIXME: make this saner if (rs_cell.seq_nodes.size() == 1) {
n_seqs -= 1; // a lone tail becomes a shared cell
n_shared_tail_cells += 1; if (rs_cell.tail_rc > 0) {
} else if (rs_cell.tail_rc == 0) { n_seqs -= 1;
// shared cell without a tail gets a tail;
// FIXME: don't prune, in case this is used in llama_cache_seq_cp
GGML_ASSERT(false); // make sure we don't get here by accident
// prune the other sequences out of this cell
// NOTE: have to inline the removal because the state tree is partially invalid
bool first = true;
for (auto & node : rs_cell.seq_nodes) {
GGML_ASSERT(node.seq_id != id);
GGML_ASSERT(node.next_cell >= 0);
// easy removal, none of the nodes are tails
llama_rs_cell & next_cell = cells[node.next_cell];
next_cell.prev = rs_cell.prev;
if (first) {
auto & first_seq = seq_tails[node.seq_id];
first_seq.n_cells -= 1;
first = false;
} }
n_shared_tail_cells += 1;
} else if (rs_cell.tail_rc == 0) {
n_shared_tail_cells += 1;
} }
rs_cell.seq_nodes.clear();
} else if (rs_cell.seq_nodes.size() != rs_cell.tail_rc) {
// this is correct as long as this isn't called when trying to find a slot
// TODO: find a way to assert this
} }
// the target cell was not already a tail of this seq_id // the target cell was not already a tail of this seq_id
rs_cell.insert_node(id); // next_cell == -1 by default rs_cell.insert_node(id); // next_cell == -1 by default
@ -2977,6 +2978,7 @@ static bool llama_kv_cache_find_slot(
llama_rs_cell & candidate = cache.rs.cells[cell_id]; llama_rs_cell & candidate = cache.rs.cells[cell_id];
if (candidate.is_empty()) { break; } if (candidate.is_empty()) { break; }
if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) { if (candidate.tail_rc == 1 && seq.tail == (int32_t) cell_id) {
// the candidate is the old tail
if (candidate.seq_nodes.size() > 1) { if (candidate.seq_nodes.size() > 1) {
// prune out the other seq_ids, because they diverge // prune out the other seq_ids, because they diverge
// TODO(maybe): hande this in insert_seq_tail_to_cell_id // TODO(maybe): hande this in insert_seq_tail_to_cell_id
@ -3198,40 +3200,42 @@ static llama_pos llama_cache_seq_rm(
llama_pos new_p0 = 0; llama_pos new_p0 = 0;
llama_pos new_p1 = std::numeric_limits<llama_pos>::max(); llama_pos new_p1 = std::numeric_limits<llama_pos>::max();
for (uint32_t i = 0; i < cache.rs.size; ++i) { // partial seq_id removal has to happen from the tail
llama_rs_cell & rs_cell = cache.rs.cells[i]; llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id];
auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id); int32_t cell_id = seq.tail;
while (cell_id >= 0) {
llama_rs_cell & rs_cell = cache.rs.cells[cell_id];
// copy before the cell is potentially changed
int32_t prev_id = rs_cell.prev;
if (rs_cell.pos >= p1 && rs_cell.seq_nodes.size() > 1) {
// non-tail removal for shared cells can only be done when clearing a cell
// (i.e. when the next cell's link to the previous cell can be safely changed)
p1 = rs_cell.pos + 1;
}
if (rs_cell.pos >= p0 && rs_cell.pos < p1) {
auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id);
// if the node isn't found, the sequence tree is malformed
GGML_ASSERT(node_iter != rs_cell.seq_nodes.end());
cache.rs.remove_seq_node_from_cell(rs_cell, node_iter);
// get the smallest removed cell id
if (new_head > (uint32_t) cell_id) { new_head = cell_id; }
} else {
// one more than the biggest non-removed cell of this sequence
if (rs_cell.pos >= n_past) { n_past = rs_cell.pos + 1; }
if (seq_id < 0 || seq_node != rs_cell.seq_nodes.end()) {
if (rs_cell.pos < p0) { if (rs_cell.pos < p0) {
// move forward the new p0 further // new_p0 should be right after the max pos in the states before p0
if (rs_cell.pos >= new_p0) { if (rs_cell.pos >= new_p0) { new_p0 = rs_cell.pos + 1; }
new_p0 = rs_cell.pos + 1; } else { // (rs_cell.pos >= p1)
} // new_p1 should be the min pos in the states after p1
} else if (rs_cell.pos >= p1) { if (rs_cell.pos < new_p1) { new_p1 = rs_cell.pos; }
// move back the new p1 further
if (rs_cell.pos < new_p1) {
new_p1 = rs_cell.pos;
}
if (rs_cell.pos >= n_past) {
n_past = rs_cell.pos + 1;
}
} else { // (rs_cell.pos >= p0 && rs_cell.pos < p1)
if (seq_id < 0) {
cache.rs.clear_cell(rs_cell);
} else { // (rs_cell.has_seq_id(seq_id))
cache.rs.remove_seq_node_from_cell(rs_cell, seq_node);
}
if (rs_cell.is_empty() && new_head == cache.rs.size) {
new_head = i;
}
} }
} }
cell_id = prev_id;
} }
p0 = new_p0; p0 = new_p0;
p1 = new_p1; p1 = new_p1;
// correctly set n_past when there's nothing after p1
if (n_past < p0) { n_past = p0; }
// If we freed up a slot, set head to it so searching can start there. // If we freed up a slot, set head to it so searching can start there.
if (new_head != cache.rs.size && new_head < cache.rs.head) { if (new_head != cache.rs.size && new_head < cache.rs.head) {
@ -3259,10 +3263,8 @@ static llama_pos llama_cache_seq_rm(
kv_cell.pos = -1; kv_cell.pos = -1;
if (new_head == cache.kv.size) { new_head = i; } if (new_head == cache.kv.size) { new_head = i; }
} }
} else { } else if (kv_cell.pos >= n_past) {
if (kv_cell.pos >= n_past) { n_past = kv_cell.pos + 1;
n_past = kv_cell.pos + 1;
}
} }
} }
} }
@ -3292,42 +3294,37 @@ static llama_pos llama_cache_seq_cp(
llama_pos n_past = 0; llama_pos n_past = 0;
if (cache.rs.size > 0) { if (cache.rs.size > 0) {
// have to start from beginning for recurrent models // have to start from the beginning for recurrent models
p0 = 0; p0 = 0;
if ((uint32_t) seq_id_dst < cache.rs.size && (uint32_t) seq_id_src < cache.rs.size) { if ((uint32_t) seq_id_dst < cache.rs.size && (uint32_t) seq_id_src < cache.rs.size) {
auto seq_src = cache.rs.seq_tails[seq_id_src]; int32_t src_head = -1;
int32_t src_tail = seq_src.tail; int32_t head_pos = p1;
// find the last tail of src in the pos range int32_t src_next = -1;
while (src_tail >= 0 && (uint32_t) src_tail < cache.rs.size) { // find the start of the sequence
llama_rs_cell & tail_cell = cache.rs.cells[src_tail];
if (tail_cell.pos < p1) {
break;
}
src_tail = tail_cell.prev;
}
uint32_t new_head = cache.rs.size;
for (uint32_t i = 0; i < cache.rs.size; ++i) { for (uint32_t i = 0; i < cache.rs.size; ++i) {
llama_rs_cell & rs_cell = cache.rs.cells[i]; llama_rs_cell & rs_cell = cache.rs.cells[i];
if (rs_cell.pos >= p0 && rs_cell.pos < p1 && rs_cell.has_seq_id(seq_id_src)) { if (!rs_cell.is_empty() && rs_cell.prev < 0) {
if (i == (uint32_t) src_tail) { auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id_src);
// need to be inserted in order, but there's only one if (seq_node != rs_cell.seq_nodes.end()) {
cache.rs.insert_seq_tail_to_cell_id(i, seq_id_dst); src_head = i;
} else { head_pos = rs_cell.pos;
// keep only the tail cell of the source src_next = seq_node->next_cell;
// assuming a copy means no rollback will be attempted afterwards break;
cache.rs.remove_seq_from_cell_id(i, seq_id_src);
if (new_head == cache.rs.size) {
new_head = i;
}
} }
} }
} }
while (src_head >= 0 && head_pos < p1) {
// If we freed up a slot, set head to it so searching can start there. cache.rs.insert_seq_tail_to_cell_id(src_head, seq_id_dst);
if (new_head != cache.rs.size && new_head < cache.rs.head) { src_head = src_next;
cache.rs.head = new_head; if (head_pos >= n_past) { n_past = head_pos + 1; }
if (src_next >= 0) {
llama_rs_cell & rs_cell = cache.rs.cells[src_next];
auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id_src);
head_pos = rs_cell.pos;
// it should always be found if the seq tree is valid
GGML_ASSERT(seq_node != rs_cell.seq_nodes.end());
src_next = seq_node->next_cell;
}
} }
} }
p1 = n_past; p1 = n_past;
@ -3338,9 +3335,7 @@ static llama_pos llama_cache_seq_cp(
llama_kv_cell & kv_cell = cache.kv.cells[i]; llama_kv_cell & kv_cell = cache.kv.cells[i];
if (kv_cell.pos >= p0 && kv_cell.pos < p1 && kv_cell.has_seq_id(seq_id_src)) { if (kv_cell.pos >= p0 && kv_cell.pos < p1 && kv_cell.has_seq_id(seq_id_src)) {
kv_cell.seq_id.insert(seq_id_dst); kv_cell.seq_id.insert(seq_id_dst);
if (kv_cell.pos >= n_past) { if (kv_cell.pos >= n_past) { n_past = kv_cell.pos + 1; }
n_past = kv_cell.pos + 1;
}
} }
} }
} }
@ -3352,18 +3347,19 @@ static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id
if (cache.rs.size > 0) { if (cache.rs.size > 0) {
uint32_t new_head = cache.rs.size; uint32_t new_head = cache.rs.size;
for (uint32_t i = 0; i < cache.rs.size; ++i) { // partial seq_id removal has to happen from the tail(s)
llama_rs_cell & rs_cell = cache.rs.cells[i]; for (uint32_t i = 0; i < cache.rs.seq_tails.size(); ++i) {
if (!rs_cell.seq_nodes.empty()) { if (i == (uint32_t) seq_id) { continue; }
for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) { llama_rs_seq_meta & seq = cache.rs.seq_tails[i];
if (node_iter->seq_id != seq_id) { int32_t cell_id = seq.tail;
node_iter = cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); while (cell_id >= 0) {
} else { llama_rs_cell & rs_cell = cache.rs.cells[cell_id];
node_iter = std::next(node_iter); auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), i);
} GGML_ASSERT(node_iter != rs_cell.seq_nodes.end());
} cache.rs.remove_seq_node_from_cell(rs_cell, node_iter);
if (new_head == cache.rs.size && rs_cell.is_empty()) { cell_id = rs_cell.prev;
new_head = i; if (new_head > (uint32_t) cell_id && rs_cell.is_empty()) {
new_head = cell_id;
} }
} }
} }
@ -3414,6 +3410,7 @@ static llama_pos llama_cache_seq_add(
auto & seq = cache.rs.seq_tails[seq_id]; auto & seq = cache.rs.seq_tails[seq_id];
// follow the sequence from its tail // follow the sequence from its tail
int32_t cell_id = seq.tail; int32_t cell_id = seq.tail;
uint32_t new_head = cache.rs.size;
while (cell_id >= 0) { while (cell_id >= 0) {
GGML_ASSERT((uint32_t) cell_id < cache.rs.size); GGML_ASSERT((uint32_t) cell_id < cache.rs.size);
llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; llama_rs_cell & rs_cell = cache.rs.cells[cell_id];
@ -3423,13 +3420,19 @@ static llama_pos llama_cache_seq_add(
if (rs_cell.pos < 0) { if (rs_cell.pos < 0) {
// NOTE: this affects the other sequences which share the cell // NOTE: this affects the other sequences which share the cell
cache.rs.clear_cell(rs_cell); cache.rs.clear_cell(rs_cell);
// TODO: update cache.rs.head if (new_head > (uint32_t) cell_id) {
new_head = cell_id;
}
} }
} }
if (n_past <= rs_cell.pos) { if (n_past <= rs_cell.pos) {
n_past = rs_cell.pos + 1; n_past = rs_cell.pos + 1;
} }
} }
// If we freed up a slot, set head to it so searching can start there.
// Otherwise we just start the next search from the beginning.
cache.rs.head = new_head != cache.rs.size ? new_head : 0;
} }
if (cache.kv.size > 0) { if (cache.kv.size > 0) {
@ -3474,8 +3477,8 @@ static llama_pos llama_cache_seq_div(
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
int d) { int d) {
if (p0 < 0) p0 = 0; if (p0 < 0) { p0 = 0; }
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max(); if (p1 < 0) { p1 = std::numeric_limits<llama_pos>::max(); }
llama_pos n_past = p0; llama_pos n_past = p0;
@ -11275,6 +11278,10 @@ static int llama_decode_internal(
} }
} }
n_outputs_prev += lctx.n_outputs; n_outputs_prev += lctx.n_outputs;
#ifndef NDEBUG
GGML_ASSERT(lctx.cache.rs.rebuild(true));
#endif
} }
// wait for the computation to finish (automatically done when obtaining the model output) // wait for the computation to finish (automatically done when obtaining the model output)
@ -16332,11 +16339,19 @@ void llama_batch_free(struct llama_batch batch) {
int32_t llama_decode( int32_t llama_decode(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch batch) { struct llama_batch batch) {
#ifndef NDEBUG
GGML_ASSERT(ctx->cache.rs.rebuild(true));
#endif
const int ret = llama_decode_internal(*ctx, batch); const int ret = llama_decode_internal(*ctx, batch);
if (ret < 0) { if (ret < 0) {
LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
} }
#ifndef NDEBUG
GGML_ASSERT(ctx->cache.rs.rebuild(true));
#endif
return ret; return ret;
} }