cuda : add ROCm / hipBLAS cublasGemmBatchedEx define

This commit is contained in:
Georgi Gerganov 2023-10-24 00:18:49 +03:00
parent 878aa4f209
commit d415669087
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -29,6 +29,7 @@
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasCreate hipblasCreate
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasHandle_t hipblasHandle_t
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream