ggml-cuda: Fix HIP build (#4528)

regression of #4490
Adds defines for two new datatypes
cublasComputeType_t, cudaDataType_t.

Currently using deprecated hipblasDatatype_t since newer ones very recent.
This commit is contained in:
arlo-phoenix 2023-12-18 22:33:45 +01:00 committed by GitHub
parent 0e18b2e7d0
commit a7aee47b98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -31,6 +31,7 @@
#define CUDA_R_16F HIPBLAS_R_16F #define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F #define CUDA_R_32F HIPBLAS_R_32F
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate #define cublasCreate hipblasCreate
#define cublasGemmEx hipblasGemmEx #define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx #define cublasGemmBatchedEx hipblasGemmBatchedEx
@ -40,6 +41,7 @@
#define cublasSetStream hipblasSetStream #define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm #define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t #define cublasStatus_t hipblasStatus_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess