commit 1eb8804c18
parent 3773e1afe7
Author: Jared Van Bortel
Date:   2024-01-10 11:29:04 -05:00
18 changed files with 2183 additions and 2081 deletions


@ -543,9 +543,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
params.n_gpu_layers = std::stoi(argv[i]);
#else
#ifndef LLAMA_SUPPORTS_GPU_OFFLOAD
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
#endif
@ -554,9 +553,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
params.n_gpu_layers_draft = std::stoi(argv[i]);
#else
#ifndef LLAMA_SUPPORTS_GPU_OFFLOAD
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
#endif
@ -565,25 +563,44 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
#ifdef GGML_USE_CUBLAS
params.main_gpu = std::stoi(argv[i]);
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
#endif
#ifndef GGML_USE_CUBLAS
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting the main GPU has no effect.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--split-mode" || arg == "-sm") {
if (++i >= argc) {
invalid_param = true;
break;
}
std::string arg_next = argv[i];
if (arg_next == "none") {
params.split_mode = LLAMA_SPLIT_NONE;
} else if (arg_next == "layer") {
params.split_mode = LLAMA_SPLIT_LAYER;
} else if (arg_next == "row") {
params.split_mode = LLAMA_SPLIT_ROW;
} else {
invalid_param = true;
break;
}
#ifndef GGML_USE_CUBLAS
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting the split mode has no effect.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--tensor-split" || arg == "-ts") {
if (++i >= argc) {
invalid_param = true;
break;
}
#ifdef GGML_USE_CUBLAS
std::string arg_next = argv[i];
// split string by , and /
const std::regex regex{R"([,/]+)"};
std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
std::vector<std::string> split_arg{it, {}};
GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
if (split_arg.size() >= LLAMA_MAX_DEVICES) {
invalid_param = true;
break;
}
for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) {
if (i < split_arg.size()) {
params.tensor_split[i] = std::stof(split_arg[i]);
@ -591,14 +608,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.tensor_split[i] = 0.0f;
}
}
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--no-mul-mat-q" || arg == "-nommq") {
#ifdef GGML_USE_CUBLAS
params.mul_mat_q = false;
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Disabling mul_mat_q kernels has no effect.\n");
#ifndef GGML_USE_CUBLAS
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting a tensor split has no effect.\n");
#endif // GGML_USE_CUBLAS
} else if (arg == "--no-mmap") {
params.use_mmap = false;
@ -909,14 +920,15 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" number of layers to store in VRAM\n");
printf(" -ngld N, --n-gpu-layers-draft N\n");
printf(" number of layers to store in VRAM for the draft model\n");
printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
printf(" how to split the model across multiple GPUs, one of:\n");
printf(" - none: use one GPU only\n");
printf(" - layer (default): split layers and KV across GPUs\n");
printf(" - row: split rows across GPUs\n");
printf(" -ts SPLIT --tensor-split SPLIT\n");
printf(" how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
printf(" -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
#ifdef GGML_USE_CUBLAS
printf(" -nommq, --no-mul-mat-q\n");
printf(" use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n");
printf(" Not recommended since this is both slower and uses more VRAM.\n");
#endif // GGML_USE_CUBLAS
printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
#endif
printf(" -gan N, --grp-attn-n N\n");
printf(" group-attention factor (default: %d)\n", params.grp_attn_n);
@ -1033,6 +1045,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
mparams.n_gpu_layers = params.n_gpu_layers;
}
mparams.main_gpu = params.main_gpu;
mparams.split_mode = params.split_mode;
mparams.tensor_split = params.tensor_split;
mparams.use_mmap = params.use_mmap;
mparams.use_mlock = params.use_mlock;
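This last hunk is the API-side core of the change: split_mode now flows from gpt_params into llama_model_params alongside main_gpu and tensor_split. For reference, a hand-rolled equivalent of what common.cpp does here might look like the sketch below (the function name, layer count and model path are illustrative, not taken from this commit; assumes "llama.h" and <vector>):

    static llama_model * load_split_model(const char * path) {
        llama_model_params mparams = llama_model_default_params();
        mparams.n_gpu_layers = 99;                 // offload all layers
        mparams.split_mode   = LLAMA_SPLIT_LAYER;  // or LLAMA_SPLIT_NONE / LLAMA_SPLIT_ROW
        mparams.main_gpu     = 0;                  // GPU for the model (split-mode none)
                                                   // or for intermediate results and KV (split-mode row)

        // proportions per device; all zeros keeps the default (even) split
        static const std::vector<float> t_split(LLAMA_MAX_DEVICES, 0.0f);
        mparams.tensor_split = t_split.data();

        return llama_load_model_from_file(path, mparams);
    }

An all-zero tensor_split keeps the default even split, which is also what the batched example further down does with its t_split vector.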


@ -59,6 +59,7 @@ struct gpt_params {
float p_split = 0.1f; // speculative decoding split probability
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
llama_split_mode split_mode = LLAMA_SPLIT_LAYER; // how to split the model across GPUs
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
int32_t n_beams = 0; // if non-zero then use beam search of given width.


@ -88,7 +88,10 @@ int main(int argc, char ** argv) {
llama_model_params model_params = llama_model_default_params();
const std::vector<float> t_split (LLAMA_MAX_DEVICES, 0.0f);
model_params.n_gpu_layers = n_gpu_layers;
model_params.tensor_split = t_split.data();
llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);


@ -229,6 +229,7 @@ void ggml_tallocr_reset(ggml_tallocr_t alloc) {
alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
} else {
alloc->free_blocks[0].size = ggml_backend_buffer_get_size(alloc->buffer) - align_offset;
ggml_backend_buffer_reset(alloc->buffer);
}
}
@ -779,10 +780,21 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte
if (nbytes == 0) {
// all the tensors in the context are already allocated
#ifndef NDEBUG
fprintf(stderr, "%s: all tensors in the context are already allocated\n", __func__);
#endif
return NULL;
}
ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, nbytes);
if (buffer == NULL) {
// failed to allocate buffer
#ifndef NDEBUG
fprintf(stderr, "%s: failed to allocate buffer\n", __func__);
#endif
return NULL;
}
ggml_tallocr_t tallocr = ggml_tallocr_new_from_buffer(buffer);
for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
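Both early returns added to ggml_backend_alloc_ctx_tensors_from_buft mean a NULL result is now an expected outcome rather than a crash, so callers are responsible for handling it. A rough caller-side sketch, assuming "ggml-alloc.h", "ggml-backend.h" and <stdio.h> are included (the helper name is made up):

    static ggml_backend_buffer_t alloc_weights(struct ggml_context * ctx, ggml_backend_buffer_type_t buft) {
        ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
        if (buf == NULL) {
            // with this change, NULL can mean "all tensors were already allocated"
            // or "the backend failed to allocate the buffer"; both must be handled
            fprintf(stderr, "%s: failed to allocate buffer\n", __func__);
        }
        return buf;
    }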


@ -16,9 +16,10 @@ extern "C" {
typedef void * ggml_backend_buffer_type_context_t;
struct ggml_backend_buffer_type_i {
const char * (*get_name) (ggml_backend_buffer_type_t buft);
ggml_backend_buffer_t (*alloc_buffer) (ggml_backend_buffer_type_t buft, size_t size);
size_t (*get_alignment) (ggml_backend_buffer_type_t buft); // tensor alignment
size_t (*get_alloc_size) (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
size_t (*get_alloc_size) (ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor); // data size needed to allocate the tensor, including padding
bool (*supports_backend)(ggml_backend_buffer_type_t buft, ggml_backend_t backend); // check if the buffer type is usable by the backend
// check if tensor data is in host memory
// should be equivalent to supports_backend(buft, ggml_backend_cpu_init())
@ -34,16 +35,17 @@ extern "C" {
typedef void * ggml_backend_buffer_context_t;
struct ggml_backend_buffer_i {
const char * (*get_name) (ggml_backend_buffer_t buffer);
void (*free_buffer) (ggml_backend_buffer_t buffer);
//void (*reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
void * (*get_base) (ggml_backend_buffer_t buffer);
void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
void (*set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void (*get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
// (optional) copy tensor between different buffer types, allowing single-copy transfers
void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to) (ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_from)(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst);
void (*clear) (ggml_backend_buffer_t buffer, uint8_t value);
void (*reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras
};
struct ggml_backend_buffer {
@ -51,6 +53,7 @@ extern "C" {
ggml_backend_buffer_type_t buft;
ggml_backend_buffer_context_t context;
size_t size;
enum ggml_backend_buffer_usage usage;
};
ggml_backend_buffer_t ggml_backend_buffer_init(
@ -79,13 +82,13 @@ extern "C" {
void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
// (optional) asynchronous tensor copy
void (*cpy_tensor_from_async)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to_async) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_from_async)(ggml_backend_t backend, const struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to_async) (ggml_backend_t backend, const struct ggml_tensor * src, struct ggml_tensor * dst);
void (*synchronize)(ggml_backend_t backend);
// compute graph with a plan
ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, const struct ggml_cgraph * cgraph);
void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);


@ -15,6 +15,10 @@
// backend buffer type
const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) {
return buft->iface.get_name(buft);
}
ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
return buft->iface.alloc_buffer(buft, size);
}
@ -58,11 +62,16 @@ ggml_backend_buffer_t ggml_backend_buffer_init(
/* .buft = */ buft,
/* .context = */ context,
/* .size = */ size,
/* .usage = */ GGML_BACKEND_BUFFER_USAGE_ANY
};
return buffer;
}
const char * ggml_backend_buffer_name(ggml_backend_buffer_t buffer) {
return buffer->iface.get_name(buffer);
}
void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
if (buffer == NULL) {
return;
@ -94,11 +103,11 @@ void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_t
}
size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer) {
return ggml_backend_buft_get_alignment(ggml_backend_buffer_type(buffer));
return ggml_backend_buft_get_alignment(ggml_backend_buffer_get_type(buffer));
}
size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type(buffer), tensor);
return ggml_backend_buft_get_alloc_size(ggml_backend_buffer_get_type(buffer), tensor);
}
void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
@ -106,13 +115,23 @@ void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
}
bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) {
return ggml_backend_buft_is_host(ggml_backend_buffer_type(buffer));
return ggml_backend_buft_is_host(ggml_backend_buffer_get_type(buffer));
}
ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer) {
void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
buffer->usage = usage;
}
ggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) {
return buffer->buft;
}
void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) {
if (buffer->iface.reset) {
buffer->iface.reset(buffer);
}
}
// backend
const char * ggml_backend_name(ggml_backend_t backend) {
@ -392,6 +411,12 @@ ggml_backend_buffer_t ggml_backend_reg_alloc_buffer(size_t i, size_t size) {
// backend CPU
static const char * ggml_backend_cpu_buffer_name(ggml_backend_buffer_t buffer) {
return "CPU";
GGML_UNUSED(buffer);
}
static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
return (void *)buffer->context;
}
@ -412,13 +437,13 @@ static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, con
GGML_UNUSED(buffer);
}
static void ggml_backend_cpu_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
static void ggml_backend_cpu_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
GGML_UNUSED(buffer);
}
static void ggml_backend_cpu_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
static void ggml_backend_cpu_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
GGML_UNUSED(buffer);
@ -429,6 +454,7 @@ static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t
}
static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
/* .get_name = */ ggml_backend_cpu_buffer_name,
/* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer,
/* .get_base = */ ggml_backend_cpu_buffer_get_base,
/* .init_tensor = */ NULL, // no initialization required
@ -437,10 +463,12 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
/* .cpy_tensor_from = */ ggml_backend_cpu_buffer_cpy_tensor_from,
/* .cpy_tensor_to = */ ggml_backend_cpu_buffer_cpy_tensor_to,
/* .clear = */ ggml_backend_cpu_buffer_clear,
/* .reset = */ NULL,
};
// for buffers from ptr, free is not called
static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
/* .get_name = */ ggml_backend_cpu_buffer_name,
/* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
/* .get_base = */ ggml_backend_cpu_buffer_get_base,
/* .init_tensor = */ NULL, // no initialization required
@ -449,10 +477,17 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
/* .cpy_tensor_from = */ ggml_backend_cpu_buffer_cpy_tensor_from,
/* .cpy_tensor_to = */ ggml_backend_cpu_buffer_cpy_tensor_to,
/* .clear = */ ggml_backend_cpu_buffer_clear,
/* .reset = */ NULL,
};
static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512
static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
return "CPU";
GGML_UNUSED(buft);
}
static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned
void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC?
@ -483,6 +518,7 @@ static bool ggml_backend_cpu_buffer_type_is_host(ggml_backend_buffer_type_t buft
ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) {
static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type = {
/* .iface = */ {
/* .get_name = */ ggml_backend_cpu_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_cpu_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment,
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
@ -501,6 +537,18 @@ ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void) {
#include <hbwmalloc.h>
static const char * ggml_backend_cpu_hbm_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
return "CPU_HBM";
GGML_UNUSED(buft);
}
static const char * ggml_backend_cpu_hbm_buffer_get_name(ggml_backend_buffer_t buf) {
return "CPU_HBM";
GGML_UNUSED(buf);
}
static void ggml_backend_cpu_hbm_buffer_free_buffer(ggml_backend_buffer_t buffer) {
hbw_free(buffer->context);
}
@ -514,17 +562,18 @@ static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_
return NULL;
}
// FIXME: this is a hack to avoid having to implement a new buffer type
ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
buffer->buft = buft;
buffer->iface.get_name = ggml_backend_cpu_hbm_buffer_get_name;
buffer->iface.free_buffer = ggml_backend_cpu_hbm_buffer_free_buffer;
return buffer;
}
ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type() {
ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void) {
static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_hbm = {
/* .iface = */ {
/* .get_name = */ ggml_backend_cpu_hbm_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_cpu_hbm_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_cpu_buffer_type_get_alignment,
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
@ -568,7 +617,7 @@ struct ggml_backend_plan_cpu {
struct ggml_cgraph cgraph;
};
static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) {
struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu));
@ -661,7 +710,7 @@ ggml_backend_t ggml_backend_cpu_init(void) {
}
bool ggml_backend_is_cpu(ggml_backend_t backend) {
return backend->iface.get_name == ggml_backend_cpu_name;
return backend && backend->iface.get_name == ggml_backend_cpu_name;
}
void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
@ -685,7 +734,7 @@ static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user
// scheduler
#define GGML_MAX_BACKENDS 4
#define GGML_MAX_BACKENDS 16
#define GGML_MAX_SPLITS 256
#define GGML_MAX_SPLIT_INPUTS 16
@ -695,9 +744,16 @@ struct ggml_backend_sched_split {
int i_end;
struct ggml_tensor * inputs[GGML_MAX_SPLIT_INPUTS];
int n_inputs;
// graph view of this split
struct ggml_cgraph graph;
};
// TODO: group all the hash values into a single struct for clarity
//struct sched_hash_value {
// ggml_tallocr_t tallocr;
// struct ggml_tensor * copies[GGML_MAX_BACKENDS];
//};
struct ggml_backend_sched {
int n_backends;
ggml_backend_t backends[GGML_MAX_BACKENDS];
@ -705,11 +761,15 @@ struct ggml_backend_sched {
ggml_gallocr_t galloc;
// hash keys of the nodes in the graph
struct ggml_hash_set hash_set;
ggml_tallocr_t * node_talloc; // [hash_set.size]
struct ggml_tensor * (* node_copies)[GGML_MAX_BACKENDS]; // [hash_set.size][GGML_MAX_BACKENDS]
// hash values (arrays of [hash_set.size])
ggml_tallocr_t * node_talloc; // tallocr assigned to each node (indirectly this is the backend)
struct ggml_tensor * (* node_copies)[GGML_MAX_BACKENDS]; // copies of each node for each destination backend
// copy of the graph with modified inputs
struct ggml_cgraph * graph;
struct ggml_backend_sched_split splits[GGML_MAX_SPLITS];
int n_splits;
@ -777,7 +837,7 @@ static ggml_backend_t get_allocr_backend(ggml_backend_sched_t sched, ggml_talloc
}
#if 0
static char causes[GGML_DEFAULT_GRAPH_SIZE*8 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
#define SET_CAUSE(node, ...) sprintf(causes[hash_id(node)], __VA_ARGS__)
#define GET_CAUSE(node) causes[hash_id(node)]
#else
@ -790,6 +850,7 @@ static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct
// if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there
// ie. kv cache updates
// note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend.
// dst
ggml_backend_t cur_backend = get_buffer_backend(sched, node->buffer);
if (cur_backend != NULL) {
@ -804,7 +865,6 @@ static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct
}
// src
int cur_prio = INT_MAX;
size_t cur_size = 0;
for (int i = 0; i < GGML_MAX_SRC; i++) {
@ -812,18 +872,22 @@ static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct
if (src == NULL) {
break;
}
ggml_backend_t src_backend = get_buffer_backend(sched, src->buffer);
if (src_backend != NULL) {
int src_prio = sched_backend_prio(sched, src_backend);
if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
// operations with weights are always on the same backend as the weights
cur_backend = src_backend;
SET_CAUSE(node, "1.wgt%d", i);
break;
}
size_t src_size = ggml_nbytes(src);
if (src_prio < cur_prio && src_size >= cur_size) {
cur_prio = src_prio;
if (src_size >= cur_size) {
cur_size = src_size;
cur_backend = src_backend;
SET_CAUSE(node, "1.src%d", i);
}
}
}
return cur_backend;
}
@ -857,7 +921,7 @@ static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgra
}
ggml_tallocr_t node_allocr = node_allocr(node);
ggml_backend_t node_backend = node_allocr ? get_allocr_backend(sched, node_allocr) : NULL; // FIXME:
fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, ggml_op_name(node->op), node->name,
fprintf(stderr, "node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", GET_CAUSE(node));
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * src = node->src[j];
@ -866,7 +930,7 @@ static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgra
}
ggml_tallocr_t src_allocr = node_allocr(src);
ggml_backend_t src_backend = src_allocr ? get_allocr_backend(sched, src_allocr) : NULL;
fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name,
fprintf(stderr, " %20.20s (%5.5s) [%5.5s %8.8s]", src->name,
fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src));
}
fprintf(stderr, "\n");
@ -882,14 +946,16 @@ static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, co
return dup;
}
//#define DEBUG_PASS1
//#define DEBUG_PASS2
//#define DEBUG_PASS3
//#define DEBUG_PASS4
// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
// TODO: merge passes
static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
// reset state
size_t hash_size = sched->hash_set.size;
memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size);
memset(sched->node_talloc, 0, sizeof(sched->node_talloc[0]) * hash_size);
memset(sched->node_copies, 0, sizeof(sched->node_copies[0]) * hash_size);
// reset splits
sched->n_splits = 0;
struct ggml_init_params params = {
@ -898,11 +964,13 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
/* .no_alloc = */ true
};
if (sched->ctx != NULL) {
ggml_free(sched->ctx);
}
sched->ctx = ggml_init(params);
if (sched->ctx == NULL) {
fprintf(stderr, "%s: failed to initialize context\n", __func__);
GGML_ASSERT(false);
}
// pass 1: assign backends to ops with allocated inputs
for (int i = 0; i < graph->n_leafs; i++) {
@ -931,45 +999,91 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
node_allocr(node) = ggml_backend_sched_get_tallocr(sched, node_backend);
}
}
//printf("PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
#ifdef DEBUG_PASS1
fprintf(stderr, "PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
#endif
// pass 2: assign backends to ops from current assignments
// TODO:
// - reuse sched_backend_from_cur
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
ggml_tallocr_t node_allocr = node_allocr(node);
if (node_allocr == NULL) {
int cur_prio = INT_MAX;
size_t cur_size = 0;
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * src = node->src[j];
if (src == NULL) {
break;
}
ggml_tallocr_t src_allocr = node_allocr(src);
if (src_allocr != NULL) {
int src_prio = sched_allocr_prio(sched, src_allocr);
size_t src_size = ggml_nbytes(src);
if (src_prio < cur_prio && src_size >= cur_size) {
cur_prio = src_prio;
cur_size = src_size;
node_allocr = src_allocr;
SET_CAUSE(node, "2.src%d", j);
}
}
}
if (node_allocr != NULL) {
node_allocr(node) = node_allocr;
}
}
}
//printf("PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
// start from the end and assign the same backend to previous ops
// pass 3: assign backends to remaining src from dst (should only be leafs)
// expand gpu backends (i.e. non last prio) up and down, ignoring cpu
// thus, cpu will never be used unless weights are on cpu, or there are no gpu ops between cpu ops
// pass 2.1 expand gpu up
{
ggml_tallocr_t cur_allocr = NULL;
for (int i = graph->n_nodes - 1; i >= 0; i--) {
struct ggml_tensor * node = graph->nodes[i];
if (ggml_is_view_op(node->op)) {
continue;
}
ggml_tallocr_t node_allocr = node_allocr(node);
if (node_allocr != NULL) {
if (sched_allocr_prio(sched, node_allocr) == sched->n_backends - 1) {
// skip cpu
cur_allocr = NULL;
} else {
cur_allocr = node_allocr;
}
} else {
node_allocr(node) = cur_allocr;
SET_CAUSE(node, "2.cur");
}
}
}
// pass 2.2 expand gpu down
{
ggml_tallocr_t cur_allocr = NULL;
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
if (ggml_is_view_op(node->op)) {
continue;
}
ggml_tallocr_t node_allocr = node_allocr(node);
if (node_allocr != NULL) {
if (sched_allocr_prio(sched, node_allocr) == sched->n_backends - 1) {
// skip cpu
cur_allocr = NULL;
} else {
cur_allocr = node_allocr;
}
} else {
node_allocr(node) = cur_allocr;
SET_CAUSE(node, "2.cur");
}
}
}
// pass 2.3 expand rest up
{
ggml_tallocr_t cur_allocr = NULL;
for (int i = graph->n_nodes - 1; i >= 0; i--) {
struct ggml_tensor * node = graph->nodes[i];
if (ggml_is_view_op(node->op)) {
continue;
}
ggml_tallocr_t node_allocr = node_allocr(node);
if (node_allocr != NULL) {
cur_allocr = node_allocr;
} else {
node_allocr(node) = cur_allocr;
SET_CAUSE(node, "2.cur");
}
}
}
#ifdef DEBUG_PASS2
fprintf(stderr, "PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
#endif
// pass 3: assign backends to remaining src from dst and view_src
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
ggml_tallocr_t cur_allocr = node_allocr(node);
if (ggml_is_view_op(node->op) && cur_allocr == NULL) {
cur_allocr = node_allocr(node) = node_allocr(node->view_src);
SET_CAUSE(node, "3.vsrc");
}
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * src = node->src[j];
if (src == NULL) {
@ -977,16 +1091,21 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
}
ggml_tallocr_t src_allocr = node_allocr(src);
if (src_allocr == NULL) {
node_allocr(src) = node_allocr;
if (src->view_src != NULL) {
// views are always on the same backend as the source
node_allocr(src) = node_allocr(src->view_src);
} else {
node_allocr(src) = cur_allocr;
}
}
}
//printf("PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
}
#ifdef DEBUG_PASS3
fprintf(stderr, "PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
#endif
// pass 4: split graph, find tensors that need to be copied
// TODO:
// - when switching from a less preferred backend to a more preferred backend, check if it is possible to move the switch to an earlier point for the same cost
// find first backend
{
int cur_split = 0;
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
@ -1029,11 +1148,23 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
}
ggml_tallocr_t src_allocr = node_allocr(src);
if (src_allocr != node_allocr) {
// check if the input is already in the split
bool found = false;
for (int k = 0; k < sched->splits[cur_split].n_inputs; k++) {
if (sched->splits[cur_split].inputs[k] == src) {
found = true;
break;
}
}
if (!found) {
int n_inputs = sched->splits[cur_split].n_inputs++;
//printf("split %d input %d: %s (%s)\n", cur_split, n_inputs, src->name, ggml_backend_name(get_allocr_backend(sched, src_allocr)));
GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS);
sched->splits[cur_split].inputs[n_inputs] = (struct ggml_tensor *)src;
}
// create copies
// create a copy of the input in the split's backend
size_t id = hash_id(src);
if (sched->node_copies[id][cur_backend_id] == NULL) {
struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
@ -1048,10 +1179,12 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
}
sched->splits[cur_split].i_end = graph->n_nodes;
sched->n_splits = cur_split + 1;
}
#ifdef DEBUG_PASS4
fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
#endif
//fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); fflush(stdout);
#if 1
#ifndef NDEBUG
// sanity check: all sources should have the same backend as the node
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
@ -1059,6 +1192,11 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
if (node_allocr == NULL) {
fprintf(stderr, "!!!!!!! %s has no backend\n", node->name);
}
if (node->view_src != NULL && node_allocr != node_allocr(node->view_src)) {
fprintf(stderr, "!!!!!!! %s has backend %s, view_src %s has backend %s\n",
node->name, node_allocr ? ggml_backend_name(get_allocr_backend(sched, node_allocr)) : "NULL",
node->view_src->name, node_allocr(node->view_src) ? ggml_backend_name(get_allocr_backend(sched, node_allocr(node->view_src))) : "NULL");
}
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * src = node->src[j];
if (src == NULL) {
@ -1070,8 +1208,14 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
node->name, node_allocr ? ggml_backend_name(get_allocr_backend(sched, node_allocr)) : "NULL",
j, src->name, src_allocr ? ggml_backend_name(get_allocr_backend(sched, src_allocr)) : "NULL");
}
if (src->view_src != NULL && src_allocr != node_allocr(src->view_src)) {
fprintf(stderr, "!!!!!!! [src] %s has backend %s, view_src %s has backend %s\n",
src->name, src_allocr ? ggml_backend_name(get_allocr_backend(sched, src_allocr)) : "NULL",
src->view_src->name, node_allocr(src->view_src) ? ggml_backend_name(get_allocr_backend(sched, node_allocr(src->view_src))) : "NULL");
}
}
}
fflush(stderr);
#endif
// create copies of the graph for each split
@ -1085,6 +1229,7 @@ static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * g
for (int j = 0; j < split->n_inputs; j++) {
struct ggml_tensor * input = split->inputs[j];
struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_allocr_prio(sched, split->tallocr)];
// add a dependency to the input source so that it is not freed before the copy is done
input_cpy->src[0] = input;
graph_copy->nodes[graph_copy->n_nodes++] = input_cpy;
}
@ -1121,19 +1266,20 @@ static void sched_compute_splits(ggml_backend_sched_t sched) {
struct ggml_tensor * input = split->inputs[j];
struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_backend_prio(sched, split_backend)];
if (input->buffer == NULL) {
GGML_ASSERT(false);
if (input->view_src == NULL) {
fprintf(stderr, "input %s has no buffer and no view_src\n", input->name);
exit(1);
GGML_ASSERT(false);
}
// FIXME: may need to use the sched buffer instead
ggml_backend_view_init(input->view_src->buffer, input);
}
if (input_cpy->buffer == NULL) {
fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name);
exit(1);
GGML_ASSERT(false);
}
//GGML_ASSERT(input->buffer->backend != input_cpy->buffer->backend);
//GGML_ASSERT(input_cpy->buffer->backend == split_backend);
// TODO: avoid this copy if it was already copied in a previous split, and the input didn't change
// this is important to avoid copying constants such as KQ_mask and inp_pos multiple times
ggml_backend_tensor_copy(input, input_cpy);
}
// ggml_backend_synchronize(split_backend);
@ -1168,13 +1314,23 @@ static void sched_reset(ggml_backend_sched_t sched) {
for (int i = 0; i < sched->n_backends; i++) {
ggml_tallocr_reset(sched->tallocs[i]);
}
// reset state for the next run
size_t hash_size = sched->hash_set.size;
memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size);
memset(sched->node_talloc, 0, sizeof(sched->node_talloc[0]) * hash_size);
memset(sched->node_copies, 0, sizeof(sched->node_copies[0]) * hash_size);
}
ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends) {
ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends, size_t graph_size) {
GGML_ASSERT(n_backends > 0);
GGML_ASSERT(n_backends <= GGML_MAX_BACKENDS);
struct ggml_backend_sched * sched = malloc(sizeof(struct ggml_backend_sched));
memset(sched, 0, sizeof(struct ggml_backend_sched));
struct ggml_backend_sched * sched = calloc(sizeof(struct ggml_backend_sched), 1);
// initialize hash table
sched->hash_set = ggml_hash_set_new(graph_size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);
sched->node_talloc = calloc(sizeof(sched->node_talloc[0]) * sched->hash_set.size, 1);
sched->node_copies = calloc(sizeof(sched->node_copies[0]) * sched->hash_set.size, 1);
sched->n_backends = n_backends;
for (int i = 0; i < n_backends; i++) {
@ -1199,6 +1355,7 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
ggml_tallocr_free(sched->tallocs[i]);
}
ggml_gallocr_free(sched->galloc);
ggml_free(sched->ctx);
free(sched->hash_set.keys);
free(sched->node_talloc);
free(sched->node_copies);
@ -1206,12 +1363,7 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
}
void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
// initialize hash tables
size_t hash_size = measure_graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS;
sched->hash_set.size = hash_size;
sched->hash_set.keys = malloc(sizeof(sched->hash_set.keys[0]) * hash_size);
sched->node_talloc = malloc(sizeof(sched->node_talloc[0]) * hash_size);
sched->node_copies = malloc(sizeof(sched->node_copies[0]) * hash_size);
GGML_ASSERT(ggml_tallocr_is_measure(sched->tallocs[0])); // can only be initialized once
sched_split_graph(sched, measure_graph);
sched_alloc_splits(sched);
@ -1227,7 +1379,7 @@ void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgr
}
void ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
GGML_ASSERT(sched->hash_set.size >= graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);
GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);
sched_split_graph(sched, graph);
sched_alloc_splits(sched);
@ -1235,13 +1387,19 @@ void ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cg
sched_reset(sched);
}
int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
return sched->n_splits;
}
ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend) {
int backend_index = sched_backend_prio(sched, backend);
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
return sched->tallocs[backend_index];
}
ggml_backend_buffer_t ggml_backend_sched_get_buffer(ggml_backend_sched_t sched, ggml_backend_t backend) {
int backend_index = sched_backend_prio(sched, backend);
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
return ggml_tallocr_get_buffer(sched->tallocs[backend_index]);
}
@ -1252,9 +1410,10 @@ void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml
}
// utils
void ggml_backend_view_init(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
GGML_ASSERT(tensor->buffer == NULL);
//GGML_ASSERT(tensor->data == NULL); // views of pre-allocted tensors may have the data set, but still need to be initialized
//GGML_ASSERT(tensor->data == NULL); // views of pre-allocated tensors may have the data set in ggml_new_tensor, but still need to be initialized by the backend
GGML_ASSERT(tensor->view_src != NULL);
GGML_ASSERT(tensor->view_src->buffer != NULL);
GGML_ASSERT(tensor->view_src->data != NULL);
@ -1320,6 +1479,7 @@ static void graph_init_tensor(struct ggml_hash_set hash_set, struct ggml_tensor
struct ggml_tensor * dst = node_copies[id];
if (dst->view_src != NULL) {
graph_init_tensor(hash_set, node_copies, node_init, src->view_src);
ggml_backend_view_init(dst->view_src->buffer, dst);
}
else {
@ -1353,6 +1513,21 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s
struct ggml_context * ctx_allocated = ggml_init(params);
struct ggml_context * ctx_unallocated = ggml_init(params);
if (ctx_allocated == NULL || ctx_unallocated == NULL) {
fprintf(stderr, "failed to allocate context for graph copy\n");
free(hash_set.keys);
free(node_copies);
free(node_init);
ggml_free(ctx_allocated);
ggml_free(ctx_unallocated);
return (struct ggml_backend_graph_copy) {
/* .buffer = */ NULL,
/* .ctx_allocated = */ NULL,
/* .ctx_unallocated = */ NULL,
/* .graph = */ NULL,
};
}
// dup nodes
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
@ -1361,6 +1536,20 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s
// allocate nodes
ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend);
if (buffer == NULL) {
fprintf(stderr, "failed to allocate buffer for graph copy\n");
free(hash_set.keys);
free(node_copies);
free(node_init);
ggml_free(ctx_allocated);
ggml_free(ctx_unallocated);
return (struct ggml_backend_graph_copy) {
/* .buffer = */ NULL,
/* .ctx_allocated = */ NULL,
/* .ctx_unallocated = */ NULL,
/* .graph = */ NULL,
};
}
//printf("copy buffer size: %zu MB\n", ggml_backend_buffer_get_size(buffer) / 1024 / 1024);
@ -1397,8 +1586,12 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
ggml_free(copy.ctx_unallocated);
}
void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
if (copy.buffer == NULL) {
return false;
}
struct ggml_cgraph * g1 = graph;
struct ggml_cgraph * g2 = copy.graph;
@ -1428,4 +1621,6 @@ void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
}
ggml_backend_graph_copy_free(copy);
return true;
}


@ -17,13 +17,20 @@ extern "C" {
//
// buffer type
GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size);
GGML_API const char * ggml_backend_buft_name (ggml_backend_buffer_type_t buft);
GGML_API ggml_backend_buffer_t ggml_backend_buft_alloc_buffer (ggml_backend_buffer_type_t buft, size_t size);
GGML_API size_t ggml_backend_buft_get_alignment (ggml_backend_buffer_type_t buft);
GGML_API size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
GGML_API size_t ggml_backend_buft_get_alloc_size (ggml_backend_buffer_type_t buft, struct ggml_tensor * tensor);
GGML_API bool ggml_backend_buft_supports_backend(ggml_backend_buffer_type_t buft, ggml_backend_t backend);
GGML_API bool ggml_backend_buft_is_host (ggml_backend_buffer_type_t buft);
// buffer
enum ggml_backend_buffer_usage {
GGML_BACKEND_BUFFER_USAGE_ANY = 0,
GGML_BACKEND_BUFFER_USAGE_WEIGHTS = 1,
};
GGML_API const char * ggml_backend_buffer_name (ggml_backend_buffer_t buffer);
GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
@ -32,7 +39,10 @@ extern "C" {
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API void ggml_backend_buffer_clear (ggml_backend_buffer_t buffer, uint8_t value);
GGML_API bool ggml_backend_buffer_is_host (ggml_backend_buffer_t buffer);
GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_type(ggml_backend_buffer_t buffer);
GGML_API void ggml_backend_buffer_set_usage (ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage);
GGML_API ggml_backend_buffer_type_t ggml_backend_buffer_get_type (ggml_backend_buffer_t buffer);
GGML_API void ggml_backend_buffer_reset (ggml_backend_buffer_t buffer);
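GGML_BACKEND_BUFFER_USAGE_WEIGHTS is the hint the scheduler keys on when it pins an operation to the backend that holds the weights (the "1.wgt" case in sched_backend_from_cur earlier in this diff). A short sketch of tagging a buffer once the weights have been uploaded (buffer creation is not shown):

    // weights_buffer was allocated and filled elsewhere
    ggml_backend_buffer_set_usage(weights_buffer, GGML_BACKEND_BUFFER_USAGE_WEIGHTS);
    // the default is GGML_BACKEND_BUFFER_USAGE_ANY; the flag only influences scheduling decisions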
//
// Backend
@ -140,24 +150,23 @@ extern "C" {
typedef struct ggml_backend_sched * ggml_backend_sched_t;
// Initialize a backend scheduler
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends);
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends, size_t graph_size);
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
// Initialize backend buffers from a measure graph
GGML_API void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
// Get the number of splits of the last graph
GGML_API int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched);
GGML_API ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend);
GGML_API ggml_backend_buffer_t ggml_backend_sched_get_buffer (ggml_backend_sched_t sched, ggml_backend_t backend);
GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
// Allocate a graph on the backend scheduler
// Allocate and compute graph on the backend scheduler
GGML_API void ggml_backend_sched_graph_compute(
ggml_backend_sched_t sched,
struct ggml_cgraph * graph);
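Because ggml_backend_sched_new now takes the expected graph size, the hash tables are sized at creation time instead of inside ggml_backend_sched_init_measure. Under that reading, typical usage is roughly the following sketch (the backends and graphs are placeholders supplied by the caller; assumes "ggml-backend.h"):

    static void run_with_sched(ggml_backend_t backend_gpu, ggml_backend_t backend_cpu,
                               struct ggml_cgraph * measure_graph, struct ggml_cgraph * graph) {
        ggml_backend_t backends[] = { backend_gpu, backend_cpu };   // listed in order of priority
        ggml_backend_sched_t sched = ggml_backend_sched_new(backends, 2, GGML_DEFAULT_GRAPH_SIZE);

        ggml_backend_sched_init_measure(sched, measure_graph);      // one-time measurement pass
        ggml_backend_sched_graph_compute(sched, graph);             // split + allocate + compute

        ggml_backend_sched_free(sched);
    }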
//
// Utils
//
@ -176,7 +185,7 @@ extern "C" {
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
// Compare the output of two backends
GGML_API void ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
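Since ggml_backend_compare_graph_backend now reports failure (for example when the graph copy cannot be allocated), callers should check the result instead of assuming the callback ran. A rough sketch, assuming "ggml-backend.h" and <stdio.h> (the callback body is illustrative only):

    static bool check_node(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data) {
        // compare the outputs of the two backends for this node; return false to stop early
        (void)node_index; (void)t1; (void)t2; (void)user_data;
        return true;
    }

    static void compare_backends(ggml_backend_t b1, ggml_backend_t b2, struct ggml_cgraph * graph) {
        if (!ggml_backend_compare_graph_backend(b1, b2, graph, check_node, NULL)) {
            fprintf(stderr, "backend comparison skipped: graph copy failed\n");
        }
    }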
// Tensor initialization
GGML_API void ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);

File diff suppressed because it is too large.


@ -27,22 +27,6 @@ GGML_API void * ggml_cuda_host_malloc(size_t size);
GGML_API void ggml_cuda_host_free(void * ptr);
GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_assign_buffers_no_alloc(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_assign_scratch_offset(struct ggml_tensor * tensor, size_t offset);
GGML_API void ggml_cuda_copy_to_device(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_set_main_device(int main_device);
GGML_API void ggml_cuda_set_mul_mat_q(bool mul_mat_q);
GGML_API void ggml_cuda_set_scratch_size(size_t scratch_size);
GGML_API void ggml_cuda_free_scratch(void);
GGML_API bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
GGML_API int ggml_cuda_get_device_count(void);
@ -52,13 +36,17 @@ GGML_API void ggml_cuda_get_device_description(int device, char * description,
GGML_API ggml_backend_t ggml_backend_cuda_init(int device);
GGML_API bool ggml_backend_is_cuda(ggml_backend_t backend);
GGML_API int ggml_backend_cuda_get_device(ggml_backend_t backend);
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
// pinned host buffer for use with CPU backend for faster copies between CPU and GPU
// split tensor buffer that splits matrices by rows across multiple devices
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_split_buffer_type(const float * tensor_split);
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
GGML_API int ggml_backend_cuda_get_device_count(void);
GGML_API void ggml_backend_cuda_get_device_description(int device, char * description, size_t description_size);
GGML_API void ggml_backend_cuda_get_device_memory(int device, size_t * free, size_t * total);
#ifdef __cplusplus
}
#endif
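With ggml_cuda_set_tensor_split removed, a row split is requested by choosing the split buffer type for the weights rather than by mutating global CUDA state. A minimal sketch of that choice, assuming "ggml-cuda.h" is included (the flag and the 3:1 proportions are illustrative):

    static const float tensor_split[GGML_CUDA_MAX_DEVICES] = { 3.0f, 1.0f };  // proportions; unused entries stay 0

    static ggml_backend_buffer_type_t pick_weights_buft(int use_row_split, int main_device) {
        return use_row_split ? ggml_backend_cuda_split_buffer_type(tensor_split)
                             : ggml_backend_cuda_buffer_type(main_device);
    }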


@ -228,6 +228,8 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
#define GGML_HASHTABLE_FULL ((size_t)-1)
#define GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2)
struct ggml_hash_set ggml_hash_set_new(size_t size);
bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
// returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted


@ -2482,10 +2482,10 @@ static void ggml_backend_metal_free_device(void) {
}
}
static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
return "Metal";
return ctx->all_data;
UNUSED(buffer);
}
static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer) {
@ -2503,6 +2503,12 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
free(ctx);
}
static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) {
struct ggml_backend_metal_buffer_context * ctx = (struct ggml_backend_metal_buffer_context *)buffer->context;
return ctx->all_data;
}
static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
memcpy((char *)tensor->data + offset, data, size);
@ -2515,13 +2521,13 @@ static void ggml_backend_metal_buffer_get_tensor(ggml_backend_buffer_t buffer, c
UNUSED(buffer);
}
static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
static void ggml_backend_metal_buffer_cpy_tensor_from(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
UNUSED(buffer);
}
static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, struct ggml_tensor * src, struct ggml_tensor * dst) {
static void ggml_backend_metal_buffer_cpy_tensor_to(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
UNUSED(buffer);
@ -2534,6 +2540,7 @@ static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buffer, uint8_
}
static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
/* .get_name = */ ggml_backend_metal_buffer_get_name,
/* .free_buffer = */ ggml_backend_metal_buffer_free_buffer,
/* .get_base = */ ggml_backend_metal_buffer_get_base,
/* .init_tensor = */ NULL,
@ -2542,10 +2549,17 @@ static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = {
/* .cpy_tensor_from = */ ggml_backend_metal_buffer_cpy_tensor_from,
/* .cpy_tensor_to = */ ggml_backend_metal_buffer_cpy_tensor_to,
/* .clear = */ ggml_backend_metal_buffer_clear,
/* .reset = */ NULL,
};
// default buffer type
static const char * ggml_backend_metal_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
return "Metal";
UNUSED(buft);
}
static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
struct ggml_backend_metal_buffer_context * ctx = malloc(sizeof(struct ggml_backend_metal_buffer_context));
@ -2618,6 +2632,7 @@ static bool ggml_backend_metal_buffer_type_is_host(ggml_backend_buffer_type_t bu
ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
static struct ggml_backend_buffer_type ggml_backend_buffer_type_metal = {
/* .iface = */ {
/* .get_name = */ ggml_backend_metal_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_metal_buffer_type_get_alignment,
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
@ -2641,6 +2656,14 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
ctx->n_buffers = 0;
const size_t size_page = sysconf(_SC_PAGESIZE);
// page-align the data ptr
{
const uintptr_t offs = (uintptr_t) data % size_page;
data = (void *) ((char *) data - offs);
size += offs;
}
size_t size_aligned = size;
if ((size_aligned % size_page) != 0) {
size_aligned += (size_page - (size_aligned % size_page));
@ -2741,7 +2764,7 @@ static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct
UNUSED(backend);
}
static struct ggml_backend_i metal_backend_i = {
static struct ggml_backend_i ggml_backend_metal_i = {
/* .get_name = */ ggml_backend_metal_name,
/* .free = */ ggml_backend_metal_free,
/* .get_default_buffer_type = */ ggml_backend_metal_get_default_buffer_type,
@ -2767,7 +2790,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
*metal_backend = (struct ggml_backend) {
/* .interface = */ metal_backend_i,
/* .interface = */ ggml_backend_metal_i,
/* .context = */ ctx,
};
@ -2775,7 +2798,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
}
bool ggml_backend_is_metal(ggml_backend_t backend) {
return backend->iface.get_name == ggml_backend_metal_name;
return backend && backend->iface.get_name == ggml_backend_metal_name;
}
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {


@ -1,5 +1,6 @@
#include "ggml.h"
#include "ggml-opencl.h"
#include "ggml-backend-impl.h"
#include <array>
#include <atomic>
@ -10,7 +11,7 @@
#include <sstream>
#include <vector>
#define CL_TARGET_OPENCL_VERSION 110
#define CL_TARGET_OPENCL_VERSION 120
#include <clblast.h>
#if defined(_MSC_VER)
@ -929,6 +930,11 @@ static cl_program build_program_from_source(cl_context ctx, cl_device_id dev, co
}
void ggml_cl_init(void) {
static bool initialized = false;
if (initialized) {
return;
}
cl_int err;
struct cl_device;
@ -1483,8 +1489,8 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
} else {
d_X = ggml_cl_pool_malloc(sizeof(float) * x_ne, &x_size);
}
cl_mem d_Y = ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
cl_mem d_Y = src1->backend == GGML_BACKEND_GPU ? (cl_mem) src1->extra : ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
cl_mem d_D = dst->backend == GGML_BACKEND_GPU ? (cl_mem) dst->extra : ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
size_t x_offset = 0;
@ -1501,7 +1507,9 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) {
// copy src1 to device
if (src1->backend == GGML_BACKEND_CPU) {
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));
}
CL_CHECK(clFinish(queue));
@ -1522,18 +1530,24 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
}
// copy dst to host
if (dst->backend == GGML_BACKEND_CPU) {
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &ev_sgemm, NULL));
}
}
}
}
}
if (src0->backend != GGML_BACKEND_GPU) {
ggml_cl_pool_free(d_X, x_size);
}
if (src1->backend != GGML_BACKEND_GPU) {
ggml_cl_pool_free(d_Y, y_size);
}
if (dst->backend != GGML_BACKEND_GPU) {
ggml_cl_pool_free(d_D, d_size);
}
}
static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
@ -1598,6 +1612,8 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
}
// FIXME: convert on device
for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) {
// convert src1 to fp16
// TODO: use multiple threads
@ -1643,11 +1659,13 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
}
// copy dst to host, then convert to float
if (dst->backend == GGML_BACKEND_CPU) {
CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(ggml_fp16_t) * d_ne, tmp, 1, &ev_sgemm, NULL));
float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
ggml_fp16_to_fp32_row(tmp, d, d_ne);
} else {
// FIXME: convert dst to fp32 on device
}
}
}
}
@ -1801,7 +1819,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
}
bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst) {
const int64_t ne10 = src1->ne[0];
const int64_t ne0 = dst->ne[0];
@ -1895,3 +1913,292 @@ void ggml_cl_transform_tensor(void * data, ggml_tensor * tensor) {
tensor->extra = dst;
GGML_ASSERT(tensor->backend == GGML_BACKEND_GPU);
}
// ggml-backend
// buffer
struct ggml_backend_opencl_buffer_context {
~ggml_backend_opencl_buffer_context() {
if (buffer) {
clReleaseMemObject(buffer);
}
for (auto * sub_buffer : sub_buffers) {
clReleaseMemObject(sub_buffer);
}
}
cl_mem buffer;
std::vector<cl_mem> sub_buffers;
};
static void * const cl_ptr_base = (void *)(uintptr_t) 0x1000;
static const char * ggml_backend_opencl_buffer_get_name(ggml_backend_buffer_t buffer) {
return "OpenCL";
GGML_UNUSED(buffer);
}
static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
delete ctx;
}
static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) {
return cl_ptr_base;
GGML_UNUSED(buffer);
}
static void ggml_backend_opencl_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
if (tensor->view_src != NULL && tensor->view_offs == 0) {
tensor->extra = tensor->view_src->extra;
} else {
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
cl_buffer_region region = {(size_t)((char *)tensor->data - (char *)cl_ptr_base), ggml_nbytes(tensor)};
cl_int err;
cl_mem sub_buffer = clCreateSubBuffer(ctx->buffer, CL_MEM_READ_WRITE, CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
ctx->sub_buffers.push_back(sub_buffer);
tensor->extra = sub_buffer;
}
tensor->backend = GGML_BACKEND_GPU;
}
static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
cl_mem tensor_buffer = (cl_mem) tensor->extra;
CL_CHECK(clEnqueueWriteBuffer(queue, tensor_buffer, true, offset, size, data, 0, NULL, NULL));
CL_CHECK(clFinish(queue));
GGML_UNUSED(buffer);
}
static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
cl_mem tensor_buffer = (cl_mem) tensor->extra;
CL_CHECK(clEnqueueReadBuffer(queue, tensor_buffer, true, offset, size, data, 0, NULL, NULL));
CL_CHECK(clFinish(queue));
GGML_UNUSED(buffer);
}
static void ggml_backend_opencl_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
CL_CHECK(clEnqueueFillBuffer(queue, ctx->buffer, &value, sizeof(value), 0, buffer->size, 0, NULL, NULL));
CL_CHECK(clFinish(queue));
}
static void ggml_backend_opencl_buffer_reset(ggml_backend_buffer_t buffer) {
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
for (auto * sub_buffer : ctx->sub_buffers) {
clReleaseMemObject(sub_buffer);
}
ctx->sub_buffers.clear();
}
static ggml_backend_buffer_i ggml_backend_opencl_buffer_interface = {
/* .get_name = */ ggml_backend_opencl_buffer_get_name,
/* .free_buffer = */ ggml_backend_opencl_buffer_free_buffer,
/* .get_base = */ ggml_backend_opencl_buffer_get_base,
/* .init_tensor = */ ggml_backend_opencl_buffer_init_tensor,
/* .set_tensor = */ ggml_backend_opencl_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_opencl_buffer_get_tensor,
/* .cpy_tensor_from = */ NULL,
/* .cpy_tensor_to = */ NULL,
/* .clear = */ ggml_backend_opencl_buffer_clear,
/* .reset = */ ggml_backend_opencl_buffer_reset,
};
// buffer type
static const char * ggml_backend_opencl_buffer_type_name(ggml_backend_buffer_type_t buffer_type) {
return "OpenCL";
GGML_UNUSED(buffer_type);
}
static ggml_backend_buffer_t ggml_backend_opencl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buffer_type, size_t size) {
ggml_cl_init();
cl_int err;
cl_mem mem = clCreateBuffer(context, CL_MEM_READ_WRITE, size, NULL, &err);
if (err != CL_SUCCESS) {
fprintf(stderr, "%s: failed to allocate %.2f MiB\n", __func__, size / 1024.0 / 1024.0);
return nullptr;
}
ggml_backend_opencl_buffer_context * ctx = new ggml_backend_opencl_buffer_context{mem, {}};
return ggml_backend_buffer_init(buffer_type, ggml_backend_opencl_buffer_interface, ctx, size);
}
static size_t ggml_backend_opencl_buffer_type_get_alignment(ggml_backend_buffer_type_t buffer_type) {
// FIXME: not thread safe, device may not be initialized yet
static cl_uint alignment = -1;
if (alignment == (cl_uint)-1) {
ggml_cl_init();
clGetDeviceInfo(device, CL_DEVICE_MEM_BASE_ADDR_ALIGN, sizeof(cl_uint), &alignment, NULL);
}
return alignment;
GGML_UNUSED(buffer_type);
}
static bool ggml_backend_opencl_buffer_type_supports_backend(ggml_backend_buffer_type_t buffer_type, ggml_backend_t backend) {
//return ggml_backend_is_opencl(backend); // opencl must be used through the cpu backend
return ggml_backend_is_cpu(backend);
GGML_UNUSED(buffer_type);
}
static ggml_backend_buffer_type_i ggml_backend_opencl_buffer_type_interface = {
/* .get_name = */ ggml_backend_opencl_buffer_type_name,
/* .alloc_buffer = */ ggml_backend_opencl_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_opencl_buffer_type_get_alignment,
/* .get_alloc_size = */ NULL,
/* .supports_backend = */ ggml_backend_opencl_buffer_type_supports_backend,
/* .is_host = */ NULL,
};
ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type() {
static ggml_backend_buffer_type buffer_type = {
/* .iface = */ ggml_backend_opencl_buffer_type_interface,
/* .context = */ nullptr,
};
return &buffer_type;
}
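For illustration only (not part of the diff): the buffer type above is consumed through the generic ggml-backend buffer API, so callers never touch cl_mem directly. A minimal sketch, assuming only the public helpers from ggml-backend.h:
#include "ggml-backend.h"
#include "ggml-opencl.h"
#include <cstdio>
static void opencl_buffer_smoke_test(void) {
    ggml_backend_buffer_type_t buft = ggml_backend_opencl_buffer_type();
    // alignment is the CL_DEVICE_MEM_BASE_ADDR_ALIGN value queried above
    printf("alignment: %zu\n", ggml_backend_buft_get_alignment(buft));
    // one 16 MiB cl_mem; alloc_buffer() above returns nullptr on failure
    ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(buft, 16u * 1024 * 1024);
    if (buf == NULL) {
        fprintf(stderr, "allocation failed\n");
        return;
    }
    // get_base() returns the fake cl_ptr_base used for offset bookkeeping, not a host pointer
    printf("base: %p\n", ggml_backend_buffer_get_base(buf));
    ggml_backend_buffer_free(buf);
}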
#if 0
// host buffer type
static const char * ggml_backend_opencl_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
return "CL_Host";
GGML_UNUSED(buft);
}
static const char * ggml_backend_opencl_host_buffer_name(ggml_backend_buffer_t buffer) {
return "CL_Host";
GGML_UNUSED(buffer);
}
static void ggml_backend_opencl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_cl_host_free(buffer->context);
}
static ggml_backend_buffer_t ggml_backend_opencl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
void * ptr = ggml_cl_host_malloc(size);
if (ptr == nullptr) {
// fallback to cpu buffer
return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
}
ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
buffer->buft = buft;
buffer->iface.get_name = ggml_backend_opencl_host_buffer_name;
buffer->iface.free_buffer = ggml_backend_opencl_host_buffer_free_buffer;
return buffer;
}
ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type() {
static struct ggml_backend_buffer_type ggml_backend_opencl_buffer_type_host = {
/* .iface = */ {
/* .get_name = */ ggml_backend_opencl_host_buffer_type_name,
/* .alloc_buffer = */ ggml_backend_opencl_host_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
/* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
/* .supports_backend = */ ggml_backend_cpu_buffer_type()->iface.supports_backend,
/* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
},
/* .context = */ nullptr,
};
return &ggml_backend_opencl_buffer_type_host;
}
// backend
static const char * ggml_backend_opencl_name(ggml_backend_t backend) {
return "OpenCL";
GGML_UNUSED(backend);
}
static void ggml_backend_opencl_free(ggml_backend_t backend) {
GGML_UNUSED(backend);
}
static ggml_backend_buffer_type_t ggml_backend_opencl_get_default_buffer_type(ggml_backend_t backend) {
return ggml_backend_opencl_buffer_type();
GGML_UNUSED(backend);
}
static bool ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) {
for (int i = 0; i < graph->n_nodes; ++i) {
ggml_tensor * node = graph->nodes[i];
switch (node->op) {
case GGML_OP_MUL_MAT:
ggml_cl_mul_mat(node->src[0], node->src[1], node, nullptr, 0);
break;
case GGML_OP_MUL:
ggml_cl_mul(node->src[0], node->src[1], node);
break;
default:
GGML_ASSERT(false);
}
}
return true;
GGML_UNUSED(backend);
}
static bool ggml_backend_opencl_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
switch (op->op) {
case GGML_OP_MUL_MAT:
return ggml_cl_can_mul_mat(op->src[0], op->src[1], op);
case GGML_OP_MUL:
// return ggml_can_repeat_rows(op->src[1], op->src[0]);
return true;
default:
return false;
}
GGML_UNUSED(backend);
}
static ggml_backend_i opencl_backend_i = {
/* .get_name = */ ggml_backend_opencl_name,
/* .free = */ ggml_backend_opencl_free,
/* .get_default_buffer_type = */ ggml_backend_opencl_get_default_buffer_type,
/* .set_tensor_async = */ NULL,
/* .get_tensor_async = */ NULL,
/* .cpy_tensor_from_async = */ NULL,
/* .cpy_tensor_to_async = */ NULL,
/* .synchronize = */ NULL,
/* .graph_plan_create = */ NULL,
/* .graph_plan_free = */ NULL,
/* .graph_plan_compute = */ NULL,
/* .graph_compute = */ ggml_backend_opencl_graph_compute,
/* .supports_op = */ ggml_backend_opencl_supports_op,
};
ggml_backend_t ggml_backend_opencl_init() {
ggml_backend_t backend = new ggml_backend {
/* .interface = */ opencl_backend_i,
/* .context = */ nullptr
};
return backend;
}
bool ggml_backend_is_opencl(ggml_backend_t backend) {
return backend && backend->iface.get_name == ggml_backend_opencl_name;
}
#endif
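A slightly larger sketch (again not from the commit): placing a context's tensors into the OpenCL buffer type and uploading data through the set_tensor path implemented above. It assumes ggml_backend_alloc_ctx_tensors_from_buft() from ggml-alloc.h; every identifier is public API and the sizes are arbitrary.
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-opencl.h"
#include <vector>
static void upload_weight_to_opencl(void) {
    struct ggml_init_params ip = {
        /* .mem_size   = */ ggml_tensor_overhead() * 4,
        /* .mem_buffer = */ NULL,
        /* .no_alloc   = */ true,   // tensor data will live in the OpenCL buffer, not in ctx
    };
    struct ggml_context * ctx = ggml_init(ip);
    struct ggml_tensor  * w   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1024, 1024);
    // one cl_mem for the whole context; each tensor gets a sub-buffer via init_tensor()
    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, ggml_backend_opencl_buffer_type());
    std::vector<float> host(ggml_nelements(w), 0.5f);
    ggml_backend_tensor_set(w, host.data(), 0, ggml_nbytes(w)); // clEnqueueWriteBuffer + clFinish
    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
}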

ggml-opencl.h

@@ -1,6 +1,7 @@
#pragma once
#include "ggml.h"
#include "ggml-backend.h"
#ifdef __cplusplus
extern "C" {
@@ -9,17 +10,26 @@ extern "C" {
GGML_API void ggml_cl_init(void);
GGML_API void ggml_cl_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
GGML_API bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
GGML_API bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, const struct ggml_tensor * dst);
GGML_API size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
GGML_API void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
GGML_API void * ggml_cl_host_malloc(size_t size);
GGML_API void ggml_cl_host_free(void * ptr);
// GGML_API void * ggml_cl_host_malloc(size_t size);
// GGML_API void ggml_cl_host_free(void * ptr);
GGML_API void ggml_cl_free_data(const struct ggml_tensor* tensor);
GGML_API void ggml_cl_transform_tensor(void * data, struct ggml_tensor * tensor);
// backend API
// GGML_API ggml_backend_t ggml_backend_opencl_init(void);
// GGML_API bool ggml_backend_is_opencl(ggml_backend_t backend);
GGML_API ggml_backend_buffer_type_t ggml_backend_opencl_buffer_type(void);
// GGML_API ggml_backend_buffer_type_t ggml_backend_opencl_host_buffer_type(void);
#ifdef __cplusplus
}
#endif

ggml.c

@@ -2336,6 +2336,10 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
}
void ggml_free(struct ggml_context * ctx) {
if (ctx == NULL) {
return;
}
// make this function thread safe
ggml_critical_section_start();
@@ -4351,6 +4355,23 @@ struct ggml_tensor * ggml_cpy_inplace(
return ggml_cpy_impl(ctx, a, b, true);
}
struct ggml_tensor * ggml_cast(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_type type) {
bool is_node = false;
struct ggml_tensor * result = ggml_new_tensor(ctx, type, GGML_MAX_DIMS, a->ne);
ggml_format_name(result, "%s (copy)", a->name);
result->op = GGML_OP_CPY;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = result;
return result;
}
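For orientation (not part of the diff): ggml_cast behaves like ggml_cpy, except the destination tensor is created internally with the requested type. A minimal, self-contained sketch casting F32 to F16 on the CPU, using only the public ggml API:
#include "ggml.h"
int main(void) {
    struct ggml_init_params ip = { /* .mem_size = */ 16*1024*1024, /* .mem_buffer = */ NULL, /* .no_alloc = */ false };
    struct ggml_context * ctx = ggml_init(ip);
    struct ggml_tensor * x   = ggml_set_f32(ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8), 1.25f);
    struct ggml_tensor * x16 = ggml_cast(ctx, x, GGML_TYPE_F16); // op is GGML_OP_CPY, dst created internally
    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, x16);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);
    // x16 now holds the F16 copy of x
    ggml_free(ctx);
    return 0;
}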
// ggml_cont
static struct ggml_tensor * ggml_cont_impl(
@@ -14851,7 +14872,7 @@ size_t ggml_hash_find_or_insert(struct ggml_hash_set hash_set, struct ggml_tenso
return i;
}
static struct ggml_hash_set ggml_hash_set_new(size_t size) {
struct ggml_hash_set ggml_hash_set_new(size_t size) {
size = ggml_hash_size(size);
struct ggml_hash_set result;
result.size = size;
@@ -16600,7 +16621,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
return GGML_EXIT_SUCCESS;
}
struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threads) {
if (n_threads <= 0) {
n_threads = GGML_DEFAULT_N_THREADS;
}
@@ -16662,14 +16683,15 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
} break;
case GGML_OP_MUL_MAT_ID:
{
cur = 0;
const struct ggml_tensor * src0 = node->src[2];
const struct ggml_tensor * src1 = node->src[1];
const enum ggml_type vec_dot_type = type_traits[src0->type].vec_dot_type;
if (src1->type != vec_dot_type) {
cur = ggml_row_size(vec_dot_type, ggml_nelements(src1));
cur += ggml_row_size(vec_dot_type, ggml_nelements(src1));
}
const int n_as = ggml_get_op_params_i32(node, 1);
cur = GGML_PAD(cur, sizeof(int64_t)); // align
cur += GGML_PAD(cur, sizeof(int64_t)); // align
cur += n_as * sizeof(int64_t); // matrix_row_counts
cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
} break;

ggml.h

@@ -1167,6 +1167,11 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor * ggml_cast(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_type type);
// make contiguous
GGML_API struct ggml_tensor * ggml_cont(
struct ggml_context * ctx,
@@ -1849,8 +1854,8 @@ extern "C" {
// ggml_graph_plan() has to be called before ggml_graph_compute()
// when plan.work_size > 0, caller must allocate memory for plan.work_data
GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
GGML_API int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
GGML_API struct ggml_cplan ggml_graph_plan (const struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
GGML_API int ggml_graph_compute( struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
// same as ggml_graph_compute() but the work data is allocated as a part of the context
// note: the drawback of this API is that you must have ensured that the context has enough memory for the work data
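A small sketch (not part of the diff) of the caller-managed path described by the ggml_graph_plan() comment above: build the plan, provide the work memory yourself, then compute. Only the public API declared here is used.
#include "ggml.h"
#include <cstdint>
#include <vector>
// compute a graph with caller-managed work memory, per the ggml_graph_plan() contract above
static int compute_graph(struct ggml_cgraph * gf, int n_threads) {
    struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);
    std::vector<uint8_t> work(plan.work_size);
    if (plan.work_size > 0) {
        plan.work_data = work.data();
    }
    return ggml_graph_compute(gf, &plan);
}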

llama.cpp

File diff suppressed because it is too large.

llama.h

@@ -116,6 +116,12 @@ extern "C" {
LLAMA_ROPE_SCALING_MAX_VALUE = LLAMA_ROPE_SCALING_YARN,
};
enum llama_split_mode {
LLAMA_SPLIT_NONE = 0, // single GPU
LLAMA_SPLIT_LAYER = 1, // split layers and KV across GPUs
LLAMA_SPLIT_ROW = 2, // split rows across GPUs
};
typedef struct llama_token_data {
llama_token id; // token id
float logit; // log-odds of the token
@@ -178,8 +184,15 @@ extern "C" {
struct llama_model_params {
int32_t n_gpu_layers; // number of layers to store in VRAM
int32_t main_gpu; // the GPU that is used for scratch and small tensors
const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)
enum llama_split_mode split_mode; // how to split the model across multiple GPUs
// main_gpu interpretation depends on split_mode:
// LLAMA_SPLIT_NONE: the GPU that is used for the entire model
// LLAMA_SPLIT_ROW: the GPU that is used for small tensors and intermediate results
// LLAMA_SPLIT_LAYER: ignored
int32_t main_gpu;
// proportion of the model (layers or rows) to offload to each GPU, size: LLAMA_MAX_DEVICES
const float * tensor_split;
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
// If the provided progress_callback returns true, model loading continues.
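To make the new fields concrete (example only, not from the diff): a hypothetical layer-split configuration. Only public llama.h API is used; the model path is a placeholder.
#include "llama.h"
static struct llama_model * load_split_model(void) {
    struct llama_model_params mparams = llama_model_default_params();
    // proportions per device, size LLAMA_MAX_DEVICES; with more than one GPU you would
    // spread these, e.g. {0.6f, 0.4f}
    static float split[LLAMA_MAX_DEVICES] = {0.0f};
    split[0] = 1.0f;
    mparams.n_gpu_layers = 99;                 // offload as many layers as fit
    mparams.split_mode   = LLAMA_SPLIT_LAYER;
    mparams.main_gpu     = 0;                  // ignored for LLAMA_SPLIT_LAYER, per the comment above
    mparams.tensor_split = split;
    return llama_load_model_from_file("model.gguf", mparams); // "model.gguf" is a placeholder
}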

tests/test-backend-ops.cpp

@@ -376,6 +376,11 @@ struct test_case {
// allocate
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend1);
if (buf == NULL) {
printf("failed to allocate tensors [%s] ", ggml_backend_name(backend1));
ggml_free(ctx);
return false;
}
// build graph
ggml_build_forward_expand(gf, out);
@@ -463,19 +468,23 @@ struct test_case {
GGML_UNUSED(index);
};
ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);
const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);
if (ud.ok) {
printf("\033[1;32mOK\033[0m\n");
} else {
printf("\033[1;31mFAIL\033[0m\n");
if (!cmp_ok) {
printf("compare failed ");
}
ggml_backend_buffer_free(buf);
ggml_free(ctx);
return ud.ok;
if (ud.ok && cmp_ok) {
printf("\033[1;32mOK\033[0m\n");
return true;
}
printf("\033[1;31mFAIL\033[0m\n");
return false;
}
bool eval_perf(ggml_backend_t backend, const char * op_name) {
@@ -519,6 +528,11 @@ struct test_case {
// allocate
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);
if (buf == NULL) {
printf("failed to allocate tensors\n");
ggml_free(ctx);
return false;
}
// randomize tensors
initialize_tensors(ctx);