diff --git a/llama.cpp b/llama.cpp index d55219256..cf796cce3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -102,6 +102,9 @@ struct llama_context { // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; bool logits_all = false; + + // work buffer for transformer evaluation + std::vector buf_eval; }; struct llama_context_params llama_context_default_params() { @@ -627,27 +630,19 @@ static bool llama_eval_internal( const int n_rot = hparams.n_embd/hparams.n_head; auto & mem_per_token = lctx.mem_per_token; + auto & buf_eval = lctx.buf_eval; - // TODO: fix this hardcoded size - static size_t buf_size = 512u*1024*1024; - static void * buf = malloc(buf_size); + if (mem_per_token*(n_past + N + 16) > buf_eval.size()) { + const size_t buf_size_new = 1.618*buf_eval.size(); - if (mem_per_token > 0 && mem_per_token*N > buf_size) { - const size_t buf_size_new = 1.3*(mem_per_token*N); // add 30% to account for ggml object overhead - //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new); + //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_eval.size(), buf_size_new); - // reallocate - buf_size = buf_size_new; - buf = realloc(buf, buf_size); - if (buf == nullptr) { - fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size); - return false; - } + buf_eval.resize(buf_size_new); } struct ggml_init_params params = { - /*.mem_size =*/ buf_size, - /*.mem_buffer =*/ buf, + /*.mem_size =*/ buf_eval.size(), + /*.mem_buffer =*/ buf_eval.data(), }; struct ggml_context * ctx0 = ggml_init(params); @@ -832,10 +827,11 @@ static bool llama_eval_internal( memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); } - if (mem_per_token == 0) { - mem_per_token = ggml_used_mem(ctx0)/N; + if (N == 1) { + mem_per_token = ggml_used_mem(ctx0)/(n_past + N); } - //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0)); + + //fprintf(stderr, "\nused_mem = %zu, %zu MB\n", ggml_used_mem(ctx0), ggml_used_mem(ctx0)/1024/1024); ggml_free(ctx0); @@ -1416,6 +1412,8 @@ struct llama_context * llama_init_from_file( return nullptr; } + ctx->buf_eval.resize(512u*1024u*1024u); + return ctx; }