ggml : add metal backend registry / device (#9713)

* ggml : add metal backend registry / device

ggml-ci

* metal : fix names [no ci]

* metal : global registry and device instances

ggml-ci

* cont : alternative initialization of global objects

ggml-ci

* llama : adapt to backend changes

ggml-ci

* fixes

* metal : fix indent

* metal : fix build when MTLGPUFamilyApple3 is not available

ggml-ci

* fix merge

* metal : avoid unnecessary singleton accesses

ggml-ci

* metal : minor fix [no ci]

* metal : g_state -> g_ggml_ctx_dev_main [no ci]

* metal : avoid reference of device context in the backend context

ggml-ci

* metal : minor [no ci]

* metal : fix maxTransferRate check

* metal : remove transfer rate stuff

---------

Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
Georgi Gerganov 2024-10-07 18:27:51 +03:00 committed by GitHub
parent 96b6912103
commit d5ac8cf2f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 535 additions and 284 deletions

View File

@ -127,6 +127,8 @@ extern "C" {
bool async; bool async;
// pinned host buffer // pinned host buffer
bool host_buffer; bool host_buffer;
// creating buffers from host ptr
bool buffer_from_host_ptr;
// event synchronization // event synchronization
bool events; bool events;
}; };

View File

@ -43,7 +43,9 @@ GGML_API ggml_backend_t ggml_backend_metal_init(void);
GGML_API bool ggml_backend_is_metal(ggml_backend_t backend); GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
GGML_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size); GGML_DEPRECATED(
GGML_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
"obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713");
GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data); GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
@ -57,6 +59,8 @@ GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int fam
// capture all command buffers committed the next time `ggml_backend_graph_compute` is called // capture all command buffers committed the next time `ggml_backend_graph_compute` is called
GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend); GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
GGML_API ggml_backend_reg_t ggml_backend_metal_reg(void);
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif

View File

@ -463,6 +463,7 @@ enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) {
} }
void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props) { void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props) {
memset(props, 0, sizeof(*props));
device->iface.get_props(device, props); device->iface.get_props(device, props);
} }
@ -479,6 +480,10 @@ ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t devic
} }
ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device) { ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device) {
if (device->iface.get_host_buffer_type == NULL) {
return NULL;
}
return device->iface.get_host_buffer_type(device); return device->iface.get_host_buffer_type(device);
} }
@ -525,6 +530,10 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
#include "ggml-cuda.h" #include "ggml-cuda.h"
#endif #endif
#ifdef GGML_USE_METAL
#include "ggml-metal.h"
#endif
struct ggml_backend_registry { struct ggml_backend_registry {
std::vector<ggml_backend_reg_t> backends; std::vector<ggml_backend_reg_t> backends;
std::vector<ggml_backend_dev_t> devices; std::vector<ggml_backend_dev_t> devices;
@ -533,10 +542,13 @@ struct ggml_backend_registry {
#ifdef GGML_USE_CUDA #ifdef GGML_USE_CUDA
register_backend(ggml_backend_cuda_reg()); register_backend(ggml_backend_cuda_reg());
#endif #endif
#ifdef GGML_USE_METAL
register_backend(ggml_backend_metal_reg());
#endif
register_backend(ggml_backend_cpu_reg()); register_backend(ggml_backend_cpu_reg());
// TODO: sycl, metal, vulkan, kompute, cann // TODO: sycl, vulkan, kompute, cann
} }
void register_backend(ggml_backend_reg_t reg) { void register_backend(ggml_backend_reg_t reg) {
@ -1118,9 +1130,10 @@ static void ggml_backend_cpu_device_get_props(ggml_backend_dev_t dev, struct ggm
props->type = ggml_backend_cpu_device_get_type(dev); props->type = ggml_backend_cpu_device_get_type(dev);
ggml_backend_cpu_device_get_memory(dev, &props->memory_free, &props->memory_total); ggml_backend_cpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = { props->caps = {
/* async */ false, /* .async = */ false,
/* host_buffer */ false, /* .host_buffer = */ false,
/* events */ false, /* .buffer_from_host_ptr = */ true,
/* .events = */ false,
}; };
} }

View File

@ -2920,9 +2920,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
#endif #endif
props->caps = { props->caps = {
/* async */ true, /* .async = */ true,
/* host_buffer */ host_buffer, /* .host_buffer = */ host_buffer,
/* events */ events, /* .buffer_from_host_ptr = */ false,
/* .events = */ events,
}; };
} }

File diff suppressed because it is too large Load Diff

View File

