From 704a3033236bf39a5e50f172a5a4413d37f3d55f Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 28 Jul 2024 01:59:10 -0400 Subject: [PATCH] llama : fix Mamba session save and restore --- src/llama.cpp | 98 +++++++++++++++++++++++++++------------------------ 1 file changed, 51 insertions(+), 47 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 6b8306ddb..1145d3d55 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3508,11 +3508,11 @@ static bool llama_kv_cache_find_slot( int32_t cell_id = s + min; llama_kv_cell & cell = cache.cells[cell_id]; - if (last_pos != cell.pos + (llama_pos) n_seq_tokens) { + if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { // What should happen when the pos backtracks or skips a value? // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", - __func__, last_pos, cell.pos, batch.seq_id[s][0]); + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", + __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens); } cell.pos = last_pos; cell.seq_id.clear(); @@ -15013,12 +15013,6 @@ static int llama_decode_internal( const auto n_ubatch = cparams.n_ubatch; - // TODO: simplify or deprecate - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_arr; - std::vector> seq_id; - // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE; @@ -15636,6 +15630,44 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { } } +// make the outputs have the same order they had in the user-provided batch +static void llama_reorder_outputs(struct llama_context * ctx) { + std::vector & out_ids = ctx->sbatch.out_ids; + if (!out_ids.empty()) { + uint32_t n_vocab = ctx->model.hparams.n_vocab; + uint32_t n_embd = ctx->model.hparams.n_embd; + int32_t n_outputs = ctx->n_outputs; + GGML_ASSERT((size_t) n_outputs == out_ids.size()); + // TODO: is there something more efficient which also minimizes swaps? + // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) + for (int32_t i = 0; i < n_outputs - 1; ++i) { + int32_t j_min = i; + for (int32_t j = i + 1; j < n_outputs; ++j) { + if (out_ids[j] < out_ids[j_min]) { + j_min = j; + } + } + if (j_min == i) { continue; } + std::swap(out_ids[i], out_ids[j_min]); + if (ctx->logits_size > 0) { + for (uint32_t k = 0; k < n_vocab; k++) { + std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]); + } + } + if (ctx->embd_size > 0) { + for (uint32_t k = 0; k < n_embd; k++) { + std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]); + } + } + } + std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1); + for (int32_t i = 0; i < n_outputs; ++i) { + ctx->output_ids[out_ids[i]] = i; + } + out_ids.clear(); + } +} + // // quantization // @@ -17822,6 +17854,8 @@ struct llama_data_write { } void write_output_ids(struct llama_context * ctx) { + llama_reorder_outputs(ctx); + const uint32_t n_outputs = ctx->n_outputs; std::vector output_pos; @@ -18192,6 +18226,14 @@ struct llama_data_read { kv_self.used = cell_count; } + if (kv_self.recurrent) { + for (uint32_t i = 0; i < cell_count; ++i) { + uint32_t cell_id = kv_self.head + i; + // make sure the recurrent states will keep their restored state + kv_self.cells[cell_id].src = cell_id; + } + } + return true; } @@ -18843,44 +18885,6 @@ void llama_synchronize(struct llama_context * ctx) { ctx->t_compute_start_us = 0; } -// make the outputs have the same order they had in the user-provided batch -static void llama_reorder_outputs(struct llama_context * ctx) { - std::vector & out_ids = ctx->sbatch.out_ids; - if (!out_ids.empty()) { - uint32_t n_vocab = ctx->model.hparams.n_vocab; - uint32_t n_embd = ctx->model.hparams.n_embd; - int32_t n_outputs = ctx->n_outputs; - GGML_ASSERT((size_t) n_outputs == out_ids.size()); - // TODO: is there something more efficient which also minimizes swaps? - // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) - for (int32_t i = 0; i < n_outputs - 1; ++i) { - int32_t j_min = i; - for (int32_t j = i + 1; j < n_outputs; ++j) { - if (out_ids[j] < out_ids[j_min]) { - j_min = j; - } - } - if (j_min == i) { continue; } - std::swap(out_ids[i], out_ids[j_min]); - if (ctx->logits_size > 0) { - for (uint32_t k = 0; k < n_vocab; k++) { - std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]); - } - } - if (ctx->embd_size > 0) { - for (uint32_t k = 0; k < n_embd; k++) { - std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]); - } - } - } - std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1); - for (int32_t i = 0; i < n_outputs; ++i) { - ctx->output_ids[out_ids[i]] = i; - } - out_ids.clear(); - } -} - float * llama_get_logits(struct llama_context * ctx) { llama_synchronize(ctx);