mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 20:14:29 +00:00
cuda : fine-tune >= VOLTA params + use MMQ only for small batches
This commit is contained in:
parent
16b60dd75c
commit
a3c28439d3
46
ggml-cuda.cu
46
ggml-cuda.cu
@ -3554,8 +3554,8 @@ static __device__ __forceinline__ void mul_mat_q(
|
||||
#define MMQ_X_Q4_0_RDNA1 64
|
||||
#define MMQ_Y_Q4_0_RDNA1 64
|
||||
#define NWARPS_Q4_0_RDNA1 8
|
||||
#define MMQ_X_Q4_0_AMPERE 64
|
||||
#define MMQ_Y_Q4_0_AMPERE 128
|
||||
#define MMQ_X_Q4_0_AMPERE 4
|
||||
#define MMQ_Y_Q4_0_AMPERE 32
|
||||
#define NWARPS_Q4_0_AMPERE 4
|
||||
#define MMQ_X_Q4_0_PASCAL 64
|
||||
#define MMQ_Y_Q4_0_PASCAL 64
|
||||
@ -3615,8 +3615,8 @@ template <bool need_check> static __global__ void
|
||||
#define MMQ_X_Q4_1_RDNA1 64
|
||||
#define MMQ_Y_Q4_1_RDNA1 64
|
||||
#define NWARPS_Q4_1_RDNA1 8
|
||||
#define MMQ_X_Q4_1_AMPERE 64
|
||||
#define MMQ_Y_Q4_1_AMPERE 128
|
||||
#define MMQ_X_Q4_1_AMPERE 4
|
||||
#define MMQ_Y_Q4_1_AMPERE 32
|
||||
#define NWARPS_Q4_1_AMPERE 4
|
||||
#define MMQ_X_Q4_1_PASCAL 64
|
||||
#define MMQ_Y_Q4_1_PASCAL 64
|
||||
@ -3678,8 +3678,8 @@ template <bool need_check> static __global__ void
|
||||
#define MMQ_X_Q5_0_RDNA1 64
|
||||
#define MMQ_Y_Q5_0_RDNA1 64
|
||||
#define NWARPS_Q5_0_RDNA1 8
|
||||
#define MMQ_X_Q5_0_AMPERE 128
|
||||
#define MMQ_Y_Q5_0_AMPERE 64
|
||||
#define MMQ_X_Q5_0_AMPERE 4
|
||||
#define MMQ_Y_Q5_0_AMPERE 32
|
||||
#define NWARPS_Q5_0_AMPERE 4
|
||||
#define MMQ_X_Q5_0_PASCAL 64
|
||||
#define MMQ_Y_Q5_0_PASCAL 64
|
||||
@ -3739,8 +3739,8 @@ template <bool need_check> static __global__ void
|
||||
#define MMQ_X_Q5_1_RDNA1 64
|
||||
#define MMQ_Y_Q5_1_RDNA1 64
|
||||
#define NWARPS_Q5_1_RDNA1 8
|
||||
#define MMQ_X_Q5_1_AMPERE 128
|
||||
#define MMQ_Y_Q5_1_AMPERE 64
|
||||
#define MMQ_X_Q5_1_AMPERE 4
|
||||
#define MMQ_Y_Q5_1_AMPERE 32
|
||||
#define NWARPS_Q5_1_AMPERE 4
|
||||
#define MMQ_X_Q5_1_PASCAL 64
|
||||
#define MMQ_Y_Q5_1_PASCAL 64
|
||||
@ -3800,8 +3800,8 @@ mul_mat_q5_1(
|
||||
#define MMQ_X_Q8_0_RDNA1 64
|
||||
#define MMQ_Y_Q8_0_RDNA1 64
|
||||
#define NWARPS_Q8_0_RDNA1 8
|
||||
#define MMQ_X_Q8_0_AMPERE 128
|
||||
#define MMQ_Y_Q8_0_AMPERE 64
|
||||
#define MMQ_X_Q8_0_AMPERE 4
|
||||
#define MMQ_Y_Q8_0_AMPERE 32
|
||||
#define NWARPS_Q8_0_AMPERE 4
|
||||
#define MMQ_X_Q8_0_PASCAL 64
|
||||
#define MMQ_Y_Q8_0_PASCAL 64
|
||||
@ -3861,8 +3861,8 @@ template <bool need_check> static __global__ void
|
||||
#define MMQ_X_Q2_K_RDNA1 128
|
||||
#define MMQ_Y_Q2_K_RDNA1 32
|
||||
#define NWARPS_Q2_K_RDNA1 8
|
||||
#define MMQ_X_Q2_K_AMPERE 64
|
||||
#define MMQ_Y_Q2_K_AMPERE 128
|
||||
#define MMQ_X_Q2_K_AMPERE 4
|
||||
#define MMQ_Y_Q2_K_AMPERE 32
|
||||
#define NWARPS_Q2_K_AMPERE 4
|
||||
#define MMQ_X_Q2_K_PASCAL 64
|
||||
#define MMQ_Y_Q2_K_PASCAL 64
|
||||
@ -3922,8 +3922,8 @@ mul_mat_q2_K(
|
||||
#define MMQ_X_Q3_K_RDNA1 32
|
||||
#define MMQ_Y_Q3_K_RDNA1 128
|
||||
#define NWARPS_Q3_K_RDNA1 8
|
||||
#define MMQ_X_Q3_K_AMPERE 128
|
||||
#define MMQ_Y_Q3_K_AMPERE 128
|
||||
#define MMQ_X_Q3_K_AMPERE 4
|
||||
#define MMQ_Y_Q3_K_AMPERE 32
|
||||
#define NWARPS_Q3_K_AMPERE 4
|
||||
#define MMQ_X_Q3_K_PASCAL 64
|
||||
#define MMQ_Y_Q3_K_PASCAL 64
|
||||
@ -3985,8 +3985,8 @@ template <bool need_check> static __global__ void
|
||||
#define MMQ_X_Q4_K_RDNA1 32
|
||||
#define MMQ_Y_Q4_K_RDNA1 64
|
||||
#define NWARPS_Q4_K_RDNA1 8
|
||||
#define MMQ_X_Q4_K_AMPERE 64
|
||||
#define MMQ_Y_Q4_K_AMPERE 128
|
||||
#define MMQ_X_Q4_K_AMPERE 4
|
||||
#define MMQ_Y_Q4_K_AMPERE 32
|
||||
#define NWARPS_Q4_K_AMPERE 4
|
||||
#define MMQ_X_Q4_K_PASCAL 64
|
||||
#define MMQ_Y_Q4_K_PASCAL 64
|
||||
@ -4048,8 +4048,8 @@ template <bool need_check> static __global__ void
|
||||
#define MMQ_X_Q5_K_RDNA1 32
|
||||
#define MMQ_Y_Q5_K_RDNA1 64
|
||||
#define NWARPS_Q5_K_RDNA1 8
|
||||
#define MMQ_X_Q5_K_AMPERE 64
|
||||
#define MMQ_Y_Q5_K_AMPERE 128
|
||||
#define MMQ_X_Q5_K_AMPERE 4
|
||||
#define MMQ_Y_Q5_K_AMPERE 32
|
||||
#define NWARPS_Q5_K_AMPERE 4
|
||||
#define MMQ_X_Q5_K_PASCAL 64
|
||||
#define MMQ_Y_Q5_K_PASCAL 64
|
||||
@ -4109,8 +4109,8 @@ mul_mat_q5_K(
|
||||
#define MMQ_X_Q6_K_RDNA1 32
|
||||
#define MMQ_Y_Q6_K_RDNA1 64
|
||||
#define NWARPS_Q6_K_RDNA1 8
|
||||
#define MMQ_X_Q6_K_AMPERE 64
|
||||
#define MMQ_Y_Q6_K_AMPERE 64
|
||||
#define MMQ_X_Q6_K_AMPERE 4
|
||||
#define MMQ_Y_Q6_K_AMPERE 32
|
||||
#define NWARPS_Q6_K_AMPERE 4
|
||||
#define MMQ_X_Q6_K_PASCAL 64
|
||||
#define MMQ_Y_Q6_K_PASCAL 64
|
||||
@ -7252,7 +7252,7 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm
|
||||
ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
|
||||
float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
{
|
||||
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
|
||||
half * src0_as_f16 = nullptr;
|
||||
@ -7309,6 +7309,7 @@ static void ggml_cuda_mul_mat_mat_deq_cublas(const ggml_tensor * src0, const ggm
|
||||
}
|
||||
}
|
||||
#else
|
||||
// NOTE: this seems faster for tiny models and small batch-size
|
||||
{
|
||||
// convert src0 to fp32, multiply as fp32
|
||||
float * src0_as_f32 = nullptr;
|
||||
@ -7372,7 +7373,8 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
||||
} else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
|
||||
// KQ + KQV multi-batch
|
||||
ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
|
||||
} else if (all_on_device && (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) && src1->ne[1] > 1) {
|
||||
} else if (all_on_device && (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) && src1->ne[1] > 32) {
|
||||
// F16 and quantized src0 + high-batch src1
|
||||
ggml_cuda_mul_mat_mat_deq_cublas(src0, src1, dst);
|
||||
} else if (src0->type == GGML_TYPE_F32) {
|
||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
|
||||
|
Loading…
Reference in New Issue
Block a user