@ -26,10 +26,6 @@
# include "ggml-blas.h" # include "ggml-blas.h"
#endif #endif
#ifdef GGML_USE_METAL
# include "ggml-metal.h"
#endif
// TODO: replace with ggml API call // TODO: replace with ggml API call
#define QK_K 256 #define QK_K 256
@ -3292,9 +3288,6 @@ struct llama_context {
std::unordered_map<struct llama_lora_adapter *, float> lora_adapters; std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
std::vector<ggml_backend_t> backends; std::vector<ggml_backend_t> backends;
#ifdef GGML_USE_METAL
ggml_backend_t backend_metal = nullptr;
#endif
#ifdef GGML_USE_BLAS #ifdef GGML_USE_BLAS
ggml_backend_t backend_blas = nullptr; ggml_backend_t backend_blas = nullptr;
#endif #endif
@ -3420,9 +3413,7 @@ static int llama_get_device_count(const llama_model & model) {
count += (int) model.rpc_servers.size(); count += (int) model.rpc_servers.size();
#endif #endif
#if defined(GGML_USE_METAL) #if defined(GGML_USE_SYCL)
count += 1;
#elif defined(GGML_USE_SYCL)
count += ggml_backend_sycl_get_device_count(); count += ggml_backend_sycl_get_device_count();
#elif defined(GGML_USE_VULKAN) #elif defined(GGML_USE_VULKAN)
count += ggml_backend_vk_get_device_count(); count += ggml_backend_vk_get_device_count();
@ -3488,9 +3479,7 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_
} }
device -= (int)model.devices.size(); device -= (int)model.devices.size();
#if defined(GGML_USE_METAL) #if defined(GGML_USE_VULKAN)
buft = ggml_backend_metal_buffer_type();
#elif defined(GGML_USE_VULKAN)
buft = ggml_backend_vk_buffer_type(device); buft = ggml_backend_vk_buffer_type(device);
#elif defined(GGML_USE_SYCL) #elif defined(GGML_USE_SYCL)
buft = ggml_backend_sycl_buffer_type(device); buft = ggml_backend_sycl_buffer_type(device);
@ -8918,48 +8907,39 @@ static bool llm_load_tensors(
llama_buf_map bufs; llama_buf_map bufs;
bufs.reserve(n_max_backend_buffer); bufs.reserve(n_max_backend_buffer);
// check if this backend device supports buffer_from_host_ptr
ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
bool buffer_from_host_ptr_supported = false;
if (dev) {
ggml_backend_dev_props props;
ggml_backend_dev_get_props(dev, &props);
buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
}
if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported) {
for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
// only the mmap region containing the tensors in the model is mapped to the backend buffer // only the mmap region containing the tensors in the model is mapped to the backend buffer
// this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers
// this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size
if (ml.use_mmap && use_mmap_buffer && buft == llama_default_buffer_type_cpu(model, true)) {
for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
void * addr = nullptr; void * addr = nullptr;
size_t first, last; size_t first, last; // NOLINT
ml.get_mapping_range(&first, &last, &addr, idx, ctx); ml.get_mapping_range(&first, &last, &addr, idx, ctx);
if (first >= last) { if (first >= last) {
continue; continue;
} }
ggml_backend_buffer_t buf = ggml_backend_cpu_buffer_from_ptr((char *) addr + first, last - first);
if (buf == nullptr) {
throw std::runtime_error("unable to allocate backend CPU buffer");
}
model.bufs.push_back(buf);
bufs.emplace(idx, buf);
}
}
#ifdef GGML_USE_METAL
else if (ml.use_mmap && use_mmap_buffer && buft == ggml_backend_metal_buffer_type()) {
for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
const size_t max_size = ggml_get_max_tensor_size(ctx); const size_t max_size = ggml_get_max_tensor_size(ctx);
void * addr = nullptr; ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
size_t first, last;
ml.get_mapping_range(&first, &last, &addr, idx, ctx);
if (first >= last) {
continue;
}
ggml_backend_buffer_t buf = ggml_backend_metal_buffer_from_ptr((char *) addr + first, last - first, max_size);
if (buf == nullptr) { if (buf == nullptr) {
throw std::runtime_error("unable to allocate backend metal buffer"); throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
} }
model.bufs.push_back(buf); model.bufs.push_back(buf);
bufs.emplace(idx, buf); bufs.emplace(idx, buf);
} }
} }
#endif
else { else {
ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
if (buf == nullptr) { if (buf == nullptr) {
throw std::runtime_error("unable to allocate backend buffer"); throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
} }
model.bufs.push_back(buf); model.bufs.push_back(buf);
if (use_mlock && ggml_backend_buffer_is_host(buf)) { if (use_mlock && ggml_backend_buffer_is_host(buf)) {
@ -19041,7 +19021,7 @@ bool llama_supports_mlock(void) {
} }
bool llama_supports_gpu_offload(void) { bool llama_supports_gpu_offload(void) {
#if defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \ #if defined(GGML_USE_VULKAN) || \
defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_RPC) defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_RPC)
// Defined when llama.cpp is compiled with support for offloading model layers to GPU. // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
return true; return true;
@ -19344,17 +19324,7 @@ struct llama_context * llama_new_context_with_model(
} }
#endif #endif
#if defined(GGML_USE_METAL) #if defined(GGML_USE_VULKAN)
if (model->n_gpu_layers > 0) {
ctx->backend_metal = ggml_backend_metal_init();
if (ctx->backend_metal == nullptr) {
LLAMA_LOG_ERROR("%s: failed to initialize Metal backend\n", __func__);
llama_free(ctx);
return nullptr;
}
ctx->backends.push_back(ctx->backend_metal);
}
#elif defined(GGML_USE_VULKAN)
if (model->split_mode == LLAMA_SPLIT_MODE_ROW) { if (model->split_mode == LLAMA_SPLIT_MODE_ROW) {
LLAMA_LOG_ERROR("%s: Row split not supported. Failed to initialize Vulkan backend\n", __func__); LLAMA_LOG_ERROR("%s: Row split not supported. Failed to initialize Vulkan backend\n", __func__);
llama_free(ctx); llama_free(ctx);