llama : free ggml context in set / copy state data (close #1425)

This commit is contained in:
Georgi Gerganov 2023-05-13 09:08:52 +03:00
parent 699b1ad7fe
commit 738ace394a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 29 additions and 21 deletions

View File

@ -2450,8 +2450,8 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
} }
// Copies the state to the specified destination address // Copies the state to the specified destination address
size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
uint8_t * out = dest; uint8_t * out = dst;
// copy rng // copy rng
{ {
@ -2511,7 +2511,9 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
if (kv_size) { if (kv_size) {
const size_t elt_size = ggml_element_size(kv_self.k); const size_t elt_size = ggml_element_size(kv_self.k);
char buffer[4096]; char buffer[4096];
ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true }); ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
ggml_cgraph gf{}; ggml_cgraph gf{};
gf.n_threads = 1; gf.n_threads = 1;
@ -2535,10 +2537,12 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
ggml_graph_compute(cpy_ctx, &gf); ggml_graph_compute(cpy_ctx, &gf);
ggml_free(cpy_ctx);
} }
} }
const size_t written = out - dest; const size_t written = out - dst;
const size_t max_size = llama_get_state_size(ctx); const size_t max_size = llama_get_state_size(ctx);
LLAMA_ASSERT(written <= max_size); LLAMA_ASSERT(written <= max_size);
@ -2548,15 +2552,15 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
// Sets the state reading from the specified source address // Sets the state reading from the specified source address
size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
const uint8_t * in = src; const uint8_t * inp = src;
// set rng // set rng
{ {
size_t rng_size; size_t rng_size;
char rng_buf[LLAMA_MAX_RNG_STATE]; char rng_buf[LLAMA_MAX_RNG_STATE];
memcpy(&rng_size, in, sizeof(rng_size)); in += sizeof(rng_size); memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size);
memcpy(&rng_buf[0], in, LLAMA_MAX_RNG_STATE); in += LLAMA_MAX_RNG_STATE; memcpy(&rng_buf[0], inp, LLAMA_MAX_RNG_STATE); inp += LLAMA_MAX_RNG_STATE;
std::stringstream rng_ss; std::stringstream rng_ss;
rng_ss.str(std::string(&rng_buf[0], rng_size)); rng_ss.str(std::string(&rng_buf[0], rng_size));
@ -2570,30 +2574,30 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
size_t logits_cap; size_t logits_cap;
size_t logits_size; size_t logits_size;
memcpy(&logits_cap, in, sizeof(logits_cap)); in += sizeof(logits_cap); memcpy(&logits_cap, inp, sizeof(logits_cap)); inp += sizeof(logits_cap);
memcpy(&logits_size, in, sizeof(logits_size)); in += sizeof(logits_size); memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size);
LLAMA_ASSERT(ctx->logits.capacity() == logits_cap); LLAMA_ASSERT(ctx->logits.capacity() == logits_cap);
if (logits_size) { if (logits_size) {
ctx->logits.resize(logits_size); ctx->logits.resize(logits_size);
memcpy(ctx->logits.data(), in, logits_size * sizeof(float)); memcpy(ctx->logits.data(), inp, logits_size * sizeof(float));
} }
in += logits_cap * sizeof(float); inp += logits_cap * sizeof(float);
} }
// set embeddings // set embeddings
{ {
size_t embedding_size; size_t embedding_size;
memcpy(&embedding_size, in, sizeof(embedding_size)); in += sizeof(embedding_size); memcpy(&embedding_size, inp, sizeof(embedding_size)); inp += sizeof(embedding_size);
LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size); LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);
if (embedding_size) { if (embedding_size) {
memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float)); memcpy(ctx->embedding.data(), inp, embedding_size * sizeof(float));
in += embedding_size * sizeof(float); inp += embedding_size * sizeof(float);
} }
} }
@ -2608,25 +2612,27 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
size_t kv_size; size_t kv_size;
int kv_ntok; int kv_ntok;
memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size); memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok); memcpy(&kv_ntok, inp, sizeof(kv_ntok)); inp += sizeof(kv_ntok);
if (kv_size) { if (kv_size) {
LLAMA_ASSERT(kv_self.buf.size == kv_size); LLAMA_ASSERT(kv_self.buf.size == kv_size);
const size_t elt_size = ggml_element_size(kv_self.k); const size_t elt_size = ggml_element_size(kv_self.k);
char buffer[4096]; char buffer[4096];
ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true }); ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
ggml_cgraph gf{}; ggml_cgraph gf{};
gf.n_threads = 1; gf.n_threads = 1;
ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
kin3d->data = (void *) in; kin3d->data = (void *) inp;
in += ggml_nbytes(kin3d); inp += ggml_nbytes(kin3d);
ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer); ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
vin3d->data = (void *) in; vin3d->data = (void *) inp;
in += ggml_nbytes(vin3d); inp += ggml_nbytes(vin3d);
ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k, ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
n_embd, kv_ntok, n_layer, n_embd, kv_ntok, n_layer,
@ -2639,12 +2645,14 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
ggml_graph_compute(cpy_ctx, &gf); ggml_graph_compute(cpy_ctx, &gf);
ggml_free(cpy_ctx);
} }
ctx->model.kv_self.n = kv_ntok; ctx->model.kv_self.n = kv_ntok;
} }
const size_t nread = in - src; const size_t nread = inp - src;
const size_t max_size = llama_get_state_size(ctx); const size_t max_size = llama_get_state_size(ctx);
LLAMA_ASSERT(nread <= max_size); LLAMA_ASSERT(nread <= max_size);

View File

@ -134,7 +134,7 @@ extern "C" {
// Copies the state to the specified destination address. // Copies the state to the specified destination address.
// Destination needs to have allocated enough memory. // Destination needs to have allocated enough memory.
// Returns the number of bytes copied // Returns the number of bytes copied
LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest); LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst);
// Set the state reading from the specified address // Set the state reading from the specified address
// Returns the number of bytes read // Returns the number of bytes read