llama : accept a list of devices to use to offload a model (#10497)

* llama : accept a list of devices to use to offload a model

* accept `--dev none` to completely disable offloading

* fix dev list with dl backends

* rename env parameter to LLAMA_ARG_DEVICE for consistency
This commit is contained in:
Diego Devesa 2024-11-25 19:30:06 +01:00 committed by GitHub
parent 1f922254f0
commit 10bce0450f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 104 additions and 27 deletions

View File

@ -298,6 +298,27 @@ static void common_params_print_usage(common_params_context & ctx_arg) {
print_options(specific_options);
}
static std::vector<ggml_backend_dev_t> parse_device_list(const std::string & value) {
std::vector<ggml_backend_dev_t> devices;
auto dev_names = string_split<std::string>(value, ',');
if (dev_names.empty()) {
throw std::invalid_argument("no devices specified");
}
if (dev_names.size() == 1 && dev_names[0] == "none") {
devices.push_back(nullptr);
} else {
for (const auto & device : dev_names) {
auto * dev = ggml_backend_dev_by_name(device.c_str());
if (!dev || ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_GPU) {
throw std::invalid_argument(string_format("invalid device: %s", device.c_str()));
}
devices.push_back(dev);
}
devices.push_back(nullptr);
}
return devices;
}
bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
auto ctx_arg = common_params_parser_init(params, ex, print_usage);
const common_params params_org = ctx_arg.params; // the example can modify the default params
@ -324,6 +345,9 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e
}
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
// load dynamic backends
ggml_backend_load_all();
common_params_context ctx_arg(params);
ctx_arg.print_usage = print_usage;
ctx_arg.ex = ex;
@ -1312,6 +1336,30 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
else { throw std::invalid_argument("invalid value"); }
}
).set_env("LLAMA_ARG_NUMA"));
add_opt(common_arg(
{"-dev", "--device"}, "<dev1,dev2,..>",
"comma-separated list of devices to use for offloading (none = don't offload)\n"
"use --list-devices to see a list of available devices",
[](common_params & params, const std::string & value) {
params.devices = parse_device_list(value);
}
).set_env("LLAMA_ARG_DEVICE"));
add_opt(common_arg(
{"--list-devices"},
"print list of available devices and exit",
[](common_params &) {
printf("Available devices:\n");
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
auto * dev = ggml_backend_dev_get(i);
if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_GPU) {
size_t free, total;
ggml_backend_dev_memory(dev, &free, &total);
printf(" %s: %s (%zu MiB, %zu MiB free)\n", ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), total / 1024 / 1024, free / 1024 / 1024);
}
}
exit(0);
}
));
add_opt(common_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
"number of layers to store in VRAM",
@ -1336,10 +1384,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
} else if (arg_next == "layer") {
params.split_mode = LLAMA_SPLIT_MODE_LAYER;
} else if (arg_next == "row") {
#ifdef GGML_USE_SYCL
fprintf(stderr, "warning: The split mode value:[row] is not supported by llama.cpp with SYCL. It's developing.\nExit!\n");
exit(1);
#endif // GGML_USE_SYCL
params.split_mode = LLAMA_SPLIT_MODE_ROW;
} else {
throw std::invalid_argument("invalid value");
@ -2042,6 +2086,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.n_ctx = value;
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-devd", "--device-draft"}, "<dev1,dev2,..>",
"comma-separated list of devices to use for offloading the draft model (none = don't offload)\n"
"use --list-devices to see a list of available devices",
[](common_params & params, const std::string & value) {
params.speculative.devices = parse_device_list(value);
}
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"-ngld", "--gpu-layers-draft", "--n-gpu-layers-draft"}, "N",
"number of layers to store in VRAM for the draft model",

View File

