only allocate as much memory as is required in each backend for the model

This commit is contained in:
slaren 2023-07-17 11:18:19 +02:00
parent 9c72e7e916
commit c2beeb8e3a

154
llama.cpp
View File

@ -1,5 +1,5 @@
#define DEFAULT_COMPUTE_TYPE GGML_TYPE_F32
//#define DEFAULT_COMPUTE_TYPE GGML_TYPE_F16
#define LLAMA_DEFAULT_COMPUTE_TYPE GGML_TYPE_F32
//#define LLAMA_DEFAULT_COMPUTE_TYPE GGML_TYPE_F16
// Defines fileno on msys:
#ifndef _GNU_SOURCE
@ -276,8 +276,7 @@ struct llama_context {
// key + value cache for the self attention
struct llama_kv_cache kv_self;
size_t mem_per_token = 0;
ggml_backend * backend_kv = NULL;
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
@ -287,25 +286,22 @@ struct llama_context {
std::vector<float> embedding;
// memory buffers used to evaluate the model
ggml_buffer buf_compute_cpu;
ggml_buffer buf_compute_cpu = {};
#ifdef GGML_USE_CUDA
ggml_buffer buf_compute_cuda;
ggml_buffer buf_compute_cuda = {};
#endif
// input/output tensors
// inputs
// input tensors
struct ggml_tensor * graph_tokens_in = nullptr;
struct ggml_tensor * graph_embeddings_in = nullptr;
// outputs
// output tensors
struct ggml_tensor * graph_logits = nullptr;
struct ggml_tensor * graph_embeddings_out = nullptr;
// buffers to store the inputs and outputs of the graphs
ggml_buffer buf_input;
ggml_buffer buf_output;
ggml_backend * backend_kv = NULL;
ggml_buffer buf_input = {};
ggml_buffer buf_output = {};
};
template <typename T>
@ -942,11 +938,6 @@ static void llama_model_load_internal(
return;
}
size_t ctx_size;
size_t mmapped_size;
ml->calc_sizes(&ctx_size, &mmapped_size);
fprintf(stderr, "%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/1024.0/1024.0);
// initialize backends
const uint32_t n_layer = hparams.n_layer;
@ -959,13 +950,56 @@ static void llama_model_load_internal(
}
#endif
// assign splits to the backends
const int i_gpu_start = std::max(0, (int)n_layer - n_gpu_layers);
model.backend_input = n_gpu_layers > (int)n_layer ? backend_gpu : &model.backend_cpu;
model.backend_output = n_gpu_layers > 0 ? backend_gpu : &model.backend_cpu;
model.backend_layers.resize(n_layer);
std::fill(model.backend_layers.begin(), model.backend_layers.begin() + i_gpu_start, &model.backend_cpu);
std::fill(model.backend_layers.begin() + i_gpu_start, model.backend_layers.end(), backend_gpu);
// calculate the size of each context
std::unordered_map<struct ggml_backend *, size_t> ctx_sizes;
for (const llama_load_tensor & lt : ml->tensors_map.tensors) {
if (lt.name == "tok_embeddings.weight") {
ctx_sizes[model.backend_input] += lt.size;
}
else if (lt.name == "norm.weight" || lt.name == "output.weight") {
ctx_sizes[model.backend_output] += lt.size;
}
else {
// parse layer number from name
int layer = -1;
if (sscanf(lt.name.c_str(), "layers.%d.", &layer) != 1) {
throw std::runtime_error(format("failed to parse layer number from tensor name '%s'", lt.name.c_str()));
}
if (layer < 0 || layer >= (int)n_layer) {
throw std::runtime_error(format("invalid layer number %d", layer));
}
ctx_sizes[model.backend_layers[layer]] += lt.size;
}
}
// TODO: generalize support for mmap
size_t mmap_size = 0;
if (ml->use_mmap) {
mmap_size = ctx_sizes[&model.backend_cpu];
ctx_sizes[&model.backend_cpu] = 0;
}
fprintf(stderr, "%s: ggml ctx sizes:\n", __func__);
for (const auto & it : ctx_sizes) {
fprintf(stderr, "%8s = %7.2f MB", ggml_backend_name(it.first), it.second / 1024.0 / 1024.0);
if (it.first == &model.backend_cpu && ml->use_mmap) {
fprintf(stderr, " + %7.2f MB (mmap)", mmap_size / 1024.0 / 1024.0);
}
fprintf(stderr, "\n");
}
// create the buffers and contexts
// TODO: only allocate the amount of memory needed for each backend
// TODO: all of this is bad, clean up
{
size_t cpu_num_tensors = ml->tensors_map.tensors.size();
size_t cpu_ctx_size = ctx_size;
model.buf_cpu = ggml_backend_alloc_buffer(&model.backend_cpu, cpu_ctx_size, cpu_num_tensors);
size_t ctx_size = ctx_sizes[&model.backend_cpu];
model.buf_cpu = ggml_backend_alloc_buffer(&model.backend_cpu, ctx_size, cpu_num_tensors);
struct ggml_init_params params = ggml_init_params_default();
params.buffer = &model.buf_cpu;
params.no_alloc = ml->use_mmap;
@ -979,8 +1013,8 @@ static void llama_model_load_internal(
#ifdef GGML_USE_CUDA
if (n_gpu_layers > 0) {
size_t gpu_num_tensors = ml->tensors_map.tensors.size();
size_t gpu_ctx_size = ctx_size + mmapped_size;
model.buf_cuda = ggml_backend_alloc_buffer(&model.backend_cuda, gpu_ctx_size, gpu_num_tensors);
size_t ctx_size = ctx_sizes[&model.backend_cuda];
model.buf_cuda = ggml_backend_alloc_buffer(&model.backend_cuda, ctx_size, gpu_num_tensors);
struct ggml_init_params params = ggml_init_params_default();
params.buffer = &model.buf_cuda;
model.ctx_cuda = ggml_init(params);
@ -990,30 +1024,6 @@ static void llama_model_load_internal(
ctx_gpu = model.ctx_cuda;
}
#endif
if ((uint32_t)n_gpu_layers > n_layer) {
model.backend_input = backend_gpu;
} else {
model.backend_input = &model.backend_cpu;
}
if (n_gpu_layers > 0) {
model.backend_output = backend_gpu;
} else {
model.backend_output = &model.backend_cpu;
}
// assign splits to the backends
const int i_gpu_start = n_layer - n_gpu_layers;
model.backend_layers.resize(n_layer);
for (int i = 0; i < (int)n_layer; ++i) {
struct ggml_backend * layer_backend;
if (i >= i_gpu_start) {
layer_backend = backend_gpu;
} else {
layer_backend = &model.backend_cpu;
}
model.backend_layers[i] = layer_backend;
}
// TODO: clean this
ggml_context * ctx_input = model.ctx_cpu;
@ -1074,10 +1084,16 @@ static void llama_model_load_internal(
{
const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
// FIXME: this is not very useful without knowing the CPU/GPU memory split
// this is the total memory required to run the inference
size_t ctx_sum = mmap_size;
for (const auto & it : ctx_sizes) {
ctx_sum += it.second;
}
const size_t mem_required =
ctx_size + mmapped_size +
MEM_REQ_EVAL().at (model.type);
ctx_sum + MEM_REQ_EVAL().at(model.type);
// this is the memory required by one llama_state
const size_t mem_required_state =
@ -1085,11 +1101,9 @@ static void llama_model_load_internal(
fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
}
// populate `tensors_by_name`
// populate tensors_by_name
for (llama_load_tensor & lt : ml->tensors_map.tensors) {
model.tensors_by_name.emplace_back(lt.name, lt.ggml_tensor);
}
@ -1140,7 +1154,7 @@ static ggml_graph_splits llama_build_graph(
const int n_tokens,
const int n_past,
bool embeddings_input = false,
ggml_type compute_type = DEFAULT_COMPUTE_TYPE) {
ggml_type compute_type = LLAMA_DEFAULT_COMPUTE_TYPE) {
// const int64_t t_start_us = ggml_time_us();
@ -1164,15 +1178,12 @@ static ggml_graph_splits llama_build_graph(
const float freq_scale = hparams.rope_freq_scale;
//auto & mem_per_token = lctx.mem_per_token;
struct ggml_graph_splits splits = ggml_graph_split_init();
// initialize contexts for every backend
struct ggml_context * ctx_cpu = nullptr;
// TODO: don't create context if there are no CPU layers
{
if (lctx.buf_compute_cpu.mem_size > 0) {
struct ggml_init_params params = ggml_init_params_default();
params.buffer = &lctx.buf_compute_cpu;
params.compute_type = compute_type;
@ -1181,8 +1192,7 @@ static ggml_graph_splits llama_build_graph(
#ifdef GGML_USE_CUDA
struct ggml_context * ctx_cuda = nullptr;
// TODO: don't create context if there are no CUDA layers
if (lctx.model.n_gpu_layers > 0) {
if (lctx.buf_compute_cuda.mem_size > 0) {
struct ggml_init_params params = ggml_init_params_default();
params.buffer = &lctx.buf_compute_cuda;
params.compute_type = compute_type;
@ -1436,10 +1446,12 @@ static ggml_graph_splits llama_build_graph(
if (embeddings != nullptr) {
// TODO: fix this, only the last embedding has to be copied
LLAMA_ASSERT(false);
ggml_cpy(ctx_o, cur, embeddings);
cur = ggml_cpy(ctx_o, cur, embeddings);
}
}
// TODO: skip output layer when using embeddings?
// lm_head
cur = ggml_mul_mat(ctx_o, model.output, cur);
ggml_set_name(cur, "result_output");
@ -1473,10 +1485,6 @@ static ggml_graph_splits llama_build_graph(
}
#endif
//if (mem_per_token == 0) {
// mem_per_token = ggml_used_mem(ctx0)/N;
//}
#if 0
printf("\n%s: used_mem = %.3f MB, scratch -- %.3f MB %.3f MB\n", __func__,
ggml_used_mem(ctx0)/1024.0/1024.0,
@ -1538,17 +1546,14 @@ static bool llama_eval_internal(
struct ggml_graph_splits splits = llama_build_graph(lctx, N, n_past, embd_input);
// TODO: use backend functions
if (tokens != nullptr) {
// copy the tokens to the input tensor
ggml_backend_set_tensor(lctx.graph_tokens_in, tokens, 0, N*ggml_element_size(lctx.graph_tokens_in));
ggml_backend_set_tensor_async(lctx.graph_tokens_in, tokens, 0, N*ggml_element_size(lctx.graph_tokens_in));
} else {
// copy the embeddings to the input tensor
ggml_backend_set_tensor(lctx.graph_embeddings_in, embd, 0, N*n_embd*ggml_element_size(lctx.graph_embeddings_in));
ggml_backend_set_tensor_async(lctx.graph_embeddings_in, embd, 0, N*n_embd*ggml_element_size(lctx.graph_embeddings_in));
}
// run the computation
ggml_graph_splits_compute(&splits);
ggml_graph_splits_free(&splits);
@ -2702,14 +2707,15 @@ struct llama_context * llama_new_context_with_model(
}
}
printf("input: %s, ", ggml_backend_name(ctx->model.backend_input));
fprintf(stderr, "%s: layer backends: ", __func__);
fprintf(stderr, "input: %s, ", ggml_backend_name(ctx->model.backend_input));
for (int i = 0; i < (int)ctx->model.hparams.n_layer; i++) {
if (i == 0 || ctx->model.backend_layers[i] != ctx->model.backend_layers[i-1]) {
printf("layer %d: %s, ", i, ggml_backend_name(ctx->model.backend_layers[i]));
fprintf(stderr, "layer %d: %s, ", i, ggml_backend_name(ctx->model.backend_layers[i]));
}
}
printf("output: %s, ", ggml_backend_name(ctx->model.backend_output));
printf("kv: %s\n", ggml_backend_name(ctx->backend_kv));
fprintf(stderr, "output: %s, ", ggml_backend_name(ctx->model.backend_output));
fprintf(stderr, "kv: %s\n", ggml_backend_name(ctx->backend_kv));
#ifdef GGML_USE_MPI
ctx->ctx_mpi = ggml_mpi_init();