cuda : ROCm AMD Unified Memory Architecture (UMA) handling (#4449)
* AMD ROCm: handle UMA memory VRAM expansions

  This resolves #2797 by allowing ROCm AMD GPU users with a UMA to dynamically expand the VRAM allocated to the GPU.

  Without this, AMD ROCm users with shared CPU/GPU memory usually are stuck with the BIOS-set (or fixed) framebuffer VRAM, making it impossible to load more than 1-2 layers.

  Note that the model is duplicated in RAM because it's loaded once for the CPU and then copied into a second set of allocations that are managed by the HIP UMA system. We can fix this later.

* clarify build process for ROCm on linux with cmake

* avoid using deprecated ROCm hipMallocHost

* keep simplifying the change required for UMA

* cmake: enable UMA-compatible allocation when LLAMA_HIP_UMA=ON
parent 562cf222b5
commit 0f630fbc92
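At the heart of the change is the allocator switch: on an APU, `hipMallocManaged` hands back unified memory that the driver can page between system RAM and the GPU, so allocations are not capped by the BIOS-reserved framebuffer the way plain `hipMalloc` device allocations are. Below is a minimal stand-alone HIP sketch of the contrast; the 4 GiB size and the messages are illustrative, not taken from this commit.

```cpp
// Hedged sketch (not from the commit): contrast a device allocation, which
// on an APU is limited by the BIOS-reserved VRAM carve-out, with a managed
// (UMA) allocation that can grow into system RAM.
#include <hip/hip_runtime.h>
#include <cstdio>

int main() {
    const size_t size = 4ull * 1024 * 1024 * 1024; // 4 GiB, illustrative

    void* fixed = nullptr;
    if (hipMalloc(&fixed, size) != hipSuccess) {
        // On an APU with a small carve-out this typically fails.
        printf("hipMalloc failed: allocation exceeds the fixed VRAM carve-out\n");
    }

    void* uma = nullptr;
    if (hipMallocManaged(&uma, size) != hipSuccess) {
        printf("hipMallocManaged failed\n");
    } else {
        printf("UMA allocation succeeded: backed by system RAM\n");
        hipFree(uma);
    }
    if (fixed) hipFree(fixed);
    return 0;
}
```

Build with `hipcc` on a ROCm install; on an APU with a small BIOS carve-out, the first allocation will typically fail while the managed one succeeds.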
CMakeLists.txt:

```diff
@@ -91,6 +91,7 @@ set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for
 set(LLAMA_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
     "llama: max. batch size for using peer access")
 option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
+option(LLAMA_HIP_UMA "llama: use HIP unified memory architecture" OFF)
 option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
 option(LLAMA_METAL "llama: use Metal" ${LLAMA_METAL_DEFAULT})
 option(LLAMA_METAL_NDEBUG "llama: disable Metal debugging" OFF)
@@ -377,6 +378,9 @@ if (LLAMA_HIPBLAS)
     if (${hipblas_FOUND} AND ${hip_FOUND})
         message(STATUS "HIP and hipBLAS found")
         add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
+        if (LLAMA_HIP_UMA)
+            add_compile_definitions(GGML_HIP_UMA)
+        endif()
         add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
         if (BUILD_SHARED_LIBS)
             set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON)
```
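With the option wired up, configuring with `-DLLAMA_HIP_UMA=ON` defines `GGML_HIP_UMA` for the ROCm build, which the source-level allocation remap shown further below picks up.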
README.md (16 changed lines):

````diff
@@ -432,14 +432,15 @@ Building the program with BLAS support may lead to some performance improvements
   ```bash
   make LLAMA_HIPBLAS=1
   ```
-- Using `CMake` for Linux:
+- Using `CMake` for Linux (assuming a gfx1030-compatible AMD GPU):
   ```bash
-  mkdir build
-  cd build
-  CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ cmake .. -DLLAMA_HIPBLAS=ON
-  cmake --build .
+  CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ \
+  cmake -H. -Bbuild -DLLAMA_HIPBLAS=ON -DAMDGPU_TARGETS=gfx1030 -DCMAKE_BUILD_TYPE=Release \
+  && cmake --build build -- -j 16
   ```
-- Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS):
+  On Linux it is also possible to use unified memory architecture (UMA) to share main memory between the CPU and integrated GPU by setting `-DLLAMA_HIP_UMA=ON`.
+  However, this hurts performance for non-integrated GPUs.
+- Using `CMake` for Windows (using x64 Native Tools Command Prompt for VS, and assuming a gfx1100-compatible AMD GPU):
   ```bash
   set PATH=%HIP_PATH%\bin;%PATH%
   mkdir build
@@ -448,10 +449,11 @@ Building the program with BLAS support may lead to some performance improvements
   cmake --build .
   ```
   Make sure that `AMDGPU_TARGETS` is set to the GPU arch you want to compile for. The above example uses `gfx1100` that corresponds to Radeon RX 7900XTX/XT/GRE. You can find a list of targets [here](https://llvm.org/docs/AMDGPUUsage.html#processors)
+  Find your gpu version string by matching the most significant version information from `rocminfo | grep gfx | head -1 | awk '{print $2}'` with the list of processors, e.g. `gfx1035` maps to `gfx1030`.

   The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used.
-  If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 or 11.0.0 on RDNA3.
+  If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3.
   The following compilation options are also available to tweak performance (yes, they refer to CUDA, not HIP, because it uses the same code as the cuBLAS version above):

   | Option | Legal values | Default | Description |
````
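Whether the UMA build is a win depends on the GPU being integrated, as the new README note says. As a hedged, stand-alone sketch (not part of this commit), one could query HIP's device properties for the `integrated` flag before deciding to configure with `-DLLAMA_HIP_UMA=ON`:

```cpp
// Hedged sketch (not from the commit): report whether device 0 is an
// integrated GPU, which is the case where a -DLLAMA_HIP_UMA=ON build helps.
#include <hip/hip_runtime.h>
#include <cstdio>

int main() {
    hipDeviceProp_t prop;
    if (hipGetDeviceProperties(&prop, 0) != hipSuccess) {
        printf("no HIP device found\n");
        return 1;
    }
    // `integrated` is non-zero for APUs that share memory with the CPU.
    printf("%s: %s\n", prop.name,
           prop.integrated ? "integrated GPU (UMA build recommended)"
                           : "discrete GPU (UMA build hurts performance)");
    return 0;
}
```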
ggml-cuda.cu:

```diff
@@ -60,8 +60,13 @@
 #define cudaGetDeviceProperties hipGetDeviceProperties
 #define cudaGetErrorString hipGetErrorString
 #define cudaGetLastError hipGetLastError
+#ifdef GGML_HIP_UMA
+#define cudaMalloc hipMallocManaged
+#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
+#else
 #define cudaMalloc hipMalloc
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
+#endif
 #define cudaMemcpy hipMemcpy
 #define cudaMemcpy2DAsync hipMemcpy2DAsync
 #define cudaMemcpyAsync hipMemcpyAsync
```
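Because ggml's CUDA backend is compiled as HIP via these `cuda*` → `hip*` remaps, the same call site changes meaning with the build flag. A hedged, stand-alone sketch of the dispatch follows; only the two `cudaMalloc` mappings mirror the diff, while the call site is illustrative:

```cpp
// Stand-alone sketch of the macro dispatch above; the buffer and size are
// illustrative. Compile with `hipcc -DGGML_HIP_UMA sketch.cpp` to get the
// managed (UMA) path, or without the define for plain device memory.
#include <hip/hip_runtime.h>

#ifdef GGML_HIP_UMA
#define cudaMalloc hipMallocManaged   // UMA: pageable, shared with the CPU
#else
#define cudaMalloc hipMalloc          // default: fixed device (VRAM) memory
#endif

int main() {
    void* buf = nullptr;
    // The same CUDA-style call site resolves to either allocator.
    if (cudaMalloc(&buf, 1024 * sizeof(float)) == hipSuccess) {
        hipFree(buf);
    }
    return 0;
}
```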