cuda : fine-tune >= VOLTA params + use MMQ only for small batches

This commit is contained in:
Georgi Gerganov 2023-10-25 15:07:34 +03:00
parent 16b60dd75c
commit a3c28439d3
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -3554,8 +3554,8 @@ static __device__ __forceinline__ void mul_mat_q(
 #define MMQ_X_Q4_0_RDNA1 64
 #define MMQ_Y_Q4_0_RDNA1 64
 #define NWARPS_Q4_0_RDNA1 8
-#define MMQ_X_Q4_0_AMPERE 64
-#define MMQ_Y_Q4_0_AMPERE 128
+#define MMQ_X_Q4_0_AMPERE 4
+#define MMQ_Y_Q4_0_AMPERE 32
 #define NWARPS_Q4_0_AMPERE 4
 #define MMQ_X_Q4_0_PASCAL 64
 #define MMQ_Y_Q4_0_PASCAL 64
@@ -3615,8 +3615,8 @@ template <bool need_check> static __global__ void
 #define MMQ_X_Q4_1_RDNA1 64
 #define MMQ_Y_Q4_1_RDNA1 64
 #define NWARPS_Q4_1_RDNA1 8
-#define MMQ_X_Q4_1_AMPERE 64
-#define MMQ_Y_Q4_1_AMPERE 128
+#define MMQ_X_Q4_1_AMPERE 4
+#define MMQ_Y_Q4_1_AMPERE 32
 #define NWARPS_Q4_1_AMPERE 4
 #define MMQ_X_Q4_1_PASCAL 64
 #define MMQ_Y_Q4_1_PASCAL 64
@@ -3678,8 +3678,8 @@ template <bool need_check> static __global__ void
 #define MMQ_X_Q5_0_RDNA1 64
 #define MMQ_Y_Q5_0_RDNA1 64
 #define NWARPS_Q5_0_RDNA1 8
-#define MMQ_X_Q5_0_AMPERE 128
-#define MMQ_Y_Q5_0_AMPERE 64
+#define MMQ_X_Q5_0_AMPERE 4
+#define MMQ_Y_Q5_0_AMPERE 32
 #define NWARPS_Q5_0_AMPERE 4
 #define MMQ_X_Q5_0_PASCAL 64
 #define MMQ_Y_Q5_0_PASCAL 64
@@ -3739,8 +3739,8 @@ template <bool need_check> static __global__ void
 #define MMQ_X_Q5_1_RDNA1 64
 #define MMQ_Y_Q5_1_RDNA1 64
 #define NWARPS_Q5_1_RDNA1 8
-#define MMQ_X_Q5_1_AMPERE 128
-#define MMQ_Y_Q5_1_AMPERE 64
+#define MMQ_X_Q5_1_AMPERE 4
+#define MMQ_Y_Q5_1_AMPERE 32
 #define NWARPS_Q5_1_AMPERE 4
 #define MMQ_X_Q5_1_PASCAL 64
 #define MMQ_Y_Q5_1_PASCAL 64
@@ -3800,8 +3800,8 @@ mul_mat_q5_1(
 #define MMQ_X_Q8_0_RDNA1 64
 #define MMQ_Y_Q8_0_RDNA1 64
 #define NWARPS_Q8_0_RDNA1 8
-#define MMQ_X_Q8_0_AMPERE 128
-#define MMQ_Y_Q8_0_AMPERE 64
+#define MMQ_X_Q8_0_AMPERE 4
+#define MMQ_Y_Q8_0_AMPERE 32
 #define NWARPS_Q8_0_AMPERE 4
 #define MMQ_X_Q8_0_PASCAL 64
 #define MMQ_Y_Q8_0_PASCAL 64
@@ -3861,8 +3861,8 @@ template <bool need_check> static __global__ void
 #define MMQ_X_Q2_K_RDNA1 128
 #define MMQ_Y_Q2_K_RDNA1 32
 #define NWARPS_Q2_K_RDNA1 8
-#define MMQ_X_Q2_K_AMPERE 64
-#define MMQ_Y_Q2_K_AMPERE 128
+#define MMQ_X_Q2_K_AMPERE 4
+#define MMQ_Y_Q2_K_AMPERE 32
 #define NWARPS_Q2_K_AMPERE 4
 #define MMQ_X_Q2_K_PASCAL 64
 #define MMQ_Y_Q2_K_PASCAL 64
@@ -3922,8 +3922,8 @@ mul_mat_q2_K(
 #define MMQ_X_Q3_K_RDNA1 32
 #define MMQ_Y_Q3_K_RDNA1 128
 #define NWARPS_Q3_K_RDNA1 8
-#define MMQ_X_Q3_K_AMPERE 128
-#define MMQ_Y_Q3_K_AMPERE 128
+#define MMQ_X_Q3_K_AMPERE 4
+#define MMQ_Y_Q3_K_AMPERE 32
 #define NWARPS_Q3_K_AMPERE 4
 #define MMQ_X_Q3_K_PASCAL 64
 #define MMQ_Y_Q3_K_PASCAL 64
@@ -3985,8 +3985,8 @@ template <bool need_check> static __global__ void
 #define MMQ_X_Q4_K_RDNA1 32
 #define MMQ_Y_Q4_K_RDNA1 64
 #define NWARPS_Q4_K_RDNA1 8
-#define MMQ_X_Q4_K_AMPERE 64
-#define MMQ_Y_Q4_K_AMPERE 128
+#define MMQ_X_Q4_K_AMPERE 4
+#define MMQ_Y_Q4_K_AMPERE 32
 #define NWARPS_Q4_K_AMPERE 4
 #define MMQ_X_Q4_K_PASCAL 64
 #define MMQ_Y_Q4_K_PASCAL 64
@@ -4048,8 +4048,8 @@ template <bool need_check> static __global__ void
 #define MMQ_X_Q5_K_RDNA1 32
 #define MMQ_Y_Q5_K_RDNA1 64
 #define NWARPS_Q5_K_RDNA1 8
-#define MMQ_X_Q5_K_AMPERE 64
-#define MMQ_Y_Q5_K_AMPERE 128
+#define MMQ_X_Q5_K_AMPERE 4
+#define MMQ_Y_Q5_K_AMPERE 32
 #define NWARPS_Q5_K_AMPERE 4
 #define MMQ_X_Q5_K_PASCAL 64
 #define MMQ_Y_Q5_K_PASCAL 64
@@ -4109,8 +4109,8 @@ mul_mat_q5_K(
 #define MMQ_X_Q6_K_RDNA1 32
 #define MMQ_Y_Q6_K_RDNA1 64
 #define NWARPS_Q6_K_RDNA1 8
-#define MMQ_X_Q6_K_AMPERE 64
-#define MMQ_Y_Q6_K_AMPERE 64
+#define MMQ_X_Q6_K_AMPERE 4
+#define MMQ_Y_Q6_K_AMPERE 32
 #define NWARPS_Q6_K_AMPERE 4
 #define MMQ_X_Q6_K_PASCAL 64
 #define MMQ_Y_Q6_K_PASCAL 64
@@ -7252,7 +7252,7 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm
 ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
 float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
-#if 0
+#if 1
 {
 // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
 half * src0_as_f16 = nullptr;
@@ -7309,6 +7309,7 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm
 }
 }
 #else
+// NOTE: this seems faster for tiny models and small batch-size
 {
 // convert src0 to fp32, multiply as fp32
 float * src0_as_f32 = nullptr;
@@ -7372,7 +7373,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
 } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
 // KQ + KQV multi-batch
 ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
-} else if (all_on_device && (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) && src1->ne[1] > 1) {
+} else if (all_on_device && (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) && src1->ne[1] > 32) {
+// F16 and quantized src0 + high-batch src1
 ggml_cuda_mul_mat_mat_deq_cublas(src0, src1, dst);
 } else if (src0->type == GGML_TYPE_F32) {
 ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);