update HIP_UMA #7399 (#7414)

* update HIP_UMA #7399

add use of hipMemAdviseSetCoarseGrain when LLAMA_HIP_UMA is enabled.
- gets roughly 2x on prompt eval and 1.5x on token gen with ROCm 6.0 on a Ryzen 7940HX iGPU (780M/gfx1103)

* simplify code, more consistent style

---------

Co-authored-by: slaren <slarengh@gmail.com>
Djip007 2024-05-28 01:40:47 +02:00 committed by GitHub
parent 0136966daf
commit 852aafb163
2 changed files with 17 additions and 8 deletions
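
For context before the diff: the patch replaces a blanket #define cudaMalloc hipMallocManaged with an explicit allocator that requests coarse-grained coherence on the managed pages. Below is a minimal standalone sketch of that pattern, assuming a ROCm/HIP toolchain; the helper name uma_alloc and the error handling are illustrative, not code from the patch.

// sketch.cpp — illustrative only; build with: hipcc sketch.cpp -o sketch
#include <hip/hip_runtime.h>
#include <cstdio>

// Allocate managed (UMA) memory, then mark it coarse-grained for the device.
static hipError_t uma_alloc(void ** ptr, size_t size, int device) {
    hipError_t err = hipMallocManaged(ptr, size);
    if (err != hipSuccess) {
        return err;
    }
    // Coarse-grained pages opt out of fine-grained CPU/GPU cache coherence,
    // which is the costly part of managed memory on an iGPU.
    err = hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device);
    if (err != hipSuccess) {
        hipFree(*ptr);
        *ptr = nullptr;
    }
    return err;
}

int main() {
    void * buf = nullptr;
    if (uma_alloc(&buf, 1 << 20, 0) != hipSuccess) {
        std::fprintf(stderr, "allocation failed\n");
        return 1;
    }
    static_cast<char *>(buf)[0] = 42; // host writes directly: no explicit copies
    hipFree(buf);
    return 0;
}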

ggml-cuda.cu

@@ -119,6 +119,20 @@ int ggml_cuda_get_device() {
     return id;
 }
 
+static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
+    ggml_cuda_set_device(device);
+#if defined(GGML_USE_HIPBLAS) && defined(GGML_HIP_UMA)
+    auto res = hipMallocManaged(ptr, size);
+    if (res == hipSuccess) {
+        // if error we "need" to know why...
+        CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+    }
+    return res;
+#else
+    return cudaMalloc(ptr, size);
+#endif
+}
+
 static ggml_cuda_device_info ggml_cuda_init() {
 #ifdef __HIP_PLATFORM_AMD__
     // Workaround for a rocBLAS bug when using multiple graphics cards:
@@ -271,7 +285,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
         size_t look_ahead_size = (size_t) (1.05 * size);
         look_ahead_size = 256 * ((look_ahead_size + 255)/256);
         ggml_cuda_set_device(device);
-        CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
+        CUDA_CHECK(ggml_cuda_device_malloc(&ptr, look_ahead_size, device));
         *actual_size = look_ahead_size;
         pool_size += look_ahead_size;
 #ifdef DEBUG_CUDA_MALLOC
@@ -537,7 +551,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffe
     size = std::max(size, (size_t)1); // cudaMalloc returns null for size 0
     void * dev_ptr;
-    cudaError_t err = cudaMalloc(&dev_ptr, size);
+    cudaError_t err = ggml_cuda_device_malloc(&dev_ptr, size, buft_ctx->device);
     if (err != cudaSuccess) {
         // clear the error
         cudaGetLastError();
@@ -798,7 +812,7 @@ GGML_CALL static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_bu
         // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
         ggml_cuda_set_device(id);
         char * buf;
-        CUDA_CHECK(cudaMalloc(&buf, size));
+        CUDA_CHECK(ggml_cuda_device_malloc((void**)&buf, size, id));
         // set padding to 0 to avoid possible NaN values
         if (size > original_size) {

ggml-cuda/common.cuh

@@ -79,13 +79,8 @@
 #define cudaHostRegisterReadOnly hipHostRegisterReadOnly
 #define cudaHostUnregister hipHostUnregister
 #define cudaLaunchHostFunc hipLaunchHostFunc
-#ifdef GGML_HIP_UMA
-#define cudaMalloc hipMallocManaged
-#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
-#else
 #define cudaMalloc hipMalloc
 #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
-#endif
 #define cudaMemcpy hipMemcpy
 #define cudaMemcpyAsync hipMemcpyAsync
 #define cudaMemcpyPeerAsync hipMemcpyPeerAsync
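
Usage note (not part of the commit): with the option names this commit message references, a UMA-enabled build on an APU would look something like the lines below. This is a sketch; LLAMA_HIPBLAS and LLAMA_HIP_UMA are the build options used by llama.cpp around this time and may differ in later versions, while the in-code guard is GGML_HIP_UMA. Note that under non-UMA builds ggml_cuda_device_malloc simply forwards to cudaMalloc, so discrete-GPU behavior is unchanged.

make LLAMA_HIPBLAS=1 LLAMA_HIP_UMA=1
# or, with CMake:
cmake -B build -DLLAMA_HIPBLAS=ON -DLLAMA_HIP_UMA=ON && cmake --build build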