mirror of https://github.com/ggerganov/llama.cpp.git
llama : free ggml context in set / copy state data (close #1425)
commit 738ace394a (parent 699b1ad7fe)
llama.cpp (48 changed lines)
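Why this matters: even though the scratch context lives in a 4 KB stack buffer with no_alloc set, ggml_init also claims an entry in ggml's fixed-size static context table, and only ggml_free returns it. Without the ggml_free calls added in this commit, every state save/load permanently consumed a slot, so repeated calls eventually made ggml_init fail (#1425). A minimal sketch of the failure mode, assuming this era's ggml API (the exact slot count, GGML_MAX_CONTEXTS, is 64 if memory serves; worth double-checking against ggml.c):

    #include <cstdio>
    #include "ggml.h"

    int main() {
        char buffer[4096];

        // Each iteration mimics one state save/load before this commit:
        // ggml_init claims a slot in ggml's static context table.
        for (int i = 0; i < 1000; ++i) {
            ggml_context * ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
            if (ctx == nullptr) {
                // Without ggml_free, the table fills up after GGML_MAX_CONTEXTS
                // initializations and ggml_init starts returning NULL.
                printf("ggml_init failed at iteration %d\n", i);
                return 1;
            }
            ggml_free(ctx); // the fix: release the slot each time
        }
        printf("all contexts released cleanly\n");
        return 0;
    }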
@@ -2450,8 +2450,8 @@ size_t llama_get_state_size(const struct llama_context * ctx) {
 }

 // Copies the state to the specified destination address
-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
-    uint8_t * out = dest;
+size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
+    uint8_t * out = dst;

     // copy rng
     {
@@ -2511,7 +2511,9 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {

         if (kv_size) {
             const size_t elt_size = ggml_element_size(kv_self.k);
+
             char buffer[4096];
+
             ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
             ggml_cgraph gf{};
             gf.n_threads = 1;
@@ -2535,10 +2537,12 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d));
             ggml_graph_compute(cpy_ctx, &gf);
+
+            ggml_free(cpy_ctx);
         }
     }

-    const size_t written = out - dest;
+    const size_t written = out - dst;
     const size_t max_size = llama_get_state_size(ctx);

     LLAMA_ASSERT(written <= max_size);
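For context, the lines elided between the two hunks above build the copy graph: kout3d/vout3d are fresh tensors whose data pointers aim into the caller's buffer, k3d/v3d are strided 3-D views of the live KV cache, and the ggml_cpy nodes plus ggml_graph_compute perform the actual transfer. Because the context is created with no_alloc = true, the 4 KB stack buffer only has to hold tensor metadata, never tensor data; the context slot is still consumed, which is why the new ggml_free matters. A sketch of the pattern, not the verbatim elided code, reconstructed from the symmetric read path shown further down and assuming this era's ggml API (gf.n_threads and ggml_graph_compute(ctx, &gf)):

    #include "ggml.h"

    // Copy n_layer x kv_ntok x n_embd elements of the K cache into out,
    // returning the advanced cursor.
    static uint8_t * copy_k_cache(ggml_tensor * kv_k, uint8_t * out,
                                  int n_embd, int kv_ntok, int n_layer, int n_ctx) {
        const size_t elt_size = ggml_element_size(kv_k);

        char buffer[4096]; // metadata-only pool: no_alloc means no tensor data here
        ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
        ggml_cgraph gf{};
        gf.n_threads = 1;

        // Destination tensor whose storage is the caller's buffer.
        ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_k->type, n_embd, kv_ntok, n_layer);
        kout3d->data = out;

        // Strided view over the cache (row stride n_embd, layer stride n_embd*n_ctx).
        ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_k,
            n_embd, kv_ntok, n_layer,
            elt_size*n_embd, elt_size*n_embd*n_ctx, 0);

        ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d));
        ggml_graph_compute(cpy_ctx, &gf);

        out += ggml_nbytes(kout3d);
        ggml_free(cpy_ctx); // release the context slot (the point of this commit)
        return out;
    }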
@@ -2548,15 +2552,15 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) {

 // Sets the state reading from the specified source address
 size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
-    const uint8_t * in = src;
+    const uint8_t * inp = src;

     // set rng
     {
         size_t rng_size;
         char rng_buf[LLAMA_MAX_RNG_STATE];

-        memcpy(&rng_size, in, sizeof(rng_size)); in += sizeof(rng_size);
-        memcpy(&rng_buf[0], in, LLAMA_MAX_RNG_STATE); in += LLAMA_MAX_RNG_STATE;
+        memcpy(&rng_size, inp, sizeof(rng_size)); inp += sizeof(rng_size);
+        memcpy(&rng_buf[0], inp, LLAMA_MAX_RNG_STATE); inp += LLAMA_MAX_RNG_STATE;

         std::stringstream rng_ss;
         rng_ss.str(std::string(&rng_buf[0], rng_size));
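The rng block works because std::mt19937 defines stream operators: operator<< serializes the full engine state as text and operator>> restores it, so a stringstream round trip reproduces the exact sampling sequence. A self-contained illustration in standard C++:

    #include <cassert>
    #include <random>
    #include <sstream>

    int main() {
        std::mt19937 rng(1234);
        (void) rng(); // advance the engine a bit

        // Serialize: operator<< writes the complete engine state as text.
        std::stringstream ss;
        ss << rng;
        const std::string saved = ss.str();

        const unsigned next = rng(); // the draw we want to reproduce

        // Restore into a fresh engine; the sequence resumes identically.
        std::mt19937 restored;
        std::stringstream src(saved);
        src >> restored;
        assert(restored() == next);
        return 0;
    }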
@@ -2570,30 +2574,30 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
         size_t logits_cap;
         size_t logits_size;

-        memcpy(&logits_cap, in, sizeof(logits_cap)); in += sizeof(logits_cap);
-        memcpy(&logits_size, in, sizeof(logits_size)); in += sizeof(logits_size);
+        memcpy(&logits_cap, inp, sizeof(logits_cap)); inp += sizeof(logits_cap);
+        memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size);

         LLAMA_ASSERT(ctx->logits.capacity() == logits_cap);

         if (logits_size) {
             ctx->logits.resize(logits_size);
-            memcpy(ctx->logits.data(), in, logits_size * sizeof(float));
+            memcpy(ctx->logits.data(), inp, logits_size * sizeof(float));
         }

-        in += logits_cap * sizeof(float);
+        inp += logits_cap * sizeof(float);
     }

     // set embeddings
     {
         size_t embedding_size;

-        memcpy(&embedding_size, in, sizeof(embedding_size)); in += sizeof(embedding_size);
+        memcpy(&embedding_size, inp, sizeof(embedding_size)); inp += sizeof(embedding_size);

         LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size);

         if (embedding_size) {
-            memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float));
-            in += embedding_size * sizeof(float);
+            memcpy(ctx->embedding.data(), inp, embedding_size * sizeof(float));
+            inp += embedding_size * sizeof(float);
         }
     }

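Every section of the state blob follows the same length-prefixed layout: a fixed-size header (here, the capacity and element count) followed by the payload, with the read cursor advanced past the full reserved extent (logits_cap, not logits_size) so writer and reader always agree on offsets. A small hedged helper showing the read pattern; read_pod is illustrative, not part of llama.cpp:

    #include <cstdint>
    #include <cstring>

    // Hypothetical helper: copy a plain-old-data header out of the byte
    // stream and advance the cursor, mirroring the repeated
    // "memcpy(...); inp += ..." idiom above.
    template <typename T>
    static const uint8_t * read_pod(const uint8_t * inp, T & out) {
        memcpy(&out, inp, sizeof(T));
        return inp + sizeof(T);
    }

    // Usage, matching the logits section:
    //   size_t logits_cap, logits_size;
    //   inp = read_pod(inp, logits_cap);
    //   inp = read_pod(inp, logits_size);
    //   ... copy logits_size floats, then skip the full logits_cap extent.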
@@ -2608,25 +2612,27 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
         size_t kv_size;
         int kv_ntok;

-        memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
-        memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);
+        memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
+        memcpy(&kv_ntok, inp, sizeof(kv_ntok)); inp += sizeof(kv_ntok);

         if (kv_size) {
             LLAMA_ASSERT(kv_self.buf.size == kv_size);

             const size_t elt_size = ggml_element_size(kv_self.k);
+
             char buffer[4096];
+
             ggml_context * cpy_ctx = ggml_init({ sizeof(buffer), buffer, /* no_alloc */ true });
             ggml_cgraph gf{};
             gf.n_threads = 1;

             ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer);
-            kin3d->data = (void *) in;
-            in += ggml_nbytes(kin3d);
+            kin3d->data = (void *) inp;
+            inp += ggml_nbytes(kin3d);

             ggml_tensor * vin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.v->type, kv_ntok, n_embd, n_layer);
-            vin3d->data = (void *) in;
-            in += ggml_nbytes(vin3d);
+            vin3d->data = (void *) inp;
+            inp += ggml_nbytes(vin3d);

             ggml_tensor * k3d = ggml_view_3d(cpy_ctx, kv_self.k,
                 n_embd, kv_ntok, n_layer,
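Note the deserialization trick in this hunk: because cpy_ctx was created with no_alloc = true, ggml_new_tensor_3d produces a tensor with metadata but no storage, and the code then points kin3d->data directly at the serialized bytes. The subsequent ggml_cpy graph therefore reads straight out of the state buffer with no intermediate staging copy. A hypothetical helper (not in llama.cpp) that names the pattern:

    #include <cstdint>
    #include "ggml.h"

    // Bind a no_alloc tensor's storage to the next bytes of a serialized
    // state buffer, returning the advanced cursor.
    static const uint8_t * map_tensor_data(ggml_tensor * t, const uint8_t * inp) {
        t->data = (void *) inp;       // borrow the caller's buffer, no copy
        return inp + ggml_nbytes(t);  // skip past this tensor's payload
    }

    // With it, the two reads above would collapse to:
    //   inp = map_tensor_data(kin3d, inp);
    //   inp = map_tensor_data(vin3d, inp);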
@@ -2639,12 +2645,14 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d));
             ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d));
             ggml_graph_compute(cpy_ctx, &gf);
+
+            ggml_free(cpy_ctx);
         }

         ctx->model.kv_self.n = kv_ntok;
     }

-    const size_t nread = in - src;
+    const size_t nread = inp - src;
     const size_t max_size = llama_get_state_size(ctx);

     LLAMA_ASSERT(nread <= max_size);
llama.h (2 changed lines)
@@ -134,7 +134,7 @@ extern "C" {
     // Copies the state to the specified destination address.
     // Destination needs to have allocated enough memory.
     // Returns the number of bytes copied
-    LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest);
+    LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst);

     // Set the state reading from the specified address
     // Returns the number of bytes read
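Taken together, the header declares a symmetric save/restore pair sized by llama_get_state_size. A minimal round trip with the public API, assuming a llama_context already created as in the examples of this era:

    #include <vector>
    #include "llama.h"

    // Save the full session state (rng, logits, embeddings, KV cache) and
    // restore it later to rewind the context to this exact point.
    void checkpoint_roundtrip(llama_context * ctx) {
        std::vector<uint8_t> state(llama_get_state_size(ctx)); // upper bound
        const size_t written = llama_copy_state_data(ctx, state.data());
        // ... evaluate more tokens here, then rewind:
        const size_t nread = llama_set_state_data(ctx, state.data());
        (void) written; (void) nread; // both are <= state.size()
    }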