mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 11:24:35 +00:00
CUDA: more warps for mmvq on NVIDIA (#5394)
This commit is contained in:
parent
41f308f58e
commit
8e6a9d2de0
133
ggml-cuda.cu
133
ggml-cuda.cu
@ -5310,22 +5310,26 @@ template <bool need_check> static __global__ void
|
|||||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
#endif // __CUDA_ARCH__ >= CC_VOLTA
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
|
#define MMVQ_NWARPS_NVIDIA 4
|
||||||
|
#define MMVQ_NWARPS_AMD_RDNA2 1
|
||||||
|
#define MMVQ_NWARPS_AMD_OLD 4
|
||||||
|
|
||||||
|
template <int nwarps, int ncols_y_template, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
|
||||||
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
|
__launch_bounds__(nwarps*WARP_SIZE, 1) // tells the compiler to use as many registers as it wants
|
||||||
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
static __global__ void mul_mat_vec_q(
|
static __global__ void mul_mat_vec_q(
|
||||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||||
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y_par, const int nrows_dst) {
|
||||||
|
|
||||||
const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
|
const int ncols_y = ncols_y_template != 0 ? ncols_y_template : ncols_y_par;
|
||||||
|
|
||||||
const int row = blockIdx.x*blockDim.y + threadIdx.y;
|
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||||
|
const int row = blockIdx.x;
|
||||||
if (row >= nrows_x) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int blocks_per_row_x = ncols_x / qk;
|
const int blocks_per_row_x = ncols_x / qk;
|
||||||
const int blocks_per_col_y = nrows_y / QK8_1;
|
const int blocks_per_col_y = nrows_y / QK8_1;
|
||||||
const int blocks_per_warp = vdr * WARP_SIZE / qi;
|
const int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
|
||||||
|
|
||||||
// partial sum for each thread
|
// partial sum for each thread
|
||||||
float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f};
|
float tmp[ncols_y_template != 0 ? ncols_y_template : 8] = {0.0f};
|
||||||
@ -5333,12 +5337,12 @@ static __global__ void mul_mat_vec_q(
|
|||||||
const block_q_t * x = (const block_q_t *) vx;
|
const block_q_t * x = (const block_q_t *) vx;
|
||||||
const block_q8_1 * y = (const block_q8_1 *) vy;
|
const block_q8_1 * y = (const block_q8_1 *) vy;
|
||||||
|
|
||||||
for (int i = threadIdx.x / (qi/vdr); i < blocks_per_row_x; i += blocks_per_warp) {
|
for (int i = tid / (qi/vdr); i < blocks_per_row_x; i += blocks_per_iter) {
|
||||||
const int ibx = row*blocks_per_row_x + i; // x block index
|
const int ibx = row*blocks_per_row_x + i; // x block index
|
||||||
|
|
||||||
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
|
const int iby = i * (qk/QK8_1); // y block index that aligns with ibx
|
||||||
|
|
||||||
const int iqs = vdr * (threadIdx.x % (qi/vdr)); // x block quant index when casting the quants to int
|
const int iqs = vdr * (tid % (qi/vdr)); // x block quant index when casting the quants to int
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols_y; ++j) {
|
for (int j = 0; j < ncols_y; ++j) {
|
||||||
@ -5346,9 +5350,25 @@ static __global__ void mul_mat_vec_q(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y_template != 0 ? ncols_y_template : 8][WARP_SIZE];
|
||||||
|
if (threadIdx.y > 0) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int j = 0; j < ncols_y; ++j) {
|
||||||
|
tmp_shared[threadIdx.y-1][j][threadIdx.x] = tmp[j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
if (threadIdx.y > 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols_y; ++j) {
|
for (int j = 0; j < ncols_y; ++j) {
|
||||||
|
#pragma unroll
|
||||||
|
for (int i = 0; i < nwarps-1; ++i) {
|
||||||
|
tmp[j] += tmp_shared[i][j][threadIdx.x];
|
||||||
|
}
|
||||||
tmp[j] = warp_reduce_sum(tmp[j]);
|
tmp[j] = warp_reduce_sum(tmp[j]);
|
||||||
|
|
||||||
if (threadIdx.x == 0) {
|
if (threadIdx.x == 0) {
|
||||||
@ -6833,46 +6853,65 @@ static void mul_mat_vec_q_cuda(
|
|||||||
GGML_ASSERT(ncols_x % qk == 0);
|
GGML_ASSERT(ncols_x % qk == 0);
|
||||||
GGML_ASSERT(ncols_y <= 4);
|
GGML_ASSERT(ncols_y <= 4);
|
||||||
|
|
||||||
const int block_num_y = (nrows_x + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
int id;
|
||||||
const dim3 block_nums(block_num_y, 1, 1);
|
CUDA_CHECK(cudaGetDevice(&id));
|
||||||
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
|
|
||||||
switch (ncols_y) {
|
int nwarps;
|
||||||
case 1:
|
if (g_device_caps[id].cc >= CC_OFFSET_AMD) {
|
||||||
mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
|
nwarps = g_device_caps[id].cc >= CC_RDNA2 ? MMVQ_NWARPS_AMD_RDNA2 : MMVQ_NWARPS_AMD_OLD;
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
} else {
|
||||||
break;
|
nwarps = MMVQ_NWARPS_NVIDIA;
|
||||||
case 2:
|
}
|
||||||
mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
|
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
const dim3 block_nums(nrows_x, 1, 1);
|
||||||
break;
|
const dim3 block_dims(WARP_SIZE, nwarps, 1);
|
||||||
case 3:
|
|
||||||
mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
|
switch (nwarps) {
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
case 1: switch(ncols_y) {
|
||||||
break;
|
case 1:
|
||||||
case 4:
|
mul_mat_vec_q<1, 1, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
break;
|
||||||
break;
|
case 2:
|
||||||
// case 5:
|
mul_mat_vec_q<1, 2, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
// mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
||||||
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
break;
|
||||||
// break;
|
case 3:
|
||||||
// case 6:
|
mul_mat_vec_q<1, 3, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
// mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
||||||
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
break;
|
||||||
// break;
|
case 4:
|
||||||
// case 7:
|
mul_mat_vec_q<1, 4, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
// mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
||||||
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
break;
|
||||||
// break;
|
default:
|
||||||
// case 8:
|
GGML_ASSERT(false);
|
||||||
// mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
|
break;
|
||||||
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
} break;
|
||||||
// break;
|
case 4: switch(ncols_y) {
|
||||||
|
case 1:
|
||||||
|
mul_mat_vec_q<4, 1, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
mul_mat_vec_q<4, 2, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
mul_mat_vec_q<4, 3, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
||||||
|
break;
|
||||||
|
case 4:
|
||||||
|
mul_mat_vec_q<4, 4, qk, qi, block_q_t, vdr, vec_dot>
|
||||||
|
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(false);
|
||||||
|
break;
|
||||||
|
} break;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
// mul_mat_vec_q<0, qk, qi, block_q_t, vdr, vec_dot>
|
|
||||||
// <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst);
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user