mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 13:30:35 +00:00
feat: Support Moore Threads GPU (#8383)
* Update doc for MUSA Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * Add GGML_MUSA in Makefile Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * Add GGML_MUSA in CMake Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * CUDA => MUSA Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * MUSA adds support for __vsubss4 Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> * Fix CI build failure Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com> --------- Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
This commit is contained in:
parent
5e2727fe03
commit
e54c35e4fb
57
Makefile
57
Makefile
@ -528,10 +528,21 @@ ifndef GGML_NO_ACCELERATE
|
||||
endif
|
||||
endif # GGML_NO_ACCELERATE
|
||||
|
||||
ifdef GGML_MUSA
|
||||
CC := clang
|
||||
CXX := clang++
|
||||
GGML_CUDA := 1
|
||||
MK_CPPFLAGS += -DGGML_USE_MUSA
|
||||
endif
|
||||
|
||||
ifndef GGML_NO_OPENMP
|
||||
MK_CPPFLAGS += -DGGML_USE_OPENMP
|
||||
MK_CFLAGS += -fopenmp
|
||||
MK_CXXFLAGS += -fopenmp
|
||||
ifdef GGML_MUSA
|
||||
MK_CPPFLAGS += -I/usr/lib/llvm-10/include/openmp
|
||||
MK_LDFLAGS += -L/usr/lib/llvm-10/lib
|
||||
endif # GGML_MUSA
|
||||
endif # GGML_NO_OPENMP
|
||||
|
||||
ifdef GGML_OPENBLAS
|
||||
@ -582,15 +593,27 @@ else
|
||||
endif # GGML_CUDA_FA_ALL_QUANTS
|
||||
|
||||
ifdef GGML_CUDA
|
||||
ifneq ('', '$(wildcard /opt/cuda)')
|
||||
CUDA_PATH ?= /opt/cuda
|
||||
else
|
||||
CUDA_PATH ?= /usr/local/cuda
|
||||
endif
|
||||
ifdef GGML_MUSA
|
||||
ifneq ('', '$(wildcard /opt/musa)')
|
||||
CUDA_PATH ?= /opt/musa
|
||||
else
|
||||
CUDA_PATH ?= /usr/local/musa
|
||||
endif
|
||||
|
||||
MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS
|
||||
MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib
|
||||
MK_NVCCFLAGS += -use_fast_math
|
||||
MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include
|
||||
MK_LDFLAGS += -lmusa -lmublas -lmusart -lpthread -ldl -lrt -L$(CUDA_PATH)/lib -L/usr/lib64
|
||||
MK_NVCCFLAGS += -x musa -mtgpu --cuda-gpu-arch=mp_22
|
||||
else
|
||||
ifneq ('', '$(wildcard /opt/cuda)')
|
||||
CUDA_PATH ?= /opt/cuda
|
||||
else
|
||||
CUDA_PATH ?= /usr/local/cuda
|
||||
endif
|
||||
|
||||
MK_CPPFLAGS += -DGGML_USE_CUDA -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include -DGGML_CUDA_USE_GRAPHS
|
||||
MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L$(CUDA_PATH)/lib64/stubs -L/usr/lib/wsl/lib
|
||||
MK_NVCCFLAGS += -use_fast_math
|
||||
endif # GGML_MUSA
|
||||
|
||||
OBJ_GGML += ggml/src/ggml-cuda.o
|
||||
OBJ_GGML += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu))
|
||||
@ -600,9 +623,11 @@ ifdef LLAMA_FATAL_WARNINGS
|
||||
MK_NVCCFLAGS += -Werror all-warnings
|
||||
endif # LLAMA_FATAL_WARNINGS
|
||||
|
||||
ifndef GGML_MUSA
|
||||
ifndef JETSON_EOL_MODULE_DETECT
|
||||
MK_NVCCFLAGS += --forward-unknown-to-host-compiler
|
||||
endif # JETSON_EOL_MODULE_DETECT
|
||||
endif # GGML_MUSA
|
||||
|
||||
ifdef LLAMA_DEBUG
|
||||
MK_NVCCFLAGS += -lineinfo
|
||||
@ -615,8 +640,12 @@ endif # GGML_CUDA_DEBUG
|
||||
ifdef GGML_CUDA_NVCC
|
||||
NVCC = $(CCACHE) $(GGML_CUDA_NVCC)
|
||||
else
|
||||
NVCC = $(CCACHE) nvcc
|
||||
endif #GGML_CUDA_NVCC
|
||||
ifdef GGML_MUSA
|
||||
NVCC = $(CCACHE) mcc
|
||||
else
|
||||
NVCC = $(CCACHE) nvcc
|
||||
endif # GGML_MUSA
|
||||
endif # GGML_CUDA_NVCC
|
||||
|
||||
ifdef CUDA_DOCKER_ARCH
|
||||
MK_NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
|
||||
@ -687,9 +716,15 @@ define NVCC_COMPILE
|
||||
$(NVCC) -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
|
||||
endef # NVCC_COMPILE
|
||||
else
|
||||
ifdef GGML_MUSA
|
||||
define NVCC_COMPILE
|
||||
$(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -c $< -o $@
|
||||
endef # NVCC_COMPILE
|
||||
else
|
||||
define NVCC_COMPILE
|
||||
$(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
|
||||
endef # NVCC_COMPILE
|
||||
endif # GGML_MUSA
|
||||
endif # JETSON_EOL_MODULE_DETECT
|
||||
|
||||
ggml/src/ggml-cuda/%.o: \
|
||||
@ -944,6 +979,7 @@ $(info I CXX: $(shell $(CXX) --version | head -n 1))
|
||||
ifdef GGML_CUDA
|
||||
$(info I NVCC: $(shell $(NVCC) --version | tail -n 1))
|
||||
CUDA_VERSION := $(shell $(NVCC) --version | grep -oP 'release (\K[0-9]+\.[0-9])')
|
||||
ifndef GGML_MUSA
|
||||
ifeq ($(shell awk -v "v=$(CUDA_VERSION)" 'BEGIN { print (v < 11.7) }'),1)
|
||||
|
||||
ifndef CUDA_DOCKER_ARCH
|
||||
@ -953,6 +989,7 @@ endif # CUDA_POWER_ARCH
|
||||
endif # CUDA_DOCKER_ARCH
|
||||
|
||||
endif # eq ($(shell echo "$(CUDA_VERSION) < 11.7" | bc),1)
|
||||
endif # GGML_MUSA
|
||||
endif # GGML_CUDA
|
||||
$(info )
|
||||
|
||||
|
@ -409,6 +409,7 @@ Please refer to [Build llama.cpp locally](./docs/build.md)
|
||||
| [BLAS](./docs/build.md#blas-build) | All |
|
||||
| [BLIS](./docs/backend/BLIS.md) | All |
|
||||
| [SYCL](./docs/backend/SYCL.md) | Intel and Nvidia GPU |
|
||||
| [MUSA](./docs/build.md#musa) | Moore Threads GPU |
|
||||
| [CUDA](./docs/build.md#cuda) | Nvidia GPU |
|
||||
| [hipBLAS](./docs/build.md#hipblas) | AMD GPU |
|
||||
| [Vulkan](./docs/build.md#vulkan) | GPU |
|
||||
|
@ -192,6 +192,19 @@ The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/c
|
||||
| GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |
|
||||
| GGML_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. |
|
||||
|
||||
### MUSA
|
||||
|
||||
- Using `make`:
|
||||
```bash
|
||||
make GGML_MUSA=1
|
||||
```
|
||||
- Using `CMake`:
|
||||
|
||||
```bash
|
||||
cmake -B build -DGGML_MUSA=ON
|
||||
cmake --build build --config Release
|
||||
```
|
||||
|
||||
### hipBLAS
|
||||
|
||||
This provides BLAS acceleration on HIP-supported AMD GPUs.
|
||||
|
@ -113,6 +113,7 @@ set(GGML_BLAS_VENDOR ${GGML_BLAS_VENDOR_DEFAULT} CACHE STRING
|
||||
option(GGML_LLAMAFILE "ggml: use LLAMAFILE" OFF)
|
||||
|
||||
option(GGML_CUDA "ggml: use CUDA" OFF)
|
||||
option(GGML_MUSA "ggml: use MUSA" OFF)
|
||||
option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF)
|
||||
option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF)
|
||||
option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF)
|
||||
|
@ -6,6 +6,9 @@
|
||||
#ifdef GGML_USE_HIPBLAS
|
||||
#define GGML_CUDA_NAME "ROCm"
|
||||
#define GGML_CUBLAS_NAME "hipBLAS"
|
||||
#elif defined(GGML_USE_MUSA)
|
||||
#define GGML_CUDA_NAME "MUSA"
|
||||
#define GGML_CUBLAS_NAME "muBLAS"
|
||||
#else
|
||||
#define GGML_CUDA_NAME "CUDA"
|
||||
#define GGML_CUBLAS_NAME "cuBLAS"
|
||||
|
@ -139,6 +139,17 @@ if (GGML_METAL)
|
||||
)
|
||||
endif()
|
||||
|
||||
if (GGML_MUSA)
|
||||
set(CMAKE_C_COMPILER clang)
|
||||
set(CMAKE_C_EXTENSIONS OFF)
|
||||
set(CMAKE_CXX_COMPILER clang++)
|
||||
set(CMAKE_CXX_EXTENSIONS OFF)
|
||||
|
||||
set(GGML_CUDA ON)
|
||||
|
||||
list(APPEND GGML_CDEF_PUBLIC GGML_USE_MUSA)
|
||||
endif()
|
||||
|
||||
if (GGML_OPENMP)
|
||||
find_package(OpenMP)
|
||||
if (OpenMP_FOUND)
|
||||
@ -147,6 +158,11 @@ if (GGML_OPENMP)
|
||||
add_compile_definitions(GGML_USE_OPENMP)
|
||||
|
||||
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
|
||||
|
||||
if (GGML_MUSA)
|
||||
set(GGML_EXTRA_INCLUDES ${GGML_EXTRA_INCLUDES} "/usr/lib/llvm-10/include/openmp")
|
||||
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} "/usr/lib/llvm-10/lib/libomp.so")
|
||||
endif()
|
||||
else()
|
||||
message(WARNING "OpenMP not found")
|
||||
endif()
|
||||
@ -249,7 +265,13 @@ endif()
|
||||
if (GGML_CUDA)
|
||||
cmake_minimum_required(VERSION 3.18) # for CMAKE_CUDA_ARCHITECTURES
|
||||
|
||||
find_package(CUDAToolkit)
|
||||
if (GGML_MUSA)
|
||||
list(APPEND CMAKE_MODULE_PATH "/usr/local/musa/cmake/")
|
||||
find_package(MUSAToolkit)
|
||||
set(CUDAToolkit_FOUND ${MUSAToolkit_FOUND})
|
||||
else()
|
||||
find_package(CUDAToolkit)
|
||||
endif()
|
||||
|
||||
if (CUDAToolkit_FOUND)
|
||||
message(STATUS "CUDA found")
|
||||
@ -268,7 +290,11 @@ if (GGML_CUDA)
|
||||
endif()
|
||||
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
|
||||
|
||||
enable_language(CUDA)
|
||||
if (GGML_MUSA)
|
||||
set(CMAKE_CUDA_COMPILER ${MUSAToolkit_MCC_EXECUTABLE})
|
||||
else()
|
||||
enable_language(CUDA)
|
||||
endif()
|
||||
|
||||
file(GLOB GGML_HEADERS_CUDA "ggml-cuda/*.cuh")
|
||||
list(APPEND GGML_HEADERS_CUDA "../include/ggml-cuda.h")
|
||||
@ -332,21 +358,40 @@ if (GGML_CUDA)
|
||||
add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
|
||||
endif()
|
||||
|
||||
if (GGML_MUSA)
|
||||
set_source_files_properties(${GGML_SOURCES_CUDA} PROPERTIES LANGUAGE CXX)
|
||||
foreach(SOURCE ${GGML_SOURCES_CUDA})
|
||||
set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS "-x musa -mtgpu --cuda-gpu-arch=mp_22")
|
||||
endforeach()
|
||||
endif()
|
||||
|
||||
if (GGML_STATIC)
|
||||
if (WIN32)
|
||||
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
|
||||
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas CUDA::cublasLt)
|
||||
else ()
|
||||
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
||||
if (GGML_MUSA)
|
||||
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart_static MUSA::mublas_static)
|
||||
else()
|
||||
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
|
||||
endif()
|
||||
endif()
|
||||
else()
|
||||
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
|
||||
if (GGML_MUSA)
|
||||
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musart MUSA::mublas)
|
||||
else()
|
||||
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (GGML_CUDA_NO_VMM)
|
||||
# No VMM requested, no need to link directly with the cuda driver lib (libcuda.so)
|
||||
else()
|
||||
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
|
||||
if (GGML_MUSA)
|
||||
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} MUSA::musa_driver) # required by muDeviceGetAttribute(), muMemGetAllocationGranularity(...), ...
|
||||
else()
|
||||
set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cuda_driver) # required by cuDeviceGetAttribute(), cuMemGetAllocationGranularity(...), ...
|
||||
endif()
|
||||
endif()
|
||||
else()
|
||||
message(WARNING "CUDA not found")
|
||||
@ -857,8 +902,10 @@ function(get_flags CCID CCVER)
|
||||
set(C_FLAGS -Wdouble-promotion)
|
||||
set(CXX_FLAGS -Wno-array-bounds)
|
||||
|
||||
if (CCVER VERSION_GREATER_EQUAL 7.1.0)
|
||||
list(APPEND CXX_FLAGS -Wno-format-truncation)
|
||||
if (NOT GGML_MUSA)
|
||||
if (CCVER VERSION_GREATER_EQUAL 7.1.0)
|
||||
list(APPEND CXX_FLAGS -Wno-format-truncation)
|
||||
endif()
|
||||
endif()
|
||||
if (CCVER VERSION_GREATER_EQUAL 8.1.0)
|
||||
list(APPEND CXX_FLAGS -Wextra-semi)
|
||||
@ -1264,6 +1311,7 @@ endif()
|
||||
target_compile_definitions(ggml PUBLIC ${GGML_CDEF_PUBLIC})
|
||||
target_include_directories(ggml PUBLIC ../include)
|
||||
target_include_directories(ggml PRIVATE . ${GGML_EXTRA_INCLUDES})
|
||||
target_link_directories(ggml PRIVATE ${GGML_EXTRA_LIBDIRS})
|
||||
target_compile_features (ggml PRIVATE c_std_11) # don't bump
|
||||
|
||||
target_link_libraries(ggml PRIVATE Threads::Threads ${GGML_EXTRA_LIBS})
|
||||
|
@ -19,7 +19,11 @@ typedef half2 ggml_half2;
|
||||
|
||||
#define GGML_COMMON_DECL
|
||||
#elif defined(GGML_COMMON_DECL_CUDA)
|
||||
#if defined(GGML_COMMON_DECL_MUSA)
|
||||
#include <musa_fp16.h>
|
||||
#else
|
||||
#include <cuda_fp16.h>
|
||||
#endif
|
||||
#include <cstdint>
|
||||
|
||||
typedef half ggml_half;
|
||||
@ -415,7 +419,7 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_
|
||||
#define GGML_TABLE_END() };
|
||||
|
||||
#define GGML_COMMON_IMPL
|
||||
#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP)
|
||||
#elif defined(GGML_COMMON_IMPL_CUDA) || defined(GGML_COMMON_IMPL_HIP) || defined(GGML_COMMON_IMPL_MUSA)
|
||||
#include <cstdint>
|
||||
|
||||
#define GGML_TABLE_BEGIN(type, name, size) static const __device__ type name[size] = {
|
||||
|
@ -167,7 +167,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
for (int id = 0; id < info.device_count; ++id) {
|
||||
int device_vmm = 0;
|
||||
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
||||
CUdevice device;
|
||||
CU_CHECK(cuDeviceGet(&device, id));
|
||||
CU_CHECK(cuDeviceGetAttribute(&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
|
||||
@ -179,7 +179,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
alloc_prop.location.id = id;
|
||||
CU_CHECK(cuMemGetAllocationGranularity(&info.devices[id].vmm_granularity, &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
|
||||
}
|
||||
#endif // !defined(GGML_USE_HIPBLAS)
|
||||
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
||||
info.devices[id].vmm = !!device_vmm;
|
||||
|
||||
cudaDeviceProp prop;
|
||||
@ -315,7 +315,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
|
||||
};
|
||||
|
||||
// pool with virtual memory
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
||||
struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35; // 32 GB
|
||||
|
||||
@ -409,14 +409,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
|
||||
GGML_ASSERT(ptr == (void *) (pool_addr + pool_used));
|
||||
}
|
||||
};
|
||||
#endif // !defined(GGML_USE_HIPBLAS)
|
||||
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
||||
|
||||
std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM)
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
||||
if (ggml_cuda_info().devices[device].vmm) {
|
||||
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
|
||||
}
|
||||
#endif
|
||||
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_CUDA_NO_VMM) && !defined(GGML_USE_MUSA)
|
||||
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
|
||||
}
|
||||
|
||||
@ -1341,7 +1341,7 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
|
||||
static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
|
||||
void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
|
||||
|
||||
#if !defined(GGML_USE_HIPBLAS)
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
||||
// cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
|
||||
cudaMemcpy3DPeerParms p = {};
|
||||
p.dstDevice = dstDevice;
|
||||
@ -1355,7 +1355,7 @@ static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
|
||||
GGML_UNUSED(dstDevice);
|
||||
GGML_UNUSED(srcDevice);
|
||||
return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
|
||||
#endif // !defined(GGML_USE_HIPBLAS)
|
||||
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
||||
}
|
||||
|
||||
static void ggml_cuda_op_mul_mat(
|
||||
@ -1828,6 +1828,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
||||
}
|
||||
}
|
||||
#else
|
||||
#ifdef GGML_USE_MUSA
|
||||
GGML_ASSERT(false);
|
||||
#else // !GGML_USE_MUSA
|
||||
if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
|
||||
// there is no broadcast and src0, src1 are contiguous across dims 2, 3
|
||||
// use cublasGemmStridedBatchedEx
|
||||
@ -1870,6 +1873,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
|
||||
cu_compute_type,
|
||||
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
}
|
||||
#endif // GGML_USE_MUSA
|
||||
#endif
|
||||
|
||||
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
|
||||
@ -3027,7 +3031,7 @@ GGML_CALL bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size
|
||||
return false;
|
||||
}
|
||||
|
||||
#if CUDART_VERSION >= 11100
|
||||
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
|
||||
cudaError_t err = cudaHostRegister(buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
|
||||
if (err != cudaSuccess) {
|
||||
// clear the error
|
||||
|
@ -12,6 +12,10 @@
|
||||
#else
|
||||
#define GGML_COMMON_DECL_CUDA
|
||||
#define GGML_COMMON_IMPL_CUDA
|
||||
#if defined(GGML_USE_MUSA)
|
||||
#define GGML_COMMON_DECL_MUSA
|
||||
#define GGML_COMMON_IMPL_MUSA
|
||||
#endif
|
||||
#endif
|
||||
#include "ggml-common.h"
|
||||
|
||||
@ -114,6 +118,150 @@
|
||||
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
|
||||
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
||||
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
||||
#elif defined(GGML_USE_MUSA)
|
||||
#include <musa_runtime.h>
|
||||
#include <musa.h>
|
||||
#include <mublas.h>
|
||||
#include <musa_fp16.h>
|
||||
// XXX: Keep the following order the same as hipBLAS
|
||||
// #define CUBLAS_COMPUTE_16F MUBLAS_COMPUTE_16F
|
||||
// #define CUBLAS_COMPUTE_32F MUBLAS_COMPUTE_32F
|
||||
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
|
||||
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_OP_N MUBLAS_OP_N
|
||||
#define CUBLAS_OP_T MUBLAS_OP_T
|
||||
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
|
||||
// #define CUBLAS_TF32_TENSOR_OP_MATH 0
|
||||
#define CUDA_R_16F MUSA_R_16F
|
||||
#define CUDA_R_32F MUSA_R_32F
|
||||
// #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
||||
// #define cublasComputeType_t mublasComputeType_t
|
||||
#define cublasCreate mublasCreate
|
||||
#define cublasDestroy mublasDestroy
|
||||
#define cublasGemmEx mublasGemmEx
|
||||
#define cublasGemmBatchedEx mublasGemmBatchedEx
|
||||
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
|
||||
#define cublasHandle_t mublasHandle_t
|
||||
// #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
||||
#define cublasSetMathMode mublasSetMathMode
|
||||
#define cublasSetStream mublasSetStream
|
||||
#define cublasSgemm mublasSgemm
|
||||
#define cublasStatus_t mublasStatus_t
|
||||
#define cudaDataType_t musaDataType_t //deprecated, new hipblasDatatype not in 5.6
|
||||
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
||||
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
|
||||
#define cudaDeviceProp musaDeviceProp
|
||||
#define cudaDeviceSynchronize musaDeviceSynchronize
|
||||
#define cudaError_t musaError_t
|
||||
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
|
||||
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
|
||||
#define cudaEventCreateWithFlags musaEventCreateWithFlags
|
||||
#define cudaEventDisableTiming musaEventDisableTiming
|
||||
#define cudaEventRecord musaEventRecord
|
||||
#define cudaEventSynchronize musaEventSynchronize
|
||||
#define cudaEvent_t musaEvent_t
|
||||
#define cudaEventDestroy musaEventDestroy
|
||||
#define cudaFree musaFree
|
||||
#define cudaFreeHost musaFreeHost
|
||||
#define cudaGetDevice musaGetDevice
|
||||
#define cudaGetDeviceCount musaGetDeviceCount
|
||||
#define cudaGetDeviceProperties musaGetDeviceProperties
|
||||
#define cudaGetErrorString musaGetErrorString
|
||||
#define cudaGetLastError musaGetLastError
|
||||
#define cudaHostRegister musaHostRegister
|
||||
#define cudaHostRegisterPortable musaHostRegisterPortable
|
||||
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
|
||||
#define cudaHostUnregister musaHostUnregister
|
||||
#define cudaLaunchHostFunc musaLaunchHostFunc
|
||||
#define cudaMalloc musaMalloc
|
||||
#define cudaMallocHost musaMallocHost
|
||||
#define cudaMemcpy musaMemcpy
|
||||
#define cudaMemcpyAsync musaMemcpyAsync
|
||||
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
|
||||
#define cudaMemcpy2DAsync musaMemcpy2DAsync
|
||||
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
|
||||
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
|
||||
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
|
||||
#define cudaMemcpyKind musaMemcpyKind
|
||||
#define cudaMemset musaMemset
|
||||
#define cudaMemsetAsync musaMemsetAsync
|
||||
#define cudaMemGetInfo musaMemGetInfo
|
||||
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
|
||||
#define cudaSetDevice musaSetDevice
|
||||
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
|
||||
#define cudaStreamDestroy musaStreamDestroy
|
||||
#define cudaStreamFireAndForget musaStreamFireAndForget
|
||||
#define cudaStreamNonBlocking musaStreamNonBlocking
|
||||
#define cudaStreamPerThread musaStreamPerThread
|
||||
#define cudaStreamSynchronize musaStreamSynchronize
|
||||
#define cudaStreamWaitEvent musaStreamWaitEvent
|
||||
#define cudaStream_t musaStream_t
|
||||
#define cudaSuccess musaSuccess
|
||||
|
||||
// XXX: Other CUDA => MUSA mapping
|
||||
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
|
||||
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
|
||||
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
|
||||
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
|
||||
#define CUdevice MUdevice
|
||||
#define CUdeviceptr MUdeviceptr
|
||||
#define CUmemAccessDesc MUmemAccessDesc
|
||||
#define CUmemAllocationProp MUmemAllocationProp
|
||||
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
|
||||
#define cuDeviceGet muDeviceGet
|
||||
#define cuDeviceGetAttribute muDeviceGetAttribute
|
||||
#define cuMemAddressFree muMemAddressFree
|
||||
#define cuMemAddressReserve muMemAddressReserve
|
||||
#define cuMemCreate muMemCreate
|
||||
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
|
||||
#define cuMemMap muMemMap
|
||||
#define cuMemRelease muMemRelease
|
||||
#define cuMemSetAccess muMemSetAccess
|
||||
#define cuMemUnmap muMemUnmap
|
||||
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
|
||||
#define cudaFuncSetAttribute musaFuncSetAttribute
|
||||
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
|
||||
#define make_cudaExtent make_musaExtent
|
||||
#define make_cudaPitchedPtr make_musaPitchedPtr
|
||||
|
||||
// XXX: USE_CUDA_GRAPH
|
||||
#define CUDA_SUCCESS MUSA_SUCCESS
|
||||
#define CUresult MUresult
|
||||
#define cuGetErrorString muGetErrorString
|
||||
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
|
||||
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
|
||||
#define cudaGraphDestroy musaGraphDestroy
|
||||
#define cudaGraphExecDestroy musaGraphExecDestroy
|
||||
#define cudaGraphExec_t musaGraphExec_t
|
||||
#define cudaGraphExecUpdate musaGraphExecUpdate
|
||||
#define cudaGraphExecUpdateResultInfo musaGraphExecUpdateResult
|
||||
#define cudaGraphGetNodes musaGraphGetNodes
|
||||
#define cudaGraphInstantiate musaGraphInstantiate
|
||||
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
|
||||
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
|
||||
#define cudaGraphLaunch musaGraphLaunch
|
||||
#define cudaGraphNodeGetType musaGraphNodeGetType
|
||||
#define cudaGraphNode_t musaGraphNode_t
|
||||
#define cudaGraphNodeType musaGraphNodeType
|
||||
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
|
||||
#define cudaGraph_t musaGraph_t
|
||||
#define cudaKernelNodeParams musaKernelNodeParams
|
||||
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
||||
#define cudaStreamEndCapture musaStreamEndCapture
|
||||
|
||||
// XXX: cuBLAS => muBLAS mapping
|
||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
||||
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT
|
||||
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
||||
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
||||
#define cublasComputeType_t cudaDataType_t
|
||||
|
||||
// XXX: Clang builtins mapping
|
||||
#define __vsub4 __vsub4_musa
|
||||
#define __vcmpeq4 __vcmpeq4_musa
|
||||
#define __vcmpne4 __vcmpne4_musa
|
||||
#else
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda.h>
|
||||
@ -168,9 +316,13 @@ void ggml_cuda_error(const char * stmt, const char * func, const char * file, in
|
||||
|
||||
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
|
||||
|
||||
#if CUDART_VERSION >= 12000
|
||||
#if CUDART_VERSION >= 12000 || defined(GGML_USE_MUSA)
|
||||
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
||||
#ifndef GGML_USE_MUSA
|
||||
return cublasGetStatusString(err);
|
||||
#else
|
||||
return mublasStatus_to_string(err);
|
||||
#endif // GGML_USE_MUSA
|
||||
}
|
||||
#else
|
||||
static const char * cublas_get_error_str(const cublasStatus_t err) {
|
||||
@ -200,7 +352,7 @@ static const char * cu_get_error_str(CUresult err) {
|
||||
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
|
||||
#endif
|
||||
|
||||
#if CUDART_VERSION >= 11100
|
||||
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
|
||||
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
|
||||
#else
|
||||
#define GGML_CUDA_ASSUME(x)
|
||||
@ -214,6 +366,42 @@ typedef float dfloat; // dequantize float
|
||||
typedef float2 dfloat2;
|
||||
#endif //GGML_CUDA_F16
|
||||
|
||||
#if defined(GGML_USE_MUSA)
|
||||
#ifndef __has_builtin
|
||||
#define __has_builtin(x) 0
|
||||
#endif
|
||||
|
||||
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
||||
|
||||
static __device__ __forceinline__ int __vsub4_musa(const int a, const int b) {
|
||||
return __vsubss4(a, b);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ unsigned int __vcmpeq4_musa(unsigned int a, unsigned int b) {
|
||||
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
||||
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
||||
unsigned int c;
|
||||
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
vc[i] = va[i] == vb[i] ? 0xff : 0x00;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ unsigned int __vcmpne4_musa(unsigned int a, unsigned int b) {
|
||||
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
||||
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
||||
unsigned int c;
|
||||
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
#endif // defined(GGML_USE_MUSA)
|
||||
|
||||
#if defined(GGML_USE_HIPBLAS)
|
||||
#define __CUDA_ARCH__ 1300
|
||||
|
||||
@ -455,7 +643,7 @@ static __device__ __forceinline__ uint32_t __hgt2_mask(const half2 a, const half
|
||||
const uint32_t mask_high = 0xFFFF0000 * (float(__high2half(a)) > float(__high2half(b)));
|
||||
return mask_low | mask_high;
|
||||
}
|
||||
#endif // CUDART_VERSION < 12000
|
||||
#endif // CUDART_VERSION < CUDART_HMASK
|
||||
|
||||
static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, int c) {
|
||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
|
Loading…
Reference in New Issue
Block a user