mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 03:44:35 +00:00
CUDA: mul_mat_vec_q for batch sizes > 1 (#5351)
This commit is contained in:
parent
8a79c591de
commit
2c516611f1
236
ggml-cuda.cu
236
ggml-cuda.cu
@ -5310,41 +5310,50 @@ template <bool need_check> static __global__ void
|
|||||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
|
template <int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
|
||||||
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
|
static __global__ void mul_mat_vec_q(
|
||||||
|
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||||
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par) {
|
||||||
|
|
||||||
|
const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
|
||||||
|
|
||||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
||||||
|
|
||||||
if (row >= nrows) {
|
if (row >= nrows_x) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int blocks_per_row = ncols / qk;
|
const int blocks_per_row_x = ncols_x / qk;
|
||||||
|
const int blocks_per_col_y = nrows_y / QK8_1;
|
||||||
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
||||||
|
|
||||||
// partial sum for each thread
|
// partial sum for each thread
|
||||||
float tmp = 0.0f;
|
float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f};
|
||||||
|
|
||||||
const block_q_t * x = (const block_q_t *) vx;
|
const block_q_t * x = (const block_q_t *) vx;
|
||||||
const block_q8_1 * y = (const block_q8_1 *) vy;
|
const block_q8_1 * y = (const block_q8_1 *) vy;
|
||||||
|
|
||||||
for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row; i += blocks_per_warp) {
|
for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row_x; i += blocks_per_warp) {
|
||||||
const int ibx = row*blocks_per_row + i; // x block index
|
const int ibx = row*blocks_per_row_x + i; // x block index
|
||||||
|
|
||||||
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
|
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
|
||||||
|
|
||||||
const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
|
const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
|
||||||
|
|
||||||
tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
|
#pragma unroll
|
||||||
|
for (int j = 0; j < ncols_y; ++j) {
|
||||||
|
tmp[j] += vec_dot_q_cuda(&x[ibx], &y[j*blocks_per_col_y + iby], iqs);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int j = 0; j < ncols_y; ++j) {
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
tmp[j] = warp_reduce_sum(tmp[j]);
|
||||||
}
|
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
dst[row] = tmp;
|
dst[j*nrows_x + row] = tmp[j];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -6816,121 +6825,56 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
|
|||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot>
|
||||||
GGML_ASSERT(ncols % QK4_0 == 0);
|
static void mul_mat_vec_q_cuda(
|
||||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
const void * vx, const void * vy, float * dst,
|
||||||
const dim3 block_nums(block_num_y, 1, 1);
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, cudaStream_t stream) {
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
|
||||||
mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
GGML_ASSERT(ncols_x % qk == 0);
|
||||||
GGML_ASSERT(ncols % QK4_1 == 0);
|
GGML_ASSERT(ncols_y <= 8);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
|
||||||
const dim3 block_nums(block_num_y, 1, 1);
|
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
|
||||||
mul_mat_vec_q<QK4_0, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||||
GGML_ASSERT(ncols % QK5_0 == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
|
||||||
const dim3 block_nums(block_num_y, 1, 1);
|
const dim3 block_nums(block_num_y, 1, 1);
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
||||||
mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
switch (ncols_y) {
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
case 1:
|
||||||
|
mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
|
||||||
|
break;
|
||||||
|
case 5:
|
||||||
|
mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
|
||||||
|
break;
|
||||||
|
case 6:
|
||||||
|
mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
|
||||||
|
break;
|
||||||
|
case 7:
|
||||||
|
mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
|
||||||
|
break;
|
||||||
|
case 8:
|
||||||
|
mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
// mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
|
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y);
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
|
||||||
GGML_ASSERT(ncols % QK5_1 == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
|
||||||
const dim3 block_nums(block_num_y, 1, 1);
|
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
|
||||||
mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
|
||||||
GGML_ASSERT(ncols % QK8_0 == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
|
||||||
const dim3 block_nums(block_num_y, 1, 1);
|
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
|
||||||
mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
|
||||||
const dim3 block_nums(block_num_y, 1, 1);
|
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
|
||||||
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
|
||||||
const dim3 block_nums(block_num_y, 1, 1);
|
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
|
||||||
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
|
||||||
const dim3 block_nums(block_num_y, 1, 1);
|
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
|
||||||
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
|
||||||
const dim3 block_nums(block_num_y, 1, 1);
|
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
|
||||||
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
|
||||||
const dim3 block_nums(block_num_y, 1, 1);
|
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
|
||||||
mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_iq2_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
|
||||||
const dim3 block_nums(block_num_y, 1, 1);
|
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
|
||||||
mul_mat_vec_q<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
|
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_iq2_xs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
|
||||||
const dim3 block_nums(block_num_y, 1, 1);
|
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
|
||||||
mul_mat_vec_q<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
|
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
|
||||||
}
|
|
||||||
|
|
||||||
static void mul_mat_vec_iq3_xxs_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
|
||||||
GGML_ASSERT(ncols % QK_K == 0);
|
|
||||||
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
|
||||||
const dim3 block_nums(block_num_y, 1, 1);
|
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
|
||||||
mul_mat_vec_q<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
|
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_mul_mat_q4_0_q8_1_cuda(
|
static void ggml_mul_mat_q4_0_q8_1_cuda(
|
||||||
@ -8578,50 +8522,61 @@ static void ggml_cuda_op_mul_mat_vec_q(
|
|||||||
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
|
||||||
const int64_t src1_padded_row_size, cudaStream_t stream) {
|
const int64_t src1_padded_row_size, cudaStream_t stream) {
|
||||||
|
|
||||||
GGML_ASSERT(ggml_nrows(src1) == 1);
|
|
||||||
|
|
||||||
const int64_t ne00 = src0->ne[0];
|
const int64_t ne00 = src0->ne[0];
|
||||||
const int64_t row_diff = row_high - row_low;
|
const int64_t row_diff = row_high - row_low;
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
|
||||||
|
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
|
||||||
|
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
|
||||||
|
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q5_1:
|
case GGML_TYPE_Q5_1:
|
||||||
mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
|
||||||
|
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
|
||||||
|
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q2_K:
|
case GGML_TYPE_Q2_K:
|
||||||
mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
|
||||||
|
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q3_K:
|
case GGML_TYPE_Q3_K:
|
||||||
mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
|
||||||
|
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q4_K:
|
case GGML_TYPE_Q4_K:
|
||||||
mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
|
||||||
|
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q5_K:
|
case GGML_TYPE_Q5_K:
|
||||||
mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
|
||||||
|
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q6_K:
|
case GGML_TYPE_Q6_K:
|
||||||
mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
|
||||||
|
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_IQ2_XXS:
|
case GGML_TYPE_IQ2_XXS:
|
||||||
mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
|
||||||
|
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_IQ2_XS:
|
case GGML_TYPE_IQ2_XS:
|
||||||
mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
|
||||||
|
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_IQ3_XXS:
|
case GGML_TYPE_IQ3_XXS:
|
||||||
mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stream);
|
mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
|
||||||
|
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, stream);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
@ -9945,17 +9900,18 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
|||||||
#ifdef GGML_CUDA_FORCE_DMMV
|
#ifdef GGML_CUDA_FORCE_DMMV
|
||||||
const bool use_mul_mat_vec_q = false;
|
const bool use_mul_mat_vec_q = false;
|
||||||
#else
|
#else
|
||||||
const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type) && ggml_nrows(src1) == 1;
|
const bool use_mul_mat_vec_q = min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type);
|
||||||
#endif // GGML_CUDA_FORCE_DMMV
|
#endif // GGML_CUDA_FORCE_DMMV
|
||||||
|
|
||||||
if (use_mul_mat_vec_q) {
|
if (use_mul_mat_vec_q) {
|
||||||
// NOTE: this kernel does not support ggml_nrows(src1) > 1
|
|
||||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
|
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
|
||||||
} else {
|
} else {
|
||||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
|
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (use_mul_mat_q) {
|
if (src1->ne[1] <= 8 && min_compute_capability >= MIN_CC_DP4A && ggml_is_quantized(src0->type)) {
|
||||||
|
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
|
||||||
|
} else if (use_mul_mat_q) {
|
||||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
|
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
|
||||||
} else {
|
} else {
|
||||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
|
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
|
||||||
|
Loading…
Reference in New Issue
Block a user