From 2df373ac40ea581ccca8a58c713f03ad9d4b658d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Tue, 25 Jun 2024 01:22:33 +0200 Subject: [PATCH] CUDA: fix matrix multiplication algorithm choice (#8102) --- ggml-cuda.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 2dda03924..0acfda91d 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -1924,16 +1924,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { // FP32 precision KQV single-batch for batch size 1 without FlashAttention ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst); + } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) + && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + // KQ + KQV multi-batch without FlashAttention + ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); } else if (use_dequantize_mul_mat_vec) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr); } else if (use_mul_mat_vec_q) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); } else if (use_mul_mat_q) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda); - } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) - && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { - // KQ + KQV multi-batch without FlashAttention - ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); } else { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr); }