mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 02:14:35 +00:00
Basic Vulkan Multi-GPU implementation (#5321)
* Initial Vulkan multi-gpu implementation Move most global variables into backend context * Add names to backend device functions * Add further missing cleanup code * Reduce code duplication in tensor split layer assignment * generalize LLAMA_SPLIT_LAYER for all backends, do not expose device count and memory in llama.h * Only do device info print in the beginning and initialize one backend for cpu assist Add missing cleanup code * Rework backend memory management to make sure devices and buffers get properly allocated and freed * Rename cpu assist free function --------- Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
parent
ed0bf32290
commit
ee1628bdfe
@ -46,6 +46,10 @@
|
||||
#define GGML_USE_CUBLAS_SYCL
|
||||
#endif
|
||||
|
||||
#if (defined(GGML_USE_CUBLAS) || defined(GGML_USE_SYCL)) || defined(GGML_USE_VULKAN)
|
||||
#define GGML_USE_CUBLAS_SYCL_VULKAN
|
||||
#endif
|
||||
|
||||
int32_t get_num_physical_cores() {
|
||||
#ifdef __linux__
|
||||
// enumerate the set of thread siblings, num entries is num cores
|
||||
@ -660,8 +664,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
||||
params.tensor_split[i] = 0.0f;
|
||||
}
|
||||
}
|
||||
#ifndef GGML_USE_CUBLAS_SYCL
|
||||
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL. Setting a tensor split has no effect.\n");
|
||||
#ifndef GGML_USE_CUBLAS_SYCL_VULKAN
|
||||
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS/SYCL/Vulkan. Setting a tensor split has no effect.\n");
|
||||
#endif // GGML_USE_CUBLAS_SYCL
|
||||
} else if (arg == "--no-mmap") {
|
||||
params.use_mmap = false;
|
||||
|
2639
ggml-vulkan.cpp
2639
ggml-vulkan.cpp
File diff suppressed because it is too large
Load Diff
@ -8,24 +8,29 @@ extern "C" {
|
||||
#endif
|
||||
|
||||
#define GGML_VK_NAME "Vulkan"
|
||||
#define GGML_VK_MAX_DEVICES 16
|
||||
|
||||
GGML_API void ggml_vk_init(void);
|
||||
GGML_API void ggml_vk_init_cpu_assist(void);
|
||||
|
||||
GGML_API void ggml_vk_preallocate_buffers_graph(struct ggml_tensor * node);
|
||||
GGML_API void ggml_vk_preallocate_buffers(void);
|
||||
GGML_API void ggml_vk_build_graph(struct ggml_tensor * node, bool last_node);
|
||||
GGML_API bool ggml_vk_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
|
||||
GGML_API void ggml_vk_preallocate_buffers_graph_cpu_assist(struct ggml_tensor * node);
|
||||
GGML_API void ggml_vk_preallocate_buffers_cpu_assist(void);
|
||||
GGML_API void ggml_vk_build_graph_cpu_assist(struct ggml_tensor * node, bool last_node);
|
||||
GGML_API bool ggml_vk_compute_forward_cpu_assist(struct ggml_compute_params * params, struct ggml_tensor * tensor);
|
||||
#ifdef GGML_VULKAN_CHECK_RESULTS
|
||||
void ggml_vk_check_results_1(struct ggml_compute_params * params, struct ggml_tensor * tensor);
|
||||
void ggml_vk_check_results_1_cpu_assist(struct ggml_compute_params * params, struct ggml_tensor * tensor);
|
||||
#endif
|
||||
GGML_API void ggml_vk_graph_cleanup(void);
|
||||
GGML_API void ggml_vk_graph_cleanup_cpu_assist(void);
|
||||
GGML_API void ggml_vk_free_cpu_assist(void);
|
||||
|
||||
// backend API
|
||||
GGML_API GGML_CALL ggml_backend_t ggml_backend_vk_init(void);
|
||||
GGML_API GGML_CALL ggml_backend_t ggml_backend_vk_init(size_t dev_num);
|
||||
|
||||
GGML_API GGML_CALL bool ggml_backend_is_vk(ggml_backend_t backend);
|
||||
GGML_API GGML_CALL int ggml_backend_vk_get_device_count(void);
|
||||
GGML_API GGML_CALL void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size);
|
||||
GGML_API GGML_CALL void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total);
|
||||
|
||||
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(void);
|
||||
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num);
|
||||
// pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
|
||||
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void);
|
||||
|
||||
|
14
ggml.c
14
ggml.c
@ -2343,7 +2343,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
|
||||
#elif defined(GGML_USE_CLBLAST)
|
||||
ggml_cl_init();
|
||||
#elif defined(GGML_USE_VULKAN)
|
||||
ggml_vk_init();
|
||||
ggml_vk_init_cpu_assist();
|
||||
#elif defined(GGML_USE_SYCL)
|
||||
ggml_init_sycl();
|
||||
#endif
|
||||
@ -14850,10 +14850,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
GGML_ASSERT(tensor->src[0] == NULL || tensor->src[0]->backend == GGML_BACKEND_CPU);
|
||||
GGML_ASSERT(tensor->src[1] == NULL || tensor->src[1]->backend == GGML_BACKEND_CPU);
|
||||
#elif defined(GGML_USE_VULKAN)
|
||||
const bool skip_cpu = ggml_vk_compute_forward(params, tensor);
|
||||
const bool skip_cpu = ggml_vk_compute_forward_cpu_assist(params, tensor);
|
||||
#ifdef GGML_VULKAN_CHECK_RESULTS
|
||||
if (skip_cpu) {
|
||||
ggml_vk_check_results_1(params, tensor);
|
||||
ggml_vk_check_results_1_cpu_assist(params, tensor);
|
||||
}
|
||||
#endif
|
||||
if (skip_cpu) {
|
||||
@ -17269,12 +17269,12 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
|
||||
|
||||
#ifdef GGML_USE_VULKAN
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_vk_preallocate_buffers_graph(cgraph->nodes[i]);
|
||||
ggml_vk_preallocate_buffers_graph_cpu_assist(cgraph->nodes[i]);
|
||||
}
|
||||
ggml_vk_preallocate_buffers();
|
||||
ggml_vk_preallocate_buffers_cpu_assist();
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
ggml_vk_build_graph(cgraph->nodes[i], i == cgraph->n_nodes - 1);
|
||||
ggml_vk_build_graph_cpu_assist(cgraph->nodes[i], i == cgraph->n_nodes - 1);
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -17330,7 +17330,7 @@ int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_VULKAN
|
||||
ggml_vk_graph_cleanup();
|
||||
ggml_vk_graph_cleanup_cpu_assist();
|
||||
#endif
|
||||
|
||||
// performance stats (graph)
|
||||
|
69
llama.cpp
69
llama.cpp
@ -1355,7 +1355,7 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(int gpu) {
|
||||
#elif defined(GGML_USE_CUBLAS)
|
||||
buft = ggml_backend_cuda_buffer_type(gpu);
|
||||
#elif defined(GGML_USE_VULKAN)
|
||||
buft = ggml_backend_vk_buffer_type();
|
||||
buft = ggml_backend_vk_buffer_type(gpu);
|
||||
#elif defined(GGML_USE_SYCL)
|
||||
buft = ggml_backend_sycl_buffer_type(gpu);
|
||||
#elif defined(GGML_USE_CLBLAST)
|
||||
@ -1392,6 +1392,33 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_split(int fallback_g
|
||||
GGML_UNUSED(tensor_split);
|
||||
}
|
||||
|
||||
static size_t llama_get_device_count() {
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
return ggml_backend_cuda_get_device_count();
|
||||
#elif defined(GGML_USE_VULKAN)
|
||||
return ggml_backend_vk_get_device_count();
|
||||
#else
|
||||
return 1;
|
||||
#endif
|
||||
}
|
||||
|
||||
static size_t llama_get_device_memory(int device) {
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
size_t total;
|
||||
size_t free;
|
||||
ggml_backend_cuda_get_device_memory(device, &total, &free);
|
||||
return free;
|
||||
#elif defined(GGML_USE_VULKAN)
|
||||
size_t total;
|
||||
size_t free;
|
||||
ggml_backend_vk_get_device_memory(device, &total, &free);
|
||||
return free;
|
||||
#else
|
||||
return 1;
|
||||
GGML_UNUSED(device);
|
||||
#endif
|
||||
}
|
||||
|
||||
//
|
||||
// globals
|
||||
//
|
||||
@ -1763,6 +1790,10 @@ struct llama_context {
|
||||
ggml_backend_free(backend);
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_VULKAN
|
||||
ggml_vk_free_cpu_assist();
|
||||
#endif
|
||||
|
||||
ggml_backend_buffer_free(buf_input);
|
||||
ggml_free(ctx_input);
|
||||
}
|
||||
@ -3436,22 +3467,18 @@ static bool llm_load_tensors(
|
||||
model.buft_layer[i] = llama_default_buffer_type_cpu(true);
|
||||
}
|
||||
|
||||
#ifdef GGML_USE_CUBLAS
|
||||
if (split_mode == LLAMA_SPLIT_LAYER) {
|
||||
// calculate the split points
|
||||
int device_count = ggml_backend_cuda_get_device_count();
|
||||
int device_count = llama_get_device_count();
|
||||
bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; });
|
||||
float splits[GGML_CUDA_MAX_DEVICES];
|
||||
std::vector<float> splits(device_count);
|
||||
if (all_zero) {
|
||||
// default split, by free memory
|
||||
for (int i = 0; i < device_count; ++i) {
|
||||
size_t total;
|
||||
size_t free;
|
||||
ggml_backend_cuda_get_device_memory(i, &total, &free);
|
||||
splits[i] = free;
|
||||
splits[i] = llama_get_device_memory(i);
|
||||
}
|
||||
} else {
|
||||
std::copy(tensor_split, tensor_split + device_count, splits);
|
||||
std::copy(tensor_split, tensor_split + device_count, splits.begin());
|
||||
}
|
||||
|
||||
// sum and normalize the splits to get the split points
|
||||
@ -3467,19 +3494,17 @@ static bool llm_load_tensors(
|
||||
// assign the repeating layers to the devices according to the splits
|
||||
int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1);
|
||||
for (int64_t i = i_gpu_start; i < n_layer; ++i) {
|
||||
int layer_gpu = std::upper_bound(splits, splits + device_count, float(i - i_gpu_start)/act_gpu_layers) - splits;
|
||||
int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(i - i_gpu_start)/act_gpu_layers) - splits.begin();
|
||||
model.buft_layer[i] = llama_default_buffer_type_offload(layer_gpu);
|
||||
}
|
||||
// assign the output layer
|
||||
if (n_gpu_layers > n_layer) {
|
||||
int layer_gpu = std::upper_bound(splits, splits + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - splits;
|
||||
int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - splits.begin();
|
||||
model.buft_output = llama_default_buffer_type_offload(layer_gpu);
|
||||
} else {
|
||||
model.buft_output = llama_default_buffer_type_cpu(true);
|
||||
}
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
} else {
|
||||
ggml_backend_buffer_type_t split_buft;
|
||||
if (split_mode == LLAMA_SPLIT_ROW) {
|
||||
split_buft = llama_default_buffer_type_split(main_gpu, tensor_split);
|
||||
@ -10483,6 +10508,8 @@ size_t llama_max_devices(void) {
|
||||
return GGML_CUDA_MAX_DEVICES;
|
||||
#elif defined(GGML_USE_SYCL)
|
||||
return GGML_SYCL_MAX_DEVICES;
|
||||
#elif defined(GGML_USE_VULKAN)
|
||||
return GGML_VK_MAX_DEVICES;
|
||||
#else
|
||||
return 1;
|
||||
#endif
|
||||
@ -10690,13 +10717,15 @@ struct llama_context * llama_new_context_with_model(
|
||||
}
|
||||
#elif defined(GGML_USE_VULKAN)
|
||||
if (model->n_gpu_layers > 0) {
|
||||
ggml_backend_t backend = ggml_backend_vk_init();
|
||||
if (backend == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize Vulkan backend\n", __func__);
|
||||
llama_free(ctx);
|
||||
return nullptr;
|
||||
for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) {
|
||||
ggml_backend_t backend = ggml_backend_vk_init(device);
|
||||
if (backend == nullptr) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize Vulkan%d backend\n", __func__, device);
|
||||
llama_free(ctx);
|
||||
return nullptr;
|
||||
}
|
||||
ctx->backends.push_back(backend);
|
||||
}
|
||||
ctx->backends.push_back(backend);
|
||||
}
|
||||
#elif defined(GGML_USE_SYCL)
|
||||
if (model->n_gpu_layers > 0) {
|
||||
|
Loading…
Reference in New Issue
Block a user