mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 13:30:35 +00:00
CUDA: use tensor cores for MMQ (#7676)
* CUDA: int8 tensor cores for MMQ (legacy quants) * fix out-of-bounds writes * __builtin_assume -> GGML_CUDA_ASSUME * fix writeback returning too early
This commit is contained in:
parent
af4ae502dd
commit
1f0dabda8d
@ -139,6 +139,7 @@
|
||||
#define CC_PASCAL 600
|
||||
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
|
||||
#define CC_VOLTA 700
|
||||
#define CC_TURING 750
|
||||
#define CC_AMPERE 800
|
||||
#define CC_OFFSET_AMD 1000000
|
||||
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
|
||||
@ -326,9 +327,17 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
|
||||
#endif // defined(__HIP_PLATFORM_AMD__) && HIP_VERSION < 50600000
|
||||
#endif // defined(GGML_USE_HIPBLAS)
|
||||
|
||||
#define FP16_AVAILABLE (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
|
||||
#if (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
|
||||
#define FP16_AVAILABLE
|
||||
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
|
||||
|
||||
#define FP16_MMA_AVAILABLE !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
||||
#define FP16_MMA_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
||||
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
|
||||
#define INT8_MMA_AVAILABLE
|
||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
|
||||
|
||||
static bool fast_fp16_available(const int cc) {
|
||||
return cc >= CC_PASCAL && cc != 610;
|
||||
@ -338,6 +347,10 @@ static bool fp16_mma_available(const int cc) {
|
||||
return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
|
||||
}
|
||||
|
||||
static bool int8_mma_available(const int cc) {
|
||||
return cc < CC_OFFSET_AMD && cc >= CC_TURING;
|
||||
}
|
||||
|
||||
[[noreturn]]
|
||||
static __device__ void no_device_code(
|
||||
const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
|
||||
@ -379,7 +392,7 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
|
||||
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
|
||||
#pragma unroll
|
||||
@ -412,7 +425,7 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
|
||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
|
||||
return __float2half(fmaxf(__half2float(a), __half2float(b)));
|
||||
|
@ -74,7 +74,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
|
||||
|
||||
const int sumi = __dp4a(v, u, 0);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||
|
||||
@ -122,7 +122,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
|
||||
|
||||
const int sumi = __dp4a(v, u, 0);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||
|
||||
@ -181,7 +181,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
|
||||
|
||||
const int sumi = __dp4a(v, u, 0);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||
|
||||
@ -236,7 +236,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
|
||||
|
||||
const int sumi = __dp4a(v, u, 0);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_ds = (const half2 *) Q_ds_v;
|
||||
|
||||
@ -314,7 +314,7 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
|
||||
GGML_UNUSED(Q_q8);
|
||||
GGML_UNUSED(Q_ds_v);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
const half2 * Q_h2 = (const half2 *) Q_v;
|
||||
|
||||
@ -407,7 +407,7 @@ static __device__ __forceinline__ T dequantize_1_q4_0(const void * __restrict__
|
||||
const int q0 = x[ib].qs[iqs];
|
||||
const int q = ((q0 >> (4*shift)) & 0x0F) - 8;
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
return ((half) d)*((half) q);
|
||||
}
|
||||
@ -428,7 +428,7 @@ static __device__ __forceinline__ T dequantize_1_q4_1(const void * __restrict__
|
||||
const int q0 = x[ib].qs[iqs];
|
||||
const int q = ((q0 >> (4*shift)) & 0x0F);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
return __low2half(dm)*((half) q) + __high2half(dm);
|
||||
}
|
||||
@ -453,7 +453,7 @@ static __device__ __forceinline__ T dequantize_1_q5_0(const void * __restrict__
|
||||
const int qh = ((qh0 >> idq) << 4) & 0x10;
|
||||
const int q = (ql | qh) - 16;
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
return ((half) d)*((half) q);
|
||||
}
|
||||
@ -478,7 +478,7 @@ static __device__ __forceinline__ T dequantize_1_q5_1(const void * __restrict__
|
||||
const int qh = ((qh0 >> idq) << 4) & 0x10;
|
||||
const int q = (ql | qh);
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
return __low2half(dm)*((half) q) + __high2half(dm);
|
||||
}
|
||||
@ -497,7 +497,7 @@ static __device__ __forceinline__ T dequantize_1_q8_0(const void * __restrict__
|
||||
const T d = x[ib].d;
|
||||
const int q = x[ib].qs[iqs];
|
||||
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
if (std::is_same<T, half>::value) {
|
||||
return ((half) d)*((half) q);
|
||||
}
|
||||
|
@ -43,7 +43,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
|
||||
|
@ -40,7 +40,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#if FP16_AVAILABLE
|
||||
#ifdef FP16_AVAILABLE
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
constexpr vec_dot_KQ_f16_t vec_dot_KQ = get_vec_dot_KQ_f16<D>(type_K);
|
||||
|
@ -1,9 +1,9 @@
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
#if FP16_MMA_AVAILABLE
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
#include <mma.h>
|
||||
#endif
|
||||
#endif // FP16_MMA_AVAILABLE
|
||||
|
||||
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
|
||||
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t>
|
||||
@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const int ne1,
|
||||
const int ne2,
|
||||
const int ne3) {
|
||||
#if FP16_MMA_AVAILABLE
|
||||
#ifdef FP16_MMA_AVAILABLE
|
||||
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
|
||||
|
||||
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
|
||||
|
95
ggml-cuda/mma.cuh
Normal file
95
ggml-cuda/mma.cuh
Normal file
@ -0,0 +1,95 @@
|
||||
#include "common.cuh"
|
||||
|
||||
struct mma_int_A_I16K8 {
|
||||
static constexpr int I = 16;
|
||||
static constexpr int K = 8;
|
||||
static constexpr int ne = 4;
|
||||
|
||||
int x[ne] = {0};
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < I);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_k(const int l) {
|
||||
const int ret = (l/2) * (K/2) + threadIdx.x % (K/2);
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < K);
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
struct mma_int_B_J8K8 {
|
||||
static constexpr int J = 8;
|
||||
static constexpr int K = 8;
|
||||
static constexpr int ne = 2;
|
||||
|
||||
int x[ne] = {0};
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int /* l */) {
|
||||
const int ret = threadIdx.x / (K/2);
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < J);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_k(const int l) {
|
||||
const int ret = l * (K/2) + threadIdx.x % (K/2);
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < K);
|
||||
return ret;
|
||||
}
|
||||
};
|
||||
|
||||
struct mma_int_C_I16J8 {
|
||||
static constexpr int I = 16;
|
||||
static constexpr int J = 8;
|
||||
static constexpr int ne = 4;
|
||||
|
||||
int x[ne] = {0};
|
||||
|
||||
static __device__ __forceinline__ int get_i(const int l) {
|
||||
const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < I);
|
||||
return ret;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int get_j(const int l) {
|
||||
const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
|
||||
GGML_CUDA_ASSUME(ret >= 0);
|
||||
GGML_CUDA_ASSUME(ret < J);
|
||||
return ret;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
#if __CUDA_ARCH__ >= CC_AMPERE
|
||||
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
|
||||
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
|
||||
: "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
|
||||
#else
|
||||
// On Turing m16n8k32 mma is not available, use 4x m8n8k16 mma instead:
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(x[0]), "+r"(x[1])
|
||||
: "r"(mma_A.x[0]), "r"(mma_B.x[0]));
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(x[2]), "+r"(x[3])
|
||||
: "r"(mma_A.x[1]), "r"(mma_B.x[0]));
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(x[0]), "+r"(x[1])
|
||||
: "r"(mma_A.x[2]), "r"(mma_B.x[1]));
|
||||
asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
|
||||
: "+r"(x[2]), "+r"(x[3])
|
||||
: "r"(mma_A.x[3]), "r"(mma_B.x[1]));
|
||||
#endif // __CUDA_ARCH__ >= CC_AMPERE
|
||||
#else
|
||||
GGML_UNUSED(mma_A);
|
||||
GGML_UNUSED(mma_B);
|
||||
NO_DEVICE_CODE;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
}
|
||||
};
|
@ -2,6 +2,7 @@
|
||||
|
||||
#include "common.cuh"
|
||||
#include "vecdotq.cuh"
|
||||
#include "mma.cuh"
|
||||
|
||||
#include <climits>
|
||||
#include <cstdint>
|
||||
@ -14,6 +15,7 @@ typedef void (*load_tiles_mmq_t)(
|
||||
typedef void (*vec_dot_mmq_t)(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0);
|
||||
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1);
|
||||
|
||||
struct block_q8_1_mmq {
|
||||
half2 ds[4];
|
||||
@ -141,15 +143,15 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
|
||||
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
||||
|
||||
const float * x_dmf = (const float *) x_dm;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const half2 * y_ds = (const half2 *) y;
|
||||
const float * x_df = (const float *) x_dm;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const half2 * y_ds = (const half2 *) y;
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
||||
@ -170,12 +172,76 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mul_mat(
|
||||
}
|
||||
|
||||
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl<VDR_Q4_0_Q8_1_MMQ>
|
||||
(&x_ql[i*(WARP_SIZE + 1) + k0], u, x_dmf[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
|
||||
(&x_ql[i*(WARP_SIZE + 1) + k0], u, x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0],
|
||||
y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
||||
|
||||
typedef mma_int_A_I16K8 mma_A;
|
||||
typedef mma_int_B_J8K8 mma_B;
|
||||
typedef mma_int_C_I16J8 mma_C;
|
||||
|
||||
const float * x_df = (const float *) x_dm;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const half2 * y_ds = (const half2 *) y;
|
||||
|
||||
mma_A A;
|
||||
float dA[mma_C::ne/2];
|
||||
|
||||
const int i0 = threadIdx.y*mma_A::I;
|
||||
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_A::ne; ++l) {
|
||||
const int i = i0 + mma_A::get_i(l);
|
||||
const int k = k0 + mma_A::get_k(l) % QI4_0;
|
||||
const int shift = 4*(mma_A::get_k(l) / QI4_0);
|
||||
|
||||
A.x[l] = __vsubss4((x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F, 0x08080808);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + mma_C::get_i(2*l);
|
||||
|
||||
dA[l] = x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
|
||||
}
|
||||
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
|
||||
mma_C C;
|
||||
mma_B B;
|
||||
half2 dsB[mma_C::ne/2];
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_B::ne; ++l) {
|
||||
const int j = j0 + mma_B::get_j(l);
|
||||
const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
|
||||
|
||||
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
|
||||
dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
|
||||
}
|
||||
|
||||
C.mma_K8(A, B);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
|
||||
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
||||
@ -215,7 +281,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
|
||||
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
@ -249,6 +315,70 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mul_mat(
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
||||
|
||||
typedef mma_int_A_I16K8 mma_A;
|
||||
typedef mma_int_B_J8K8 mma_B;
|
||||
typedef mma_int_C_I16J8 mma_C;
|
||||
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const half2 * y_ds = (const half2 *) y;
|
||||
|
||||
mma_A A;
|
||||
half2 dmA[mma_C::ne/2];
|
||||
|
||||
const int i0 = threadIdx.y*mma_A::I;
|
||||
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_A::ne; ++l) {
|
||||
const int i = i0 + mma_A::get_i(l);
|
||||
const int k = k0 + mma_A::get_k(l) % QI4_0;
|
||||
const int shift = 4*(mma_A::get_k(l) / QI4_0);
|
||||
|
||||
A.x[l] = (x_ql[i*(WARP_SIZE + 1) + k] >> shift) & 0x0F0F0F0F;
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + mma_C::get_i(2*l);
|
||||
|
||||
dmA[l] = x_dm[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/QI4_0];
|
||||
}
|
||||
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
|
||||
mma_C C;
|
||||
mma_B B;
|
||||
half2 dsB[mma_C::ne/2];
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_B::ne; ++l) {
|
||||
const int j = j0 + mma_B::get_j(l);
|
||||
const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
|
||||
|
||||
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
|
||||
dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
|
||||
}
|
||||
|
||||
C.mma_K8(A, B);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
|
||||
sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
|
||||
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
||||
@ -308,7 +438,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
|
||||
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_dp4a(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
@ -343,6 +473,68 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mul_mat(
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
||||
|
||||
typedef mma_int_A_I16K8 mma_A;
|
||||
typedef mma_int_B_J8K8 mma_B;
|
||||
typedef mma_int_C_I16J8 mma_C;
|
||||
|
||||
const float * x_df = (const float *) x_dm;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const float * y_df = (const float *) y;
|
||||
|
||||
mma_A A;
|
||||
float dA[mma_C::ne/2];
|
||||
|
||||
const int i0 = threadIdx.y*mma_A::I;
|
||||
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_A::ne; ++l) {
|
||||
const int i = i0 + mma_A::get_i(l);
|
||||
const int k = 2*(k0 + mma_A::get_k(l) % QI5_0) + mma_A::get_k(l) / QI5_0;
|
||||
|
||||
A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + mma_C::get_i(2*l);
|
||||
|
||||
dA[l] = x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + k0/QI5_0];
|
||||
}
|
||||
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
|
||||
mma_C C;
|
||||
mma_B B;
|
||||
float dB[mma_C::ne/2];
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_B::ne; ++l) {
|
||||
const int j = j0 + mma_B::get_j(l);
|
||||
const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
|
||||
|
||||
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
|
||||
dB[l] = y_df[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
|
||||
}
|
||||
|
||||
C.mma_K8(A, B);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
|
||||
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
@ -400,7 +592,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
|
||||
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_dp4a(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
@ -434,6 +626,69 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mul_mat(
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
||||
|
||||
typedef mma_int_A_I16K8 mma_A;
|
||||
typedef mma_int_B_J8K8 mma_B;
|
||||
typedef mma_int_C_I16J8 mma_C;
|
||||
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const half2 * y_ds = (const half2 *) y;
|
||||
|
||||
mma_A A;
|
||||
half2 dmA[mma_C::ne/2];
|
||||
|
||||
const int i0 = threadIdx.y*mma_A::I;
|
||||
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_A::ne; ++l) {
|
||||
const int i = i0 + mma_A::get_i(l);
|
||||
const int k = 2*(k0 + mma_A::get_k(l) % QI5_1) + mma_A::get_k(l) / QI5_1;
|
||||
|
||||
A.x[l] = x_ql[i*(2*WARP_SIZE + 1) + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + mma_C::get_i(2*l);
|
||||
|
||||
dmA[l] = x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI5_1];
|
||||
}
|
||||
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
|
||||
mma_C C;
|
||||
mma_B B;
|
||||
half2 dsB[mma_C::ne/2];
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_B::ne; ++l) {
|
||||
const int j = j0 + mma_B::get_j(l);
|
||||
const int k = (2*k0 + mma_B::get_k(l)) % WARP_SIZE;
|
||||
|
||||
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
|
||||
dsB[l] = y_ds[j*MMQ_TILE_Y_K + (2*k0/QI8_1) % (WARP_SIZE/QI8_1)];
|
||||
}
|
||||
|
||||
C.mma_K8(A, B);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
const half2 dmA_dsB = dmA[l/2]*dsB[l%2];
|
||||
sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
|
||||
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
||||
@ -475,7 +730,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
|
||||
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
@ -500,6 +755,69 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mul_mat(
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps>
|
||||
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
|
||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y, float * __restrict__ sum, const int & k0) {
|
||||
|
||||
GGML_UNUSED(x_qh); GGML_UNUSED(x_sc);
|
||||
|
||||
typedef mma_int_A_I16K8 mma_A;
|
||||
typedef mma_int_B_J8K8 mma_B;
|
||||
typedef mma_int_C_I16J8 mma_C;
|
||||
|
||||
const float * x_df = (const float *) x_dm;
|
||||
const int * y_qs = (const int *) y + 4;
|
||||
const float * y_df = (const float *) y;
|
||||
|
||||
mma_A A;
|
||||
float dA[mma_C::ne/2];
|
||||
|
||||
const int i0 = threadIdx.y*mma_A::I;
|
||||
static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y");
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_A::ne; ++l) {
|
||||
const int i = i0 + mma_A::get_i(l);
|
||||
const int k = k0 + mma_A::get_k(l);
|
||||
|
||||
A.x[l] = x_ql[i*(WARP_SIZE + 1) + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int i = i0 + mma_C::get_i(2*l);
|
||||
|
||||
dA[l] = x_df[i*(WARP_SIZE/QI8_0) + i/QI8_0 + k0/QI8_0];
|
||||
}
|
||||
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += mma_int_B_J8K8::J) {
|
||||
mma_C C;
|
||||
mma_B B;
|
||||
float dB[mma_C::ne/2];
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_B::ne; ++l) {
|
||||
const int j = j0 + mma_B::get_j(l);
|
||||
const int k = k0 + mma_B::get_k(l);
|
||||
|
||||
B.x[l] = y_qs[j*MMQ_TILE_Y_K + k];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne/2; ++l) {
|
||||
const int j = j0 + mma_C::get_j(l);
|
||||
|
||||
dB[l] = y_df[j*MMQ_TILE_Y_K + k0/QI8_1];
|
||||
}
|
||||
|
||||
C.mma_K8(A, B);
|
||||
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
|
||||
const char * __restrict__ x, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||
int * __restrict__ x_sc, const int & kbx0, const int & i_max, const int & stride) {
|
||||
@ -989,6 +1307,57 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mul_mat(
|
||||
}
|
||||
}
|
||||
|
||||
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
static __device__ __forceinline__ void mmq_write_back_dp4a(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
||||
const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
|
||||
|
||||
if (j >= ne1) {
|
||||
return;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
|
||||
const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
|
||||
|
||||
if (need_check && i >= ne0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
static __device__ __forceinline__ void mmq_write_back_mma(const float * __restrict__ sum, float * __restrict__ dst, const int & ne0, const int & ne1) {
|
||||
typedef mma_int_C_I16J8 mma_C;
|
||||
|
||||
const int i0 = threadIdx.y*mma_C::I;
|
||||
static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y");
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < mma_C::ne; ++l) {
|
||||
const int j = blockIdx.y*mmq_x + j0 + mma_C::get_j(l);
|
||||
|
||||
if (j >= ne1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int i = blockIdx.x*mmq_y + i0 + mma_C::get_i(l);
|
||||
|
||||
if (need_check && i >= ne0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
dst[j*ne0 + i] = sum[(j0/mma_C::J)*mma_C::ne + l];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check, ggml_type type>
|
||||
@ -998,35 +1367,65 @@ template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_0> {
|
||||
static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#else
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_1> {
|
||||
static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#else
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_0> {
|
||||
static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#else
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_1> {
|
||||
static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#else
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_1_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q8_0> {
|
||||
static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
#ifdef INT8_MMA_AVAILABLE
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_mma<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#else
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
#endif // INT8_MMA_AVAILABLE
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
@ -1034,6 +1433,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q2_K> {
|
||||
static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q2_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
@ -1041,6 +1441,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q3_K> {
|
||||
static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q3_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
@ -1048,6 +1449,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q4_K> {
|
||||
static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q4_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
@ -1055,6 +1457,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q5_K> {
|
||||
static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q5_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
};
|
||||
|
||||
template <int mmq_x, int mmq_y, int nwarps, bool need_check>
|
||||
@ -1062,6 +1465,7 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_Q6_K> {
|
||||
static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ;
|
||||
static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K<mmq_y, nwarps, need_check>;
|
||||
static constexpr vec_dot_mmq_t vec_dot = vec_dot_q6_K_q8_1_mul_mat<mmq_x, mmq_y, nwarps>;
|
||||
static constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, nwarps, need_check>;
|
||||
};
|
||||
|
||||
static int mmq_need_sum(const ggml_type type_x) {
|
||||
@ -1118,6 +1522,7 @@ static __global__ void mul_mat_q(
|
||||
constexpr int vdr = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vdr;
|
||||
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
|
||||
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::vec_dot;
|
||||
constexpr mmq_write_back_t write_back = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::write_back;
|
||||
|
||||
constexpr tile_x_sizes txs = get_tile_x_sizes_device<mmq_y>(type);
|
||||
|
||||
@ -1137,7 +1542,7 @@ static __global__ void mul_mat_q(
|
||||
|
||||
const int * y = (const int *) yc + blockIdx.y*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
|
||||
|
||||
float sum[(mmq_x/nwarps) * (mmq_y/WARP_SIZE)] = {0.0f};
|
||||
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
|
||||
|
||||
for (int kb0 = 0; kb0 < blocks_per_row_x; kb0 += blocks_per_warp) {
|
||||
|
||||
@ -1164,25 +1569,7 @@ static __global__ void mul_mat_q(
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
|
||||
const int j = blockIdx.y*mmq_x + j0 + threadIdx.y;
|
||||
|
||||
if (j >= ne1) {
|
||||
return;
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
|
||||
const int i = blockIdx.x*mmq_y + i0 + threadIdx.x;
|
||||
|
||||
if (need_check && i >= ne0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
dst[j*ne0 + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
|
||||
}
|
||||
}
|
||||
write_back(sum, dst, ne0, ne1);
|
||||
}
|
||||
|
||||
struct mmq_args {
|
||||
@ -1256,10 +1643,10 @@ void mul_mat_q_case(const mmq_args & args, cudaStream_t stream) {
|
||||
launch_mul_mat_q<type, 8, 4>(args, stream);
|
||||
break;
|
||||
case 16:
|
||||
launch_mul_mat_q<type, 16, 8>(args, stream);
|
||||
launch_mul_mat_q<type, 16, 4>(args, stream);
|
||||
break;
|
||||
case 24:
|
||||
launch_mul_mat_q<type, 24, 8>(args, stream);
|
||||
launch_mul_mat_q<type, 24, 4>(args, stream);
|
||||
break;
|
||||
case 32:
|
||||
launch_mul_mat_q<type, 32, 8>(args, stream);
|
||||
|
Loading…
Reference in New Issue
Block a user