From 5cb04dbc16d1da38c8fdcc0111b40e67d00dd1c3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 31 Jan 2024 17:30:17 +0200 Subject: [PATCH] llama : remove LLAMA_MAX_DEVICES and LLAMA_SUPPORTS_GPU_OFFLOAD (#5240) * llama : remove LLAMA_MAX_DEVICES from llama.h ggml-ci * Update llama.cpp Co-authored-by: slaren * server : remove LLAMA_MAX_DEVICES ggml-ci * llama : remove LLAMA_SUPPORTS_GPU_OFFLOAD ggml-ci * train : remove LLAMA_SUPPORTS_GPU_OFFLOAD * readme : add deprecation notice * readme : change deprecation notice to "remove" and fix url * llama : remove gpu includes from llama.h ggml-ci --------- Co-authored-by: slaren --- README.md | 3 +- common/common.cpp | 56 ++++++++++---------- common/common.h | 66 ++++++++++++------------ common/train.cpp | 12 ++--- examples/batched-bench/batched-bench.cpp | 2 +- examples/llama-bench/llama-bench.cpp | 16 +++--- examples/server/server.cpp | 44 ++++++++-------- llama.cpp | 39 +++++++++++--- llama.h | 29 ++++------- 9 files changed, 143 insertions(+), 124 deletions(-) diff --git a/README.md b/README.md index 7746cb510..e6ed1d429 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,8 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++ ### Hot topics -- ⚠️ Incoming backends: https://github.com/ggerganov/llama.cpp/discussions/5138 +- Remove LLAMA_MAX_DEVICES and LLAMA_SUPPORTS_GPU_OFFLOAD: https://github.com/ggerganov/llama.cpp/pull/5240 +- Incoming backends: https://github.com/ggerganov/llama.cpp/discussions/5138 - [SYCL backend](README-sycl.md) is ready (1/28/2024), support Linux/Windows in Intel GPUs (iGPU, Arc/Flex/Max series) - New SOTA quantized models, including pure 2-bits: https://huggingface.co/ikawrakow - Collecting Apple Silicon performance stats: diff --git a/common/common.cpp b/common/common.cpp index 9d976c7c8..ce739b15c 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -583,20 +583,20 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.n_gpu_layers = std::stoi(argv[i]); -#ifndef LLAMA_SUPPORTS_GPU_OFFLOAD - fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n"); - fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); -#endif + if (!llama_supports_gpu_offload()) { + fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n"); + fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); + } } else if (arg == "--gpu-layers-draft" || arg == "-ngld" || arg == "--n-gpu-layers-draft") { if (++i >= argc) { invalid_param = true; break; } params.n_gpu_layers_draft = std::stoi(argv[i]); -#ifndef LLAMA_SUPPORTS_GPU_OFFLOAD - fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n"); - fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); -#endif + if (!llama_supports_gpu_offload()) { + fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers-draft option will be ignored\n"); + fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); + } } else if (arg == "--main-gpu" || arg == "-mg") { if (++i >= argc) { invalid_param = true; @@ -637,11 +637,11 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { const std::regex regex{R"([,/]+)"}; std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; std::vector split_arg{it, {}}; - if (split_arg.size() >= LLAMA_MAX_DEVICES) { + if (split_arg.size() >= llama_max_devices()) { invalid_param = true; break; } - for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) { + for (size_t i = 0; i < llama_max_devices(); ++i) { if (i < split_arg.size()) { params.tensor_split[i] = std::stof(split_arg[i]); } else { @@ -989,30 +989,30 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n"); printf(" --image IMAGE_FILE path to an image file. use with multimodal models\n"); - if (llama_mlock_supported()) { + if (llama_supports_mlock()) { printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n"); } - if (llama_mmap_supported()) { + if (llama_supports_mmap()) { printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n"); } printf(" --numa attempt optimizations that help on some NUMA systems\n"); printf(" if run without this previously, it is recommended to drop the system page cache before using this\n"); printf(" see https://github.com/ggerganov/llama.cpp/issues/1437\n"); -#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD - printf(" -ngl N, --n-gpu-layers N\n"); - printf(" number of layers to store in VRAM\n"); - printf(" -ngld N, --n-gpu-layers-draft N\n"); - printf(" number of layers to store in VRAM for the draft model\n"); - printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n"); - printf(" how to split the model across multiple GPUs, one of:\n"); - printf(" - none: use one GPU only\n"); - printf(" - layer (default): split layers and KV across GPUs\n"); - printf(" - row: split rows across GPUs\n"); - printf(" -ts SPLIT, --tensor-split SPLIT\n"); - printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n"); - printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n"); - printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu); -#endif // LLAMA_SUPPORTS_GPU_OFFLOAD + if (llama_supports_gpu_offload()) { + printf(" -ngl N, --n-gpu-layers N\n"); + printf(" number of layers to store in VRAM\n"); + printf(" -ngld N, --n-gpu-layers-draft N\n"); + printf(" number of layers to store in VRAM for the draft model\n"); + printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n"); + printf(" how to split the model across multiple GPUs, one of:\n"); + printf(" - none: use one GPU only\n"); + printf(" - layer (default): split layers and KV across GPUs\n"); + printf(" - row: split rows across GPUs\n"); + printf(" -ts SPLIT, --tensor-split SPLIT\n"); + printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n"); + printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n"); + printf(" or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu); + } printf(" --verbose-prompt print a verbose prompt before generation (default: %s)\n", params.verbose_prompt ? "true" : "false"); printf(" --no-display-prompt don't print prompt at generation (default: %s)\n", !params.display_prompt ? "true" : "false"); printf(" -gan N, --grp-attn-n N\n"); @@ -1651,7 +1651,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); - const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES); + const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); dump_vector_float_yaml(stream, "tensor_split", tensor_split_vector); fprintf(stream, "tfs: %f # default: 1.0\n", sparams.tfs_z); diff --git a/common/common.h b/common/common.h index 214a379b5..24a99d728 100644 --- a/common/common.h +++ b/common/common.h @@ -43,40 +43,40 @@ extern char const *LLAMA_BUILD_TARGET; int32_t get_num_physical_cores(); struct gpt_params { - uint32_t seed = -1; // RNG seed + uint32_t seed = -1; // RNG seed - int32_t n_threads = get_num_physical_cores(); - int32_t n_threads_draft = -1; - int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) - int32_t n_threads_batch_draft = -1; - int32_t n_predict = -1; // new tokens to predict - int32_t n_ctx = 512; // context size - int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_draft = 8; // number of tokens to draft during speculative decoding - int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) - int32_t n_parallel = 1; // number of parallel sequences to decode - int32_t n_sequences = 1; // number of sequences to decode - float p_accept = 0.5f; // speculative decoding accept probability - float p_split = 0.1f; // speculative decoding split probability - int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) - int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) - llama_split_mode split_mode = LLAMA_SPLIT_LAYER; // how to split the model across GPUs - int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors - float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs - int32_t n_beams = 0; // if non-zero then use beam search of given width. - int32_t grp_attn_n = 1; // group-attention factor - int32_t grp_attn_w = 512; // group-attention width - int32_t n_print = -1; // print token count every n tokens (-1 = disabled) - float rope_freq_base = 0.0f; // RoPE base frequency - float rope_freq_scale = 0.0f; // RoPE frequency scaling factor - float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor - float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor - float yarn_beta_fast = 32.0f; // YaRN low correction dim - float yarn_beta_slow = 1.0f; // YaRN high correction dim - int32_t yarn_orig_ctx = 0; // YaRN original context length - int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment - // pinging @cebtenzzre + int32_t n_threads = get_num_physical_cores(); + int32_t n_threads_draft = -1; + int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads) + int32_t n_threads_batch_draft = -1; + int32_t n_predict = -1; // new tokens to predict + int32_t n_ctx = 512; // context size + int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) + int32_t n_keep = 0; // number of tokens to keep from initial prompt + int32_t n_draft = 8; // number of tokens to draft during speculative decoding + int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) + int32_t n_parallel = 1; // number of parallel sequences to decode + int32_t n_sequences = 1; // number of sequences to decode + float p_accept = 0.5f; // speculative decoding accept probability + float p_split = 0.1f; // speculative decoding split probability + int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default) + int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default) + llama_split_mode split_mode = LLAMA_SPLIT_LAYER; // how to split the model across GPUs + int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs + int32_t n_beams = 0; // if non-zero then use beam search of given width. + int32_t grp_attn_n = 1; // group-attention factor + int32_t grp_attn_w = 512; // group-attention width + int32_t n_print = -1; // print token count every n tokens (-1 = disabled) + float rope_freq_base = 0.0f; // RoPE base frequency + float rope_freq_scale = 0.0f; // RoPE frequency scaling factor + float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor + float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor + float yarn_beta_fast = 32.0f; // YaRN low correction dim + float yarn_beta_slow = 1.0f; // YaRN high correction dim + int32_t yarn_orig_ctx = 0; // YaRN original context length + int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment + // pinging @cebtenzzre // // sampling parameters struct llama_sampling_params sparams; diff --git a/common/train.cpp b/common/train.cpp index e6f2f7a2f..e4c3d5df6 100644 --- a/common/train.cpp +++ b/common/train.cpp @@ -1363,12 +1363,12 @@ bool consume_common_train_arg( *invalid_param = true; return true; } -#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD - params->n_gpu_layers = std::stoi(argv[i]); -#else - fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n"); - fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); -#endif + if (llama_supports_gpu_offload()) { + params->n_gpu_layers = std::stoi(argv[i]); + } else { + fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n"); + fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); + } } else if (arg == "-h" || arg == "--help") { params->print_usage = true; return true; diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 7924db267..b52d68457 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -88,7 +88,7 @@ int main(int argc, char ** argv) { llama_model_params model_params = llama_model_default_params(); - const std::vector t_split (LLAMA_MAX_DEVICES, 0.0f); + const std::vector t_split(llama_max_devices(), 0.0f); model_params.n_gpu_layers = n_gpu_layers; model_params.tensor_split = t_split.data(); diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 542cc7bb8..c5a6f744e 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -160,7 +160,7 @@ struct cmd_params { std::vector main_gpu; std::vector no_kv_offload; std::vector mul_mat_q; - std::vector> tensor_split; + std::vector> tensor_split; int reps; bool verbose; output_formats output_format; @@ -179,7 +179,7 @@ static const cmd_params cmd_params_defaults = { /* main_gpu */ {0}, /* no_kv_offload */ {false}, /* mul_mat_q */ {true}, - /* tensor_split */ {{}}, + /* tensor_split */ {std::vector(llama_max_devices(), 0.0f)}, /* reps */ 5, /* verbose */ false, /* output_format */ MARKDOWN @@ -380,10 +380,10 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { const std::regex regex{R"([;/]+)"}; std::sregex_token_iterator it{ts.begin(), ts.end(), regex, -1}; std::vector split_arg{it, {}}; - GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES); + GGML_ASSERT(split_arg.size() <= llama_max_devices()); - std::array tensor_split; - for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) { + std::vector tensor_split(llama_max_devices()); + for (size_t i = 0; i < llama_max_devices(); ++i) { if (i < split_arg.size()) { tensor_split[i] = std::stof(split_arg[i]); } else { @@ -459,7 +459,7 @@ struct cmd_params_instance { int main_gpu; bool no_kv_offload; bool mul_mat_q; - std::array tensor_split; + std::vector tensor_split; llama_model_params to_llama_mparams() const { llama_model_params mparams = llama_model_default_params(); @@ -582,7 +582,7 @@ struct test { int main_gpu; bool no_kv_offload; bool mul_mat_q; - std::array tensor_split; + std::vector tensor_split; int n_prompt; int n_gen; std::string test_time; @@ -704,7 +704,7 @@ struct test { std::vector get_values() const { std::string tensor_split_str; int max_nonzero = 0; - for (int i = 0; i < LLAMA_MAX_DEVICES; i++) { + for (size_t i = 0; i < llama_max_devices(); i++) { if (tensor_split[i] > 0) { max_nonzero = i; } diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 21bdce8ed..ea77125ea 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1789,28 +1789,28 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); - if (llama_mlock_supported()) + if (llama_supports_mlock()) { printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n"); } - if (llama_mmap_supported()) + if (llama_supports_mmap()) { printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n"); } printf(" --numa attempt optimizations that help on some NUMA systems\n"); -#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD - printf(" -ngl N, --n-gpu-layers N\n"); - printf(" number of layers to store in VRAM\n"); - printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n"); - printf(" how to split the model across multiple GPUs, one of:\n"); - printf(" - none: use one GPU only\n"); - printf(" - layer (default): split layers and KV across GPUs\n"); - printf(" - row: split rows across GPUs\n"); - printf(" -ts SPLIT --tensor-split SPLIT\n"); - printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n"); - printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n"); - printf(" or for intermediate results and KV (with split-mode = row)\n"); -#endif + if (llama_supports_gpu_offload()) { + printf(" -ngl N, --n-gpu-layers N\n"); + printf(" number of layers to store in VRAM\n"); + printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n"); + printf(" how to split the model across multiple GPUs, one of:\n"); + printf(" - none: use one GPU only\n"); + printf(" - layer (default): split layers and KV across GPUs\n"); + printf(" - row: split rows across GPUs\n"); + printf(" -ts SPLIT --tensor-split SPLIT\n"); + printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n"); + printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n"); + printf(" or for intermediate results and KV (with split-mode = row)\n"); + } printf(" -m FNAME, --model FNAME\n"); printf(" model path (default: %s)\n", params.model.c_str()); printf(" -a ALIAS, --alias ALIAS\n"); @@ -2066,13 +2066,13 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, invalid_param = true; break; } -#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD - params.n_gpu_layers = std::stoi(argv[i]); -#else - LOG_WARNING("Not compiled with GPU offload support, --n-gpu-layers option will be ignored. " + if (llama_supports_gpu_offload()) { + params.n_gpu_layers = std::stoi(argv[i]); + } else { + LOG_WARNING("Not compiled with GPU offload support, --n-gpu-layers option will be ignored. " "See main README.md for information on enabling GPU BLAS support", {{"n_gpu_layers", params.n_gpu_layers}}); -#endif + } } else if (arg == "--split-mode" || arg == "-sm") { @@ -2115,9 +2115,9 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, const std::regex regex{R"([,/]+)"}; std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1}; std::vector split_arg{it, {}}; - GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES); + GGML_ASSERT(split_arg.size() <= llama_max_devices()); - for (size_t i_device = 0; i_device < LLAMA_MAX_DEVICES; ++i_device) + for (size_t i_device = 0; i_device < llama_max_devices(); ++i_device) { if (i_device < split_arg.size()) { diff --git a/llama.cpp b/llama.cpp index bb23689fa..9b249ba9c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -10090,18 +10090,45 @@ struct llama_model_quantize_params llama_model_quantize_default_params() { return result; } -int32_t llama_max_devices(void) { - return LLAMA_MAX_DEVICES; +size_t llama_max_devices(void) { +#if defined(GGML_USE_METAL) + return 1; +#elif defined(GGML_USE_CUBLAS) + return GGML_CUDA_MAX_DEVICES; +#elif defined(GGML_USE_SYCL) + return GGML_SYCL_MAX_DEVICES; +#else + return 1; +#endif } -bool llama_mmap_supported(void) { +bool llama_supports_mmap(void) { return llama_mmap::SUPPORTED; } -bool llama_mlock_supported(void) { +bool llama_supports_mlock(void) { return llama_mlock::SUPPORTED; } +bool llama_supports_gpu_offload(void) { +#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \ + defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE) + // Defined when llama.cpp is compiled with support for offloading model layers to GPU. + return true; +#else + return false; +#endif +} + +// deprecated: +bool llama_mmap_supported(void) { + return llama_supports_mmap(); +} + +bool llama_mlock_supported(void) { + return llama_supports_mlock(); +} + void llama_backend_init(bool numa) { ggml_time_init(); @@ -10133,8 +10160,8 @@ int64_t llama_time_us(void) { } struct llama_model * llama_load_model_from_file( - const char * path_model, - struct llama_model_params params) { + const char * path_model, + struct llama_model_params params) { ggml_time_init(); llama_model * model = new llama_model; diff --git a/llama.h b/llama.h index 17d43d039..9a60e9bfb 100644 --- a/llama.h +++ b/llama.h @@ -3,15 +3,7 @@ #include "ggml.h" #include "ggml-backend.h" -#ifdef GGML_USE_CUBLAS -#include "ggml-cuda.h" -#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES -#elif defined(GGML_USE_SYCL) -#include "ggml-sycl.h" -#define LLAMA_MAX_DEVICES GGML_SYCL_MAX_DEVICES -#else -#define LLAMA_MAX_DEVICES 1 -#endif // GGML_USE_CUBLAS + #include #include #include @@ -49,12 +41,6 @@ #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN #define LLAMA_SESSION_VERSION 4 -#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \ - defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE) -// Defined when llama.cpp is compiled with support for offloading model layers to GPU. -#define LLAMA_SUPPORTS_GPU_OFFLOAD -#endif - #ifdef __cplusplus extern "C" { #endif @@ -201,7 +187,7 @@ extern "C" { // LLAMA_SPLIT_LAYER: ignored int32_t main_gpu; - // proportion of the model (layers or rows) to offload to each GPU, size: LLAMA_MAX_DEVICES + // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices() const float * tensor_split; // Called with a progress value between 0.0 and 1.0. Pass NULL to disable. @@ -338,9 +324,14 @@ extern "C" { LLAMA_API int64_t llama_time_us(void); - LLAMA_API int32_t llama_max_devices(void); - LLAMA_API bool llama_mmap_supported (void); - LLAMA_API bool llama_mlock_supported(void); + LLAMA_API size_t llama_max_devices(void); + + LLAMA_API bool llama_supports_mmap (void); + LLAMA_API bool llama_supports_mlock (void); + LLAMA_API bool llama_supports_gpu_offload(void); + + LLAMA_API DEPRECATED(bool llama_mmap_supported (void), "use llama_supports_mmap() instead"); + LLAMA_API DEPRECATED(bool llama_mlock_supported(void), "use llama_supports_mlock() instead"); LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);