From c2beeb8e3a4a1ed52bd496b9ba2fdade1871f726 Mon Sep 17 00:00:00 2001 From: slaren Date: Mon, 17 Jul 2023 11:18:19 +0200 Subject: [PATCH] only allocate as much memory as is required in each backend for the model --- llama.cpp | 154 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 80 insertions(+), 74 deletions(-) diff --git a/llama.cpp b/llama.cpp index 8a40002d8..8eacdc33a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1,5 +1,5 @@ -#define DEFAULT_COMPUTE_TYPE GGML_TYPE_F32 -//#define DEFAULT_COMPUTE_TYPE GGML_TYPE_F16 +#define LLAMA_DEFAULT_COMPUTE_TYPE GGML_TYPE_F32 +//#define LLAMA_DEFAULT_COMPUTE_TYPE GGML_TYPE_F16 // Defines fileno on msys: #ifndef _GNU_SOURCE @@ -276,8 +276,7 @@ struct llama_context { // key + value cache for the self attention struct llama_kv_cache kv_self; - - size_t mem_per_token = 0; + ggml_backend * backend_kv = NULL; // decode output (2-dimensional array: [n_tokens][n_vocab]) std::vector logits; @@ -287,25 +286,22 @@ struct llama_context { std::vector embedding; // memory buffers used to evaluate the model - ggml_buffer buf_compute_cpu; + ggml_buffer buf_compute_cpu = {}; #ifdef GGML_USE_CUDA - ggml_buffer buf_compute_cuda; + ggml_buffer buf_compute_cuda = {}; #endif - // input/output tensors - // inputs + // input tensors struct ggml_tensor * graph_tokens_in = nullptr; struct ggml_tensor * graph_embeddings_in = nullptr; - // outputs + // output tensors struct ggml_tensor * graph_logits = nullptr; struct ggml_tensor * graph_embeddings_out = nullptr; // buffers to store the inputs and outputs of the graphs - ggml_buffer buf_input; - ggml_buffer buf_output; - - ggml_backend * backend_kv = NULL; + ggml_buffer buf_input = {}; + ggml_buffer buf_output = {}; }; template @@ -942,11 +938,6 @@ static void llama_model_load_internal( return; } - size_t ctx_size; - size_t mmapped_size; - ml->calc_sizes(&ctx_size, &mmapped_size); - fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/1024.0/1024.0); - // initialize backends const uint32_t n_layer = hparams.n_layer; @@ -959,13 +950,56 @@ static void llama_model_load_internal( } #endif + // assign splits to the backends + const int i_gpu_start = std::max(0, (int)n_layer - n_gpu_layers); + model.backend_input = n_gpu_layers > (int)n_layer ? backend_gpu : &model.backend_cpu; + model.backend_output = n_gpu_layers > 0 ? backend_gpu : &model.backend_cpu; + model.backend_layers.resize(n_layer); + std::fill(model.backend_layers.begin(), model.backend_layers.begin() + i_gpu_start, &model.backend_cpu); + std::fill(model.backend_layers.begin() + i_gpu_start, model.backend_layers.end(), backend_gpu); + + // calculate the size of each context + std::unordered_map ctx_sizes; + for (const llama_load_tensor & lt : ml->tensors_map.tensors) { + if (lt.name == "tok_embeddings.weight") { + ctx_sizes[model.backend_input] += lt.size; + } + else if (lt.name == "norm.weight" || lt.name == "output.weight") { + ctx_sizes[model.backend_output] += lt.size; + } + else { + // parse layer number from name + int layer = -1; + if (sscanf(lt.name.c_str(), "layers.%d.", &layer) != 1) { + throw std::runtime_error(format("failed to parse layer number from tensor name '%s'", lt.name.c_str())); + } + if (layer < 0 || layer >= (int)n_layer) { + throw std::runtime_error(format("invalid layer number %d", layer)); + } + ctx_sizes[model.backend_layers[layer]] += lt.size; + } + } + // TODO: generalize support for mmap + size_t mmap_size = 0; + if (ml->use_mmap) { + mmap_size = ctx_sizes[&model.backend_cpu]; + ctx_sizes[&model.backend_cpu] = 0; + } + + fprintf(stderr, "%s: ggml ctx sizes:\n", __func__); + for (const auto & it : ctx_sizes) { + fprintf(stderr, "%8s = %7.2f MB", ggml_backend_name(it.first), it.second / 1024.0 / 1024.0); + if (it.first == &model.backend_cpu && ml->use_mmap) { + fprintf(stderr, " + %7.2f MB (mmap)", mmap_size / 1024.0 / 1024.0); + } + fprintf(stderr, "\n"); + } + // create the buffers and contexts - // TODO: only allocate the amount of memory needed for each backend - // TODO: all of this is bad, clean up { size_t cpu_num_tensors = ml->tensors_map.tensors.size(); - size_t cpu_ctx_size = ctx_size; - model.buf_cpu = ggml_backend_alloc_buffer(&model.backend_cpu, cpu_ctx_size, cpu_num_tensors); + size_t ctx_size = ctx_sizes[&model.backend_cpu]; + model.buf_cpu = ggml_backend_alloc_buffer(&model.backend_cpu, ctx_size, cpu_num_tensors); struct ggml_init_params params = ggml_init_params_default(); params.buffer = &model.buf_cpu; params.no_alloc = ml->use_mmap; @@ -979,8 +1013,8 @@ static void llama_model_load_internal( #ifdef GGML_USE_CUDA if (n_gpu_layers > 0) { size_t gpu_num_tensors = ml->tensors_map.tensors.size(); - size_t gpu_ctx_size = ctx_size + mmapped_size; - model.buf_cuda = ggml_backend_alloc_buffer(&model.backend_cuda, gpu_ctx_size, gpu_num_tensors); + size_t ctx_size = ctx_sizes[&model.backend_cuda]; + model.buf_cuda = ggml_backend_alloc_buffer(&model.backend_cuda, ctx_size, gpu_num_tensors); struct ggml_init_params params = ggml_init_params_default(); params.buffer = &model.buf_cuda; model.ctx_cuda = ggml_init(params); @@ -990,30 +1024,6 @@ static void llama_model_load_internal( ctx_gpu = model.ctx_cuda; } #endif - if ((uint32_t)n_gpu_layers > n_layer) { - model.backend_input = backend_gpu; - } else { - model.backend_input = &model.backend_cpu; - } - - if (n_gpu_layers > 0) { - model.backend_output = backend_gpu; - } else { - model.backend_output = &model.backend_cpu; - } - - // assign splits to the backends - const int i_gpu_start = n_layer - n_gpu_layers; - model.backend_layers.resize(n_layer); - for (int i = 0; i < (int)n_layer; ++i) { - struct ggml_backend * layer_backend; - if (i >= i_gpu_start) { - layer_backend = backend_gpu; - } else { - layer_backend = &model.backend_cpu; - } - model.backend_layers[i] = layer_backend; - } // TODO: clean this ggml_context * ctx_input = model.ctx_cpu; @@ -1074,10 +1084,16 @@ static void llama_model_load_internal( { const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1; + // FIXME: this is not very useful without knowing the CPU/GPU memory split + // this is the total memory required to run the inference + size_t ctx_sum = mmap_size; + for (const auto & it : ctx_sizes) { + ctx_sum += it.second; + } + const size_t mem_required = - ctx_size + mmapped_size + - MEM_REQ_EVAL().at (model.type); + ctx_sum + MEM_REQ_EVAL().at(model.type); // this is the memory required by one llama_state const size_t mem_required_state = @@ -1085,11 +1101,9 @@ static void llama_model_load_internal( fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__, mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0); - - } - // populate `tensors_by_name` + // populate tensors_by_name for (llama_load_tensor & lt : ml->tensors_map.tensors) { model.tensors_by_name.emplace_back(lt.name, lt.ggml_tensor); } @@ -1140,7 +1154,7 @@ static ggml_graph_splits llama_build_graph( const int n_tokens, const int n_past, bool embeddings_input = false, - ggml_type compute_type = DEFAULT_COMPUTE_TYPE) { + ggml_type compute_type = LLAMA_DEFAULT_COMPUTE_TYPE) { // const int64_t t_start_us = ggml_time_us(); @@ -1164,15 +1178,12 @@ static ggml_graph_splits llama_build_graph( const float freq_scale = hparams.rope_freq_scale; - //auto & mem_per_token = lctx.mem_per_token; - struct ggml_graph_splits splits = ggml_graph_split_init(); // initialize contexts for every backend struct ggml_context * ctx_cpu = nullptr; - // TODO: don't create context if there are no CPU layers - { + if (lctx.buf_compute_cpu.mem_size > 0) { struct ggml_init_params params = ggml_init_params_default(); params.buffer = &lctx.buf_compute_cpu; params.compute_type = compute_type; @@ -1181,8 +1192,7 @@ static ggml_graph_splits llama_build_graph( #ifdef GGML_USE_CUDA struct ggml_context * ctx_cuda = nullptr; - // TODO: don't create context if there are no CUDA layers - if (lctx.model.n_gpu_layers > 0) { + if (lctx.buf_compute_cuda.mem_size > 0) { struct ggml_init_params params = ggml_init_params_default(); params.buffer = &lctx.buf_compute_cuda; params.compute_type = compute_type; @@ -1436,10 +1446,12 @@ static ggml_graph_splits llama_build_graph( if (embeddings != nullptr) { // TODO: fix this, only the last embedding has to be copied LLAMA_ASSERT(false); - ggml_cpy(ctx_o, cur, embeddings); + cur = ggml_cpy(ctx_o, cur, embeddings); } } + // TODO: skip output layer when using embeddings? + // lm_head cur = ggml_mul_mat(ctx_o, model.output, cur); ggml_set_name(cur, "result_output"); @@ -1473,10 +1485,6 @@ static ggml_graph_splits llama_build_graph( } #endif - //if (mem_per_token == 0) { - // mem_per_token = ggml_used_mem(ctx0)/N; - //} - #if 0 printf("\n%s: used_mem = %.3f MB, scratch -- %.3f MB %.3f MB\n", __func__, ggml_used_mem(ctx0)/1024.0/1024.0, @@ -1538,17 +1546,14 @@ static bool llama_eval_internal( struct ggml_graph_splits splits = llama_build_graph(lctx, N, n_past, embd_input); - // TODO: use backend functions if (tokens != nullptr) { // copy the tokens to the input tensor - ggml_backend_set_tensor(lctx.graph_tokens_in, tokens, 0, N*ggml_element_size(lctx.graph_tokens_in)); + ggml_backend_set_tensor_async(lctx.graph_tokens_in, tokens, 0, N*ggml_element_size(lctx.graph_tokens_in)); } else { // copy the embeddings to the input tensor - ggml_backend_set_tensor(lctx.graph_embeddings_in, embd, 0, N*n_embd*ggml_element_size(lctx.graph_embeddings_in)); + ggml_backend_set_tensor_async(lctx.graph_embeddings_in, embd, 0, N*n_embd*ggml_element_size(lctx.graph_embeddings_in)); } - - // run the computation ggml_graph_splits_compute(&splits); ggml_graph_splits_free(&splits); @@ -2702,14 +2707,15 @@ struct llama_context * llama_new_context_with_model( } } - printf("input: %s, ", ggml_backend_name(ctx->model.backend_input)); + fprintf(stderr, "%s: layer backends: ", __func__); + fprintf(stderr, "input: %s, ", ggml_backend_name(ctx->model.backend_input)); for (int i = 0; i < (int)ctx->model.hparams.n_layer; i++) { if (i == 0 || ctx->model.backend_layers[i] != ctx->model.backend_layers[i-1]) { - printf("layer %d: %s, ", i, ggml_backend_name(ctx->model.backend_layers[i])); + fprintf(stderr, "layer %d: %s, ", i, ggml_backend_name(ctx->model.backend_layers[i])); } } - printf("output: %s, ", ggml_backend_name(ctx->model.backend_output)); - printf("kv: %s\n", ggml_backend_name(ctx->backend_kv)); + fprintf(stderr, "output: %s, ", ggml_backend_name(ctx->model.backend_output)); + fprintf(stderr, "kv: %s\n", ggml_backend_name(ctx->backend_kv)); #ifdef GGML_USE_MPI ctx->ctx_mpi = ggml_mpi_init();