diff --git a/llama.cpp b/llama.cpp index 9ca8ca0f4..6dc310bf9 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1962,11 +1962,12 @@ struct llama_rs_seq_node { llama_seq_id seq_id = -1; int32_t next_cell = -1; - // needed for automatic typecasting with .find() + // needed for automatic typecasting from a llama_seq_id llama_rs_seq_node(const llama_seq_id s = -1, int32_t i = -1) : seq_id(s), next_cell(i) {} - bool operator<(const llama_rs_seq_node & other) const { - return seq_id < other.seq_id; + // needed for more convenient std::find + bool operator==(const llama_rs_seq_node & other) const { + return seq_id == other.seq_id; } bool is_tail() const { @@ -1989,48 +1990,18 @@ struct llama_rs_cell { // seq_ids by insertion order, to simplify updating n_cells compared to a set std::vector seq_nodes; - llama_rs_seq_node * get_node(const llama_seq_id & id) { - for (size_t i = 0; i < seq_nodes.size(); ++i) { - if (seq_nodes[i].seq_id == id) { - return &seq_nodes[i]; - } - } - return nullptr; - } - void insert_node(const llama_rs_seq_node & node) { - llama_rs_seq_node * node_dest = get_node(node.seq_id); - if (node_dest == nullptr) { + auto node_dest = std::find(seq_nodes.begin(), seq_nodes.end(), node); + if (node_dest == seq_nodes.end()) { seq_nodes.push_back(node); } else { + // overwrite the pre-existing node with the same seq_id if it exists *node_dest = node; } } - bool remove_node(llama_rs_seq_node * node_ptr) { - if (node_ptr != nullptr && seq_nodes.data() <= node_ptr) { - size_t offset = node_ptr - seq_nodes.data(); - if (offset % sizeof(llama_rs_seq_node) == 0) { - offset /= sizeof(llama_rs_seq_node); - if (offset < seq_nodes.size()) { - for (size_t i = offset + 1; i < seq_nodes.size(); ++i) { - seq_nodes[i - 1] = seq_nodes[i]; - } - seq_nodes.resize(seq_nodes.size() - 1); - return true; - } - } - } - return false; - } - bool has_seq_id(const llama_seq_id & id) const { - for (size_t i = 0; i < seq_nodes.size(); ++i) { - if (seq_nodes[i].seq_id == id) { - return true; - } - } - return false; + return std::find(seq_nodes.begin(), seq_nodes.end(), id) != seq_nodes.end(); } bool is_empty() const { @@ -2132,67 +2103,65 @@ struct llama_rs_cache { bool remove_seq_from_cell(uint32_t i_cell, const llama_seq_id & id) { if (i_cell < size && (size_t) id < size) { llama_rs_cell & rs_cell = cells[i_cell]; - auto * node_ptr = rs_cell.get_node(id); // search once - if (node_ptr != nullptr) { + auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), id); // search once + if (node_iter != rs_cell.seq_nodes.end()) { if (rs_cell.seq_nodes.size() == 1) { return clear_cell(i_cell); - } else { - // update tree - llama_rs_seq_node node = *node_ptr; - if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { - cells[node.next_cell].prev = rs_cell.prev; - } - if ((uint32_t) node.seq_id < seq_tails.size()) { - auto & seq = seq_tails[node.seq_id]; - bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2; - if (node.is_tail()) { - seq.tail = rs_cell.prev; - if (seq.tail >= 0 && (uint32_t) seq.tail < size) { - llama_rs_cell & new_tail = cells[seq.tail]; - new_tail.insert_node(node.seq_id); // ensures next_cell == -1 - new_tail.tail_rc += 1; - seq.shared = cells[seq.tail].seq_nodes.size() > 1; - } else { - seq.shared = false; - } - GGML_ASSERT(rs_cell.tail_rc > 0); - rs_cell.tail_rc -= 1; - } - if (node_ptr == rs_cell.seq_nodes.data()) { - // this seq_id was the first in the list - seq.n_cells -= 1; - if (seq.n_cells == 0) { - n_seqs -= 1; - } - // the next node is the new first one, so update its n_cells - // (will never be out-of-bounds because the size is > 1) - llama_rs_seq_node next_node = node_ptr[1]; - if ((uint32_t) next_node.seq_id < seq_tails.size()) { - auto & next_seq = seq_tails[next_node.seq_id]; - next_seq.n_cells += 1; - if (next_seq.n_cells == 1) { - n_seqs += 1; - } - if (other_no_longer_shared) { - next_seq.shared = false; - } - } else { - GGML_ASSERT(false && "invalid seq_id"); - } - } else if (other_no_longer_shared) { - llama_rs_seq_node first_node = rs_cell.seq_nodes[0]; - if ((uint32_t) first_node.seq_id < seq_tails.size()) { - seq_tails[first_node.seq_id].shared = false; - } else { - GGML_ASSERT(false && "invalid seq_id"); - } - } - } else { - GGML_ASSERT(false && "invalid seq_id"); - } - const bool removed = rs_cell.remove_node(node_ptr); - GGML_ASSERT(removed); } + // else update tree + llama_rs_seq_node node = *node_iter; + if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + cells[node.next_cell].prev = rs_cell.prev; + } + if ((uint32_t) node.seq_id < seq_tails.size()) { + auto & seq = seq_tails[node.seq_id]; + bool other_no_longer_shared = rs_cell.seq_nodes.size() == 2; + if (node.is_tail()) { + seq.tail = rs_cell.prev; + if (seq.tail >= 0 && (uint32_t) seq.tail < size) { + llama_rs_cell & new_tail = cells[seq.tail]; + new_tail.insert_node(node.seq_id); // ensures next_cell == -1 + new_tail.tail_rc += 1; + seq.shared = cells[seq.tail].seq_nodes.size() > 1; + } else { + seq.shared = false; + } + GGML_ASSERT(rs_cell.tail_rc > 0); + rs_cell.tail_rc -= 1; + } + if (node_iter == rs_cell.seq_nodes.begin()) { + // this seq_id was the first in the list + seq.n_cells -= 1; + if (seq.n_cells == 0) { + n_seqs -= 1; + } + // the next node is the new first one, so update its n_cells + // (will never be out-of-bounds because the size is > 1) + llama_rs_seq_node next_node = *(std::next(node_iter)); + if ((uint32_t) next_node.seq_id < seq_tails.size()) { + auto & next_seq = seq_tails[next_node.seq_id]; + next_seq.n_cells += 1; + if (next_seq.n_cells == 1) { + n_seqs += 1; + } + if (other_no_longer_shared) { + next_seq.shared = false; + } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + } else if (other_no_longer_shared) { + llama_rs_seq_node first_node = rs_cell.seq_nodes[0]; + if ((uint32_t) first_node.seq_id < seq_tails.size()) { + seq_tails[first_node.seq_id].shared = false; + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + rs_cell.seq_nodes.erase(node_iter); } } return false; @@ -2215,8 +2184,8 @@ struct llama_rs_cache { if (prev >= 0 && (uint32_t) prev < size) { // the targeted cell has a previous cell llama_rs_cell & prev_cell = cells[prev]; - llama_rs_seq_node * prev_node = prev_cell.get_node(id); - GGML_ASSERT(prev_node != nullptr); // TODO: recursive insert instead of failing + auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), id); + GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); // TODO: recursive insert instead of failing GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken if (rs_cell.pos < 0) { GGML_ASSERT(rs_cell.is_empty()); @@ -2267,7 +2236,7 @@ struct llama_rs_cache { int32_t n_system_seqs = 0; int32_t n_system_cells = 0; for (size_t i = 0; i < seq_tails.size(); ++i) { - auto & seq = seq_tails[i]; + const auto & seq = seq_tails[i]; if (seq.tail >= 0 && (size_t) seq.tail < size) { if (seq.shared && seq.n_cells > 0) { n_system_seqs += 1;