From 17c97fb0620448b37516a3f53fea6c482b0a30a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 6 Feb 2024 18:43:06 +0100 Subject: [PATCH] CUDA: mul_mat_vec_q max. batch size 8 -> 4 (#5370) --- ggml-cuda.cu | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 95161b3f4..3b828375e 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6831,7 +6831,7 @@ static void mul_mat_vec_q_cuda( const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, cudaStream_t stream) { GGML_ASSERT(ncols_x % qk == 0); - GGML_ASSERT(ncols_y <= 8); + GGML_ASSERT(ncols_y <= 4); const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(block_num_y, 1, 1); @@ -6853,22 +6853,22 @@ static void mul_mat_vec_q_cuda( mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot> <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y); break; - case 5: - mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y); - break; - case 6: - mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y); - break; - case 7: - mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y); - break; - case 8: - mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot> - <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y); - break; + // case 5: + // mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot> + // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y); + // break; + // case 6: + // mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot> + // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y); + // break; + // case 7: + // mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot> + // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y); + // break; + // case 8: + // mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot> + // <<>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y); + // break; default: GGML_ASSERT(false); // mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot> @@ -9909,7 +9909,7 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false); } } else { - if (src1->ne[1] <= 8 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type)) { + if (src1->ne[1] <= 4 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type)) { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true); } else if (use_mul_mat_q) { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);