diff --git a/include/llama.h b/include/llama.h
index 59f38936f..6f6e73c90 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -38,10 +38,10 @@
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
 
 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
-#define LLAMA_SESSION_VERSION 8
+#define LLAMA_SESSION_VERSION 9
 
 #define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
-#define LLAMA_STATE_SEQ_VERSION 2
+#define LLAMA_STATE_SEQ_VERSION 3
 
 #ifdef __cplusplus
 extern "C" {
diff --git a/src/llama.cpp b/src/llama.cpp
index 213a27cc8..0f55196cf 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -19839,8 +19839,28 @@ struct llama_data_write {
         }
     }
 
+    void write_rs_cache_meta(const llama_rs_cache & rs_self, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
+
+        for (const auto & range : cell_ranges) {
+            for (uint32_t i = range.first; i < range.second; ++i) {
+                const auto & cell = rs_self.cells[i];
+                const llama_pos pos      = cell.pos;
+                const uint32_t n_seq_id  = seq_id == -1 ? cell.seq_nodes.size() : 0;
+
+                write(&pos,      sizeof(pos));
+                write(&n_seq_id, sizeof(n_seq_id));
+
+                if (n_seq_id) {
+                    for (auto seq_node : cell.seq_nodes) {
+                        write(&seq_node.seq_id, sizeof(seq_node.seq_id));
+                    }
+                }
+            }
+        }
+    }
+
     void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
-        const struct llama_kv_cache & kv_self = ctx->kv_self;
+        const struct llama_kv_cache & kv_self = ctx->cache.kv;
         const struct llama_hparams & hparams = ctx->model.hparams;
 
         const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
@@ -19849,12 +19869,10 @@ struct llama_data_write {
         write(&v_trans, sizeof(v_trans));
         write(&n_layer, sizeof(n_layer));
 
-        std::vector<uint8_t> tmp_buf;
-
         // Iterate and write all the keys first, each row is a cell
         // Get whole range at a time
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
             // Write key type
             const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
@@ -19874,7 +19892,7 @@ struct llama_data_write {
 
         if (!kv_self.v_trans) {
             for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
                 // Write value type
                 const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
@@ -19895,7 +19913,7 @@ struct llama_data_write {
             // When v is transposed, we also need the element size and get the element ranges from each row
             const uint32_t kv_size = kv_self.size;
             for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
                 // Write value type
                 const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
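For orientation: both meta writers emit the same flat per-cell record, a position, a sequence-id count, then that many sequence ids. A minimal standalone decoder for one record follows; `cell_meta` and `read_cell_meta` are illustrative only and not part of the patch, and the sketch assumes `llama_pos` and `llama_seq_id` are `int32_t` as typedef'd in llama.h.

```cpp
#include <cstdint>
#include <cstring>
#include <vector>

struct cell_meta {                  // hypothetical helper type, not in the patch
    int32_t              pos;
    std::vector<int32_t> seq_ids;   // empty when saving a single sequence (n_seq_id == 0)
};

// Decode one per-cell metadata record from a raw byte pointer, advancing it.
static cell_meta read_cell_meta(const uint8_t *& p) {
    cell_meta m;
    uint32_t n_seq_id;
    std::memcpy(&m.pos,    p, sizeof(m.pos));    p += sizeof(m.pos);
    std::memcpy(&n_seq_id, p, sizeof(n_seq_id)); p += sizeof(n_seq_id);
    m.seq_ids.resize(n_seq_id);
    for (uint32_t j = 0; j < n_seq_id; ++j) {
        std::memcpy(&m.seq_ids[j], p, sizeof(int32_t));
        p += sizeof(int32_t);
    }
    return m;
}
```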
@@ -19922,43 +19940,151 @@ struct llama_data_write {
         }
     }
 
-    void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
-        const struct llama_kv_cache & kv_self = ctx->kv_self;
-        std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
-        uint32_t cell_count = 0;
-
-        // Count the number of cells with the specified seq_id
-        // Find all the ranges of cells with this seq id (or all, when -1)
-        uint32_t cell_range_begin = kv_self.size;
-        for (uint32_t i = 0; i < kv_self.size; ++i) {
-            const auto & cell = kv_self.cells[i];
-            if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
-                ++cell_count;
-                if (cell_range_begin == kv_self.size) {
-                    cell_range_begin = i;
-                }
-            } else {
-                if (cell_range_begin != kv_self.size) {
-                    cell_ranges.emplace_back(cell_range_begin, i);
-                    cell_range_begin = kv_self.size;
-                }
-            }
-        }
-        if (cell_range_begin != kv_self.size) {
-            cell_ranges.emplace_back(cell_range_begin, kv_self.size);
-        }
-
-        // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
-        uint32_t cell_count_check = 0;
-        for (const auto & range : cell_ranges) {
-            cell_count_check += range.second - range.first;
-        }
-        GGML_ASSERT(cell_count == cell_count_check);
-
-        write(&cell_count, sizeof(cell_count));
-
-        write_kv_cache_meta(kv_self, cell_ranges, seq_id);
-        write_kv_cache_data(ctx, cell_ranges);
+    void write_rs_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
+        const struct llama_rs_cache & rs_self = ctx->cache.rs;
+        const struct llama_hparams & hparams = ctx->model.hparams;
+
+        const uint32_t n_layer = hparams.n_layer;
+
+        write(&n_layer, sizeof(n_layer));
+
+        // Iterate and write all recurrent states, each row is a cell
+        // Get whole range at a time
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_r = hparams.n_embd_r(il);
+
+            // Write type
+            const int32_t r_type_i = (int32_t)rs_self.r_l[il]->type;
+            write(&r_type_i, sizeof(r_type_i));
+
+            // Write row size
+            const uint64_t r_size_row = ggml_row_size(rs_self.r_l[il]->type, n_embd_r);
+            write(&r_size_row, sizeof(r_size_row));
+
+            // Read each range of cells of r_size length each and write out
+            for (const auto & range : cell_ranges) {
+                const size_t range_size = range.second - range.first;
+                const size_t buf_size = range_size * r_size_row;
+                write_tensor_data(rs_self.r_l[il], range.first * r_size_row, buf_size);
+            }
+        }
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_s = hparams.n_embd_s(il);
+
+            // Write type
+            const int32_t s_type_i = (int32_t)rs_self.s_l[il]->type;
+            write(&s_type_i, sizeof(s_type_i));
+
+            // Write row size
+            const uint64_t s_size_row = ggml_row_size(rs_self.s_l[il]->type, n_embd_s);
+            write(&s_size_row, sizeof(s_size_row));
+
+            // Read each range of cells of s_size length each and write out
+            for (const auto & range : cell_ranges) {
+                const size_t range_size = range.second - range.first;
+                const size_t buf_size = range_size * s_size_row;
+                write_tensor_data(rs_self.s_l[il], range.first * s_size_row, buf_size);
+            }
+        }
+    }
+
+    void write_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
+        const struct llama_kv_cache & kv_self = ctx->cache.kv;
+        const struct llama_rs_cache & rs_self = ctx->cache.rs;
+        std::vector<std::pair<uint32_t, uint32_t>> kv_cell_ranges; // ranges, from inclusive, to exclusive
+        std::vector<std::pair<uint32_t, uint32_t>> rs_cell_ranges; // ranges, from inclusive, to exclusive
+        uint32_t kv_cell_count = 0;
+        uint32_t rs_cell_count = 0;
+        // Transformer KV cache
+        {
+            // Count the number of cells with the specified seq_id
+            // Find all the ranges of cells with this seq id (or all, when -1)
+            uint32_t cell_range_begin = kv_self.size;
+            for (uint32_t i = 0; i < kv_self.size; ++i) {
+                const auto & cell = kv_self.cells[i];
+                if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
+                    ++kv_cell_count;
+                    if (cell_range_begin == kv_self.size) {
+                        cell_range_begin = i;
+                    }
+                } else {
+                    if (cell_range_begin != kv_self.size) {
+                        kv_cell_ranges.emplace_back(cell_range_begin, i);
+                        cell_range_begin = kv_self.size;
+                    }
+                }
+            }
+            if (cell_range_begin != kv_self.size) {
+                kv_cell_ranges.emplace_back(cell_range_begin, kv_self.size);
+            }
+
+            // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
+            uint32_t cell_count_check = 0;
+            for (const auto & range : kv_cell_ranges) {
+                cell_count_check += range.second - range.first;
+            }
+            GGML_ASSERT(kv_cell_count == cell_count_check);
+        }
+        // Recurrent state cache
+        if (seq_id == -1) {
+            // Find all the ranges of cells
+            uint32_t cell_range_begin = rs_self.size;
+            for (uint32_t i = 0; i < rs_self.size; ++i) {
+                const auto & cell = rs_self.cells[i];
+                if (!cell.is_empty()) {
+                    ++rs_cell_count;
+                    if (cell_range_begin == rs_self.size) {
+                        cell_range_begin = i;
+                    }
+                } else {
+                    if (cell_range_begin != rs_self.size) {
+                        rs_cell_ranges.emplace_back(cell_range_begin, i);
+                        cell_range_begin = rs_self.size;
+                    }
+                }
+            }
+            if (cell_range_begin != rs_self.size) {
+                rs_cell_ranges.emplace_back(cell_range_begin, rs_self.size);
+            }
+
+        } else {
+            // Find the cell ranges of the specified seq_id
+            if ((size_t) seq_id < rs_self.seq_tails.size()) {
+                int32_t tail_cell_id = rs_self.seq_tails[seq_id].tail;
+                if (tail_cell_id >= 0) {
+                    ++rs_cell_count;
+                    rs_cell_ranges.emplace_back(tail_cell_id, tail_cell_id + 1);
+                }
+            }
+        }
+
+        {
+            // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
+            uint32_t cell_count_check = 0;
+            for (const auto & range : rs_cell_ranges) {
+                cell_count_check += range.second - range.first;
+            }
+            GGML_ASSERT(rs_cell_count == cell_count_check);
+        }
+
+        write(&kv_cell_count, sizeof(kv_cell_count));
+        write(&rs_cell_count, sizeof(rs_cell_count));
+
+        if (seq_id == -1) {
+            // write metadata for both when the whole cache needs to be saved
+            write_kv_cache_meta(kv_self, kv_cell_ranges, seq_id);
+            write_rs_cache_meta(rs_self, rs_cell_ranges, seq_id);
+        } else if (kv_cell_count > 0) {
+            write_kv_cache_meta(kv_self, kv_cell_ranges, seq_id);
+        } else {
+            write_rs_cache_meta(rs_self, rs_cell_ranges, seq_id);
+        }
+        if (kv_cell_count > 0) {
+            write_kv_cache_data(ctx, kv_cell_ranges);
+        }
+        if (rs_cell_count > 0) {
+            write_rs_cache_data(ctx, rs_cell_ranges);
+        }
     }
 };
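The same maximal-run scan appears twice in write_cache, once per cache. A condensed standalone version of that loop, for illustration (llama.cpp writes it inline rather than through a helper like this):

```cpp
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

// Collect maximal [first, second) runs of cells matching a predicate.
// As in the diff, n_cells doubles as the "no open range" sentinel.
static std::vector<std::pair<uint32_t, uint32_t>> scan_cell_ranges(
        uint32_t n_cells, const std::function<bool(uint32_t)> & keep) {
    std::vector<std::pair<uint32_t, uint32_t>> ranges;
    uint32_t begin = n_cells;
    for (uint32_t i = 0; i < n_cells; ++i) {
        if (keep(i)) {
            if (begin == n_cells) { begin = i; }
        } else if (begin != n_cells) {
            ranges.emplace_back(begin, i);   // close the open run
            begin = n_cells;
        }
    }
    if (begin != n_cells) { ranges.emplace_back(begin, n_cells); }
    return ranges;
}
```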
@@ -20050,108 +20176,98 @@ struct llama_data_read {
         }
     }
 
-    bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
-        struct llama_kv_cache & kv_self = ctx->kv_self;
-
-        if (dest_seq_id != -1) {
-            // single sequence
-
-            llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
-
-            llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
-            batch.n_tokens = cell_count;
-            batch.n_seq_tokens = cell_count;
-            batch.n_seqs = 1;
-
-            for (uint32_t i = 0; i < cell_count; ++i) {
-                llama_pos pos;
-                uint32_t n_seq_id;
-
-                read_to(&pos, sizeof(pos));
-                read_to(&n_seq_id, sizeof(n_seq_id));
-
-                if (n_seq_id != 0) {
-                    LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
-                    return false;
-                }
-
-                batch.pos[i] = pos;
-            }
-            batch.n_seq_id[0] = 1;
-            batch.seq_id[0] = &dest_seq_id;
-            if (!llama_kv_cache_find_slot(kv_self, batch)) {
-                LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
-                return false;
-            }
-
-            // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
-            // Assume that this is one contiguous block of cells
-            GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
-            GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
-            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
-            GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
-            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
-        } else {
-            // whole KV cache restore
-
-            if (cell_count > kv_self.size) {
-                LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
-                return false;
-            }
-
-            llama_kv_cache_clear(kv_self);
-
-            for (uint32_t i = 0; i < cell_count; ++i) {
-                llama_kv_cell & cell = kv_self.cells[i];
-
-                llama_pos pos;
-                uint32_t n_seq_id;
-
-                read_to(&pos, sizeof(pos));
-                read_to(&n_seq_id, sizeof(n_seq_id));
-
-                cell.pos = pos;
-
-                for (uint32_t j = 0; j < n_seq_id; ++j) {
-                    llama_seq_id seq_id;
-                    read_to(&seq_id, sizeof(seq_id));
-
-                    if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
-                        LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
-                        return false;
-                    }
-
-                    cell.seq_id.insert(seq_id);
-
-                    if (kv_self.recurrent) {
-                        int32_t & tail = kv_self.cells[seq_id].tail;
-                        if (tail != -1) {
-                            LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
-                            return false;
-                        }
-                        tail = i;
-                    }
-                }
-            }
-
-            kv_self.head = 0;
-            kv_self.used = cell_count;
-        }
-
-        if (kv_self.recurrent) {
-            for (uint32_t i = 0; i < cell_count; ++i) {
-                uint32_t cell_id = kv_self.head + i;
-                // make sure the recurrent states will keep their restored state
-                kv_self.cells[cell_id].src = cell_id;
-            }
-        }
+    bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count) {
+        if (cell_count == 0) { return true; }
+        struct llama_past & cache = ctx->cache;
+        struct llama_kv_cache & kv_self = cache.kv;
+
+        // whole KV cache restore
+
+        if (cell_count > kv_self.size) {
+            LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
+            return false;
+        }
+
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            llama_kv_cell & cell = kv_self.cells[i];
+
+            llama_pos pos;
+            uint32_t n_seq_id;
+
+            read_to(&pos, sizeof(pos));
+            read_to(&n_seq_id, sizeof(n_seq_id));
+
+            cell.pos = pos;
+
+            for (uint32_t j = 0; j < n_seq_id; ++j) {
+                llama_seq_id seq_id;
+                read_to(&seq_id, sizeof(seq_id));
+
+                if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
+                    return false;
+                }
+
+                cell.seq_id.insert(seq_id);
+            }
+        }
+
+        kv_self.head = 0;
+        kv_self.used = cell_count;
 
         return true;
     }
 
+    bool read_rs_cache_meta(struct llama_context * ctx, uint32_t cell_count) {
+        if (cell_count == 0) { return true; }
+        struct llama_past & cache = ctx->cache;
+        struct llama_rs_cache & rs_self = cache.rs;
+
+        // whole RS cache restore
+
+        if (cell_count > rs_self.size) {
+            LLAMA_LOG_ERROR("%s: not enough cells in rs cache\n", __func__);
+            return false;
+        }
+
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            llama_rs_cell & cell = rs_self.cells[i];
+
+            llama_pos pos;
+            uint32_t n_seq_id;
+
+            read_to(&pos, sizeof(pos));
+            read_to(&n_seq_id, sizeof(n_seq_id));
+
+            cell.pos = pos;
+            cell.src = i;
+
+            for (uint32_t j = 0; j < n_seq_id; ++j) {
+                llama_seq_id seq_id;
+                read_to(&seq_id, sizeof(seq_id));
+
+                if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
+                    return false;
+                }
+
+                cell.insert_node(seq_id);
+            }
+        }
+
+        rs_self.head = 0;
+        rs_self.used = cell_count;
+
+        rs_self.rebuild(/* debug */ false);
+
+        return true;
+    }
+
     bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
+        if (cell_count == 0) { return true; }
         const struct llama_hparams & hparams = ctx->model.hparams;
-        struct llama_kv_cache & kv_self = ctx->kv_self;
+        struct llama_kv_cache & kv_self = ctx->cache.kv;
         uint32_t v_trans;
         uint32_t n_layer;
         read_to(&v_trans, sizeof(v_trans));
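Before the data readers below, it may help to see the overall stream order the writers above produce and the readers consume. This is orientation only, not part of the patch; the struct is illustrative and never materialized, with the variable-length parts described in comments.

```cpp
#include <cstdint>

// Stream order of the combined cache section, following the write()/read_to()
// calls in write_cache() and the read_* functions.
struct cache_stream_layout {
    uint32_t kv_cell_count;   // written first
    uint32_t rs_cell_count;   // written second
    // Per-cell metadata records follow:
    //   { llama_pos pos; uint32_t n_seq_id; llama_seq_id ids[n_seq_id]; }
    //   whole-cache save (seq_id == -1): KV cell metas, then RS cell metas;
    //   single-sequence save: metas for whichever cache holds the sequence.
    // KV data section (if kv_cell_count > 0):
    //   uint32_t v_trans; uint32_t n_layer;
    //   per layer { int32_t k_type; uint64_t k_row_size; rows... }, then V likewise.
    // RS data section (if rs_cell_count > 0):
    //   uint32_t n_layer;
    //   per layer { int32_t r_type; uint64_t r_row_size; rows... }, then s_l likewise.
};
```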
@@ -20172,7 +20288,7 @@ struct llama_data_read {
 
         // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
         for (uint32_t il = 0; il < n_layer; ++il) {
-            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
+            const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
             // Read type of key
             int32_t k_type_i_ref;
@@ -20192,15 +20308,13 @@ struct llama_data_read {
                 return false;
             }
 
-            if (cell_count) {
-                // Read and set the keys for the whole cell range
-                ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
-            }
+            // Read and set the keys for the whole cell range
+            ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
         }
 
         if (!kv_self.v_trans) {
             for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
                 // Read type of value
                 int32_t v_type_i_ref;
@@ -20220,15 +20334,13 @@ struct llama_data_read {
                     return false;
                 }
 
-                if (cell_count) {
-                    // Read and set the values for the whole cell range
-                    ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
-                }
+                // Read and set the values for the whole cell range
+                ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
             }
         } else {
             // For each layer, read the values for each cell (transposed)
             for (uint32_t il = 0; il < n_layer; ++il) {
-                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
+                const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
 
                 // Read type of value
                 int32_t v_type_i_ref;
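The readers above and below repeat one validation pattern per layer: pull a type id and a row size from the stream and require both to match the destination tensor. A condensed sketch of that pattern; `stream` and `read_pod` are illustrative stand-ins for the patch's `read_to()`:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

struct stream { const uint8_t * p; };   // stand-in for the reader's cursor

template <typename T> static T read_pod(stream & s) {
    T v;
    std::memcpy(&v, s.p, sizeof(T));
    s.p += sizeof(T);
    return v;
}

// Returns false, as the diff's readers do, when the serialized type id or row
// size disagrees with the tensor being restored.
static bool check_layer_header(stream & s, int32_t type_i, uint64_t size_row, int il) {
    const int32_t  type_i_ref   = read_pod<int32_t>(s);
    const uint64_t size_row_ref = read_pod<uint64_t>(s);
    if (type_i != type_i_ref) {
        std::fprintf(stderr, "mismatched type (%d != %d, layer %d)\n", type_i, type_i_ref, il);
        return false;
    }
    if (size_row != size_row_ref) {
        std::fprintf(stderr, "mismatched row size (layer %d)\n", il);
        return false;
    }
    return true;
}
```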
@@ -20256,29 +20368,174 @@ struct llama_data_read {
                     return false;
                 }
 
-                if (cell_count) {
-                    // For each row in the transposed matrix, read the values for the whole cell range
-                    for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
-                        const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
-                        ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
-                    }
+                // For each row in the transposed matrix, read the values for the whole cell range
+                for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
+                    const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
+                    ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
                 }
             }
         }
 
         return true;
     }
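A note on the destination-offset arithmetic just above: in the transposed layout, row j of the (n_embd_v_gqa x kv_size) matrix holds element j of every cell, so a contiguous run of cells starting at `head` begins at byte offset `(head + j * kv_size) * v_size_el` within the tensor. A minimal sketch; the helper itself is illustrative, not part of the patch:

```cpp
#include <cstddef>
#include <cstdint>

// Byte offset of cell `head`, element `j`, in a transposed V tensor whose rows
// each hold one element of all kv_size cells, with v_size_el bytes per element.
static size_t v_trans_dst_offset(uint32_t head, uint32_t j, uint32_t kv_size, size_t v_size_el) {
    return ((size_t) head + (size_t) j * kv_size) * v_size_el;
}
```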
-    void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
-        uint32_t cell_count;
-        read_to(&cell_count, sizeof(cell_count));
-
-        bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
+    bool read_rs_cache_data(struct llama_context * ctx, uint32_t cell_count) {
+        if (cell_count == 0) { return true; }
+        const struct llama_hparams & hparams = ctx->model.hparams;
+        struct llama_rs_cache & rs_self = ctx->cache.rs;
+        uint32_t n_layer;
+        read_to(&n_layer, sizeof(n_layer));
+
+        if (n_layer != hparams.n_layer) {
+            LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
+            return false;
+        }
+        if (cell_count > rs_self.size) {
+            LLAMA_LOG_ERROR("%s: not enough cells in rs cache to restore state (%u > %u)\n", __func__, cell_count, rs_self.size);
+            return false;
+        }
+
+        // For each layer, one row is one cell, read as one contiguous block
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_r = hparams.n_embd_r(il);
+
+            // Read type of the r state
+            int32_t r_type_i_ref;
+            read_to(&r_type_i_ref, sizeof(r_type_i_ref));
+            const int32_t r_type_i = (int32_t)rs_self.r_l[il]->type;
+            if (r_type_i != r_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched r state type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
+                return false;
+            }
+
+            // Read row size of the r state
+            uint64_t r_size_row_ref;
+            read_to(&r_size_row_ref, sizeof(r_size_row_ref));
+            const size_t r_size_row = ggml_row_size(rs_self.r_l[il]->type, n_embd_r);
+            if (r_size_row != r_size_row_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched r state row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
+                return false;
+            }
+
+            // Read and set the r states for the whole cell range
+            ggml_backend_tensor_set(rs_self.r_l[il], read(cell_count * r_size_row), rs_self.head * r_size_row, cell_count * r_size_row);
+        }
+        for (uint32_t il = 0; il < n_layer; ++il) {
+            const uint32_t n_embd_s = hparams.n_embd_s(il);
+
+            // Read type of the s state
+            int32_t s_type_i_ref;
+            read_to(&s_type_i_ref, sizeof(s_type_i_ref));
+            const int32_t s_type_i = (int32_t)rs_self.s_l[il]->type;
+            if (s_type_i != s_type_i_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched s state type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
+                return false;
+            }
+
+            // Read row size of the s state
+            uint64_t s_size_row_ref;
+            read_to(&s_size_row_ref, sizeof(s_size_row_ref));
+            const size_t s_size_row = ggml_row_size(rs_self.s_l[il]->type, n_embd_s);
+            if (s_size_row != s_size_row_ref) {
+                LLAMA_LOG_ERROR("%s: mismatched s state row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
+                return false;
+            }
+
+            // Read and set the s states for the whole cell range
+            ggml_backend_tensor_set(rs_self.s_l[il], read(cell_count * s_size_row), rs_self.head * s_size_row, cell_count * s_size_row);
+        }
+
+        return true;
+    }
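For intuition about the per-layer recurrent row sizes read above, Mamba is the canonical case. The formulas below follow the `n_embd_k_s()`/`n_embd_v_s()` helpers this patch replaces; treating `n_embd_r()`/`n_embd_s()` as keeping the same definitions is an assumption:

```cpp
#include <cstdint>

// Per-layer rolling convolution state width (elements), Mamba-style.
static uint32_t n_embd_r_mamba(uint32_t d_conv, uint32_t d_inner) {
    return (d_conv > 0 ? d_conv - 1 : 0) * d_inner;
}

// Per-layer SSM recurrent state width (elements), Mamba-style.
static uint32_t n_embd_s_mamba(uint32_t d_state, uint32_t d_inner) {
    return d_state * d_inner;
}
```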
+    bool read_cache_seq_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1) {
+
+        if (seq_id < 0 || seq_id >= llama_n_seq_max(ctx)) {
+            LLAMA_LOG_ERROR("%s: seq_id out of range [0, %d): %d\n", __func__, llama_n_seq_max(ctx), seq_id);
+            return false;
+        }
+
+        // single sequence
+
+        llama_past & cache = ctx->cache;
+        llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
+        batch.n_tokens = cell_count;
+        batch.n_seq_tokens = cell_count;
+        batch.n_seqs = 1;
+
+        for (uint32_t i = 0; i < cell_count; ++i) {
+            llama_pos pos;
+            uint32_t n_seq_id;
+
+            read_to(&pos, sizeof(pos));
+            read_to(&n_seq_id, sizeof(n_seq_id));
+
+            if (n_seq_id != 0) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
+                return false;
+            }
+
+            batch.pos[i] = pos;
+        }
+        batch.n_seq_id[0] = 1;
+        batch.seq_id[0] = &seq_id;
+        if (!llama_past_find_slot(cache, batch)) {
+            LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
+            return false;
+        }
+
+        if (cache.kv.size > 0) {
+            // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
+            // Assume that this is one contiguous block of cells
+            GGML_ASSERT(cache.kv.head + cell_count <= cache.kv.size);
+            GGML_ASSERT(cache.kv.cells[cache.kv.head].pos == batch.pos[0]);
+            GGML_ASSERT(cache.kv.cells[cache.kv.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
+            GGML_ASSERT(cache.kv.cells[cache.kv.head].has_seq_id(seq_id));
+            GGML_ASSERT(cache.kv.cells[cache.kv.head + cell_count - 1].has_seq_id(seq_id));
+        }
+        if (cache.rs.size > 0) {
+            GGML_ASSERT(cache.rs.head + cache.rs.n <= cache.rs.size);
+            GGML_ASSERT(cache.rs.n == 1);
+            GGML_ASSERT(cache.rs.cells[cache.rs.head + cache.rs.n - 1].pos == batch.pos[cell_count - 1]);
+            GGML_ASSERT(cache.rs.cells[cache.rs.head].has_seq_id(seq_id));
+            GGML_ASSERT(cache.rs.cells[cache.rs.head + cache.rs.n - 1].has_seq_id(seq_id));
+            // Prevent cells from being cleared
+            for (uint32_t i = cache.rs.head; i < cache.rs.head + cache.rs.n; ++i) {
+                cache.rs.cells[i].src = i;
+            }
+        }
+
+        return true;
+    }
+
+    void read_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
+        uint32_t kv_cell_count;
+        read_to(&kv_cell_count, sizeof(kv_cell_count));
+        uint32_t rs_cell_count;
+        read_to(&rs_cell_count, sizeof(rs_cell_count));
+
+        bool res = true;
+
+        if (seq_id == -1) {
+            llama_past_clear(ctx);
+            res = read_kv_cache_meta(ctx, kv_cell_count) && read_rs_cache_meta(ctx, rs_cell_count);
+        } else {
+            llama_past_seq_rm(ctx, seq_id, -1, -1);
+            // Only a single recurrent cell at most,
+            // because otherwise the cells can be shuffled when a slot is allocated
+            if (rs_cell_count > 1) {
+                LLAMA_LOG_ERROR("%s: too many recurrent state cells for single-sequence session\n", __func__);
+                res = false;
+            }
+            res = res && read_cache_seq_meta(ctx, std::max(kv_cell_count, rs_cell_count), seq_id);
+        }
+
+        res = res && read_kv_cache_data(ctx, kv_cell_count) && read_rs_cache_data(ctx, rs_cell_count);
 
         if (!res) {
             if (seq_id == -1) {
-                llama_kv_cache_clear(ctx);
+                llama_past_clear(ctx);
             } else {
-                llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
+                llama_past_seq_rm(ctx, seq_id, -1, -1);
             }
             throw std::runtime_error("failed to restore kv cache");
         }
     }
@@ -20433,7 +20690,7 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
     data_ctx.write_logits(ctx);
     data_ctx.write_embeddings(ctx);
 
-    data_ctx.write_kv_cache(ctx);
+    data_ctx.write_cache(ctx);
 
     return data_ctx.get_size_written();
 }
@@ -20473,7 +20730,7 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da
     data_ctx.read_logits(ctx);
     data_ctx.read_embeddings(ctx);
 
-    data_ctx.read_kv_cache(ctx);
+    data_ctx.read_cache(ctx);
 
     return data_ctx.get_size_read();
 }
@@ -20569,7 +20826,7 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session
 static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
     llama_synchronize(ctx);
 
-    data_ctx.write_kv_cache(ctx, seq_id);
+    data_ctx.write_cache(ctx, seq_id);
 
     return data_ctx.get_size_written();
 }
@@ -20592,7 +20849,7 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_
 static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
     llama_synchronize(ctx);
 
-    data_ctx.read_kv_cache(ctx, dest_seq_id);
+    data_ctx.read_cache(ctx, dest_seq_id);
 
     return data_ctx.get_size_read();
 }
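A hedged usage sketch to close: copying one sequence's combined (KV + recurrent) state to another through the public `llama_state_seq_*` API whose internal entry points appear above. It assumes the current four-argument get/set signatures from llama.h; error handling is omitted.

```cpp
#include <cstdint>
#include <vector>
#include "llama.h"

// Serialize sequence `src` and restore the blob into sequence `dst`.
static void copy_seq_state(llama_context * ctx, llama_seq_id src, llama_seq_id dst) {
    const size_t size = llama_state_seq_get_size(ctx, src);
    std::vector<uint8_t> buf(size);

    llama_state_seq_get_data(ctx, buf.data(), buf.size(), src); // write src's cache state
    llama_state_seq_set_data(ctx, buf.data(), buf.size(), dst); // read it back into dst
}
```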