mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 19:50:17 +00:00
llama : fix Mamba session save and restore
This commit is contained in:
parent
0dea4263aa
commit
704a303323
@ -3508,11 +3508,11 @@ static bool llama_kv_cache_find_slot(
|
||||
int32_t cell_id = s + min;
|
||||
llama_kv_cell & cell = cache.cells[cell_id];
|
||||
|
||||
if (last_pos != cell.pos + (llama_pos) n_seq_tokens) {
|
||||
if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) {
|
||||
// What should happen when the pos backtracks or skips a value?
|
||||
// Clearing the state mid-batch would require special-casing which isn't done.
|
||||
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n",
|
||||
__func__, last_pos, cell.pos, batch.seq_id[s][0]);
|
||||
LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n",
|
||||
__func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens);
|
||||
}
|
||||
cell.pos = last_pos;
|
||||
cell.seq_id.clear();
|
||||
@ -15013,12 +15013,6 @@ static int llama_decode_internal(
|
||||
|
||||
const auto n_ubatch = cparams.n_ubatch;
|
||||
|
||||
// TODO: simplify or deprecate
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<int32_t> n_seq_id;
|
||||
std::vector<llama_seq_id *> seq_id_arr;
|
||||
std::vector<std::vector<llama_seq_id>> seq_id;
|
||||
|
||||
// this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
|
||||
const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
|
||||
|
||||
@ -15636,6 +15630,44 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
||||
}
|
||||
}
|
||||
|
||||
// make the outputs have the same order they had in the user-provided batch
|
||||
static void llama_reorder_outputs(struct llama_context * ctx) {
|
||||
std::vector<size_t> & out_ids = ctx->sbatch.out_ids;
|
||||
if (!out_ids.empty()) {
|
||||
uint32_t n_vocab = ctx->model.hparams.n_vocab;
|
||||
uint32_t n_embd = ctx->model.hparams.n_embd;
|
||||
int32_t n_outputs = ctx->n_outputs;
|
||||
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
||||
// TODO: is there something more efficient which also minimizes swaps?
|
||||
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
||||
for (int32_t i = 0; i < n_outputs - 1; ++i) {
|
||||
int32_t j_min = i;
|
||||
for (int32_t j = i + 1; j < n_outputs; ++j) {
|
||||
if (out_ids[j] < out_ids[j_min]) {
|
||||
j_min = j;
|
||||
}
|
||||
}
|
||||
if (j_min == i) { continue; }
|
||||
std::swap(out_ids[i], out_ids[j_min]);
|
||||
if (ctx->logits_size > 0) {
|
||||
for (uint32_t k = 0; k < n_vocab; k++) {
|
||||
std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]);
|
||||
}
|
||||
}
|
||||
if (ctx->embd_size > 0) {
|
||||
for (uint32_t k = 0; k < n_embd; k++) {
|
||||
std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1);
|
||||
for (int32_t i = 0; i < n_outputs; ++i) {
|
||||
ctx->output_ids[out_ids[i]] = i;
|
||||
}
|
||||
out_ids.clear();
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// quantization
|
||||
//
|
||||
@ -17822,6 +17854,8 @@ struct llama_data_write {
|
||||
}
|
||||
|
||||
void write_output_ids(struct llama_context * ctx) {
|
||||
llama_reorder_outputs(ctx);
|
||||
|
||||
const uint32_t n_outputs = ctx->n_outputs;
|
||||
|
||||
std::vector<int32_t> output_pos;
|
||||
@ -18192,6 +18226,14 @@ struct llama_data_read {
|
||||
kv_self.used = cell_count;
|
||||
}
|
||||
|
||||
if (kv_self.recurrent) {
|
||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||
uint32_t cell_id = kv_self.head + i;
|
||||
// make sure the recurrent states will keep their restored state
|
||||
kv_self.cells[cell_id].src = cell_id;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
@ -18843,44 +18885,6 @@ void llama_synchronize(struct llama_context * ctx) {
|
||||
ctx->t_compute_start_us = 0;
|
||||
}
|
||||
|
||||
// make the outputs have the same order they had in the user-provided batch
|
||||
static void llama_reorder_outputs(struct llama_context * ctx) {
|
||||
std::vector<size_t> & out_ids = ctx->sbatch.out_ids;
|
||||
if (!out_ids.empty()) {
|
||||
uint32_t n_vocab = ctx->model.hparams.n_vocab;
|
||||
uint32_t n_embd = ctx->model.hparams.n_embd;
|
||||
int32_t n_outputs = ctx->n_outputs;
|
||||
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
||||
// TODO: is there something more efficient which also minimizes swaps?
|
||||
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
||||
for (int32_t i = 0; i < n_outputs - 1; ++i) {
|
||||
int32_t j_min = i;
|
||||
for (int32_t j = i + 1; j < n_outputs; ++j) {
|
||||
if (out_ids[j] < out_ids[j_min]) {
|
||||
j_min = j;
|
||||
}
|
||||
}
|
||||
if (j_min == i) { continue; }
|
||||
std::swap(out_ids[i], out_ids[j_min]);
|
||||
if (ctx->logits_size > 0) {
|
||||
for (uint32_t k = 0; k < n_vocab; k++) {
|
||||
std::swap(ctx->logits[i*n_vocab + k], ctx->logits[j_min*n_vocab + k]);
|
||||
}
|
||||
}
|
||||
if (ctx->embd_size > 0) {
|
||||
for (uint32_t k = 0; k < n_embd; k++) {
|
||||
std::swap(ctx->embd[i*n_embd + k], ctx->embd[j_min*n_embd + k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
std::fill(ctx->output_ids.begin(), ctx->output_ids.end(), -1);
|
||||
for (int32_t i = 0; i < n_outputs; ++i) {
|
||||
ctx->output_ids[out_ids[i]] = i;
|
||||
}
|
||||
out_ids.clear();
|
||||
}
|
||||
}
|
||||
|
||||
float * llama_get_logits(struct llama_context * ctx) {
|
||||
llama_synchronize(ctx);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user