Fix memory allocation issues and seg faults

This commit is contained in:
Georgi Gerganov 2023-03-24 00:11:53 +02:00
parent 483bab2e3d
commit 4870e455b3
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -102,6 +102,9 @@ struct llama_context {
// decode output (2-dimensional array: [n_tokens][n_vocab]) // decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits; std::vector<float> logits;
bool logits_all = false; bool logits_all = false;
// work buffer for transformer evaluation
std::vector<uint8_t> buf_eval;
}; };
struct llama_context_params llama_context_default_params() { struct llama_context_params llama_context_default_params() {
@ -627,27 +630,19 @@ static bool llama_eval_internal(
const int n_rot = hparams.n_embd/hparams.n_head; const int n_rot = hparams.n_embd/hparams.n_head;
auto & mem_per_token = lctx.mem_per_token; auto & mem_per_token = lctx.mem_per_token;
auto & buf_eval = lctx.buf_eval;
// TODO: fix this hardcoded size if (mem_per_token*(n_past + N + 16) > buf_eval.size()) {
static size_t buf_size = 512u*1024*1024; const size_t buf_size_new = 1.618*buf_eval.size();
static void * buf = malloc(buf_size);
if (mem_per_token > 0 && mem_per_token*N > buf_size) { //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_eval.size(), buf_size_new);
const size_t buf_size_new = 1.3*(mem_per_token*N); // add 30% to account for ggml object overhead
//fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
// reallocate buf_eval.resize(buf_size_new);
buf_size = buf_size_new;
buf = realloc(buf, buf_size);
if (buf == nullptr) {
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
return false;
}
} }
struct ggml_init_params params = { struct ggml_init_params params = {
/*.mem_size =*/ buf_size, /*.mem_size =*/ buf_eval.size(),
/*.mem_buffer =*/ buf, /*.mem_buffer =*/ buf_eval.data(),
}; };
struct ggml_context * ctx0 = ggml_init(params); struct ggml_context * ctx0 = ggml_init(params);
@ -832,10 +827,11 @@ static bool llama_eval_internal(
memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
} }
if (mem_per_token == 0) { if (N == 1) {
mem_per_token = ggml_used_mem(ctx0)/N; mem_per_token = ggml_used_mem(ctx0)/(n_past + N);
} }
//fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
//fprintf(stderr, "\nused_mem = %zu, %zu MB\n", ggml_used_mem(ctx0), ggml_used_mem(ctx0)/1024/1024);
ggml_free(ctx0); ggml_free(ctx0);
@ -1416,6 +1412,8 @@ struct llama_context * llama_init_from_file(
return nullptr; return nullptr;
} }
ctx->buf_eval.resize(512u*1024u*1024u);
return ctx; return ctx;
} }