mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 12:24:35 +00:00
llama.cpp: remove backend-specific code where possible
This commit is contained in:
parent
295f85654a
commit
77ac8deaf1
204
llama.cpp
204
llama.cpp
@ -225,14 +225,15 @@ struct llama_model {
|
||||
llama_vocab vocab;
|
||||
|
||||
// backends
|
||||
struct backend_data {
|
||||
ggml_backend * backend;
|
||||
ggml_buffer * buf;
|
||||
ggml_context * ctx;
|
||||
};
|
||||
std::vector<backend_data> backends;
|
||||
// default backends for CPU and GPU
|
||||
ggml_backend * backend_cpu = NULL;
|
||||
ggml_buffer * buf_cpu = NULL;
|
||||
ggml_context * ctx_cpu = NULL;
|
||||
#ifdef GGML_USE_CUDA
|
||||
ggml_backend * backend_cuda = NULL;
|
||||
ggml_buffer * buf_cuda = NULL;
|
||||
ggml_context * ctx_cuda = NULL;
|
||||
#endif
|
||||
ggml_backend * backend_gpu = NULL;
|
||||
|
||||
// backend assigned to each layer
|
||||
ggml_backend * backend_inp = NULL;
|
||||
@ -240,16 +241,11 @@ struct llama_model {
|
||||
std::vector<ggml_backend *> backend_layers;
|
||||
|
||||
~llama_model() {
|
||||
if (ctx_cpu) {
|
||||
ggml_free(ctx_cpu);
|
||||
ggml_buffer_free(buf_cpu);
|
||||
for (auto & b : backends) {
|
||||
ggml_free(b.ctx);
|
||||
ggml_buffer_free(b.buf);
|
||||
ggml_backend_free(b.backend);
|
||||
}
|
||||
#ifdef GGML_USE_CUDA
|
||||
if (ctx_cuda) {
|
||||
ggml_free(ctx_cuda);
|
||||
ggml_buffer_free(buf_cuda);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@ -286,10 +282,7 @@ struct llama_context {
|
||||
std::vector<float> embedding;
|
||||
|
||||
// memory buffers used to evaluate the model
|
||||
ggml_buffer * buf_compute_cpu;
|
||||
#ifdef GGML_USE_CUDA
|
||||
ggml_buffer * buf_compute_cuda;
|
||||
#endif
|
||||
std::vector<ggml_buffer *> bufs_compute;
|
||||
|
||||
// input tensors
|
||||
struct ggml_tensor * graph_tokens_in = nullptr;
|
||||
@ -308,8 +301,8 @@ struct llama_context {
|
||||
if (model_owner) {
|
||||
delete &model;
|
||||
}
|
||||
if (buf_compute_cpu) {
|
||||
ggml_buffer_free(buf_compute_cpu);
|
||||
if (ggml_buffer * buf : bufs_compute) {
|
||||
ggml_buffer_free(buf);
|
||||
}
|
||||
}
|
||||
*/
|
||||
@ -960,25 +953,27 @@ static void llama_model_load_internal(
|
||||
return;
|
||||
}
|
||||
|
||||
// initialize backends
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
|
||||
// initialize backends
|
||||
model.backend_cpu = ggml_backend_cpu_init();
|
||||
ggml_backend * backend_gpu = model.backend_cpu; // hack until we have a proper backend selection
|
||||
model.backends.push_back({model.backend_cpu, nullptr, nullptr});
|
||||
model.backend_gpu = model.backend_cpu; // default to CPU if no GPU backends are available
|
||||
#ifdef GGML_USE_CUDA
|
||||
if (n_gpu_layers > 0) {
|
||||
model.backend_cuda = ggml_backend_cuda_init();
|
||||
backend_gpu = model.backend_cuda;
|
||||
ggml_backend * backend_cuda = ggml_backend_cuda_init();
|
||||
model.backends.push_back({backend_cuda, nullptr, nullptr});
|
||||
model.backend_gpu = backend_cuda;
|
||||
}
|
||||
#endif
|
||||
|
||||
// assign splits to the backends
|
||||
const int i_gpu_start = std::max(0, (int)n_layer - n_gpu_layers);
|
||||
model.backend_inp = n_gpu_layers > (int)n_layer ? backend_gpu : model.backend_cpu;
|
||||
model.backend_out = n_gpu_layers > 0 ? backend_gpu : model.backend_cpu;
|
||||
model.backend_inp = n_gpu_layers > (int)n_layer ? model.backend_gpu : model.backend_cpu;
|
||||
model.backend_out = n_gpu_layers > 0 ? model.backend_gpu : model.backend_cpu;
|
||||
model.backend_layers.resize(n_layer);
|
||||
std::fill(model.backend_layers.begin(), model.backend_layers.begin() + i_gpu_start, model.backend_cpu);
|
||||
std::fill(model.backend_layers.begin() + i_gpu_start, model.backend_layers.end(), backend_gpu);
|
||||
std::fill(model.backend_layers.begin() + i_gpu_start, model.backend_layers.end(), model.backend_gpu);
|
||||
|
||||
// calculate the size of each context
|
||||
std::unordered_map<struct ggml_backend *, size_t> ctx_sizes;
|
||||
@ -1008,6 +1003,7 @@ static void llama_model_load_internal(
|
||||
ctx_sizes[model.backend_cpu] = 0;
|
||||
}
|
||||
|
||||
// print the context sizes
|
||||
fprintf(stderr, "%s: ggml ctx sizes:\n", __func__);
|
||||
for (const auto & it : ctx_sizes) {
|
||||
fprintf(stderr, "%8s = %7.2f MB", ggml_backend_name(it.first), it.second / 1024.0 / 1024.0);
|
||||
@ -1017,45 +1013,34 @@ static void llama_model_load_internal(
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
// create the buffers and contexts
|
||||
{
|
||||
size_t cpu_num_tensors = ml->tensors_map.tensors.size();
|
||||
size_t ctx_size = ctx_sizes[model.backend_cpu];
|
||||
model.buf_cpu = ggml_buffer_alloc(model.backend_cpu, ctx_size, cpu_num_tensors);
|
||||
// create the buffers and contexts for each backend
|
||||
for (auto & backend_data : model.backends) {
|
||||
ggml_backend * backend = backend_data.backend;
|
||||
size_t num_tensors = ml->tensors_map.tensors.size();
|
||||
size_t ctx_size = ctx_sizes[backend];
|
||||
|
||||
backend_data.buf = ggml_buffer_alloc(backend, ctx_size, num_tensors);
|
||||
struct ggml_init_params params = ggml_init_params_default();
|
||||
params.buffer = model.buf_cpu;
|
||||
params.no_alloc = ml->use_mmap;
|
||||
model.ctx_cpu = ggml_init(params);
|
||||
if (!model.ctx_cpu) {
|
||||
throw std::runtime_error(format("ggml_init() failed for CPU backend"));
|
||||
params.buffer = backend_data.buf;
|
||||
params.no_alloc = backend == model.backend_cpu && ml->use_mmap;
|
||||
backend_data.ctx = ggml_init(params);
|
||||
if (!backend_data.ctx) {
|
||||
throw std::runtime_error(format("ggml_init() failed for backend context"));
|
||||
}
|
||||
}
|
||||
|
||||
ggml_context * ctx_gpu = model.ctx_cpu;
|
||||
#ifdef GGML_USE_CUDA
|
||||
if (n_gpu_layers > 0) {
|
||||
size_t gpu_num_tensors = ml->tensors_map.tensors.size();
|
||||
size_t ctx_size = ctx_sizes[model.backend_cuda];
|
||||
model.buf_cuda = ggml_buffer_alloc(model.backend_cuda, ctx_size, gpu_num_tensors);
|
||||
struct ggml_init_params params = ggml_init_params_default();
|
||||
params.buffer = model.buf_cuda;
|
||||
model.ctx_cuda = ggml_init(params);
|
||||
if (!model.ctx_cuda) {
|
||||
throw std::runtime_error(format("ggml_init() failed for CUDA backend"));
|
||||
}
|
||||
ctx_gpu = model.ctx_cuda;
|
||||
}
|
||||
#endif
|
||||
|
||||
// TODO: clean this
|
||||
ggml_context * ctx_input = model.ctx_cpu;
|
||||
if (model.backend_inp == backend_gpu) ctx_input = ctx_gpu;
|
||||
ggml_context * ctx_output = model.ctx_cpu;
|
||||
if (model.backend_out == backend_gpu) ctx_output = ctx_gpu;
|
||||
std::vector<ggml_context *> ctx_layers(n_layer, model.ctx_cpu);
|
||||
for (uint32_t i = 0; i < n_layer; ++i) {
|
||||
if (model.backend_layers[i] == backend_gpu) {
|
||||
ctx_layers[i] = ctx_gpu;
|
||||
// find the contexts for each layer
|
||||
ggml_context * ctx_input = nullptr;
|
||||
ggml_context * ctx_output = nullptr;
|
||||
std::vector<ggml_context *> ctx_layers(n_layer, nullptr);
|
||||
for (auto & backend_data : model.backends) {
|
||||
ggml_backend * backend = backend_data.backend;
|
||||
if (backend == model.backend_inp) { ctx_input = backend_data.ctx; }
|
||||
if (backend == model.backend_out) { ctx_output = backend_data.ctx; }
|
||||
for (uint32_t i = 0; i < n_layer; ++i) {
|
||||
if (backend == model.backend_layers[i]) {
|
||||
ctx_layers[i] = backend_data.ctx;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1199,52 +1184,33 @@ static ggml_graph_splits llama_build_graph(
|
||||
const float freq_base = hparams.rope_freq_base;
|
||||
const float freq_scale = hparams.rope_freq_scale;
|
||||
|
||||
|
||||
struct ggml_graph_splits splits = ggml_graph_split_init();
|
||||
|
||||
// initialize contexts for every backend
|
||||
// initialize contexts for each backend
|
||||
|
||||
struct ggml_context * ctx_cpu = nullptr;
|
||||
if (lctx.buf_compute_cpu != nullptr) {
|
||||
struct ggml_init_params params = ggml_init_params_default();
|
||||
params.buffer = lctx.buf_compute_cpu;
|
||||
params.compute_type = compute_type;
|
||||
ctx_cpu = ggml_init(params);
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_CUDA
|
||||
struct ggml_context * ctx_cuda = nullptr;
|
||||
if (lctx.buf_compute_cuda != nullptr) {
|
||||
struct ggml_init_params params = ggml_init_params_default();
|
||||
params.buffer = lctx.buf_compute_cuda;
|
||||
params.compute_type = compute_type;
|
||||
ctx_cuda = ggml_init(params);
|
||||
}
|
||||
#endif
|
||||
|
||||
// TODO: clean this
|
||||
struct ggml_context * ctx_i = nullptr;
|
||||
struct ggml_context * ctx_ls[80] = {nullptr};
|
||||
struct ggml_context * ctx_o = nullptr;
|
||||
struct ggml_context * ctx_kv = nullptr;
|
||||
// TODO: reuse vectors to avoid allocations
|
||||
std::vector<ggml_context *> ctx_ls(n_layer);
|
||||
std::vector<struct ggml_context *> ctxs;
|
||||
|
||||
if (lctx.model.backend_inp == lctx.model.backend_cpu) ctx_i = ctx_cpu;
|
||||
if (lctx.model.backend_out == lctx.model.backend_cpu) ctx_o = ctx_cpu;
|
||||
#ifdef GGML_USE_CUDA
|
||||
if (lctx.model.backend_inp == lctx.model.backend_cuda) ctx_i = ctx_cuda;
|
||||
if (lctx.model.backend_out == lctx.model.backend_cuda) ctx_o = ctx_cuda;
|
||||
#endif
|
||||
for (int il = 0; il < n_layer; il++) {
|
||||
if (lctx.model.backend_layers[il] == lctx.model.backend_cpu) ctx_ls[il] = ctx_cpu;
|
||||
#ifdef GGML_USE_CUDA
|
||||
if (lctx.model.backend_layers[il] == lctx.model.backend_cuda) ctx_ls[il] = ctx_cuda;
|
||||
#endif
|
||||
for (ggml_buffer * buf_compute : lctx.bufs_compute) {
|
||||
struct ggml_init_params params = ggml_init_params_default();
|
||||
params.buffer = buf_compute;
|
||||
params.compute_type = compute_type;
|
||||
ggml_context * ctx_buf = ggml_init(params);
|
||||
ctxs.push_back(ctx_buf);
|
||||
|
||||
ggml_backend * buf_backend = buf_compute->backend_buffer->backend;
|
||||
|
||||
if (buf_backend == lctx.model.backend_inp) { ctx_i = ctx_buf; }
|
||||
if (buf_backend == lctx.model.backend_out) { ctx_o = ctx_buf; }
|
||||
if (buf_backend == lctx.backend_kv) { ctx_kv = ctx_buf; };
|
||||
for (int il = 0; il < n_layer; il++) {
|
||||
if (buf_backend == lctx.model.backend_layers[il]) { ctx_ls[il] = ctx_buf; }
|
||||
}
|
||||
}
|
||||
if (lctx.backend_kv == lctx.model.backend_cpu) ctx_kv = ctx_cpu;
|
||||
#ifdef GGML_USE_CUDA
|
||||
if (lctx.backend_kv == lctx.model.backend_cuda) ctx_kv = ctx_cuda;
|
||||
#endif
|
||||
|
||||
|
||||
struct ggml_tensor * inpL;
|
||||
|
||||
@ -1517,14 +1483,10 @@ static ggml_graph_splits llama_build_graph(
|
||||
//int64_t t_end_us = ggml_time_us();
|
||||
//fprintf(stderr, "%s: time = %.3f ms\n", __func__, (t_end_us-t_start_us)/1000.0);
|
||||
|
||||
if (ctx_cpu != nullptr) {
|
||||
ggml_free(ctx_cpu);
|
||||
|
||||
for (ggml_context * ctx : ctxs) {
|
||||
ggml_free(ctx);
|
||||
}
|
||||
#ifdef GGML_USE_CUDA
|
||||
if (ctx_cuda != nullptr) {
|
||||
ggml_free(ctx_cuda);
|
||||
}
|
||||
#endif
|
||||
|
||||
return splits;
|
||||
}
|
||||
@ -1564,7 +1526,8 @@ static bool llama_eval_internal(
|
||||
// for big prompts, if BLAS is enabled, it is better to use only one thread
|
||||
// otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
|
||||
n_threads = N >= 32 && ggml_cpu_has_blas() ? 1 : n_threads;
|
||||
ggml_backend_cpu_set_n_threads(const_cast<ggml_backend*>(model.backend_cpu), n_threads);
|
||||
// TODO: fix this - probably should be set during the model creation
|
||||
// ggml_backend_cpu_set_n_threads(const_cast<ggml_backend*>(model.backend_cpu), n_threads);
|
||||
|
||||
struct ggml_graph_splits splits = llama_build_graph(lctx, N, n_past, embd_input);
|
||||
|
||||
@ -1616,12 +1579,8 @@ static bool llama_eval_internal(
|
||||
ggml_backend_tensor_get_async(lctx.graph_embeddings_out, embedding_out.data(), 0, n_embd*sizeof(float));
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_CUDA
|
||||
// wait for the async copy to finish
|
||||
if (lctx.model.n_gpu_layers > 0) {
|
||||
ggml_backend_synchronize(const_cast<ggml_backend*>(lctx.model.backend_cuda));
|
||||
}
|
||||
#endif
|
||||
ggml_backend_synchronize(const_cast<ggml_backend*>(lctx.model.backend_out));
|
||||
|
||||
// measure the performance only for the single-token evals
|
||||
if (N == 1) {
|
||||
@ -2633,7 +2592,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
// TODO: choose backend depending on n_layers/low_vram
|
||||
#ifdef GGML_USE_CUDA
|
||||
if ((uint32_t)params.n_gpu_layers >= model->hparams.n_layer/2 && !params.low_vram) {
|
||||
ctx->backend_kv = model->backend_cuda;
|
||||
ctx->backend_kv = model->backend_gpu;
|
||||
} else {
|
||||
ctx->backend_kv = model->backend_cpu;
|
||||
}
|
||||
@ -2662,15 +2621,16 @@ struct llama_context * llama_new_context_with_model(
|
||||
ctx->embedding.resize(hparams.n_embd);
|
||||
}
|
||||
|
||||
// initialize compute buffers
|
||||
// TODO: size the buffers more accurately - depends on improved memory management
|
||||
ctx->buf_compute_cpu = ggml_buffer_alloc(model->backend_cpu, MEM_REQ_EVAL().at(ctx->model.type), 2048);
|
||||
// TODO: skip if no cpu layers
|
||||
for (auto & backend_data : model->backends) {
|
||||
ggml_buffer * buf_compute = ggml_buffer_alloc(backend_data.backend, MEM_REQ_EVAL().at(ctx->model.type), 2048);
|
||||
ctx->bufs_compute.push_back(buf_compute);
|
||||
}
|
||||
// TODO: pinned memory for faster host-device transfers
|
||||
//ggml_cuda_host_register(*(void**)ctx->buf_compute_cpu.backend_buffer, MEM_REQ_EVAL().at(ctx->model.type) + 128*2048);
|
||||
#ifdef GGML_USE_CUDA
|
||||
if (params.n_gpu_layers > 0) {
|
||||
ctx->buf_compute_cuda = ggml_buffer_alloc(model->backend_cuda, MEM_REQ_EVAL().at(ctx->model.type), 2048);
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
// initialize the graph input/output buffers
|
||||
// input buffer
|
||||
|
Loading…
Reference in New Issue
Block a user