mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 11:40:17 +00:00
llama : fix save/load state
This commit is contained in:
parent
29ab5a0ed1
commit
b59ddf945e
@ -18757,8 +18757,6 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
|
||||
const auto & hparams = ctx->model.hparams;
|
||||
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
|
||||
|
||||
// NOTE: kv_size and kv_buf_size are mostly used for sanity checks
|
||||
const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
|
||||
@ -18778,6 +18776,9 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
|
||||
|
||||
std::vector<uint8_t> tmp_buf;
|
||||
for (int il = 0; il < (int) n_layer; ++il) {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
|
||||
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
|
||||
|
||||
tmp_buf.resize(k_size);
|
||||
@ -18910,8 +18911,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
|
||||
const auto & hparams = ctx->model.hparams;
|
||||
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
|
||||
|
||||
size_t kv_buf_size;
|
||||
uint32_t kv_head;
|
||||
@ -18943,6 +18942,9 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
|
||||
GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
|
||||
|
||||
for (int il = 0; il < (int) n_layer; ++il) {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
|
||||
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
|
||||
|
||||
ggml_backend_tensor_set(kv_self.k_l[il], inp, 0, k_size);
|
||||
@ -19105,8 +19107,6 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
|
||||
const auto & hparams = ctx->model.hparams;
|
||||
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
|
||||
|
||||
for (uint32_t i = 0; i < kv_self.size; ++i) {
|
||||
const auto & cell = kv_self.cells[i];
|
||||
@ -19117,6 +19117,9 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
|
||||
}
|
||||
|
||||
for (int il = 0; il < (int)n_layer; ++il) {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
|
||||
// types of keys and values
|
||||
s_cell_data_size += sizeof(int32_t) * 2;
|
||||
// k_size_row and v_size_el values of layer
|
||||
@ -19191,14 +19194,15 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
|
||||
|
||||
const auto & hparams = ctx->model.hparams;
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
|
||||
|
||||
// Write the layer count
|
||||
data_ctx.write(&n_layer, sizeof(n_layer));
|
||||
|
||||
// Write n_embd_v_gqa
|
||||
data_ctx.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
|
||||
// Write n_embd_v_gqa (reference value)
|
||||
{
|
||||
const uint32_t n_embd_v_gqa_ref = hparams.n_embd_v_gqa() + hparams.n_embd_k_s();
|
||||
data_ctx.write(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref));
|
||||
}
|
||||
|
||||
// Iterate the ranges and write all the pos (this is the token position in the prompt)
|
||||
for (const auto & range : cell_ranges) {
|
||||
@ -19212,6 +19216,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
|
||||
// Get whole range at a time
|
||||
std::vector<uint8_t> tmp_buf;
|
||||
for (int il = 0; il < (int)n_layer; ++il) {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||
|
||||
// Write key type
|
||||
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
|
||||
data_ctx.write(&k_type_i, sizeof(k_type_i));
|
||||
@ -19232,6 +19238,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
|
||||
// TODO: simplify, reduce copy-paste
|
||||
if (!kv_self.v_trans) {
|
||||
for (int il = 0; il < (int)n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
|
||||
// Write value type
|
||||
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
||||
data_ctx.write(&v_type_i, sizeof(v_type_i));
|
||||
@ -19252,6 +19260,8 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
|
||||
// For the values, they are transposed, so we also need the element size and get the element ranges from each row
|
||||
const uint32_t kv_size = kv_self.size;
|
||||
for (int il = 0; il < (int)n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
|
||||
// Write value type
|
||||
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
||||
data_ctx.write(&v_type_i, sizeof(v_type_i));
|
||||
@ -19320,14 +19330,14 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
|
||||
// Sanity check model compatibility
|
||||
const auto & hparams = ctx->model.hparams;
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
|
||||
|
||||
if (n_layer != n_layer_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched n_layer (%d != %d)\n", __func__, n_layer, n_layer_ref);
|
||||
return 0;
|
||||
}
|
||||
if (n_embd_v_gqa != n_embd_v_gqa_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref);
|
||||
|
||||
if (hparams.n_embd_v_gqa() != n_embd_v_gqa_ref) {
|
||||
LLAMA_LOG_ERROR("%s: mismatched n_embd_v_gqa (%d != %d)\n", __func__, hparams.n_embd_v_gqa(), n_embd_v_gqa_ref);
|
||||
return 0;
|
||||
}
|
||||
|
||||
@ -19367,6 +19377,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
|
||||
|
||||
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous blo
|
||||
for (int il = 0; il < (int)n_layer; ++il) {
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
||||
|
||||
// Read type of key
|
||||
int32_t k_type_i_ref;
|
||||
memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref));
|
||||
@ -19399,6 +19411,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
|
||||
// TODO: simplify, reduce copy-paste
|
||||
if (!kv_self.v_trans) {
|
||||
for (int il = 0; il < (int)n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
|
||||
@ -19430,6 +19444,8 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
|
||||
} else {
|
||||
// For each layer, read the values for each cell (transposed)
|
||||
for (int il = 0; il < (int)n_layer; ++il) {
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
||||
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
|
||||
|
Loading…
Reference in New Issue
Block a user