mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-24 10:24:35 +00:00)
ggml : add RPC backend (#6829)
* ggml : add RPC backend

  The RPC backend proxies all operations to a remote server which runs a
  regular backend (CPU, CUDA, Metal, etc).

* set TCP_NODELAY
* add CI workflows
* Address review comments
* fix warning
* implement llama_max_devices() for RPC
* Address review comments
* Address review comments
* wrap sockfd into a struct
* implement get_alignment and get_max_size
* add get_device_memory
* fix warning
* win32 support
* add README
* readme : trim trailing whitespace
* Address review comments
* win32 fix
* Address review comments
* fix compile warnings on macos
This commit is contained in:
  parent 541600201e
  commit 5e31828d3e
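The commit message describes the RPC backend as a thin proxy: the client side forwards ggml operations over TCP to an `rpc-server` that runs a regular backend. As a rough illustration of the client half, here is a minimal sketch (not part of the diff below) using the public API this commit adds in `ggml-rpc.h`; the endpoint address is a placeholder:

```cpp
// Sketch only: connect to one rpc-server endpoint and query the memory it reports.
// The endpoint is a placeholder; the functions used here are declared in ggml-rpc.h.
#include "ggml-rpc.h"

#include <cstdio>

int main() {
    const char * endpoint = "192.168.88.10:50052";

    size_t free_mem = 0, total_mem = 0;
    ggml_backend_rpc_get_device_memory(endpoint, &free_mem, &total_mem);
    printf("%s reports %zu free / %zu total bytes\n", endpoint, free_mem, total_mem);

    ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
    if (backend == NULL) {
        fprintf(stderr, "failed to connect to %s\n", endpoint);
        return 1;
    }
    // from here on the handle behaves like any other ggml backend (CPU, CUDA, Metal, ...)
    ggml_backend_free(backend);
    return 0;
}
```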
.github/workflows/build.yml (vendored, 32 changes)
@@ -340,6 +340,36 @@ jobs:
           cd build
           ctest -L main --verbose
 
+  ubuntu-latest-cmake-rpc:
+    runs-on: ubuntu-latest
+
+    continue-on-error: true
+
+    steps:
+      - name: Clone
+        id: checkout
+        uses: actions/checkout@v4
+
+      - name: Dependencies
+        id: depends
+        run: |
+          sudo apt-get update
+          sudo apt-get install build-essential
+
+      - name: Build
+        id: cmake_build
+        run: |
+          mkdir build
+          cd build
+          cmake -DLLAMA_RPC=ON ..
+          cmake --build . --config Release -j $(nproc)
+
+      - name: Test
+        id: cmake_test
+        run: |
+          cd build
+          ctest -L main --verbose
+
   ubuntu-22-cmake-vulkan:
     runs-on: ubuntu-22.04
 
@@ -663,6 +693,8 @@ jobs:
     strategy:
       matrix:
         include:
+          - build: 'rpc'
+            defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_RPC=ON -DBUILD_SHARED_LIBS=ON'
           - build: 'noavx'
             defines: '-DLLAMA_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DLLAMA_AVX=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF -DBUILD_SHARED_LIBS=ON'
           - build: 'avx2'
CMakeLists.txt

@@ -123,6 +123,7 @@ set(LLAMA_METAL_MACOSX_VERSION_MIN "" CACHE STRING
 set(LLAMA_METAL_STD "" CACHE STRING "llama: metal standard version (-std flag)")
 option(LLAMA_KOMPUTE "llama: use Kompute" OFF)
 option(LLAMA_MPI "llama: use MPI" OFF)
+option(LLAMA_RPC "llama: use RPC" OFF)
 option(LLAMA_QKK_64 "llama: use super-block size of 64 for k-quants" OFF)
 option(LLAMA_SYCL "llama: use SYCL" OFF)
 option(LLAMA_SYCL_F16 "llama: use 16 bit floats for sycl calculations" OFF)
@@ -494,6 +495,17 @@ if (LLAMA_MPI)
     endif()
 endif()
 
+if (LLAMA_RPC)
+    add_compile_definitions(GGML_USE_RPC)
+
+    if (WIN32)
+        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ws2_32)
+    endif()
+
+    set(GGML_HEADERS_RPC ggml-rpc.h)
+    set(GGML_SOURCES_RPC ggml-rpc.cpp)
+endif()
+
 if (LLAMA_CLBLAST)
     find_package(CLBlast)
     if (CLBlast_FOUND)
@@ -1176,6 +1188,7 @@ add_library(ggml OBJECT
             ${GGML_SOURCES_OPENCL} ${GGML_HEADERS_OPENCL}
             ${GGML_SOURCES_METAL} ${GGML_HEADERS_METAL}
             ${GGML_SOURCES_MPI} ${GGML_HEADERS_MPI}
+            ${GGML_SOURCES_RPC} ${GGML_HEADERS_RPC}
             ${GGML_SOURCES_EXTRA} ${GGML_HEADERS_EXTRA}
             ${GGML_SOURCES_SYCL} ${GGML_HEADERS_SYCL}
             ${GGML_SOURCES_KOMPUTE} ${GGML_HEADERS_KOMPUTE}
common/common.cpp

@@ -1060,6 +1060,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
 #endif // GGML_USE_CUDA_SYCL_VULKAN
         return true;
     }
+    if (arg == "--rpc") {
+        if (++i >= argc) {
+            invalid_param = true;
+            return true;
+        }
+        params.rpc_servers = argv[i];
+        return true;
+    }
     if (arg == "--no-mmap") {
         params.use_mmap = false;
         return true;
@@ -1557,6 +1565,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
         printf("  -mg i, --main-gpu i   the GPU to use for the model (with split-mode = none),\n");
         printf("                        or for intermediate results and KV (with split-mode = row) (default: %d)\n", params.main_gpu);
     }
+    printf("  --rpc SERVERS         comma separated list of RPC servers\n");
     printf("  --verbose-prompt      print a verbose prompt before generation (default: %s)\n", params.verbose_prompt ? "true" : "false");
     printf("  --no-display-prompt   don't print prompt at generation (default: %s)\n", !params.display_prompt ? "true" : "false");
     printf("  -gan N, --grp-attn-n N\n");
@@ -1830,6 +1839,7 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
     if (params.n_gpu_layers != -1) {
         mparams.n_gpu_layers = params.n_gpu_layers;
     }
+    mparams.rpc_servers  = params.rpc_servers.c_str();
     mparams.main_gpu     = params.main_gpu;
     mparams.split_mode   = params.split_mode;
     mparams.tensor_split = params.tensor_split;
common/common.h

@@ -82,6 +82,7 @@ struct gpt_params {
     float   yarn_beta_slow  = 1.0f;  // YaRN high correction dim
     int32_t yarn_orig_ctx   = 0;     // YaRN original context length
     float   defrag_thold    = -1.0f; // KV cache defragmentation threshold
+    std::string rpc_servers = "";    // comma separated list of RPC servers
 
     ggml_backend_sched_eval_callback cb_eval = nullptr;
     void * cb_eval_user_data                 = nullptr;
examples/CMakeLists.txt

@@ -49,4 +49,7 @@ else()
         add_subdirectory(server)
     endif()
     add_subdirectory(export-lora)
+    if (LLAMA_RPC)
+        add_subdirectory(rpc)
+    endif()
 endif()
examples/rpc/CMakeLists.txt (new file, 2 lines)

@@ -0,0 +1,2 @@
+add_executable(rpc-server rpc-server.cpp)
+target_link_libraries(rpc-server PRIVATE ggml llama)
examples/rpc/README.md (new file, 74 lines)

@@ -0,0 +1,74 @@
+## Overview
+
+The `rpc-server` allows running a `ggml` backend on a remote host.
+The RPC backend communicates with one or several instances of `rpc-server` and offloads computations to them.
+This can be used for distributed LLM inference with `llama.cpp` in the following way:
+
+```mermaid
+flowchart TD
+    rpcb---|TCP|srva
+    rpcb---|TCP|srvb
+    rpcb-.-|TCP|srvn
+    subgraph hostn[Host N]
+    srvn[rpc-server]-.-backend3["Backend (CUDA,Metal,etc.)"]
+    end
+    subgraph hostb[Host B]
+    srvb[rpc-server]---backend2["Backend (CUDA,Metal,etc.)"]
+    end
+    subgraph hosta[Host A]
+    srva[rpc-server]---backend["Backend (CUDA,Metal,etc.)"]
+    end
+    subgraph host[Main Host]
+    ggml[llama.cpp]---rpcb[RPC backend]
+    end
+    style hostn stroke:#66,stroke-width:2px,stroke-dasharray: 5 5
+```
+
+Each host can run a different backend, e.g. one with CUDA and another with Metal.
+You can also run multiple `rpc-server` instances on the same host, each with a different backend.
+
+## Usage
+
+On each host, build the corresponding backend with `cmake` and add `-DLLAMA_RPC=ON` to the build options.
+For example, to build the CUDA backend with RPC support:
+
+```bash
+mkdir build-rpc-cuda
+cd build-rpc-cuda
+cmake .. -DLLAMA_CUDA=ON -DLLAMA_RPC=ON
+cmake --build . --config Release
+```
+
+Then, start the `rpc-server` with the backend:
+
+```bash
+$ bin/rpc-server 0.0.0.0 50052
+create_backend: using CUDA backend
+ggml_cuda_init: GGML_CUDA_FORCE_MMQ:   no
+ggml_cuda_init: CUDA_USE_TENSOR_CORES: yes
+ggml_cuda_init: found 1 CUDA devices:
+  Device 0: NVIDIA T1200 Laptop GPU, compute capability 7.5, VMM: yes
+Starting RPC server on 0.0.0.0:50052
+```
+
+When using the CUDA backend, you can specify the device with the `CUDA_VISIBLE_DEVICES` environment variable, e.g.:
+```bash
+$ CUDA_VISIBLE_DEVICES=0 bin/rpc-server 0.0.0.0 50052
+```
+This way you can run multiple `rpc-server` instances on the same host, each with a different CUDA device.
+
+On the main host, build `llama.cpp` with only `-DLLAMA_RPC=ON`:
+
+```bash
+mkdir build-rpc
+cd build-rpc
+cmake .. -DLLAMA_RPC=ON
+cmake --build . --config Release
+```
+
+Finally, use the `--rpc` option to specify the host and port of each `rpc-server`:
+
+```bash
+$ bin/main -m ../models/tinyllama-1b/ggml-model-f16.gguf -p "Hello, my name is" --repeat-penalty 1.0 -n 64 --rpc 192.168.88.10:50052,192.168.88.11:50052 -ngl 99
+```
examples/rpc/rpc-server.cpp (new file, 70 lines)

@@ -0,0 +1,70 @@
+#ifdef GGML_USE_CUDA
+#include "ggml-cuda.h"
+#endif
+
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
+#include "ggml-rpc.h"
+#include <string>
+#include <stdio.h>
+
+static ggml_backend_t create_backend() {
+    ggml_backend_t backend = NULL;
+#ifdef GGML_USE_CUDA
+    fprintf(stderr, "%s: using CUDA backend\n", __func__);
+    backend = ggml_backend_cuda_init(0); // init device 0
+    if (!backend) {
+        fprintf(stderr, "%s: ggml_backend_cuda_init() failed\n", __func__);
+    }
+#elif GGML_USE_METAL
+    fprintf(stderr, "%s: using Metal backend\n", __func__);
+    backend = ggml_backend_metal_init();
+    if (!backend) {
+        fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
+    }
+#endif
+
+    // if there are no GPU backends, fall back to the CPU backend
+    if (!backend) {
+        fprintf(stderr, "%s: using CPU backend\n", __func__);
+        backend = ggml_backend_cpu_init();
+    }
+    return backend;
+}
+
+static void get_backend_memory(size_t * free_mem, size_t * total_mem) {
+#ifdef GGML_USE_CUDA
+    ggml_backend_cuda_get_device_memory(0, free_mem, total_mem);
+#else
+    // TODO: implement for other backends
+    *free_mem = 1;
+    *total_mem = 1;
+#endif
+}
+
+int main(int argc, char * argv[]) {
+    if (argc < 3) {
+        fprintf(stderr, "Usage: %s <host> <port>\n", argv[0]);
+        return 1;
+    }
+    const char * host = argv[1];
+    int port = std::stoi(argv[2]);
+    if (port <= 0 || port > 65535) {
+        fprintf(stderr, "Invalid port number: %d\n", port);
+        return 1;
+    }
+    ggml_backend_t backend = create_backend();
+    if (!backend) {
+        fprintf(stderr, "Failed to create backend\n");
+        return 1;
+    }
+    printf("Starting RPC server on %s:%d\n", host, port);
+    size_t free_mem, total_mem;
+    get_backend_memory(&free_mem, &total_mem);
+    std::string endpoint = std::string(host) + ":" + std::to_string(port);
+    start_rpc_server(backend, endpoint.c_str(), free_mem, total_mem);
+    ggml_backend_free(backend);
+    return 0;
+}
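One nit worth noting in the listing above: `std::stoi` throws on non-numeric input, so the range check alone does not reject garbage port arguments. A hypothetical hardening sketch (not part of the commit) could look like this:

```cpp
// Hypothetical sketch, not from the commit: parse and validate a TCP port,
// rejecting non-numeric input instead of letting std::stoi throw.
#include <cstdio>
#include <stdexcept>
#include <string>

static bool parse_port(const std::string & arg, int & port) {
    try {
        size_t consumed = 0;
        port = std::stoi(arg, &consumed);
        // require the whole argument to be digits and the value to be a valid port
        return consumed == arg.size() && port > 0 && port <= 65535;
    } catch (const std::exception &) {
        return false; // invalid_argument or out_of_range
    }
}

int main(int argc, char * argv[]) {
    int port = 0;
    if (argc < 2 || !parse_port(argv[1], port)) {
        fprintf(stderr, "Usage: %s <port>\n", argv[0]);
        return 1;
    }
    printf("port = %d\n", port);
    return 0;
}
```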
ggml-rpc.cpp (new file, 1023 lines)

File diff suppressed because it is too large.
ggml-rpc.h (new file, 24 lines)

@@ -0,0 +1,24 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#define GGML_RPC_MAX_SERVERS 16
+
+// backend API
+GGML_API GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint);
+GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend);
+
+GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint);
+
+GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
+
+GGML_API GGML_CALL void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
+
+#ifdef __cplusplus
+}
+#endif
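The header above is the whole public surface: a client creates backends with `ggml_backend_rpc_init()` and a host exposes one with `start_rpc_server()`. As a compact restatement of the server half (essentially `rpc-server.cpp` above, reduced to the CPU backend with hard-coded placeholder values), a minimal sketch might look like:

```cpp
// Sketch only: expose the plain CPU backend over RPC on a placeholder endpoint.
// Real servers (see rpc-server.cpp) pick CUDA/Metal when available and report
// the device's actual free/total memory so clients can split layers sensibly.
#include "ggml-rpc.h"

#include <cstdio>

int main() {
    ggml_backend_t backend = ggml_backend_cpu_init();
    if (backend == NULL) {
        fprintf(stderr, "failed to create CPU backend\n");
        return 1;
    }

    const size_t free_mem  = 1; // placeholder values, as in the CPU fallback above
    const size_t total_mem = 1;

    printf("Starting RPC server on 0.0.0.0:50052\n");
    start_rpc_server(backend, "0.0.0.0:50052", free_mem, total_mem); // serves requests from here

    ggml_backend_free(backend);
    return 0;
}
```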
llama.cpp (238 changes)

@@ -7,6 +7,10 @@
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 
+#ifdef GGML_USE_RPC
+#  include "ggml-rpc.h"
+#endif
+
 #ifdef GGML_USE_CUDA
 #  include "ggml-cuda.h"
 #elif defined(GGML_USE_CLBLAST)
@@ -1685,91 +1689,6 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer
     GGML_UNUSED(host_buffer);
 }
 
-static ggml_backend_buffer_type_t llama_default_buffer_type_offload(int gpu) {
-    ggml_backend_buffer_type_t buft = nullptr;
-
-#ifdef GGML_USE_METAL
-    buft = ggml_backend_metal_buffer_type();
-#elif defined(GGML_USE_CUDA)
-    buft = ggml_backend_cuda_buffer_type(gpu);
-#elif defined(GGML_USE_VULKAN)
-    buft = ggml_backend_vk_buffer_type(gpu);
-#elif defined(GGML_USE_SYCL)
-    buft = ggml_backend_sycl_buffer_type(gpu);
-#elif defined(GGML_USE_CLBLAST)
-    buft = ggml_backend_opencl_buffer_type();
-#elif defined(GGML_USE_KOMPUTE)
-    buft = ggml_backend_kompute_buffer_type(gpu);
-    if (buft == nullptr) {
-        LLAMA_LOG_WARN("%s: cannot use GPU %d, check `vulkaninfo --summary`\n", __func__, gpu);
-    }
-#endif
-
-    if (buft == nullptr) {
-        buft = llama_default_buffer_type_cpu(true);
-    }
-    return buft;
-
-    GGML_UNUSED(gpu);
-}
-
-static ggml_backend_buffer_type_t llama_default_buffer_type_split(int fallback_gpu, const float * tensor_split) {
-    ggml_backend_buffer_type_t buft = nullptr;
-
-#ifdef GGML_USE_CUDA
-    if (ggml_backend_cuda_get_device_count() > 1) {
-        buft = ggml_backend_cuda_split_buffer_type(tensor_split);
-    }
-#endif
-
-#ifdef GGML_USE_SYCL
-    if (ggml_backend_sycl_get_device_count() > 1) {
-        buft = ggml_backend_sycl_split_buffer_type(tensor_split);
-    }
-#endif
-
-    if (buft == nullptr) {
-        buft = llama_default_buffer_type_offload(fallback_gpu);
-    }
-    return buft;
-
-    GGML_UNUSED(tensor_split);
-}
-
-static size_t llama_get_device_count() {
-#if defined(GGML_USE_CUDA)
-    return ggml_backend_cuda_get_device_count();
-#elif defined(GGML_USE_SYCL)
-    return ggml_backend_sycl_get_device_count();
-#elif defined(GGML_USE_VULKAN)
-    return ggml_backend_vk_get_device_count();
-#else
-    return 1;
-#endif
-}
-
-static size_t llama_get_device_memory(int device) {
-#if defined(GGML_USE_CUDA)
-    size_t total;
-    size_t free;
-    ggml_backend_cuda_get_device_memory(device, &free, &total);
-    return free;
-#elif defined(GGML_USE_SYCL)
-    size_t total;
-    size_t free;
-    ggml_backend_sycl_get_device_memory(device, &free, &total);
-    return free;
-#elif defined(GGML_USE_VULKAN)
-    size_t total;
-    size_t free;
-    ggml_backend_vk_get_device_memory(device, &free, &total);
-    return free;
-#else
-    return 1;
-    GGML_UNUSED(device);
-#endif
-}
-
 //
 // globals
 //
@@ -2210,6 +2129,8 @@ struct llama_model {
     int main_gpu;
     int n_gpu_layers;
 
+    std::vector<std::string> rpc_servers;
+
     // gguf metadata
     std::unordered_map<std::string, std::string> gguf_kv;
 
@@ -2353,6 +2274,104 @@ struct llama_context {
 #endif
 };
 
+static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_model & model, int gpu) {
+    ggml_backend_buffer_type_t buft = nullptr;
+
+#ifdef GGML_USE_RPC
+    std::string endpoint = model.rpc_servers[gpu];
+    buft = ggml_backend_rpc_buffer_type(endpoint.c_str());
+#elif defined(GGML_USE_METAL)
+    buft = ggml_backend_metal_buffer_type();
+#elif defined(GGML_USE_CUDA)
+    buft = ggml_backend_cuda_buffer_type(gpu);
+#elif defined(GGML_USE_VULKAN)
+    buft = ggml_backend_vk_buffer_type(gpu);
+#elif defined(GGML_USE_SYCL)
+    buft = ggml_backend_sycl_buffer_type(gpu);
+#elif defined(GGML_USE_CLBLAST)
+    buft = ggml_backend_opencl_buffer_type();
+#elif defined(GGML_USE_KOMPUTE)
+    buft = ggml_backend_kompute_buffer_type(gpu);
+    if (buft == nullptr) {
+        LLAMA_LOG_WARN("%s: cannot use GPU %d, check `vulkaninfo --summary`\n", __func__, gpu);
+    }
+#endif
+
+    if (buft == nullptr) {
+        buft = llama_default_buffer_type_cpu(true);
+    }
+    return buft;
+
+    GGML_UNUSED(model);
+    GGML_UNUSED(gpu);
+}
+
+static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_model & model, int fallback_gpu, const float * tensor_split) {
+    ggml_backend_buffer_type_t buft = nullptr;
+
+#ifdef GGML_USE_CUDA
+    if (ggml_backend_cuda_get_device_count() > 1) {
+        buft = ggml_backend_cuda_split_buffer_type(tensor_split);
+    }
+#endif
+
+#ifdef GGML_USE_SYCL
+    if (ggml_backend_sycl_get_device_count() > 1) {
+        buft = ggml_backend_sycl_split_buffer_type(tensor_split);
+    }
+#endif
+
+    if (buft == nullptr) {
+        buft = llama_default_buffer_type_offload(model, fallback_gpu);
+    }
+    return buft;
+
+    GGML_UNUSED(tensor_split);
+}
+
+static size_t llama_get_device_count(const llama_model & model) {
+#if defined(GGML_USE_RPC)
+    return model.rpc_servers.size();
+#elif defined(GGML_USE_CUDA)
+    return ggml_backend_cuda_get_device_count();
+#elif defined(GGML_USE_SYCL)
+    return ggml_backend_sycl_get_device_count();
+#elif defined(GGML_USE_VULKAN)
+    return ggml_backend_vk_get_device_count();
+#else
+    return 1;
+#endif
+    GGML_UNUSED(model);
+}
+
+static size_t llama_get_device_memory(const llama_model & model, int device) {
+#if defined(GGML_USE_RPC)
+    size_t total;
+    size_t free;
+    std::string endpoint = model.rpc_servers[device];
+    ggml_backend_rpc_get_device_memory(endpoint.c_str(), &free, &total);
+    return free;
+#elif defined(GGML_USE_CUDA)
+    size_t total;
+    size_t free;
+    ggml_backend_cuda_get_device_memory(device, &free, &total);
+    return free;
+#elif defined(GGML_USE_SYCL)
+    size_t total;
+    size_t free;
+    ggml_backend_sycl_get_device_memory(device, &free, &total);
+    return free;
+#elif defined(GGML_USE_VULKAN)
+    size_t total;
+    size_t free;
+    ggml_backend_vk_get_device_memory(device, &free, &total);
+    return free;
+#else
+    return 1;
+#endif
+    GGML_UNUSED(model);
+    GGML_UNUSED(device);
+}
+
 //
 // kv cache helpers
 //
@@ -4791,13 +4810,13 @@ static bool llm_load_tensors(
 
     if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
         // calculate the split points
-        int device_count = llama_get_device_count();
+        int device_count = llama_get_device_count(model);
         bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + device_count, [](float x) { return x == 0.0f; });
         std::vector<float> splits(device_count);
         if (all_zero) {
             // default split, by free memory
             for (int i = 0; i < device_count; ++i) {
-                splits[i] = llama_get_device_memory(i);
+                splits[i] = llama_get_device_memory(model, i);
             }
         } else {
             std::copy(tensor_split, tensor_split + device_count, splits.begin());
@@ -4817,35 +4836,35 @@ static bool llm_load_tensors(
         int act_gpu_layers = std::min(n_gpu_layers, (int)n_layer + 1);
         for (int64_t i = i_gpu_start; i < n_layer; ++i) {
             int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(i - i_gpu_start)/act_gpu_layers) - splits.begin();
-            model.buft_layer[i] = llama_default_buffer_type_offload(layer_gpu);
+            model.buft_layer[i] = llama_default_buffer_type_offload(model, layer_gpu);
         }
         // assign the output layer
         if (n_gpu_layers > n_layer) {
             int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + device_count, float(act_gpu_layers - 1)/act_gpu_layers) - splits.begin();
-            model.buft_output = llama_default_buffer_type_offload(layer_gpu);
+            model.buft_output = llama_default_buffer_type_offload(model, layer_gpu);
         } else {
             model.buft_output = llama_default_buffer_type_cpu(true);
         }
     } else {
         ggml_backend_buffer_type_t split_buft;
         if (split_mode == LLAMA_SPLIT_MODE_ROW) {
-            split_buft = llama_default_buffer_type_split(main_gpu, tensor_split);
+            split_buft = llama_default_buffer_type_split(model, main_gpu, tensor_split);
         } else {
             // LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_LAYER in backends where it is not supported
-            split_buft = llama_default_buffer_type_offload(main_gpu);
+            split_buft = llama_default_buffer_type_offload(model, main_gpu);
         }
         // assign the repeating layers
         for (int64_t i = i_gpu_start; i < n_layer; ++i) {
             model.buft_layer[i] = {
                 split_buft,
-                llama_default_buffer_type_offload(main_gpu)
+                llama_default_buffer_type_offload(model, main_gpu)
             };
         }
         // assign the output layer
         if (n_gpu_layers > n_layer) {
             model.buft_output = {
                 split_buft,
-                llama_default_buffer_type_offload(main_gpu)
+                llama_default_buffer_type_offload(model, main_gpu)
             };
         } else {
             model.buft_output = llama_default_buffer_type_cpu(true);
@@ -15390,6 +15409,7 @@ struct llama_model_params llama_model_default_params() {
         /*.split_mode                  =*/ LLAMA_SPLIT_MODE_LAYER,
         /*.main_gpu                    =*/ 0,
         /*.tensor_split                =*/ nullptr,
+        /*.rpc_servers                 =*/ nullptr,
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
         /*.kv_overrides                =*/ nullptr,
@@ -15460,7 +15480,9 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
 }
 
 size_t llama_max_devices(void) {
-#if defined(GGML_USE_METAL)
+#if defined(GGML_USE_RPC)
+    return GGML_RPC_MAX_SERVERS;
+#elif defined(GGML_USE_METAL)
     return 1;
 #elif defined(GGML_USE_CUDA)
     return GGML_CUDA_MAX_DEVICES;
@@ -15483,7 +15505,7 @@ bool llama_supports_mlock(void) {
 
 bool llama_supports_gpu_offload(void) {
 #if defined(GGML_USE_CUDA) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) || defined(GGML_USE_VULKAN) || \
-    defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE)
+    defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_RPC)
     // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
     return true;
 #else
@@ -15546,7 +15568,17 @@ struct llama_model * llama_load_model_from_file(
             return true;
         };
     }
+    if (params.rpc_servers != nullptr) {
+        // split the comma-separated servers and set them into model->rpc_servers
+        std::string servers(params.rpc_servers);
+        size_t pos = 0;
+        while ((pos = servers.find(",")) != std::string::npos) {
+            std::string server = servers.substr(0, pos);
+            model->rpc_servers.push_back(server);
+            servers.erase(0, pos + 1);
+        }
+        model->rpc_servers.push_back(servers);
+    }
     int status = llama_model_load(path_model, *model, params);
     GGML_ASSERT(status <= 0);
     if (status < 0) {
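The splitting loop above is easy to verify in isolation. A standalone sketch of the same logic (illustration only, with made-up endpoint strings):

```cpp
// Illustration only: the same comma-splitting logic as above, extracted into a
// helper and checked against a made-up --rpc value.
#include <cassert>
#include <string>
#include <vector>

static std::vector<std::string> split_servers(std::string servers) {
    std::vector<std::string> out;
    size_t pos = 0;
    while ((pos = servers.find(',')) != std::string::npos) {
        out.push_back(servers.substr(0, pos));
        servers.erase(0, pos + 1);
    }
    out.push_back(servers); // the last (or only) endpoint
    return out;
}

int main() {
    const auto servers = split_servers("192.168.88.10:50052,192.168.88.11:50052");
    assert(servers.size() == 2);
    assert(servers[0] == "192.168.88.10:50052");
    assert(servers[1] == "192.168.88.11:50052");
    return 0;
}
```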
@@ -15693,7 +15725,17 @@ struct llama_context * llama_new_context_with_model(
 
     if (!hparams.vocab_only) {
         // initialize backends
-#ifdef GGML_USE_METAL
+#if defined(GGML_USE_RPC)
+        for (auto & server : model->rpc_servers) {
+            ggml_backend_t backend = ggml_backend_rpc_init(server.c_str());
+            if (backend == nullptr) {
+                LLAMA_LOG_ERROR("%s: failed to connect RPC backend to %s\n", __func__, server.c_str());
+                llama_free(ctx);
+                return nullptr;
+            }
+            ctx->backends.push_back(backend);
+        }
+#elif defined(GGML_USE_METAL)
         if (model->n_gpu_layers > 0) {
             ctx->backend_metal = ggml_backend_metal_init();
             if (ctx->backend_metal == nullptr) {
@@ -15850,7 +15892,7 @@ struct llama_context * llama_new_context_with_model(
 
         // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
         bool pipeline_parallel =
-            llama_get_device_count() > 1 &&
+            llama_get_device_count(*model) > 1 &&
             model->n_gpu_layers > (int)model->hparams.n_layer &&
             model->split_mode == LLAMA_SPLIT_MODE_LAYER &&
             params.offload_kqv;
llama.h (3 changes)

@@ -242,6 +242,9 @@ extern "C" {
         // proportion of the model (layers or rows) to offload to each GPU, size: llama_max_devices()
         const float * tensor_split;
 
+        // comma separated list of RPC servers to use for offloading
+        const char * rpc_servers;
+
         // Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
         // If the provided progress_callback returns true, model loading continues.
         // If it returns false, model loading is immediately aborted.
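With the new `rpc_servers` field in place, RPC offloading can also be driven directly from the C API rather than through the `--rpc` flag. A minimal sketch (the model path and endpoint addresses are placeholders):

```cpp
// Sketch only: point llama_model_params.rpc_servers at a comma-separated list of
// rpc-server endpoints; the model path and addresses below are placeholders.
#include "llama.h"

#include <cstdio>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    mparams.rpc_servers  = "192.168.88.10:50052,192.168.88.11:50052";
    mparams.n_gpu_layers = 99; // offload layers to the RPC devices, as with -ngl 99

    llama_model * model = llama_load_model_from_file("models/tinyllama-1b/ggml-model-f16.gguf", mparams);
    if (model == NULL) {
        fprintf(stderr, "failed to load model\n");
        llama_backend_free();
        return 1;
    }

    // ... create a context with llama_new_context_with_model() and generate as usual ...

    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```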