cuda : add CUDA_USE_TENSOR_CORES and GGML_CUDA_FORCE_MMQ macros

Georgi Gerganov 2023-10-25 18:48:36 +03:00
parent 4c6744b526
commit a4e15a36e4
3 changed files with 104 additions and 22 deletions

ggml-cuda.cu

@@ -87,6 +87,23 @@
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA2 (CC_OFFSET_AMD + 1030)

+// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
+// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
+// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
+// - 7B quantum model: +100-200 MB
+// - 13B quantum model: +200-400 MB
+//#define GGML_CUDA_FORCE_MMQ
+
+// TODO: improve this to be correct for more hardware
+// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
+// probably other such cases, and not sure what happens on AMD hardware
+#if !defined(GGML_CUDA_FORCE_MMQ)
+#define CUDA_USE_TENSOR_CORES
+#endif
+
+// max batch size to use MMQ kernels when tensor cores are available
+#define MMQ_MAX_BATCH_SIZE 32
+
 #if defined(GGML_USE_HIPBLAS)
 #define __CUDA_ARCH__ 1300
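
A quick way to see how the two new macros interact: GGML_CUDA_FORCE_MMQ is the user-facing switch, and CUDA_USE_TENSOR_CORES is derived from it at compile time. The sketch below mirrors that preprocessor logic in a standalone program; the report_backend helper and the printed wording are illustrative, not part of the commit.

    // Minimal standalone sketch of the compile-time switch introduced here.
    // Build as-is for the cuBLAS/tensor-core path, or compile with
    // -DGGML_CUDA_FORCE_MMQ to always use the quantized MMQ kernels.
    #include <cstdio>

    #if !defined(GGML_CUDA_FORCE_MMQ)
    #define CUDA_USE_TENSOR_CORES // mirrors the logic added to ggml-cuda.cu above
    #endif

    // illustrative helper, not part of the commit
    static void report_backend() {
    #if defined(CUDA_USE_TENSOR_CORES)
        std::printf("mat-mul backend: cuBLAS + F16 tensor cores (uses extra VRAM)\n");
    #else
        std::printf("mat-mul backend: MMQ kernels only (GGML_CUDA_FORCE_MMQ defined)\n");
    #endif
    }

    int main() {
        report_backend();
        return 0;
    }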
@@ -470,7 +487,6 @@ static int g_device_count = -1;
 static int g_main_device = 0;
 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
 static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
-static bool g_mul_mat_q = true;

 static void * g_scratch_buffer = nullptr;
 static size_t g_scratch_size = 0; // disabled by default
@@ -3554,9 +3570,15 @@ static __device__ __forceinline__ void mul_mat_q(
 #define MMQ_X_Q4_0_RDNA1 64
 #define MMQ_Y_Q4_0_RDNA1 64
 #define NWARPS_Q4_0_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q4_0_AMPERE 4
 #define MMQ_Y_Q4_0_AMPERE 32
 #define NWARPS_Q4_0_AMPERE 4
+#else
+#define MMQ_X_Q4_0_AMPERE 64
+#define MMQ_Y_Q4_0_AMPERE 128
+#define NWARPS_Q4_0_AMPERE 4
+#endif
 #define MMQ_X_Q4_0_PASCAL 64
 #define MMQ_Y_Q4_0_PASCAL 64
 #define NWARPS_Q4_0_PASCAL 8
@@ -3615,9 +3637,15 @@ template <bool need_check> static __global__ void
 #define MMQ_X_Q4_1_RDNA1 64
 #define MMQ_Y_Q4_1_RDNA1 64
 #define NWARPS_Q4_1_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q4_1_AMPERE 4
 #define MMQ_Y_Q4_1_AMPERE 32
 #define NWARPS_Q4_1_AMPERE 4
+#else
+#define MMQ_X_Q4_1_AMPERE 64
+#define MMQ_Y_Q4_1_AMPERE 128
+#define NWARPS_Q4_1_AMPERE 4
+#endif
 #define MMQ_X_Q4_1_PASCAL 64
 #define MMQ_Y_Q4_1_PASCAL 64
 #define NWARPS_Q4_1_PASCAL 8
@@ -3678,9 +3706,15 @@ template <bool need_check> static __global__ void
 #define MMQ_X_Q5_0_RDNA1 64
 #define MMQ_Y_Q5_0_RDNA1 64
 #define NWARPS_Q5_0_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q5_0_AMPERE 4
 #define MMQ_Y_Q5_0_AMPERE 32
 #define NWARPS_Q5_0_AMPERE 4
+#else
+#define MMQ_X_Q5_0_AMPERE 128
+#define MMQ_Y_Q5_0_AMPERE 64
+#define NWARPS_Q5_0_AMPERE 4
+#endif
 #define MMQ_X_Q5_0_PASCAL 64
 #define MMQ_Y_Q5_0_PASCAL 64
 #define NWARPS_Q5_0_PASCAL 8
@@ -3739,9 +3773,15 @@ template <bool need_check> static __global__ void
 #define MMQ_X_Q5_1_RDNA1 64
 #define MMQ_Y_Q5_1_RDNA1 64
 #define NWARPS_Q5_1_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q5_1_AMPERE 4
 #define MMQ_Y_Q5_1_AMPERE 32
 #define NWARPS_Q5_1_AMPERE 4
+#else
+#define MMQ_X_Q5_1_AMPERE 128
+#define MMQ_Y_Q5_1_AMPERE 64
+#define NWARPS_Q5_1_AMPERE 4
+#endif
 #define MMQ_X_Q5_1_PASCAL 64
 #define MMQ_Y_Q5_1_PASCAL 64
 #define NWARPS_Q5_1_PASCAL 8
@@ -3800,9 +3840,15 @@ mul_mat_q5_1(
 #define MMQ_X_Q8_0_RDNA1 64
 #define MMQ_Y_Q8_0_RDNA1 64
 #define NWARPS_Q8_0_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q8_0_AMPERE 4
 #define MMQ_Y_Q8_0_AMPERE 32
 #define NWARPS_Q8_0_AMPERE 4
+#else
+#define MMQ_X_Q8_0_AMPERE 128
+#define MMQ_Y_Q8_0_AMPERE 64
+#define NWARPS_Q8_0_AMPERE 4
+#endif
 #define MMQ_X_Q8_0_PASCAL 64
 #define MMQ_Y_Q8_0_PASCAL 64
 #define NWARPS_Q8_0_PASCAL 8
@@ -3861,9 +3907,15 @@ template <bool need_check> static __global__ void
 #define MMQ_X_Q2_K_RDNA1 128
 #define MMQ_Y_Q2_K_RDNA1 32
 #define NWARPS_Q2_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q2_K_AMPERE 4
 #define MMQ_Y_Q2_K_AMPERE 32
 #define NWARPS_Q2_K_AMPERE 4
+#else
+#define MMQ_X_Q2_K_AMPERE 64
+#define MMQ_Y_Q2_K_AMPERE 128
+#define NWARPS_Q2_K_AMPERE 4
+#endif
 #define MMQ_X_Q2_K_PASCAL 64
 #define MMQ_Y_Q2_K_PASCAL 64
 #define NWARPS_Q2_K_PASCAL 8
@@ -3922,9 +3974,15 @@ mul_mat_q2_K(
 #define MMQ_X_Q3_K_RDNA1 32
 #define MMQ_Y_Q3_K_RDNA1 128
 #define NWARPS_Q3_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q3_K_AMPERE 4
 #define MMQ_Y_Q3_K_AMPERE 32
 #define NWARPS_Q3_K_AMPERE 4
+#else
+#define MMQ_X_Q3_K_AMPERE 128
+#define MMQ_Y_Q3_K_AMPERE 128
+#define NWARPS_Q3_K_AMPERE 4
+#endif
 #define MMQ_X_Q3_K_PASCAL 64
 #define MMQ_Y_Q3_K_PASCAL 64
 #define NWARPS_Q3_K_PASCAL 8
@@ -3985,9 +4043,15 @@ template <bool need_check> static __global__ void
 #define MMQ_X_Q4_K_RDNA1 32
 #define MMQ_Y_Q4_K_RDNA1 64
 #define NWARPS_Q4_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q4_K_AMPERE 4
 #define MMQ_Y_Q4_K_AMPERE 32
 #define NWARPS_Q4_K_AMPERE 4
+#else
+#define MMQ_X_Q4_K_AMPERE 64
+#define MMQ_Y_Q4_K_AMPERE 128
+#define NWARPS_Q4_K_AMPERE 4
+#endif
 #define MMQ_X_Q4_K_PASCAL 64
 #define MMQ_Y_Q4_K_PASCAL 64
 #define NWARPS_Q4_K_PASCAL 8
@@ -4048,9 +4112,15 @@ template <bool need_check> static __global__ void
 #define MMQ_X_Q5_K_RDNA1 32
 #define MMQ_Y_Q5_K_RDNA1 64
 #define NWARPS_Q5_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q5_K_AMPERE 4
 #define MMQ_Y_Q5_K_AMPERE 32
 #define NWARPS_Q5_K_AMPERE 4
+#else
+#define MMQ_X_Q5_K_AMPERE 64
+#define MMQ_Y_Q5_K_AMPERE 128
+#define NWARPS_Q5_K_AMPERE 4
+#endif
 #define MMQ_X_Q5_K_PASCAL 64
 #define MMQ_Y_Q5_K_PASCAL 64
 #define NWARPS_Q5_K_PASCAL 8
@@ -4109,9 +4179,15 @@ mul_mat_q5_K(
 #define MMQ_X_Q6_K_RDNA1 32
 #define MMQ_Y_Q6_K_RDNA1 64
 #define NWARPS_Q6_K_RDNA1 8
+#if defined(CUDA_USE_TENSOR_CORES)
 #define MMQ_X_Q6_K_AMPERE 4
 #define MMQ_Y_Q6_K_AMPERE 32
 #define NWARPS_Q6_K_AMPERE 4
+#else
+#define MMQ_X_Q6_K_AMPERE 64
+#define MMQ_Y_Q6_K_AMPERE 64
+#define NWARPS_Q6_K_AMPERE 4
+#endif
 #define MMQ_X_Q6_K_PASCAL 64
 #define MMQ_Y_Q6_K_PASCAL 64
 #define NWARPS_Q6_K_PASCAL 8
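
The MMQ_X_*/MMQ_Y_* values above set the output tile each thread block computes and NWARPS_* the warps per block; the tensor-core build keeps small tiles because MMQ is only used for batches up to MMQ_MAX_BATCH_SIZE there. A hedged sketch of how such constants typically feed a kernel launch (the kernel name, argument list, and launch helper are illustrative, not copied from ggml-cuda.cu):

    #include <cuda_runtime.h>

    // Illustrative tile constants, matching the tensor-core (CUDA_USE_TENSOR_CORES) build.
    #define MMQ_X     4   // output columns per block
    #define MMQ_Y     32  // output rows per block
    #define NWARPS    4   // warps per block
    #define WARP_SIZE 32

    // Hypothetical launch helper: one block per MMQ_Y x MMQ_X tile of the result.
    static void launch_mul_mat_q(int nrows_dst, int ncols_dst, cudaStream_t stream) {
        const int block_num_x = (nrows_dst + MMQ_Y - 1) / MMQ_Y;
        const int block_num_y = (ncols_dst + MMQ_X - 1) / MMQ_X;
        const dim3 block_nums(block_num_x, block_num_y, 1);
        const dim3 block_dims(WARP_SIZE, NWARPS, 1);
        // mul_mat_q_kernel<<<block_nums, block_dims, 0, stream>>>(/* ... */); // hypothetical kernel
        (void) block_nums; (void) block_dims; (void) stream;
    }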
@@ -5663,6 +5739,16 @@ void ggml_init_cublas() {
     CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
     GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
     int64_t total_vram = 0;
+#if defined(GGML_CUDA_FORCE_MMQ)
+    fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
+#else
+    fprintf(stderr, "%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
+#endif
+#if defined(CUDA_USE_TENSOR_CORES)
+    fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: yes\n", __func__);
+#else
+    fprintf(stderr, "%s: CUDA_USE_TENSOR_CORES: no\n", __func__);
+#endif
     fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
     for (int id = 0; id < g_device_count; ++id) {
         cudaDeviceProp prop;
@@ -6347,7 +6433,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
        cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                row_diff, src1_ncols, ne10,
                &alpha, src0_ddf_i, ne00,
                        src1_ddf_i, ne10,
                &beta, dst_dd_i, ldc));

    if (src0_as != 0) {
@@ -7204,18 +7290,23 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool all_on_device =
-        (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
+        (src0->backend == GGML_BACKEND_GPU) &&
         (src1->backend == GGML_BACKEND_GPU) &&
         ( dst->backend == GGML_BACKEND_GPU);

     int64_t min_compute_capability = INT_MAX;
     for (int64_t id = 0; id < g_device_count; ++id) {
-        if (min_compute_capability > g_compute_capabilities[id]
-            && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
+        if (min_compute_capability > g_compute_capabilities[id] && g_tensor_split[id] < (id + 1 < g_device_count ? g_tensor_split[id + 1] : 1.0f)) {
             min_compute_capability = g_compute_capabilities[id];
         }
     }

+#ifdef CUDA_USE_TENSOR_CORES
+    const bool use_tensor_cores = true;
+#else
+    const bool use_tensor_cores = false;
+#endif
+
     // debug helpers
     //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
     //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
@@ -7224,20 +7315,19 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
     //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
     //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);

-    if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+    if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // KQ single-batch
         ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
-    } else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
-    } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+    } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
         // KQ + KQV multi-batch
         ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
         if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
 #ifdef GGML_CUDA_FORCE_DMMV
             const bool use_mul_mat_vec_q = false;
 #else
@@ -7250,13 +7340,11 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
                 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
             }
         } else {
-            // ref: https://github.com/ggerganov/llama.cpp/pull/3776
-            bool use_mul_mat_q = g_mul_mat_q && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
+            bool use_mul_mat_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);

-            // TODO: better way to determine availability of tensor cores
-            // currently fails for GeForce GTX 1660 which is TURING arch but does not have tensor cores
-            if (min_compute_capability >= CC_VOLTA && src1->ne[1] > 32) {
-                // when tensor cores are available, use them for large batch size
+            // when tensor cores are available, use them for large batch size
+            // ref: https://github.com/ggerganov/llama.cpp/pull/3776
+            if (use_tensor_cores && min_compute_capability >= CC_VOLTA && src1->ne[1] > MMQ_MAX_BATCH_SIZE) {
                 use_mul_mat_q = false;
             }
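
The net effect of this hunk: quantized mat-muls still use MMQ for small batches, but on tensor-core hardware they switch to cuBLAS once the batch size exceeds MMQ_MAX_BATCH_SIZE. A minimal sketch of that decision in isolation (choose_mmq and its plain-int parameters are illustrative; the constants mirror ggml-cuda.cu):

    // Illustrative constants; the real values live in ggml-cuda.cu.
    #define MIN_CC_DP4A        610 // minimum compute capability for __dp4a
    #define CC_VOLTA           700
    #define MMQ_MAX_BATCH_SIZE 32

    // use_tensor_cores is the compile-time constant derived from CUDA_USE_TENSOR_CORES,
    // min_cc the minimum compute capability across devices, batch = src1->ne[1].
    static bool choose_mmq(bool src0_is_quantized, bool use_tensor_cores, int min_cc, int batch) {
        bool use_mul_mat_q = min_cc >= MIN_CC_DP4A && src0_is_quantized;

        // prefer cuBLAS (tensor cores) for large batches
        if (use_tensor_cores && min_cc >= CC_VOLTA && batch > MMQ_MAX_BATCH_SIZE) {
            use_mul_mat_q = false;
        }
        return use_mul_mat_q;
    }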
@@ -7614,10 +7702,6 @@ void ggml_cuda_set_main_device(const int main_device) {
     }
 }

-void ggml_cuda_set_mul_mat_q(const bool mul_mat_q) {
-    g_mul_mat_q = mul_mat_q;
-}
-
 void ggml_cuda_set_scratch_size(const size_t scratch_size) {
     // this is a hack to not completely break llama.cpp when using multiple models or contexts simultaneously
     // it still won't always work as expected, but it's better than nothing

llama.cpp

@@ -5959,8 +5959,6 @@ static int llama_decode_internal(
         }
     }

-    ggml_cuda_set_mul_mat_q(cparams.mul_mat_q);
-
     // HACK: ggml-alloc may change the tensor backend when reusing a parent, so force output to be on the CPU here if needed
     if (!lctx.embedding.empty()) {
         embeddings->backend = GGML_BACKEND_CPU;

llama.h

@@ -178,7 +178,7 @@ extern "C" {
         float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model

         // Keep the booleans together to avoid misalignment during copy-by-value.
-        bool mul_mat_q; // if true, use experimental mul_mat_q kernels
+        bool mul_mat_q; // if true, use experimental mul_mat_q kernels (DEPRECATED - always true)
         bool f16_kv; // use fp16 for KV cache, fp32 otherwise
         bool logits_all; // the llama_eval() call computes all logits, not just the last one
         bool embedding; // embedding mode only
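
API impact, stated as a hedged sketch: the mul_mat_q field remains in llama_context_params for compatibility, but with g_mul_mat_q and ggml_cuda_set_mul_mat_q gone the CUDA backend no longer reads it; kernel selection is governed by the compile-time macros above. Roughly:

    #include "llama.h"

    int main() {
        llama_context_params params = llama_context_default_params();
        // Still settable, but after this commit the CUDA backend ignores it;
        // kernel choice follows GGML_CUDA_FORCE_MMQ / CUDA_USE_TENSOR_CORES instead.
        params.mul_mat_q = true;
        // ... create the model and context with `params` as usual ...
        return 0;
    }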