mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 03:14:35 +00:00
CUDA: fix tensor core logic for Pascal and HIP (#4682)
This commit is contained in:
parent
0235b9b571
commit
a20f3c7465
72
ggml-cuda.cu
72
ggml-cuda.cu
@ -123,24 +123,6 @@
|
|||||||
|
|
||||||
#define GGML_CUDA_MAX_NODES 8192
|
#define GGML_CUDA_MAX_NODES 8192
|
||||||
|
|
||||||
// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
|
|
||||||
// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
|
|
||||||
// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
|
|
||||||
// - 7B quantum model: +100-200 MB
|
|
||||||
// - 13B quantum model: +200-400 MB
|
|
||||||
//
|
|
||||||
//#define GGML_CUDA_FORCE_MMQ
|
|
||||||
|
|
||||||
// TODO: improve this to be correct for more hardware
|
|
||||||
// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
|
|
||||||
// probably other such cases, and not sure what happens on AMD hardware
|
|
||||||
#if !defined(GGML_CUDA_FORCE_MMQ)
|
|
||||||
#define CUDA_USE_TENSOR_CORES
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// max batch size to use MMQ kernels when tensor cores are available
|
|
||||||
#define MMQ_MAX_BATCH_SIZE 32
|
|
||||||
|
|
||||||
#if defined(GGML_USE_HIPBLAS)
|
#if defined(GGML_USE_HIPBLAS)
|
||||||
#define __CUDA_ARCH__ 1300
|
#define __CUDA_ARCH__ 1300
|
||||||
|
|
||||||
@ -207,6 +189,23 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
|
|||||||
}
|
}
|
||||||
#endif // defined(GGML_USE_HIPBLAS)
|
#endif // defined(GGML_USE_HIPBLAS)
|
||||||
|
|
||||||
|
// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
|
||||||
|
// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
|
||||||
|
// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
|
||||||
|
// - 7B quantum model: +100-200 MB
|
||||||
|
// - 13B quantum model: +200-400 MB
|
||||||
|
//
|
||||||
|
//#define GGML_CUDA_FORCE_MMQ
|
||||||
|
|
||||||
|
// TODO: improve this to be correct for more hardware
|
||||||
|
// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
|
||||||
|
#if !defined(GGML_CUDA_FORCE_MMQ) && (!defined(GGML_USE_HIPBLAS) || defined(RDNA3))
|
||||||
|
#define CUDA_USE_TENSOR_CORES
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// max batch size to use MMQ kernels when tensor cores are available
|
||||||
|
#define MMQ_MAX_BATCH_SIZE 32
|
||||||
|
|
||||||
#if defined(_MSC_VER)
|
#if defined(_MSC_VER)
|
||||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||||
#endif
|
#endif
|
||||||
@ -8661,11 +8660,26 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef CUDA_USE_TENSOR_CORES
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
const bool use_tensor_cores = true;
|
const bool fp16_performance_good = true;
|
||||||
|
|
||||||
|
#ifdef RDNA3
|
||||||
|
const bool use_mul_mat_q = false;
|
||||||
#else
|
#else
|
||||||
const bool use_tensor_cores = false;
|
const bool use_mul_mat_q = true;
|
||||||
#endif
|
#endif // RDNA3
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
const bool fp16_performance_good = min_compute_capability >= CC_VOLTA;
|
||||||
|
bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
|
||||||
|
#ifdef CUDA_USE_TENSOR_CORES
|
||||||
|
// when tensor cores are available, use them for large batch size
|
||||||
|
// ref: https://github.com/ggerganov/llama.cpp/pull/3776
|
||||||
|
use_mul_mat_q = use_mul_mat_q && !(fp16_performance_good && src1->ne[1] > MMQ_MAX_BATCH_SIZE);
|
||||||
|
#endif // CUDA_USE_TENSOR_CORES
|
||||||
|
|
||||||
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||||
|
|
||||||
// debug helpers
|
// debug helpers
|
||||||
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
|
//printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
|
||||||
@ -8675,13 +8689,13 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
|||||||
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
|
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
|
||||||
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
|
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
|
||||||
|
|
||||||
if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
if (!split && all_on_device && !fp16_performance_good && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
|
||||||
// KQ single-batch
|
// KQ single-batch
|
||||||
ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
|
ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
|
||||||
} else if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
} else if (!split && all_on_device && !fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
||||||
// KQV single-batch
|
// KQV single-batch
|
||||||
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
|
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
|
||||||
} else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
|
} else if (!split && all_on_device && fp16_performance_good && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
|
||||||
// KQ + KQV multi-batch
|
// KQ + KQV multi-batch
|
||||||
ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
|
ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
|
||||||
} else if (src0->type == GGML_TYPE_F32) {
|
} else if (src0->type == GGML_TYPE_F32) {
|
||||||
@ -8701,14 +8715,6 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
|||||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
|
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
|
|
||||||
|
|
||||||
// when tensor cores are available, use them for large batch size
|
|
||||||
// ref: https://github.com/ggerganov/llama.cpp/pull/3776
|
|
||||||
if (use_tensor_cores && min_compute_capability >= CC_VOLTA && src1->ne[1] > MMQ_MAX_BATCH_SIZE) {
|
|
||||||
use_mul_mat_q = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (use_mul_mat_q) {
|
if (use_mul_mat_q) {
|
||||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
|
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
|
||||||
} else {
|
} else {
|
||||||
|
Loading…
Reference in New Issue
Block a user