llama : fix llama_copy_state_data with fragmented KV cache (#5840)

The row size of the saved states was based on kv_self.head while
it should be based on llama_kv_cache_cell_max.

Existing session files should still work.

* llama : fix llama_kv_cache_cell_max inability to return 1

I've also changed its return type to uint32_t,
because this function is always used to set the value of uint32_t variables,
and because the index already has this type.

* llama : fix state size calculation

Some bytes in the state were unaccounted for in llama_get_state_size.
Since the logits reserve so much space, it did not cause problems.
This commit is contained in:
compilade 2024-03-03 03:41:55 -05:00 committed by GitHub
parent e6029348e8
commit de9692a7d2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -2156,10 +2156,12 @@ static bool llama_kv_cache_find_slot(
} }
// find how many cells are currently in use // find how many cells are currently in use
static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
for (uint32_t i = cache.size - 1; i > 0; --i) { for (uint32_t i = cache.size; i > 0; --i) {
if (cache.cells[i].pos >= 0 && !cache.cells[i].is_empty()) { const llama_kv_cell & cell = cache.cells[i - 1];
return i + 1;
if (cell.pos >= 0 && !cell.is_empty()) {
return i;
} }
} }
@ -8178,7 +8180,7 @@ static int llama_decode_internal(
// a heuristic, to avoid attending the full cache if it is not yet utilized // a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears // after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important // if we start defragmenting the cache, the benefit from this will be more important
kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); kv_self.n = std::min(cparams.n_ctx, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
//kv_self.n = llama_kv_cache_cell_max(kv_self); //kv_self.n = llama_kv_cache_cell_max(kv_self);
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head); //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
@ -12615,9 +12617,14 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
const size_t s_logits = ctx->logits.capacity() * sizeof(float); const size_t s_logits = ctx->logits.capacity() * sizeof(float);
const size_t s_embedding_size = sizeof(size_t); const size_t s_embedding_size = sizeof(size_t);
const size_t s_embedding = ctx->embedding.size() * sizeof(float); const size_t s_embedding = ctx->embedding.size() * sizeof(float);
const size_t s_kv_size = sizeof(size_t); const size_t s_kv_buf_size = sizeof(size_t);
const size_t s_kv_ntok = sizeof(int); const size_t s_kv_head = sizeof(uint32_t);
const size_t s_kv_size = sizeof(uint32_t);
const size_t s_kv_used = sizeof(uint32_t);
const size_t s_kv = ctx->kv_self.total_size(); const size_t s_kv = ctx->kv_self.total_size();
// TODO: assume the max is more than 1 seq_id per KV cell
const size_t s_kv_cell = sizeof(llama_pos) + sizeof(size_t) + sizeof(llama_seq_id);
const size_t s_kv_cells = ctx->kv_self.size * s_kv_cell;
const size_t s_total = ( const size_t s_total = (
+ s_rng_size + s_rng_size
@ -12626,9 +12633,12 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
+ s_logits + s_logits
+ s_embedding_size + s_embedding_size
+ s_embedding + s_embedding
+ s_kv_buf_size
+ s_kv_head
+ s_kv_size + s_kv_size
+ s_kv_ntok + s_kv_used
+ s_kv + s_kv
+ s_kv_cells
); );
return s_total; return s_total;
@ -12728,15 +12738,13 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
{ {
const auto & kv_self = ctx->kv_self; const auto & kv_self = ctx->kv_self;
const auto & hparams = ctx->model.hparams; const auto & hparams = ctx->model.hparams;
const auto & cparams = ctx->cparams;
const uint32_t n_layer = hparams.n_layer; const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
const uint32_t n_ctx = cparams.n_ctx;
const size_t kv_buf_size = kv_self.total_size(); const size_t kv_buf_size = kv_self.total_size();
const uint32_t kv_head = kv_self.head; const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
const uint32_t kv_size = kv_self.size; const uint32_t kv_size = kv_self.size;
const uint32_t kv_used = kv_self.used; const uint32_t kv_used = kv_self.used;
@ -12756,7 +12764,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
// v is not contiguous, copy row by row // v is not contiguous, copy row by row
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head); const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx); const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
tmp_buf.resize(v_row_size); tmp_buf.resize(v_row_size);
for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) { for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
@ -12766,7 +12774,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
} }
} }
for (uint32_t i = 0; i < kv_size; ++i) { for (uint32_t i = 0; i < kv_head; ++i) {
const auto & cell = kv_self.cells[i]; const auto & cell = kv_self.cells[i];
const llama_pos pos = cell.pos; const llama_pos pos = cell.pos;
@ -12842,12 +12850,10 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
{ {
const auto & kv_self = ctx->kv_self; const auto & kv_self = ctx->kv_self;
const auto & hparams = ctx->model.hparams; const auto & hparams = ctx->model.hparams;
const auto & cparams = ctx->cparams;
const uint32_t n_layer = hparams.n_layer; const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
const uint32_t n_ctx = cparams.n_ctx;
size_t kv_buf_size; size_t kv_buf_size;
uint32_t kv_head; uint32_t kv_head;
@ -12870,7 +12876,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
// v is not contiguous, copy row by row // v is not contiguous, copy row by row
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head); const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, n_ctx); const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) { for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size); ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size);
@ -12879,13 +12885,15 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
} }
} }
GGML_ASSERT(kv_self.size == kv_size);
ctx->kv_self.head = kv_head; ctx->kv_self.head = kv_head;
ctx->kv_self.size = kv_size; ctx->kv_self.size = kv_size;
ctx->kv_self.used = kv_used; ctx->kv_self.used = kv_used;
ctx->kv_self.cells.resize(kv_size); ctx->kv_self.cells.resize(kv_size);
for (uint32_t i = 0; i < kv_size; ++i) { for (uint32_t i = 0; i < kv_head; ++i) {
llama_pos pos; llama_pos pos;
size_t seq_id_size; size_t seq_id_size;
@ -12901,6 +12909,11 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
ctx->kv_self.cells[i].seq_id.insert(seq_id); ctx->kv_self.cells[i].seq_id.insert(seq_id);
} }
} }
for (uint32_t i = kv_head; i < kv_size; ++i) {
ctx->kv_self.cells[i].pos = -1;
ctx->kv_self.cells[i].seq_id.clear();
}
} }
const size_t nread = inp - src; const size_t nread = inp - src;