From d50f8897a797a5a03f31228d1b5a7b8130ee1bc2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Thu, 20 Jun 2024 14:39:21 +0200
Subject: [PATCH] CUDA: stream-k decomposition for MMQ (#8018)

* CUDA: stream-k decomposition for MMQ

* fix undefined memory reads for small matrices

---
 ggml-cuda.cu         |   2 +-
 ggml-cuda/common.cuh |   4 +-
 ggml-cuda/mmq.cu     |  20 +--
 ggml-cuda/mmq.cuh    | 379 +++++++++++++++++++++++++++++++------------
 4 files changed, 292 insertions(+), 113 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index b8298ab20..f914efd71 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -635,7 +635,7 @@ static int64_t get_row_rounding(const std::array<float, GGML_CUDA_MAX_DEVICES> &
         }
 
         const int cc = ggml_cuda_info().devices[id].cc;
-        row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc, get_mmq_x_max_host(cc)));
+        row_rounding = std::max(row_rounding, (int64_t)get_mmq_y_host(cc));
     }
     return row_rounding;
 }
diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index de7c2e434..5bd24ebe5 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -652,8 +652,8 @@ static int get_mmq_x_max_host(const int cc) {
 }
 
 // Round rows to this value for --split-mode row:
-static int get_mmq_y_host(const int cc, const int mmq_x) {
-    return cc >= CC_VOLTA && mmq_x >= 32 ? 128 : 64;
+static int get_mmq_y_host(const int cc) {
+    return cc >= CC_VOLTA ? 128 : 64;
 }
 
 //////////////////////
diff --git a/ggml-cuda/mmq.cu b/ggml-cuda/mmq.cu
index 1d6b9e698..6dbd85fef 100644
--- a/ggml-cuda/mmq.cu
+++ b/ggml-cuda/mmq.cu
@@ -30,34 +30,34 @@ void ggml_cuda_op_mul_mat_q(
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            mul_mat_q_case<GGML_TYPE_Q4_0>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
             break;
         case GGML_TYPE_Q4_1:
-            mul_mat_q_case<GGML_TYPE_Q4_1>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
             break;
         case GGML_TYPE_Q5_0:
-            mul_mat_q_case<GGML_TYPE_Q5_0>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q5_0>(ctx, args, stream);
             break;
         case GGML_TYPE_Q5_1:
-            mul_mat_q_case<GGML_TYPE_Q5_1>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q5_1>(ctx, args, stream);
             break;
         case GGML_TYPE_Q8_0:
-            mul_mat_q_case<GGML_TYPE_Q8_0>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q8_0>(ctx, args, stream);
             break;
         case GGML_TYPE_Q2_K:
-            mul_mat_q_case<GGML_TYPE_Q2_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q3_K:
-            mul_mat_q_case<GGML_TYPE_Q3_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q3_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q4_K:
-            mul_mat_q_case<GGML_TYPE_Q4_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q4_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q5_K:
-            mul_mat_q_case<GGML_TYPE_Q5_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q5_K>(ctx, args, stream);
             break;
         case GGML_TYPE_Q6_K:
-            mul_mat_q_case<GGML_TYPE_Q6_K>(args, stream);
+            mul_mat_q_case<GGML_TYPE_Q6_K>(ctx, args, stream);
            break;
         default:
             GGML_ASSERT(false);
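The ggml_backend_cuda_context argument that is now threaded through this dispatch exists so that the launcher can take scratch memory from the context's pool: with stream-k, a CUDA block that only computes part of an output tile stages its partial sums in a temporary fixup buffer. The allocation pattern, as it appears in launch_mul_mat_q further down in this patch (one mmq_x*mmq_y float slot per launched CUDA block):

    // Taken from the context's pool so the buffer is recycled across calls;
    // it is released automatically when tmp_fixup goes out of scope.
    ggml_cuda_pool & pool = ctx.pool();
    ggml_cuda_pool_alloc<float> tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y);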
diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh
index 6d57974fb..e2d07c202 100644
--- a/ggml-cuda/mmq.cuh
+++ b/ggml-cuda/mmq.cuh
@@ -8,6 +8,7 @@
 #include <cstdint>
 
 #define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1)
+#define MMQ_NWARPS 8
 
 typedef void (*load_tiles_mmq_t)(
     const char * __restrict__ x, int * __restrict__ x_qs, half2 * __restrict__ x_dm,
@@ -16,7 +17,7 @@ typedef void (*load_tiles_mmq_t)(
 typedef void (*vec_dot_mmq_t)(
     const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0);
 
-typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
+typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
 
 struct block_q8_1_mmq {
     half2 ds[4];
@@ -50,21 +51,17 @@ static constexpr __device__ int get_mmq_x_max_device() {
 
 // get_mmq_y_host is in common.cuh so that it can be used to determine the correct way to round for --split-mode row
 
+static constexpr __device__ int get_mmq_y_device() {
 #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-static constexpr __device__ int get_mmq_y_device(int mmq_x) {
-    return mmq_x >= 32 ? 128 : 64;
-}
+    return 128;
 #else
 #if __CUDA_ARCH__ >= CC_VOLTA
-static constexpr __device__ int get_mmq_y_device(int mmq_x) {
-    return mmq_x >= 32 ? 128 : 64;
-}
+    return 128;
 #else
-static constexpr __device__ int get_mmq_y_device(int /*mmq_x*/) {
     return 64;
-}
 #endif // __CUDA_ARCH__ >= CC_VOLTA
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+}
 
 #define TILE_X_SIZES_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0}
 #define TILE_X_SIZES_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0}
@@ -1734,30 +1731,34 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
 }
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+static __device__ __forceinline__ void mmq_write_back_dp4a(
+    const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
+
 #pragma unroll
     for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
-        const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
+        const int j = j0 + threadIdx.y;
 
-        if (j >= ne1) {
+        if (j > j_max) {
             return;
         }
 
 #pragma unroll
         for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
-            const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
+            const int i = i0 + threadIdx.x;
 
-            if (need_check && i >= ne0) {
+            if (need_check && i > i_max) {
                 continue;
             }
 
-            dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+            dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
         }
     }
 }
 
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
-static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
+static __device__ __forceinline__ void mmq_write_back_mma(
+    const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
+
     typedef mma_int_C_I16J8 mma_C;
 
     const int i0 = threadIdx.y*mma_C::I;
@@ -1769,19 +1770,19 @@ static __device__ __forceinline__ void mmq_write_back_mma(const float * __restri
     for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
 #pragma unroll
         for (int l = 0; l < mma_C::ne; ++l) {
-            const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l);
+            const int j = j0 + mma_C::get_j(l);
 
-            if (j >= ne1) {
+            if (j > j_max) {
                 continue;
             }
 
-            const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l);
+            const int i = i0 + mma_C::get_i(l);
 
-            if (need_check && i >= ne0) {
+            if (need_check && i > i_max) {
                 continue;
             }
 
-            dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
+            dst[j*stride + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
         }
     }
 }
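The write-back callbacks above now work in tile-local coordinates: the caller points dst at the tile's top-left element and passes the row stride plus inclusive upper bounds, so clipping no longer depends on blockIdx and the same code can target either the output matrix or the fixup buffer. The new bound checks are equivalent to the old global ones; a sketch of the arithmetic for the j direction, using the definitions from the kernels below (jt is the tile's column index):

    // old check, global j:  j_global = jt*mmq_x + j   and   j_global >= ne11
    // new check, local  j:  j_max    = ne11 - jt*mmq_x - 1
    // j > j_max  <=>  jt*mmq_x + j > ne11 - 1  <=>  j_global >= ne11

When a tile is redirected into the fixup buffer, i_max and j_max are passed as mmq_y and mmq_x, values no thread can exceed, so that path is effectively unclipped.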
@@ -1896,32 +1897,16 @@ static bool mmq_need_sum(const ggml_type type_x) {
     return false;
 }
 
-template <ggml_type type, int mmq_x, int nwarps, bool need_check>
-#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
-    __launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2)
-#else
-#if __CUDA_ARCH__ >= CC_VOLTA
-    __launch_bounds__(WARP_SIZE*nwarps, 1)
-#else
-    __launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // __CUDA_ARCH__ >= CC_VOLTA
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-static __global__ void mul_mat_q(
-    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst,
-    const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
-
-    // Skip unused template specializations for faster compilation:
-    if (mmq_x > get_mmq_x_max_device()) {
-        NO_DEVICE_CODE;
-        return;
-    }
+template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
+static __device__ void mul_mat_q_process_tile(
+    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+    const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0,
+    const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
 
     constexpr int qk = ggml_cuda_type_traits<type>::qk;
     constexpr int qr = ggml_cuda_type_traits<type>::qr;
     constexpr int qi = ggml_cuda_type_traits<type>::qi;
-    constexpr int mmq_y = get_mmq_y_device(mmq_x);
+    constexpr int mmq_y = get_mmq_y_device();
     constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
     constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
 
@@ -1941,20 +1926,18 @@ static __global__ void mul_mat_q(
     int  * tile_x_sc = (int  *) (tile_x_dm + txs.dm);
     int  * tile_y    = (int  *) (tile_x_sc + txs.sc); // [mmq_x * (WARP_SIZE + WARP_SIZE/QI8_1)]
 
-    const int blocks_per_row_x = ne00 / qk;
-    const int blocks_per_warp = WARP_SIZE / qi;
-
-    const int & ne1 = ne11;
-
-    const int tile_x_max_i = ne01 - blockIdx.x*mmq_y - 1;
-
-    const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
+    constexpr int blocks_per_warp = WARP_SIZE / qi;
 
     float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
 
-    for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
+    const int tile_x_max_i = ne01 - it*mmq_y - 1;
+    const int tile_y_max_j = ne11 - jt*mmq_x - 1;
 
-        load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*blockIdx.x*mmq_y + kb0, tile_x_max_i, stride01);
+    const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
+
+    for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_warp) {
+
+        load_tiles(x, tile_x_qs, tile_x_dm, tile_x_sc, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
 
 #pragma unroll
         for (int kr = 0; kr < qr; ++kr) {
@@ -1977,7 +1960,176 @@ static __global__ void mul_mat_q(
         }
     }
 
-    write_back(sum, dst, ne0, ne1);
+    if (fixup) {
+        write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
+    } else {
+        write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
+    }
+}
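Of the tiles a CUDA block touches, every tile whose k range it finishes (kb0_stop == blocks_per_ne00) is written straight to dst, because exactly one block finishes each tile; a trailing tile the block stops partway through is instead written, without clipping, to a private slot of the fixup buffer and merged into dst later by mul_mat_q_stream_k_fixup. A hypothetical helper making the slot addressing explicit (the patch inlines this expression):

    // One private mmq_x*mmq_y tile per CUDA block, stored with row stride mmq_y.
    static __device__ float * fixup_slot(float * tmp_fixup, const int mmq_x, const int mmq_y) {
        return tmp_fixup + blockIdx.x * (mmq_x*mmq_y);
    }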
+
+
+// The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+#if defined(RDNA3) || defined(RDNA2)
+    __launch_bounds__(WARP_SIZE*nwarps, 2)
+#endif // defined(RDNA3) || defined(RDNA2)
+#else
+#if __CUDA_ARCH__ >= CC_VOLTA
+    __launch_bounds__(WARP_SIZE*nwarps, 1)
+#else
+    __launch_bounds__(WARP_SIZE*nwarps, 2)
+#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
+static __global__ void mul_mat_q(
+    const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
+    const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
+
+    // Skip unused template specializations for faster compilation:
+    if (mmq_x > get_mmq_x_max_device()) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    constexpr int qk    = ggml_cuda_type_traits<type>::qk;
+    constexpr int qi    = ggml_cuda_type_traits<type>::qi;
+    constexpr int mmq_y = get_mmq_y_device();
+
+    // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
+#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+    {
+        constexpr bool fixup = false;
+        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+            (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+             blockIdx.x, blockIdx.y, 0, ne00/qk);
+        return;
+    }
+#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+
+    const int64_t blocks_per_ne00 = ne00 / qk;
+    constexpr int blocks_per_warp = WARP_SIZE / qi;
+
+    const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
+    const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
+
+    // kbc == k block continuous, current index in continuous ijk space.
+    int64_t       kbc      = GGML_PAD((int64_t) blockIdx.x     *blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
+    const int64_t kbc_stop = GGML_PAD((int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x, blocks_per_warp);
+
+    // kb0 == k index when doing the matrix multiplication for an output tile.
+    int kb0_start = kbc % blocks_per_ne00;
+    int kb0_stop  = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
+    while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
+        const int jt =  kbc /    (blocks_per_ne00*nty);                    // j index of current tile.
+        const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile.
+
+        constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
+        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+            (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+             it, jt, kb0_start, kb0_stop);
+
+        kbc += blocks_per_ne00;
+        kbc -= kbc % blocks_per_ne00;
+
+        kb0_start = 0;
+        kb0_stop  = min(blocks_per_ne00, kbc_stop - kbc);
+    }
+
+    if (kbc >= kbc_stop) {
+        return;
+    }
+
+    const int jt =  kbc /    (blocks_per_ne00*nty);
+    const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+
+    constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
+    mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
+        (x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
+         it, jt, kb0_start, kb0_stop);
+}
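A small host-side program makes the kbc partitioning above concrete; it mirrors the kernel's arithmetic with made-up sizes (4 CUDA blocks standing in for gridDim.x, 3x2 output tiles, 8 k blocks per tile) and a pad() that plays the role of GGML_PAD:

    #include <cstdint>
    #include <cstdio>

    static int64_t pad(const int64_t x, const int64_t n) { return (x + n - 1) / n * n; }

    int main() {
        const int     nsm = 4;              // gridDim.x: one CUDA block per SM
        const int     ntx = 3, nty = 2;     // output tiles in each direction
        const int64_t blocks_per_ne00 = 8;  // k blocks per output tile
        const int64_t blocks_per_warp = 4;  // granularity of the split

        const int64_t total = blocks_per_ne00*ntx*nty;
        for (int b = 0; b < nsm; ++b) {
            const int64_t kbc      = pad( b   *total/nsm, blocks_per_warp);
            const int64_t kbc_stop = pad((b+1)*total/nsm, blocks_per_warp);
            printf("block %d: kbc=[%lld, %lld) fixup=%d\n", b,
                   (long long) kbc, (long long) kbc_stop, kbc_stop % blocks_per_ne00 != 0);
        }
        return 0;
    }

With these numbers each block gets 12 of the 48 k blocks: blocks 0 and 2 stop mid-tile and therefore produce a fixup write, while blocks 1 and 3 end exactly on a tile boundary and do not.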
+
+
+template <ggml_type type, int mmq_x, int nwarps, bool need_check>
+static __global__ void mul_mat_q_stream_k_fixup(
+    float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) {
+
+    constexpr int mmq_y = get_mmq_y_device();
+    constexpr int qk    = ggml_cuda_type_traits<type>::qk;
+    constexpr int qi    = ggml_cuda_type_traits<type>::qi;
+    constexpr int blocks_per_warp = WARP_SIZE / qi;
+    const int64_t blocks_per_ne00 = ne00 / qk;
+
+    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
+
+    const int ntx = (ne11 + mmq_x - 1) / mmq_x;
+    const int nty = (ne01 + mmq_y - 1) / mmq_y;
+
+    bool any_fixup = false;
+
+    const int bidx_start = (blockIdx.y*nty + blockIdx.x)     * block_num_mmq / (gridDim.y*gridDim.x);
+    const int bidx_stop  = (blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq / (gridDim.y*gridDim.x) + 1;
+
+    for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
+        const int64_t kbc      = GGML_PAD((int64_t) bidx     *blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
+        const int64_t kbc_stop = GGML_PAD((int64_t)(bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq, blocks_per_warp);
+
+        // Skip fixup tile if the MMQ CUDA block never wrote anything to it:
+        if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
+            continue;
+        }
+
+        const int jt =  kbc_stop /    (blocks_per_ne00*nty);
+        const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
+
+        // Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
+        if (it != blockIdx.x || jt != blockIdx.y) {
+            continue;
+        }
+
+        any_fixup = true;
+
+#pragma unroll
+        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+            const int j = j0 + threadIdx.y;
+
+#pragma unroll
+            for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+                const int i = i0 + threadIdx.x;
+
+                sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
+            }
+        }
+    }
+
+    if (!any_fixup) {
+        return;
+    }
+
+    dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y;
+
+    const int i_max = ne01 - blockIdx.x*mmq_y - 1;
+    const int j_max = ne11 - blockIdx.y*mmq_x - 1;
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
+        const int j = j0 + threadIdx.y;
+
+        if (j > j_max) {
+            return;
+        }
+
+#pragma unroll
+        for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
+            const int i = i0 + threadIdx.x;
+
+            if (need_check && i > i_max) {
+                continue;
+            }
+
+            dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
+        }
+    }
+}
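mul_mat_q_stream_k_fixup runs with one CUDA block per output tile and has to find which MMQ block, if any, left a partial sum for its tile: it replays the same partition and recovers tile coordinates from the end of each range. The bidx_start/bidx_stop window merely narrows the scan to candidate MMQ blocks; the two continue checks do the exact filtering. The shared index mapping, as a host-side sketch (hypothetical standalone form of the arithmetic used in both kernels):

    #include <cstdint>

    struct tile_pos { int it, jt; };

    // kbc indexes the flattened (jt, it, kb0) space: jt varies slowest, then it,
    // then the k block inside the tile; nty is the number of tiles along i.
    static tile_pos tile_of(const int64_t kbc, const int64_t blocks_per_ne00, const int nty) {
        const int jt = kbc / (blocks_per_ne00*nty);
        const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
        return {it, jt};
    }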
 
 struct mmq_args {
@@ -1987,124 +2139,151 @@ struct mmq_args {
     int64_t ne0;
 };
 
-constexpr int mmq_get_nwarps(int mmq_x) {
-    return mmq_x >= 32 ? 8 : 4;
-}
-
 static int mmq_get_shmem(const ggml_type type, const int mmq_x, const int mmq_y) {
     const tile_x_sizes txs = get_tile_x_sizes_host(type, mmq_y);
-    const int nwarps = mmq_get_nwarps(mmq_x);
 
     const int shmem_x = txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
     const int shmem_y = mmq_x*WARP_SIZE*sizeof(int) + mmq_x*(WARP_SIZE/QI8_1)*sizeof(half2);
-    return shmem_x + GGML_PAD(shmem_y, nwarps*WARP_SIZE*sizeof(int));
+    return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
 }
 
-template <ggml_type type, int mmq_x, int nwarps>
-static void launch_mul_mat_q(const mmq_args & args, cudaStream_t stream) {
+template <ggml_type type, int mmq_x>
+static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     const int id = ggml_cuda_get_device();
     const int cc = ggml_cuda_info().devices[id].cc;
-    const int mmq_y = get_mmq_y_host(cc, mmq_x);
+    const int nsm = ggml_cuda_info().devices[id].nsm;
+    const int mmq_y = get_mmq_y_host(cc);
 
-    const int block_num_x = (args.ne01 + mmq_y - 1) / mmq_y;
-    const int block_num_y = (args.ne11 + mmq_x - 1) / mmq_x;
-    const dim3 block_nums(block_num_x, block_num_y, 1);
-    const dim3 block_dims(WARP_SIZE, nwarps, 1);
+    const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
 
     const int shmem = mmq_get_shmem(type, mmq_x, mmq_y);
 
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
     static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
     if (!shmem_limit_raised[id]) {
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
-        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, nwarps, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
+        CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>,  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
         shmem_limit_raised[id] = true;
     }
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 
+    const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
+    const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
+    const dim3 block_nums_xy_tiling(nty, ntx, 1);
+
+    const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
+    if (!use_stream_k) {
+        if (args.ne01 % mmq_y == 0) {
+            constexpr bool need_check = false;
+            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
+                (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        } else {
+            constexpr bool need_check = true;
+            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
+                (args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        }
+        return;
+    }
+
+    const dim3 block_nums_mmq(nsm, 1, 1);
+
+    ggml_cuda_pool & pool = ctx.pool();
+    ggml_cuda_pool_alloc<float> tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y);
+
     if (args.ne01 % mmq_y == 0) {
-        const bool need_check = false;
-        mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
-            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        constexpr bool need_check = false;
+
+        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
+            (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+
+        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
+            (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
     } else {
-        const bool need_check = true;
-        mul_mat_q<type, mmq_x, nwarps, need_check><<<block_nums, block_dims, shmem, stream>>>
-            (args.x, args.y, args.dst, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+        constexpr bool need_check = true;
+
+        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
+            (args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
+
+        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
+            (args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
     }
 }
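For stream-k the grid is sized to the SM count (block_nums_mmq.x == nsm), so the whole multiplication runs as a single wave, and the fixup kernel is then launched on the conventional tile grid. The patch reads nsm from ggml_cuda_info(); in standalone form the same number comes from a standard CUDA device attribute:

    #include <cuda_runtime.h>

    // Number of streaming multiprocessors, i.e. the natural stream-k grid size.
    static int sm_count(const int device) {
        int nsm = 0;
        cudaDeviceGetAttribute(&nsm, cudaDevAttrMultiProcessorCount, device);
        return nsm;
    }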
 
 template <ggml_type type>
-void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
+void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
     const int id = ggml_cuda_get_device();
     const int nsm = ggml_cuda_info().devices[id].nsm;
     const int cc = ggml_cuda_info().devices[id].cc;
     const int smpbo = ggml_cuda_info().devices[id].smpbo;
 
     const int mmq_x_max = get_mmq_x_max_host(cc);
-    const int mmq_y = get_mmq_y_host(cc, mmq_x_max);
+    const int mmq_y = get_mmq_y_host(cc);
     const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
+    const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
 
     int mmq_x_best = 0;
-    int nwaves_best = INT_MAX;
+    int nparts_best = INT_MAX;
 
-    for (int mmq_x = 8; mmq_x <= mmq_x_max && nwaves_best > 1; mmq_x += 8) {
-        const int block_num_x = (args.ne11 + mmq_x - 1) / mmq_x;
-        const int nwaves = (block_num_x*block_num_y + nsm - 1) / nsm;
+    for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
+        const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x;
+        const int nwaves_xy_tiling = ntiles_x*block_num_y;
 
-        if (nwaves < nwaves_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
+        const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;
+
+        if (nparts < nparts_best && mmq_get_shmem(type, mmq_x, mmq_y) <= smpbo) {
             mmq_x_best  = mmq_x;
-            nwaves_best = nwaves;
+            nparts_best = nparts;
         }
     }
 
     switch (mmq_x_best) {
         case   8:
-            launch_mul_mat_q<type,   8, mmq_get_nwarps(  8)>(args, stream);
+            launch_mul_mat_q<type,   8>(ctx, args, stream);
            break;
         case  16:
-            launch_mul_mat_q<type,  16, mmq_get_nwarps( 16)>(args, stream);
+            launch_mul_mat_q<type,  16>(ctx, args, stream);
            break;
         case  24:
-            launch_mul_mat_q<type,  24, mmq_get_nwarps( 24)>(args, stream);
+            launch_mul_mat_q<type,  24>(ctx, args, stream);
            break;
         case  32:
-            launch_mul_mat_q<type,  32, mmq_get_nwarps( 32)>(args, stream);
+            launch_mul_mat_q<type,  32>(ctx, args, stream);
            break;
         case  40:
-            launch_mul_mat_q<type,  40, mmq_get_nwarps( 40)>(args, stream);
+            launch_mul_mat_q<type,  40>(ctx, args, stream);
            break;
         case  48:
-            launch_mul_mat_q<type,  48, mmq_get_nwarps( 48)>(args, stream);
+            launch_mul_mat_q<type,  48>(ctx, args, stream);
            break;
         case  56:
-            launch_mul_mat_q<type,  56, mmq_get_nwarps( 56)>(args, stream);
+            launch_mul_mat_q<type,  56>(ctx, args, stream);
            break;
         case  64:
-            launch_mul_mat_q<type,  64, mmq_get_nwarps( 64)>(args, stream);
+            launch_mul_mat_q<type,  64>(ctx, args, stream);
            break;
         case  72:
-            launch_mul_mat_q<type,  72, mmq_get_nwarps( 72)>(args, stream);
+            launch_mul_mat_q<type,  72>(ctx, args, stream);
            break;
         case  80:
-            launch_mul_mat_q<type,  80, mmq_get_nwarps( 80)>(args, stream);
+            launch_mul_mat_q<type,  80>(ctx, args, stream);
            break;
         case  88:
-            launch_mul_mat_q<type,  88, mmq_get_nwarps( 88)>(args, stream);
+            launch_mul_mat_q<type,  88>(ctx, args, stream);
            break;
         case  96:
-            launch_mul_mat_q<type,  96, mmq_get_nwarps( 96)>(args, stream);
+            launch_mul_mat_q<type,  96>(ctx, args, stream);
            break;
         case 104:
-            launch_mul_mat_q<type, 104, mmq_get_nwarps(104)>(args, stream);
+            launch_mul_mat_q<type, 104>(ctx, args, stream);
            break;
         case 112:
-            launch_mul_mat_q<type, 112, mmq_get_nwarps(112)>(args, stream);
+            launch_mul_mat_q<type, 112>(ctx, args, stream);
            break;
         case 120:
-            launch_mul_mat_q<type, 120, mmq_get_nwarps(120)>(args, stream);
+            launch_mul_mat_q<type, 120>(ctx, args, stream);
            break;
         case 128:
-            launch_mul_mat_q<type, 128, mmq_get_nwarps(128)>(args, stream);
+            launch_mul_mat_q<type, 128>(ctx, args, stream);
            break;
         default:
             fprintf(stderr, "mmq_x_best=%d\n", mmq_x_best);
@@ -2114,7 +2293,7 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
 }
 
 #define DECL_MMQ_CASE(type)                                                        \
-    template void mul_mat_q_case<type>(const mmq_args & args, cudaStream_t stream) \
+    template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \
 
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
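The selection loop above switches its cost metric with the launch mode: under conventional xy tiling the unit of scheduling is a whole output tile, so it minimizes the total tile count, while under stream-k the k dimension rebalances itself and only the number of tile columns along ne11 matters. A standalone rendering of the heuristic (the shared-memory limit check is omitted; mmq_x_max and ntiles_y stand in for the values computed above):

    #include <climits>

    static int pick_mmq_x(const int ne11, const int mmq_x_max, const int ntiles_y, const bool use_stream_k) {
        int mmq_x_best  = 0;
        int nparts_best = INT_MAX;
        for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
            const int ntiles_x = (ne11 + mmq_x - 1) / mmq_x;
            const int nparts   = use_stream_k ? ntiles_x : ntiles_x*ntiles_y;
            if (nparts < nparts_best) {
                mmq_x_best  = mmq_x;
                nparts_best = nparts;
            }
        }
        return mmq_x_best;
    }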