From 69c487f4ed57bb4d4514a1b7ff12608d5a8e7ef0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 20 Jul 2024 22:25:26 +0200 Subject: [PATCH] CUDA: MMQ code deduplication + iquant support (#8495) * CUDA: MMQ code deduplication + iquant support * 1 less parallel job for CI build --- .github/workflows/build.yml | 2 +- ggml/src/ggml-cuda/mmq.cu | 24 + ggml/src/ggml-cuda/mmq.cuh | 1359 +++++++++-------- .../template-instances/generate_cu_files.py | 3 +- .../template-instances/mmq-instance-iq1_s.cu | 5 + .../template-instances/mmq-instance-iq2_s.cu | 5 + .../template-instances/mmq-instance-iq2_xs.cu | 5 + .../mmq-instance-iq2_xxs.cu | 5 + .../template-instances/mmq-instance-iq3_s.cu | 5 + .../mmq-instance-iq3_xxs.cu | 5 + ggml/src/ggml-cuda/vecdotq.cuh | 21 + 11 files changed, 800 insertions(+), 639 deletions(-) create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu create mode 100644 ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 1f275603a..a1e183d11 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -860,7 +860,7 @@ jobs: mkdir build cd build cmake .. -DGGML_NATIVE=OFF -DLLAMA_BUILD_SERVER=ON -DGGML_CUDA=ON -DBUILD_SHARED_LIBS=ON - cmake --build . --config Release -j ${env:NUMBER_OF_PROCESSORS} + cmake --build . --config Release -j $((${env:NUMBER_OF_PROCESSORS} - 1)) - name: Determine tag name id: tag diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 3af2eec2f..84f6387e2 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -59,6 +59,24 @@ void ggml_cuda_op_mul_mat_q( case GGML_TYPE_Q6_K: mul_mat_q_case(ctx, args, stream); break; + case GGML_TYPE_IQ2_XXS: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ2_XS: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ2_S: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ3_XXS: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ3_S: + mul_mat_q_case(ctx, args, stream); + break; + case GGML_TYPE_IQ1_S: + mul_mat_q_case(ctx, args, stream); + break; case GGML_TYPE_IQ4_XS: mul_mat_q_case(ctx, args, stream); break; @@ -93,6 +111,12 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: mmq_supported = true; diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 51c44d857..f08a4758d 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -63,6 +63,14 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) { case GGML_TYPE_Q5_K: return MMQ_Q8_1_DS_LAYOUT_DS4; case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + return MMQ_Q8_1_DS_LAYOUT_D4; + case GGML_TYPE_IQ1_S: + return MMQ_Q8_1_DS_LAYOUT_DS4; case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_NL: return MMQ_Q8_1_DS_LAYOUT_D4; @@ -131,15 +139,16 @@ static constexpr __device__ int get_mmq_y_device() { #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) } -#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} -#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} -#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0} -#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0} -#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0} -#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K + mmq_y/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} +#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} +#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0} +#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*4/QI8_0 + mmq_y/(QI8_0/4), 0} +#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0} +#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0} +#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 : @@ -152,42 +161,46 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml type == GGML_TYPE_Q4_K ? MMQ_DP4A_TXS_Q4_K : type == GGML_TYPE_Q5_K ? MMQ_DP4A_TXS_Q5_K : type == GGML_TYPE_Q6_K ? MMQ_DP4A_TXS_Q6_K : + type == GGML_TYPE_IQ2_XXS ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_IQ2_XS ? MMQ_DP4A_TXS_Q8_0_16 : + type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 : + type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 : + type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 : type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 : type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 : tile_x_sizes{0, 0, 0}; } -#define MMQ_MMA_TILE_X_K_Q4_0 (1*WARP_SIZE + WARP_SIZE/QI4_0 + 4) -#define MMQ_MMA_TILE_X_K_Q4_1 (1*WARP_SIZE + WARP_SIZE/QI4_1 + 4) #define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) #define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) #define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4) -#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/(2*QI3_K) + WARP_SIZE/8 + 7) -#define MMQ_MMA_TILE_X_K_Q4_K (1*WARP_SIZE + WARP_SIZE/QI4_K + WARP_SIZE/8 + 7) -#define MMQ_MMA_TILE_X_K_Q5_K (2*WARP_SIZE + WARP_SIZE/QI5_K + WARP_SIZE/8 + 7) +#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/2 + 4) #define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7) -static_assert(MMQ_MMA_TILE_X_K_Q4_0 % 8 == 4, "Wrong padding."); -static_assert(MMQ_MMA_TILE_X_K_Q4_1 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding."); -static_assert(MMQ_MMA_TILE_X_K_Q4_K % 8 == 4, "Wrong padding."); -static_assert(MMQ_MMA_TILE_X_K_Q5_K % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding."); static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q4_0 : - type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q4_1 : + return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 : type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 : type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 : type == GGML_TYPE_Q8_0 ? MMQ_MMA_TILE_X_K_Q8_0 : type == GGML_TYPE_Q2_K ? MMQ_MMA_TILE_X_K_Q2_K : type == GGML_TYPE_Q3_K ? MMQ_MMA_TILE_X_K_Q3_K : - type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q4_K : - type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q5_K : + type == GGML_TYPE_Q4_K ? MMQ_MMA_TILE_X_K_Q8_1 : + type == GGML_TYPE_Q5_K ? MMQ_MMA_TILE_X_K_Q8_1 : type == GGML_TYPE_Q6_K ? MMQ_MMA_TILE_X_K_Q6_K : + type == GGML_TYPE_IQ2_XXS ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_IQ2_XS ? MMQ_MMA_TILE_X_K_Q3_K : + type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K : + type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 : + type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 : type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 : type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 : 0; @@ -216,7 +229,7 @@ template static __device__ __forceinlin #ifdef INT8_MMA_AVAILABLE int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE); + float * x_df = (float *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); int * x_qs = (int *) x_tile; @@ -235,11 +248,13 @@ template static __device__ __forceinlin } const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; + const int qs0 = get_int_b2(bxi->qs, kqsx); #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + threadIdx.x] = get_int_b2(bxi->qs, kqsx); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); #else - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b2(bxi->qs, kqsx); + x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; #endif // INT8_MMA_AVAILABLE } @@ -257,7 +272,7 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; #ifdef INT8_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + kbxd] = bxi->d; + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d; #endif // INT8_MMA_AVAILABLE @@ -304,95 +319,12 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( } } -template -static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE - - typedef mma_int_A_I16K8 mma_A; - typedef mma_int_B_J8K8 mma_B; - typedef mma_int_C_I16J8 mma_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + WARP_SIZE; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - mma_A A[ntx][4]; - float dA[ntx][mma_C::ne/2][4]; - - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) { - const int k0 = k00 + k01; - -#pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - const int i = i0 + n*mma_A::I + mma_A::get_i(l); - const int k = k0/QR4_0 + mma_A::get_k(l) % QI4_0; - const int shift = 4*(mma_A::get_k(l) / QI4_0); - - A[n][k01/(QR4_0*QI4_0)].x[l] = __vsubss4((x_qs[i*MMQ_MMA_TILE_X_K_Q4_0 + k] >> shift) & 0x0F0F0F0F, 0x08080808); - } - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - - dA[n][l][k01/(QR4_0*QI4_0)] = x_df[i*MMQ_MMA_TILE_X_K_Q4_0 + k0/(QR4_0*QI4_0)]; - } - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*QI4_0) { - mma_B B; - float dB[mma_C::ne/2]; - - B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); - - dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma_K8(A[n][k01/(QR4_0*QI4_0)], B); - -#pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2][k01/(QR4_0*QI4_0)]*dB[l%2]*C.x[l]; - } - } - } - } -#else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); - NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE -} - template static __device__ __forceinline__ void load_tiles_q4_1( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { #ifdef INT8_MMA_AVAILABLE int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + WARP_SIZE); + half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); int * x_qs = (int *) x_tile; @@ -411,11 +343,13 @@ template static __device__ __forceinlin } const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; + const int qs0 = get_int_b4(bxi->qs, kqsx); #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q4_1 + threadIdx.x] = get_int_b4(bxi->qs, kqsx); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; #else - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b4(bxi->qs, kqsx); + x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; #endif // INT8_MMA_AVAILABLE } @@ -433,7 +367,7 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; #ifdef INT8_MMA_AVAILABLE - x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + kbxd] = bxi->dm; + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; #endif // INT8_MMA_AVAILABLE @@ -480,88 +414,6 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( } } -template -static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE - - typedef mma_int_A_I16K8 mma_A; - typedef mma_int_A_I16K4 mma_A_K4; - typedef mma_int_B_J8K8 mma_B; - typedef mma_int_C_I16J8 mma_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - mma_A A[ntx][4]; - half2 dmA[ntx][mma_C::ne/2][4]; - - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) { - const int k0 = k00 + k01; - - A[n][k01/(QR4_1*QI4_1)].load_low(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_1 + k0/QR4_1, MMQ_MMA_TILE_X_K_Q4_1); - A[n][k01/(QR4_1*QI4_1)].x[2] = (A[n][k01/(QR4_1*QI4_1)].x[0] >> 4) & 0x0F0F0F0F; - A[n][k01/(QR4_1*QI4_1)].x[3] = (A[n][k01/(QR4_1*QI4_1)].x[1] >> 4) & 0x0F0F0F0F; - A[n][k01/(QR4_1*QI4_1)].x[0] &= 0x0F0F0F0F; - A[n][k01/(QR4_1*QI4_1)].x[1] &= 0x0F0F0F0F; - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - - dmA[n][l][k01/(QR4_1*QI4_1)] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_1 + k0/(QR4_1*QI4_1)]; - } - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*QI4_1) { - mma_B B; - half2 dsB[mma_C::ne/2]; - - B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); - - dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]; - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma_K8(A[n][k01/(QR4_1*QI4_1)], B); - -#pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - const half2 dmA_dsB = dmA[n][l/2][k01/(QR4_1*QI4_1)]*dsB[l%2]; - sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); - } - } - } - } -#else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); - NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE -} - template static __device__ __forceinline__ void load_tiles_q5_0( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -789,10 +641,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( } } -template +template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; @@ -808,6 +659,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const float * x_df = (const float *) x_qs + 2*WARP_SIZE; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; mma_A A[ntx][WARP_SIZE/QI8_0]; float dA[ntx][mma_C::ne/2][WARP_SIZE/QI8_0]; @@ -840,18 +692,20 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { - const int k0 = k00 + k01; - - mma_B B; + mma_B B; float dB[mma_C::ne/2]; - B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K); + B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); - dB[l] = y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]; + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } else { + dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } } #pragma unroll @@ -866,10 +720,6 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( } } } -#else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); - NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE } template @@ -905,7 +755,6 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; @@ -922,8 +771,8 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * y_qs = (const int *) y + 4; const half2 * y_dm = (const half2 *) y; - mma_A A[ntx][WARP_SIZE/QI8_1]; - half2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1]; + mma_A A[ntx][WARP_SIZE/QI8_1]; + float2 dmA[ntx][mma_C::ne/2][WARP_SIZE/QI8_1]; const int i0 = (threadIdx.y/ntx)*rows_per_warp; @@ -944,7 +793,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { const int k0 = k00 + k01; - dmA[n][l][k01/QI8_1] = x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]; + dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); } } } @@ -953,18 +802,16 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { #pragma unroll for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { - const int k0 = k00 + k01; + mma_B B; + float2 dsB[mma_C::ne/2]; - mma_B B; - half2 dsB[mma_C::ne/2]; - - B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K); + B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); #pragma unroll for (int l = 0; l < mma_C::ne/2; ++l) { const int j = j0 + mma_C::get_j(l); - dsB[l] = y_dm[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]; + dsB[l] = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); } #pragma unroll @@ -974,8 +821,120 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( #pragma unroll for (int l = 0; l < mma_C::ne; ++l) { - const half2 dmA_dsB = dmA[n][l/2][k01/QI8_1]*dsB[l%2]; - sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); + sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].x*dsB[l%2].x*C.x[l]; + sum[(j0/mma_C::J + n)*mma_C::ne + l] += dmA[n][l/2][k01/QI8_1].y*dsB[l%2].y; + } + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { + + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + txs.qs; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + +// #pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + const int k0 = k00 + k01; + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { + const int j = j0 + threadIdx.y; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + const int i = i0 + threadIdx.x; + + sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_16_q8_1_impl( + &x_qs[i*(2*WARP_SIZE + 1) + k0], + &y_qs[j*MMQ_TILE_Y_K + k01], + &x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)], + y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + } + } +} + +template +static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( + const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { +#ifdef INT8_MMA_AVAILABLE + + typedef mma_int_A_I16K4 mma_A; + typedef mma_int_A_I16K8 mma_A_K8; + typedef mma_int_B_J8K4 mma_B; + typedef mma_int_C_I16J8 mma_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = 2 * granularity; + constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + WARP_SIZE*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); + + mma_A A[ntx][8]; + float dA[ntx][mma_C::ne/2][8]; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + const int k0 = k00 + k01; + + ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); + +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) { + const int k0 = k00 + k01; + + dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4]; + } + } + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { +#pragma unroll + for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + mma_B B[2]; + float dB[mma_C::ne/2]; + + B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); + B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); + +#pragma unroll + for (int l = 0; l < mma_C::ne/2; ++l) { + const int j = j0 + mma_C::get_j(l); + + dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + mma_C C[2]; + C[0].mma_K4(A[n][k01/4 + 0], B[0]); + C[1].mma_K4(A[n][k01/4 + 1], B[1]); + +#pragma unroll + for (int l = 0; l < mma_C::ne; ++l) { + sum[(j0/mma_C::J + n)*mma_C::ne + l] += dB[l%2]*(C[0].x[l]*dA[n][l/2][k01/4 + 0] + C[1].x[l]*dA[n][l/2][k01/4 + 1]); } } } @@ -1222,7 +1181,6 @@ template static __device__ __forceinlin #ifdef INT8_MMA_AVAILABLE int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + WARP_SIZE*2); - int * x_sc = (int *) (x_df + 1); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); int * x_qs = (int *) x_tile; @@ -1262,23 +1220,6 @@ template static __device__ __forceinlin } } -#pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) { - int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } - - const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; - -#ifdef INT8_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q3_K] = bxi->d; -#else - x_df[i] = bxi->d; -#endif // INT8_MMA_AVAILABLE - } - #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) { int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8); @@ -1302,11 +1243,32 @@ template static __device__ __forceinlin const int sc = __vsubss4(sc_low | sc_high, 0x20202020); #ifdef INT8_MMA_AVAILABLE - x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + threadIdx.x % (WARP_SIZE/8)] = sc; + const int8_t * sc8 = (const int8_t *) ≻ + const float d = bxi->d; + +#pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l]; + } #else - x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc; + x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc; #endif // INT8_MMA_AVAILABLE } + +#ifndef INT8_MMA_AVAILABLE +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) { + int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; + + x_df[i] = bxi->d; + } +#endif // INT8_MMA_AVAILABLE } template @@ -1342,99 +1304,14 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( } } -template -static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE - - typedef mma_int_A_I16K4 mma_A; - typedef mma_int_A_I16K8 mma_A_K8; - typedef mma_int_B_J8K4 mma_B; - typedef mma_int_C_I16J8 mma_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + WARP_SIZE*2; - const int * x_sc = (const int *) x_df + 1; - const int * y_qs = (const int *) y + 4; - const float * y_df = (const float *) y; - - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - - mma_A A[ntx][8]; - int scA[ntx][mma_C::ne/2][8]; - float dA[ntx][mma_C::ne/2]; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { - const int k0 = k00 + k01; - - ((mma_A_K8 *) A[n])[k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); - } - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) { - const int k0 = k00 + k01; - - const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q3_K + k0/16]; - const int8_t * sc = (const int8_t *) &sc_packed; - -#pragma unroll - for (int ksc = 0; ksc < sizeof(int); ++ksc) { - scA[n][l][k01/4 + ksc] = sc[ksc]; - } - } - - dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K]; - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { - mma_B B[2]; - float dB[mma_C::ne/2]; - - B[0].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + 0), MMQ_TILE_Y_K); - B[1].load(y_qs + j0*MMQ_TILE_Y_K + (k01 + mma_B::K), MMQ_TILE_Y_K); - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); - - dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C[2]; - C[0].mma_K4(A[n][k01/4 + 0], B[0]); - C[1].mma_K4(A[n][k01/4 + 1], B[1]); - -#pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += dA[n][l/2]*dB[l%2]* - (C[0].x[l]*scA[n][l/2][k01/4 + 0] + C[1].x[l]*scA[n][l/2][k01/4 + 1]); - } - } - } - } -#else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); - NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE +static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, const int ksc) { + // scale arrangement after the following two lines: + // - ksc == 0: sc0, sc1, sc2, sc3 + // - ksc == 1: sc4, sc5, sc6, sc7 + // - ksc == 2: m0, m1, m2, m3 + // - ksc == 3: m4, m5, m6, m7 + return ((scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F) | // lower 4 bits + ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits } template static __device__ __forceinline__ void load_tiles_q4_K( @@ -1442,8 +1319,7 @@ template static __device__ __forceinlin #ifdef INT8_MMA_AVAILABLE int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + WARP_SIZE); - int * x_sc = (int *) (x_dm + WARP_SIZE/QI4_K); + half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); int * x_qs = (int *) x_tile; @@ -1451,9 +1327,6 @@ template static __device__ __forceinlin int * x_sc = (int *) (x_dm + txs.dm); #endif // INT8_MMA_AVAILABLE - const int kbx = 0; // threadIdx.x / QI4_K - const int kqsx = threadIdx.x; // threadIdx.x % QI4_K - #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { int i = i0 + threadIdx.y; @@ -1462,33 +1335,59 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbx; + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + const int qs0 = get_int_b4(bxi->qs, threadIdx.x); #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q4_K + threadIdx.x] = get_int_b4(bxi->qs, kqsx); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; #else - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = get_int_b4(bxi->qs, kqsx); + x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; #endif // INT8_MMA_AVAILABLE } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 - const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256 +#ifdef INT8_MMA_AVAILABLE #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_K) { - int i = (i0 + threadIdx.y * QI4_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) { + int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + kbxd; + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % (WARP_SIZE/16); + + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); + + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; + + const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); + +#pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } + } -#ifdef INT8_MMA_AVAILABLE - x_dm[i*MMQ_MMA_TILE_X_K_Q4_K + kbxd] = bxi->dm; #else - x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K + kbxd] = bxi->dm; -#endif // INT8_MMA_AVAILABLE + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI4_K) { + int i = (i0 + threadIdx.y*QI4_K + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + + x_dm[i] = bxi->dm; } #pragma unroll @@ -1504,17 +1403,11 @@ template static __device__ __forceinlin const int * scales = (const int *) bxi->scales; const int ksc = threadIdx.x % (WARP_SIZE/8); + const int scales8 = unpack_scales_q45_K(scales, ksc); - // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 - int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits - scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - -#ifdef INT8_MMA_AVAILABLE - x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + ksc] = scales8; -#else - x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; -#endif // INT8_MMA_AVAILABLE + x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; } +#endif // INT8_MMA_AVAILABLE } template @@ -1544,133 +1437,18 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq( &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, - x_dm[i*(WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template -static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE - - typedef mma_int_A_I16K8 mma_A; - typedef mma_int_B_J8K8 mma_B; - typedef mma_int_C_I16J8 mma_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE; - const int * x_sc = (const int *) x_dm + WARP_SIZE/QI4_K; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - - mma_A A[ntx][4]; - int scA[ntx][mma_C::ne/2][4]; - int mA[ntx][mma_C::ne/2][4]; - half2 dmA[ntx][mma_C::ne/2]; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) { - const int k0 = k00 + k01; - - A[n][k01/8 + 0].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q4_K + k0/QR4_K, MMQ_MMA_TILE_X_K_Q4_K); - -#pragma unroll - for (int l = 0; l < mma_A::ne; ++l) { - A[n][k01/8 + 1].x[l] = (A[n][k01/8 + 0].x[l] >> 4) & 0x0F0F0F0F; - A[n][k01/8 + 0].x[l] &= 0x0F0F0F0F; - } - } - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - - const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 0)]; - const int m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q4_K + (k00/32 + 2)]; - - const uint8_t * sc = (const uint8_t *) &sc_packed; - const uint8_t * m = (const uint8_t *) &m_packed; - -#pragma unroll - for (int ksc = 0; ksc < sizeof(int); ++ksc) { - scA[n][l][ksc] = sc[ksc]; - mA[n][l][ksc] = m[ksc]; - } - } - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_A::I + mma_C::get_i(2*l); - - dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q4_K]; - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - float tmpd[ntx][mma_C::ne] = {{0.0f}}; - float tmpm[ntx][mma_C::ne] = {{0.0f}}; - -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { - mma_B B; - half2 dsB[mma_C::ne/2]; - - B.load(y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); - - dsB[l] = y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]; - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma_K8(A[n][k01/8], B); - -#pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) * __low2float(dsB[l%2]); - tmpm[n][l] += mA[n][l/2][k01/8] * __high2float(dsB[l%2]); - } - } - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l]; - } - } - } -#else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); - NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE -} - template static __device__ __forceinline__ void load_tiles_q5_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { #ifdef INT8_MMA_AVAILABLE int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2); - int * x_sc = (int *) (x_dm + WARP_SIZE/QI5_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); int * x_qs = (int *) x_tile; @@ -1678,9 +1456,6 @@ template static __device__ __forceinlin int * x_sc = (int *) (x_dm + txs.dm); #endif // INT8_MMA_AVAILABLE - const int kbx = 0; // threadIdx.x / QI5_K - const int kqsx = threadIdx.x; // threadIdx.x % QI5_K - #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { int i = i0 + threadIdx.y; @@ -1689,73 +1464,91 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbx; - const int ky = QR5_K*kqsx; + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const int ky = QR5_K*threadIdx.x; - const int ql = get_int_b4(bxi->qs, kqsx); + const int ql = get_int_b4(bxi->qs, threadIdx.x); const int ql0 = (ql >> 0) & 0x0F0F0F0F; const int ql1 = (ql >> 4) & 0x0F0F0F0F; - const int qh = get_int_b4(bxi->qh, kqsx % (QI5_K/4)); - const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010; - const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010; + const int qh = get_int_b4(bxi->qh, threadIdx.x % (QI5_K/4)); + const int qh0 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 0)) << 4) & 0x10101010; + const int qh1 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 1)) << 4) & 0x10101010; const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0; - const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + (QI5_K/4); + const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4; #ifdef INT8_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq0] = ql0 | qh0; - x_qs[i*MMQ_MMA_TILE_X_K_Q5_K + kq1] = ql1 | qh1; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; #else x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0; x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1; #endif // INT8_MMA_AVAILABLE } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 - const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256 +#ifdef INT8_MMA_AVAILABLE #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_K) { - int i = (i0 + threadIdx.y * QI5_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) { + int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + kbxd; + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % (WARP_SIZE/16); + + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); + + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; + + const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); + +#pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } + } -#ifdef INT8_MMA_AVAILABLE - x_dm[i*MMQ_MMA_TILE_X_K_Q5_K + kbxd] = bxi->dm; #else - x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K + kbxd] = bxi->dm; -#endif // INT8_MMA_AVAILABLE + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI5_K) { + int i = (i0 + threadIdx.y*QI5_K + threadIdx.x) % mmq_y; + + if (need_check) { + i = min(i, i_max); + } + + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + + x_dm[i] = bxi->dm; } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) { + int i = (i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI5_K/8); + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; const int * scales = (const int *) bxi->scales; const int ksc = threadIdx.x % (WARP_SIZE/8); + const int scales8 = unpack_scales_q45_K(scales, ksc); - // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 - int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits - scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - -#ifdef INT8_MMA_AVAILABLE - x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + ksc] = scales8; -#else - x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; -#endif // INT8_MMA_AVAILABLE + x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; } +#endif // INT8_MMA_AVAILABLE } template @@ -1785,122 +1578,12 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq( &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, - x_dm[i*(WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template -static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( - const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) { -#ifdef INT8_MMA_AVAILABLE - - typedef mma_int_A_I16K8 mma_A; - typedef mma_int_B_J8K8 mma_B; - typedef mma_int_C_I16J8 mma_C; - - constexpr int granularity = mmq_get_granularity_device(mmq_x); - constexpr int rows_per_warp = 2 * granularity; - constexpr int ntx = rows_per_warp/mma_C::I; // Number of x minitiles per warp. - - y += (threadIdx.y % ntx) * (mma_B::J*MMQ_TILE_Y_K); - - const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2; - const int * x_sc = (const int *) x_dm + WARP_SIZE/QI5_K; - const int * y_qs = (const int *) y + 4; - const half2 * y_ds = (const half2 *) y; - - const int i0 = (threadIdx.y / ntx) * (ntx*mma_A::I); - - mma_A A[ntx][4]; - int scA[ntx][mma_C::ne/2][4]; - int mA[ntx][mma_C::ne/2][4]; - half2 dmA[ntx][mma_C::ne/2]; - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { - const int k0 = k00 + k01; - - A[n][k01/8].load(x_qs + (i0 + n*mma_A::I)*MMQ_MMA_TILE_X_K_Q5_K + k0, MMQ_MMA_TILE_X_K_Q5_K); - } - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - - const int sc_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 0)]; - const int m_packed = x_sc[i*MMQ_MMA_TILE_X_K_Q5_K + (k00/32 + 2)]; - - const uint8_t * sc = (const uint8_t *) &sc_packed; - const uint8_t * m = (const uint8_t *) &m_packed; - -#pragma unroll - for (int ksc = 0; ksc < sizeof(int); ++ksc) { - scA[n][l][ksc] = sc[ksc]; - mA[n][l][ksc] = m[ksc]; - } - } - - #pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int i = i0 + n*mma_C::I + mma_C::get_i(2*l); - - dmA[n][l] = x_dm[i*MMQ_MMA_TILE_X_K_Q5_K]; - } - } - -#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += ntx*mma_C::J) { - float tmpd[ntx][mma_C::ne] = {{0.0f}}; - float tmpm[ntx][mma_C::ne] = {{0.0f}}; - -#pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { - const int k0 = k00 + k01; - - mma_B B; - half2 dsB[mma_C::ne/2]; - - B.load(y_qs + j0*MMQ_TILE_Y_K + k0 % WARP_SIZE, MMQ_TILE_Y_K); - -#pragma unroll - for (int l = 0; l < mma_C::ne/2; ++l) { - const int j = j0 + mma_C::get_j(l); - - dsB[l] = y_ds[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]; - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { - mma_C C; - C.mma_K8(A[n][k01/8], B); - -#pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - tmpd[n][l] += (C.x[l]*scA[n][l/2][k01/8]) * __low2float(dsB[l%2]); - tmpm[n][l] += mA[n][l/2][k01/8] * __high2float(dsB[l%2]); - } - } - } - -#pragma unroll - for (int n = 0; n < ntx; ++n) { -#pragma unroll - for (int l = 0; l < mma_C::ne; ++l) { - sum[(j0/mma_C::J + n)*mma_C::ne + l] += __low2float(dmA[n][l/2])*tmpd[n][l] - __high2float(dmA[n][l/2])*tmpm[n][l]; - } - } - } -#else - GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); - NO_DEVICE_CODE; -#endif // INT8_MMA_AVAILABLE -} - template static __device__ __forceinline__ void load_tiles_q6_K( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -1915,9 +1598,6 @@ template static __device__ __forceinlin int * x_sc = (int *) (x_df + txs.dm); #endif // INT8_MMA_AVAILABLE - const int kbx = 0; // threadIdx.x / QI6_K - const int kqsx = threadIdx.x; // threadIdx.x % QI6_K - #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { int i = i0 + threadIdx.y; @@ -1926,19 +1606,18 @@ template static __device__ __forceinlin i = min(i, i_max); } - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbx; - const int ky = QR6_K*kqsx; + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; - const int ql = get_int_b2(bxi->ql, kqsx); + const int ql = get_int_b2(bxi->ql, threadIdx.x); const int ql0 = (ql >> 0) & 0x0F0F0F0F; const int ql1 = (ql >> 4) & 0x0F0F0F0F; - const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)); - const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030; - const int qh1 = (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) & 0x30303030; + const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (threadIdx.x / (QI6_K/2)) + threadIdx.x % (QI6_K/4)); + const int qh0 = ((qh >> ((threadIdx.x & 0x08) >> 2)) << 4) & 0x30303030; + const int qh1 = (qh >> ((threadIdx.x & 0x08) >> 2)) & 0x30303030; - const int kq0 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + 0; - const int kq1 = ky - ky % QI6_K + threadIdx.x % (QI6_K/2) + (QI6_K/2); + const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0; + const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2; #ifdef INT8_MMA_AVAILABLE x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); @@ -2187,6 +1866,358 @@ template static __device__ __forceinlin } } +template static __device__ __forceinline__ void load_tiles_iq2_xxs( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % (QI2_XXS/2); + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XXS/2)) { + int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XXS) + threadIdx.x/(QI2_XXS/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_xxs * bxi = (const block_iq2_xxs *) x + kbx0 + i*stride; + + const int q2 = get_int_b2(bxi->qs, 2*kqsx+0); + const uint8_t * aux8 = (const uint8_t *) &q2; + const uint32_t aux32 = get_int_b2(bxi->qs, 2*kqsx+1); + +#pragma unroll + for (int l = 0; l < QR2_XXS; ++l) { + const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]); + const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F]; + + const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); + const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + + const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); + const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1; +#endif // INT8_MMA_AVAILABLE + } + + const int ls = aux32 >> 28; + const float d = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq2_xs( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % (QI2_XS/2); + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XS/2)) { + int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XS) + threadIdx.x/(QI2_XS/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_xs * bxi = (const block_iq2_xs *) x + kbx0 + i*stride; + + const int2 q2_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); + const uint16_t * q2 = (const uint16_t *) &q2_packed; + + #pragma unroll + for (int l = 0; l < QR2_XS; ++l) { + const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF)); + const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); + + const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // INT8_MMA_AVAILABLE + } + + const int ls = bxi->scales[kqsx]; + const float d = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq2_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % (QI2_S/2); + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_S/2)) { + int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_S) + threadIdx.x/(QI2_S/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq2_s * bxi = (const block_iq2_s *) x + kbx0 + i*stride; + + const int qs_packed = get_int_b2(bxi->qs, kqsx); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + const int signs_packed_32 = get_int_b2(bxi->qs, QK_K/32 + kqsx); + const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; + +#pragma unroll + for (int l = 0; l < QR2_S; ++l) { + const int * grid_pos = (const int *)(iq2s_grid + (qs[l] | ((qh << (8-2*l)) & 0x300))); + + const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000); + const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000); + + const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); + const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // INT8_MMA_AVAILABLE + } + + const int ls = bxi->scales[kqsx]; + const float d = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#else + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq3_xxs( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % (QI3_XXS/2); + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_XXS/2)) { + int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_XXS) + threadIdx.x/(QI3_XXS/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_xxs * bxi = (const block_iq3_xxs *) x + kbx0 + i*stride; + + const int2 q3_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); + const uint8_t * q3 = (const uint8_t *) &q3_packed; + const uint32_t aux32 = get_int_b2(bxi->qs, QK_K/16 + kqsx); + +#pragma unroll + for (int l = 0; l < QR3_XXS; ++l) { + const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]); + + const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F)); + + const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); + const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // INT8_MMA_AVAILABLE + } + + const int ls = aux32 >> 28; + const float d = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq3_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); + int * x_qs = (int *) x_tile; + float * x_df = (float *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % (QI3_S/2); + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_S/2)) { + int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_S) + threadIdx.x/(QI3_S/2); + + if (need_check) { + i = min(i, i_max); + } + + const block_iq3_s * bxi = (const block_iq3_s *) x + kbx0 + i*stride; + + const int2 qs_packed = make_int2(get_int_b2(bxi->qs, 2*kqsx+0), get_int_b2(bxi->qs, 2*kqsx+1)); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + const int signs_packed_32 = get_int_b2(bxi->signs, kqsx); + const uint8_t * signs_packed_8 = (const uint8_t *) &signs_packed_32; + +#pragma unroll + for (int l = 0; l < QR3_S; ++l) { + const int2 grid_pos = make_int2( + iq3s_grid[qs[2*l+0] | ((qh << (8 - 2*l)) & 0x100)], + iq3s_grid[qs[2*l+1] | ((qh << (7 - 2*l)) & 0x100)]); + + const int signs0 = __vcmpne4(((signs_packed_8[l] & 0x03) << 7) | ((signs_packed_8[l] & 0x0C) << 21), 0x00000000); + const int signs1 = __vcmpne4(((signs_packed_8[l] & 0x30) << 3) | ((signs_packed_8[l] & 0xC0) << 17), 0x00000000); + + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h; +#endif // INT8_MMA_AVAILABLE + } + + const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); + const float d = bxi->d; +#ifdef INT8_MMA_AVAILABLE + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; +#else + x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d; +#endif // INT8_MMA_AVAILABLE + } +} + +template static __device__ __forceinline__ void load_tiles_iq1_s( + const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { + +#ifdef INT8_MMA_AVAILABLE + int * x_qs = (int *) x_tile; + half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2); +#else + constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); + int * x_qs = (int *) x_tile; + half2 * x_ds = (half2 *) (x_qs + txs.qs); +#endif // INT8_MMA_AVAILABLE + + const int kqsx = threadIdx.x % QI1_S; + +#pragma unroll + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI1_S) { + int i = i0 + threadIdx.y*(WARP_SIZE/QI1_S) + threadIdx.x/QI1_S; + + if (need_check) { + i = min(i, i_max); + } + + const block_iq1_s * bxi = (const block_iq1_s *) x + kbx0 + i*stride; + + const int qs_packed = get_int_b2(bxi->qs, kqsx); + const uint8_t * qs = (const uint8_t *) &qs_packed; + + const int qh = bxi->qh[kqsx]; + + #pragma unroll + for (int l = 0; l < QR1_S/2; ++l) { + const int grid = iq1s_grid_gpu[qs[l] | (((qh >> (3*l)) & 0x07) << 8)]; + + const int grid0 = (grid >> 0) & 0x0F0F0F0F; + const int grid1 = (grid >> 4) & 0x0F0F0F0F; + +#ifdef INT8_MMA_AVAILABLE + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; +#else + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0; + x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1; +#endif // INT8_MMA_AVAILABLE + } + + const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); + const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); + +#ifdef INT8_MMA_AVAILABLE + x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); +#else + x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); +#endif // INT8_MMA_AVAILABLE + } +} + template static __device__ __forceinline__ void load_tiles_iq4_xs( const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) { @@ -2320,7 +2351,7 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; }; @@ -2328,7 +2359,7 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a; }; @@ -2336,7 +2367,7 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; @@ -2352,7 +2383,7 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; @@ -2368,7 +2399,7 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q3_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a; }; @@ -2376,7 +2407,7 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q4_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a; }; @@ -2384,7 +2415,7 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q5_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a; }; @@ -2396,11 +2427,59 @@ struct mmq_type_traits { static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a; }; +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; +}; + +template +struct mmq_type_traits { + static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; +}; + template struct mmq_type_traits { static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; @@ -2408,7 +2487,7 @@ template struct mmq_type_traits { static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; @@ -2837,6 +2916,12 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q3_K); extern DECL_MMQ_CASE(GGML_TYPE_Q4_K); extern DECL_MMQ_CASE(GGML_TYPE_Q5_K); extern DECL_MMQ_CASE(GGML_TYPE_Q6_K); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); +extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S); +extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL); extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS); diff --git a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py index ffeb3c27d..d7874e6ea 100755 --- a/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +++ b/ggml/src/ggml-cuda/template-instances/generate_cu_files.py @@ -23,7 +23,8 @@ SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block} TYPES_MMQ = [ "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K", - "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS" + "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S", + "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS" ] SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu new file mode 100644 index 000000000..84ec85029 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ1_S); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu new file mode 100644 index 000000000..583c4e5a5 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_S); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu new file mode 100644 index 000000000..edaf1560d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_XS); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu new file mode 100644 index 000000000..233d9342c --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ2_XXS); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu new file mode 100644 index 000000000..6092dc713 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ3_S); diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu new file mode 100644 index 000000000..1d5bd201f --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../mmq.cuh" + +DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS); diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 6a17d0f3e..40091a0ef 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -188,6 +188,27 @@ template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp return sumi*d8d8 + m8s8 / (QI8_1 / vdr); } +template static __device__ __forceinline__ float vec_dot_q8_0_16_q8_1_impl( + const int * v, const int * u, const float * d8_0, const float & d8_1) { + + float sumf = 0.0f; + +#pragma unroll + for (int i0 = 0; i0 < vdr; i0 += QI8_0/2) { + int sumi = 0; + +#pragma unroll + for (int i = i0; i < i0 + QI8_0/2; ++i) { + // SIMD dot product of quantized values + sumi = ggml_cuda_dp4a(v[i], u[i], sumi); + } + + sumf += d8_0[i0/(QI8_0/2)]*sumi; + } + + return d8_1*sumf; +} + #define VDR_Q2_K_Q8_1_MMVQ 1 #define VDR_Q2_K_Q8_1_MMQ 4