@ -377,9 +377,6 @@ void common_init() {
#endif
LOG_INF("build: %d (%s) with %s for %s%s\n", LLAMA_BUILD_NUMBER, LLAMA_COMMIT, LLAMA_COMPILER, LLAMA_BUILD_TARGET, build_type);
// load dynamic backends
ggml_backend_load_all();
}
std::string common_params_get_system_info(const common_params & params) {
@ -982,9 +979,12 @@ void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_l
}
}
struct llama_model_params common_model_params_to_llama(const common_params & params) {
struct llama_model_params common_model_params_to_llama(common_params & params) {
auto mparams = llama_model_default_params();
if (!params.devices.empty()) {
mparams.devices = params.devices.data();
}
if (params.n_gpu_layers != -1) {
mparams.n_gpu_layers = params.n_gpu_layers;
}

View File

@ -156,6 +156,7 @@ struct common_params_sampling {
};
struct common_params_speculative {
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
int32_t n_ctx = 0; // draft context size
int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
@ -178,9 +179,6 @@ struct common_params {
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
int32_t n_parallel = 1; // number of parallel sequences to decode
int32_t n_sequences = 1; // number of sequences to decode
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
int32_t grp_attn_n = 1; // group-attention factor
int32_t grp_attn_w = 512; // group-attention width
int32_t n_print = -1; // print token count every n tokens (-1 = disabled)
@ -193,6 +191,13 @@ struct common_params {
int32_t yarn_orig_ctx = 0; // YaRN original context length
float defrag_thold = 0.1f; // KV cache defragmentation threshold
// offload params
std::vector<ggml_backend_dev_t> devices; // devices to use for offloading
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
struct cpu_params cpuparams;
struct cpu_params cpuparams_batch;
@ -201,7 +206,6 @@ struct common_params {
ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
@ -462,7 +466,7 @@ struct common_init_result {
struct common_init_result common_init_from_params(common_params & params);
struct llama_model_params common_model_params_to_llama (const common_params & params);
struct llama_model_params common_model_params_to_llama ( common_params & params);
struct llama_context_params common_context_params_to_llama(const common_params & params);
struct ggml_threadpool_params ggml_threadpool_params_from_cpu_params(const cpu_params & params);

View File

@ -692,6 +692,7 @@ struct server_context {
auto params_dft = params_base;
params_dft.devices = params_base.speculative.devices;
params_dft.model = params_base.speculative.model;
params_dft.n_ctx = params_base.speculative.n_ctx;
params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;

View File

@ -46,6 +46,7 @@ int main(int argc, char ** argv) {
ctx_tgt = llama_init_tgt.context;
// load the draft model
params.devices = params.speculative.devices;
params.model = params.speculative.model;
params.n_ctx = params.speculative.n_ctx;
params.n_batch = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;

View File

@ -76,6 +76,7 @@ int main(int argc, char ** argv) {
ctx_tgt = llama_init_tgt.context;
// load the draft model
params.devices = params.speculative.devices;
params.model = params.speculative.model;
params.n_gpu_layers = params.speculative.n_gpu_layers;
if (params.speculative.cpuparams.n_threads > 0) {

View File

@ -253,6 +253,15 @@ void ggml_backend_device_register(ggml_backend_dev_t device) {
}
// Backend (reg) enumeration
static bool striequals(const char * a, const char * b) {
for (; *a && *b; a++, b++) {
if (std::tolower(*a) != std::tolower(*b)) {
return false;
}
}
return *a == *b;
}
size_t ggml_backend_reg_count() {
return get_reg().backends.size();
}
@ -265,7 +274,7 @@ ggml_backend_reg_t ggml_backend_reg_get(size_t index) {
ggml_backend_reg_t ggml_backend_reg_by_name(const char * name) {
for (size_t i = 0; i < ggml_backend_reg_count(); i++) {
ggml_backend_reg_t reg = ggml_backend_reg_get(i);
if (std::strcmp(ggml_backend_reg_name(reg), name) == 0) {
if (striequals(ggml_backend_reg_name(reg), name)) {
return reg;
}
}
@ -285,7 +294,7 @@ ggml_backend_dev_t ggml_backend_dev_get(size_t index) {
ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {
for (size_t i = 0; i < ggml_backend_dev_count(); i++) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
if (strcmp(ggml_backend_dev_name(dev), name) == 0) {
if (striequals(ggml_backend_dev_name(dev), name)) {
return dev;
}
}

View File

@ -272,6 +272,9 @@ extern "C" {
};
struct llama_model_params {
// NULL-terminated list of devices to use for offloading (if NULL, all available devices are used)
ggml_backend_dev_t * devices;
int32_t n_gpu_layers; // number of layers to store in VRAM
enum llama_split_mode split_mode; // how to split the model across multiple GPUs

View File

@ -19364,6 +19364,7 @@ void llama_lora_adapter_free(struct llama_lora_adapter * adapter) {
//
struct llama_model_params llama_model_default_params() {
struct llama_model_params result = {
/*.devices =*/ nullptr,
/*.n_gpu_layers =*/ 0,
/*.split_mode =*/ LLAMA_SPLIT_MODE_LAYER,
/*.main_gpu =*/ 0,
@ -19576,19 +19577,24 @@ struct llama_model * llama_load_model_from_file(
}
// create list of devices to use with this model
// currently, we use all available devices
// TODO: rework API to give user more control over device selection
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
switch (ggml_backend_dev_type(dev)) {
case GGML_BACKEND_DEVICE_TYPE_CPU:
case GGML_BACKEND_DEVICE_TYPE_ACCEL:
// skip CPU backends since they are handled separately
break;
if (params.devices) {
for (ggml_backend_dev_t * dev = params.devices; *dev; ++dev) {
model->devices.push_back(*dev);
}
} else {
// use all available devices
for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
ggml_backend_dev_t dev = ggml_backend_dev_get(i);
switch (ggml_backend_dev_type(dev)) {
case GGML_BACKEND_DEVICE_TYPE_CPU:
case GGML_BACKEND_DEVICE_TYPE_ACCEL:
// skip CPU backends since they are handled separately
break;
case GGML_BACKEND_DEVICE_TYPE_GPU:
model->devices.push_back(dev);
break;
case GGML_BACKEND_DEVICE_TYPE_GPU:
model->devices.push_back(dev);
break;
}
}
}