From f514d1b306e1114c2884fcb25dd9bd48ae64ba32 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Sat, 5 Aug 2023 18:20:44 +0200
Subject: [PATCH] CUDA: faster k-quant mul_mat_q kernels (#2525)

---
 ggml-cuda.cu | 889 ++++++++++++++++++++++++++++++---------------------
 1 file changed, 518 insertions(+), 371 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index d64d7045c..9d42efb0d 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -1383,8 +1383,10 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp
         sumi = __dp4a(vi1, u[2*i+1], sumi);
     }
 
+    const float2 ds8f = __half22float2(ds8);
+
     // second part effectively subtracts 8 from each quant value
-    return d4 * (sumi * __half2float(ds8.x) - (8*vdr/QI4_0) * __half2float(ds8.y));
+    return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -1410,12 +1412,14 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
     }
 
 #ifdef GGML_CUDA_F16
-    const half2 tmp = __hmul2(dm4, ds8);
-    const float d4d8 = __half2float(tmp.x);
-    const float m4s8 = __half2float(tmp.y);
+    const float2 tmp = __half22float2(__hmul2(dm4, ds8));
+    const float d4d8 = tmp.x;
+    const float m4s8 = tmp.y;
 #else
-    const float d4d8 = __half2float(dm4.x) * __half2float(ds8.x);
-    const float m4s8 = __half2float(dm4.y) * __half2float(ds8.y);
+    const float2 dm4f = __half22float2(dm4);
+    const float2 ds8f = __half22float2(ds8);
+    const float d4d8 = dm4f.x * ds8f.x;
+    const float m4s8 = dm4f.y * ds8f.y;
 #endif // GGML_CUDA_F16
 
     // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
@@ -1434,6 +1438,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     int sumi = 0;
 
+#pragma unroll
     for (int i = 0; i < vdr; ++i) {
         int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
         vi0    |= (vh[i] << 4) & 0x00000010; // 0 -> 4
@@ -1450,8 +1455,10 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp
         sumi = __dp4a(vi1, u[2*i+1], sumi); // SIMD dot product of quantized values
     }
 
+    const float2 ds8f = __half22float2(ds8);
+
     // second part effectively subtracts 16 from each quant value
-    return d5 * (sumi*__half2float(ds8.x) - (16*vdr/QI5_0) * __half2float(ds8.y));
+    return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -1466,6 +1473,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     int sumi = 0;
 
+#pragma unroll
    for (int i = 0; i < vdr; ++i) {
         int vi0 = (vl[i] >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh as 5th bits
         vi0    |= (vh[i] << 4) & 0x00000010; // 0 -> 4
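The recurring change in the hunks above replaces two scalar `__half2float` calls on the `.x`/`.y` members of a `half2` with a single paired `__half22float2` conversion. A minimal standalone sketch of the pattern (the function names here are illustrative, not from the patch):

    #include <cuda_fp16.h>

    // Before: two scalar conversions, one per half of the register.
    __device__ float scaled_before(int sumi, half2 ds) {
        return sumi * __half2float(ds.x) - 8.0f * __half2float(ds.y);
    }

    // After: one paired conversion; .x/.y are then plain floats.
    __device__ float scaled_after(int sumi, half2 ds) {
        const float2 dsf = __half22float2(ds);
        return sumi * dsf.x - 8.0f * dsf.y;
    }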
@@ -1483,12 +1491,14 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
     }
 
 #ifdef GGML_CUDA_F16
-    const half2 tmp = __hmul2(dm5, ds8);
-    const float d5d8 = __half2float(tmp.x);
-    const float m5s8 = __half2float(tmp.y);
+    const float2 tmp = __half22float2(__hmul2(dm5, ds8));
+    const float d5d8 = tmp.x;
+    const float m5s8 = tmp.y;
 #else
-    const float d5d8 = __half2float(dm5.x) * __half2float(ds8.x);
-    const float m5s8 = __half2float(dm5.y) * __half2float(ds8.y);
+    const float2 dm5f = __half22float2(dm5);
+    const float2 ds8f = __half22float2(ds8);
+    const float d5d8 = dm5f.x * ds8f.x;
+    const float m5s8 = dm5f.y * ds8f.y;
 #endif // GGML_CUDA_F16
 
     // scale second part of sum by QI5_1 / vdr to compensate for multiple threads adding it
@@ -1503,17 +1513,18 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
 #define VDR_Q8_0_Q8_1_MMQ 8
 
 template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl(
-    const int * v, const int * u, const float & d8_0, const half2 & ds8_1) {
+    const int * v, const int * u, const float & d8_0, const float & d8_1) {
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     int sumi = 0;
 
+#pragma unroll
     for (int i = 0; i < vdr; ++i) {
         // SIMD dot product of quantized values
         sumi = __dp4a(v[i], u[i], sumi);
     }
 
-    return sumi * d8_0 * __half2float(ds8_1.x);
+    return d8_0*d8_1 * sumi;
 #else
     return 0.0f; // only to satisfy the compiler
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
@@ -1525,18 +1536,21 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     int sumi = 0;
 
+#pragma unroll
     for (int i = 0; i < vdr; ++i) {
         // SIMD dot product of quantized values
         sumi = __dp4a(v[i], u[i], sumi);
     }
 
 #ifdef GGML_CUDA_F16
-    const half2 tmp = __hmul2(dm8, ds8);
-    const float d8d8 = __half2float(tmp.x);
-    const float m8s8 = __half2float(tmp.y);
+    const float2 tmp = __half22float2(__hmul2(dm8, ds8));
+    const float d8d8 = tmp.x;
+    const float m8s8 = tmp.y;
 #else
-    const float d8d8 = __half2float(dm8.x) * __half2float(ds8.x);
-    const float m8s8 = __half2float(dm8.y) * __half2float(ds8.y);
+    const float2 dm8f = __half22float2(dm8);
+    const float2 ds8f = __half22float2(ds8);
+    const float d8d8 = dm8f.x * ds8f.x;
+    const float m8s8 = dm8f.y * ds8f.y;
 #endif // GGML_CUDA_F16
 
     // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
@@ -1546,6 +1560,312 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
 #endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 }
 
+#define VDR_Q2_K_Q8_1_MMVQ 1
+#define VDR_Q2_K_Q8_1_MMQ  2
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
+    const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+    const half2 & dm2, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR2_K; ++i) {
+        const int sc = scales[2*i];
+
+        const int vi = (v >> (2*i)) & 0x03030303;
+
+        sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
+
+        // fill int with 4x m
+        int m = sc >> 4;
+        m |= m <<  8;
+        m |= m << 16;
+        sumf_m += d8[i] * __dp4a(m, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
+    }
+
+    const float2 dm2f = __half22float2(dm2);
+
+    return dm2f.x*sumf_d - dm2f.y*sumf_m;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
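The "fill int with 4x m" trick above relies on an identity of `__dp4a`: if the same byte value m is replicated into all four lanes, the dot product against u collapses to m times the sum of u's four signed bytes. A small device-side sketch of just that identity (standalone; requires sm_61+ like the rest of this code):

    #include <cuda_fp16.h>

    // sc holds the 4-bit minimum in its high nibble (as in q2_K scales).
    // After broadcasting, one __dp4a computes m*u0 + m*u1 + m*u2 + m*u3,
    // i.e. m * (sum of the four q8_1 bytes), with no unpacking.
    __device__ int m_times_byte_sum(int sc, int u) {
        int m = sc >> 4;   // 0..15, fits in one byte
        m |= m << 8;
        m |= m << 16;      // m now holds 4 copies of the same byte
        return __dp4a(m, u, 0);
    }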
+
+// contiguous u/y values
+static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+    const half2 & dm2, const float & d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    int sumi_d = 0;
+    int sumi_m = 0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
+        int sumi_d_sc = 0;
+
+        const int sc = scales[i0 / (QI8_1/2)];
+
+        // fill int with 4x m
+        int m = sc >> 4;
+        m |= m <<  8;
+        m |= m << 16;
+
+#pragma unroll
+        for (int i = i0; i < i0 + QI8_1/2; ++i) {
+            sumi_d_sc = __dp4a(v[i], u[i], sumi_d_sc); // SIMD dot product
+            sumi_m    = __dp4a(m,    u[i], sumi_m);    // multiply sum of q8_1 values with m
+        }
+
+        sumi_d += sumi_d_sc * (sc & 0xF);
+    }
+
+    const float2 dm2f = __half22float2(dm2);
+
+    return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m);
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+#define VDR_Q3_K_Q8_1_MMVQ 1
+#define VDR_Q3_K_Q8_1_MMQ  2
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
+    const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales,
+    const int & scale_offset, const float & d3, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR3_K; ++i) {
+        const int isc = scale_offset + 2*i;
+
+        const int isc_low = isc % (QK_K/32);
+        const int sc_shift_low = 4 * (isc / (QK_K/32));
+        const int sc_low  = (scales[isc_low] >> sc_shift_low) & 0xF;
+
+        const int isc_high = isc % (QK_K/64);
+        const int sc_shift_high = 2 * (isc / (QK_K/64));
+        const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
+
+        const int sc = (sc_low | sc_high) - 32;
+
+        const int vil = (vl >> (2*i)) & 0x03030303;
+
+        const int vih = ((vh >> i) << 2) & 0x04040404;
+
+        const int vi = __vsubss4(vil, vih);
+
+        sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
+    }
+
+    return d3 * sumf;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+// contiguous u/y values
+static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
+    const float & d3, const float & d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    int sumi = 0;
+
+#pragma unroll
+    for (int i0 = 0; i0 < QR3_K*VDR_Q3_K_Q8_1_MMQ; i0 += QI8_1/2) {
+        int sumi_sc = 0;
+
+        for (int i = i0; i < i0 + QI8_1/2; ++i) {
+            sumi_sc = __dp4a(v[i], u[i], sumi_sc); // SIMD dot product
+        }
+
+        sumi += sumi_sc * scales[i0 / (QI8_1/2)];
+    }
+
+    return d3*d8 * sumi;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+#define VDR_Q4_K_Q8_1_MMVQ 2
+#define VDR_Q4_K_Q8_1_MMQ  8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR4_K; ++i) {
+        const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
+        const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;
+
+        const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
+        const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u
+
+        sumf_d += d8[i] * (dot1 * sc[i]);
+        sumf_m += d8[i] * (dot2 * m[i]);  // multiply constant part of q4_K with sum of q8_1 values
+    }
+
+    const float2 dm4f = __half22float2(dm4);
+
+    return dm4f.x*sumf_d - dm4f.y*sumf_m;
+
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
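The `*_impl_mmq` functions above all share one structure: accumulate the `__dp4a` results in integers for as long as one scale applies, and only then do a single integer-to-float multiply per scale group. A generic sketch of that shape (illustrative pseudo-kernel, not code from the patch; all parameter names are made up):

    // groups * ints_per_group packed values share one int8 scale each.
    __device__ float block_dot(const int * v, const int * u,
                               const int8_t * scales, float d_x, float d_y,
                               int groups, int ints_per_group) {
        float sumf = 0.0f;
        for (int g = 0; g < groups; ++g) {
            int sumi = 0;                          // integer accumulator
            for (int i = 0; i < ints_per_group; ++i) {
                const int idx = g*ints_per_group + i;
                sumi = __dp4a(v[idx], u[idx], sumi);
            }
            sumf += scales[g] * sumi;              // one float op per group
        }
        return d_x * d_y * sumf;                   // block scales applied once
    }

Keeping the inner loop in integer registers is what makes the per-byte products nearly free; the float pipeline is only touched once per scale group.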
+
+// contiguous u/y values
+// also used for q5_K
+static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i0 = 0; i0 < VDR_Q4_K_Q8_1_MMQ; i0 += (QI8_1/QR4_K)) {
+        int sumi_d = 0;
+
+#pragma unroll
+        for (int i = i0; i < i0 + (QI8_1/QR4_K); ++i) {
+            sumi_d = __dp4a(v[2*i+0], u[2*i+0], sumi_d); // SIMD dot product
+            sumi_d = __dp4a(v[2*i+1], u[2*i+1], sumi_d); // SIMD dot product
+        }
+
+        const float2 ds8f = __half22float2(ds8[i0 / 4]);
+
+        sumf_d += ds8f.x * (sc[i0/4] * sumi_d);
+        sumf_m += ds8f.y *  m[i0/4]; // sum of q8_1 block * q4_K min val
+    }
+
+    const float2 dm4f = __half22float2(dm4);
+
+    return dm4f.x*sumf_d - dm4f.y*sumf_m;
+
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+#define VDR_Q5_K_Q8_1_MMVQ 2
+#define VDR_Q5_K_Q8_1_MMQ  8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl(
+    const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
+    const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf_d = 0.0f;
+    float sumf_m = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR5_K; ++i) {
+        const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;
+        const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;
+
+        const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
+        const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;
+
+        const int v0i = vl0i | vh0i;
+        const int v1i = vl1i | vh1i;
+
+        const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
+        const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u
+
+        sumf_d += d8[i] * (dot1 * sc[i]);
+        sumf_m += d8[i] * (dot2 * m[i]);
+
+    }
+
+    const float2 dm5f = __half22float2(dm5);
+
+    return dm5f.x*sumf_d - dm5f.y*sumf_m;
+
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
+#define VDR_Q6_K_Q8_1_MMVQ 1
+#define VDR_Q6_K_Q8_1_MMQ  8
+
+// contiguous v/x values
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
+    const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales,
+    const float & d, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf = 0.0f;
+
+#pragma unroll
+    for (int i = 0; i < QR6_K; ++i) {
+        const int sc = scales[4*i];
+
+        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
+
+        const int vih = ((vh >> (4*i)) << 4) & 0x30303030;
+
+        const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
+
+        sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
+    }
+
+    return d*sumf;
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
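The q6_K path above reassembles each 6-bit quant from a packed low nibble and a 2-bit high part, then recenters it by 32 with `__vsubss4`, which performs four saturating int8 subtractions in one instruction. The step in isolation, mirroring the lines of `vec_dot_q6_K_q8_1_impl_mmvq` (standalone sketch):

    // ql packs low 4 bits, qh the high 2 bits, four quants per int.
    __device__ int q6_from_parts(int ql, int qh, int i) {
        const int vil = (ql >> (4*i)) & 0x0F0F0F0F;        // low 4 bits per byte
        const int vih = ((qh >> (4*i)) << 4) & 0x30303030; // high 2 bits -> bits 4..5
        return __vsubss4(vil | vih, 0x20202020);           // 0..63 -> -32..31 per byte
    }

Recentering before the dot product is what allows a plain signed `__dp4a` to be used on values that are stored unsigned.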
+
+// contiguous u/y values
+static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
+    const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
+    const float & d6, const float * __restrict__ d8) {
+
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+    float sumf_d = 0.0f;
+
+#pragma unroll
+    for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
+        int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
+
+#pragma unroll
+        for (int i = i0; i < i0 + 2; ++i) {
+            sumi_d.x = __dp4a(v[2*i+0], u[2*i+0], sumi_d.x); // SIMD dot product
+            sumi_d.x = __dp4a(v[2*i+1], u[2*i+1], sumi_d.x); // SIMD dot product
+
+            sumi_d.y = __dp4a(v[2*i+4], u[2*i+4], sumi_d.y); // SIMD dot product
+            sumi_d.y = __dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
+        }
+
+        sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y);
+    }
+
+    return d6 * sumf_d;
+
+#else
+    return 0.0f; // only to satisfy the compiler
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+}
+
 static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
     const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
 
@@ -1631,6 +1951,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q4_0_Q8_1_MMQ == 0);
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
     const float * x_dmf = (float *) x_dm;
@@ -1729,6 +2050,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q4_1_Q8_1_MMQ == 0);
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
 
@@ -1848,10 +2170,12 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q5_0_Q8_1_MMQ == 0);
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
     const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
-    const float * x_dmf = (float *) x_dm;
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
 
     int u[2*VDR_Q5_0_Q8_1_MMQ];
 
@@ -1862,7 +2186,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
     }
 
     return vec_dot_q8_0_q8_1_impl<QR5_0*VDR_Q5_0_Q8_1_MMQ>
-        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
+        (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
@@ -1965,6 +2289,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q5_1_Q8_1_MMQ == 0);
 
     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
     const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
@@ -1989,12 +2314,13 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
     int v[VDR_Q8_0_Q8_1_MMVQ];
     int u[VDR_Q8_0_Q8_1_MMVQ];
 
+#pragma unroll
     for (int i = 0; i < VDR_Q8_0_Q8_1_MMVQ; ++i) {
         v[i] = get_int_from_int8(bq8_0->qs, iqs + i);
         u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
     }
 
-    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, bq8_1->ds);
+    return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, bq8_1->ds.x);
 }
 
 static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
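The `__builtin_assume(k % VDR... == 0)` lines added throughout these hunks are pure optimizer hints: they promise the compiler that k is a multiple of the vector-dot-ratio, so modulo and division expressions derived from k can be folded away. A minimal illustration of the mechanism (standalone, made-up constants):

    __device__ int strided_index(int k) {
        __builtin_assume(k % 8 == 0); // promise: k is a multiple of 8
        return k / 8 + k % 8;         // k % 8 is provably 0, so this folds to k / 8
    }

Without the assume, the compiler must emit the remainder arithmetic; with it, the expression compiles to a single shift.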
@@ -2065,43 +2391,14 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q8_0_Q8_1_MMQ == 0);
 
-    const float * x_dmf = (float *) x_dm;
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
 
     return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMQ>
         (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0],
-         y_ds[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
-}
-
-#define VDR_q2_K_q8_1 1
-
-static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl(
-    const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales,
-    const half2 & dm, const float * __restrict__ d8) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    float sumf_d = 0.0f;
-    float sumf_m = 0.0f;
-
-    for (int i = 0; i < QR2_K; ++i) {
-        const int sc = scales[2*i];
-
-        const int vi = (v >> (2*i)) & 0x03030303;
-
-        sumf_d += d8[i] * (__dp4a(vi, u[i], 0) * (sc & 0xF)); // SIMD dot product
-
-        int sc_high = sc >> 4;
-        sc_high |= sc_high <<  8;
-        sc_high |= sc_high << 16;
-        sumf_m += d8[i] * __dp4a(sc_high, u[i], 0); // multiply constant q2_K part with sum of q8_1 values
-    }
-
-    const float2 dmf = __half22float2(dm);
-
-    return dmf.x*sumf_d - dmf.y*sumf_m;
-#else
-    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+         y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
@@ -2115,15 +2412,16 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
     const uint8_t * scales = bq2_K->scales + scale_offset;
 
     const int v = get_int_from_uint8_aligned(bq2_K->qs, iqs);
-    int u[QR2_K];
+    int    u[QR2_K];
     float d8[QR2_K];
 
+#pragma unroll
     for (int i = 0; i < QR2_K; ++ i) {
         u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
         d8[i] = bq8_1[bq8_offset + i].ds.x;
     }
 
-    return vec_dot_q2_K_q8_1_impl(v, u, scales, bq2_K->dm, d8);
+    return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
 }
 
 static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
@@ -2204,62 +2502,26 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q2_K_Q8_1_MMQ == 0);
 
-    const int kbx = k / QI2_K;
-    const int kqsx = k % QI2_K;
+    const int kbx = k / QI2_K;
+    const int ky  = (k % QI2_K) * QR2_K;
+    const float * y_df = (const float *) y_ds;
 
-    const int bq8_offset = QR2_K * (kqsx / QI8_1);
-    const int scale_offset = kqsx - kqsx % QI8_1 + (kqsx % QI8_1) / (QI8_1/2);
+    int v[QR2_K*VDR_Q2_K_Q8_1_MMQ];
 
-    const uint8_t * scales = ((uint8_t *) (x_sc + i * (WARP_SIZE/4) + i / 4)) + kbx*16 + scale_offset;
+    const int kqsx  = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2);
+    const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2));
 
-    int u[QR2_K];
-    float d8[QR2_K];
-
-    for (int l = 0; l < QR2_K; ++ l) {
-        const int y_qs_index = j * (QR2_K*WARP_SIZE) + kbx * (QR2_K*QI2_K) + (bq8_offset + l)*QI8_1 + kqsx % QI8_1;
-        u[l]  = y_qs[y_qs_index];
-        d8[l] = y_ds[y_qs_index / QI8_1].x;
+#pragma unroll
+    for (int l = 0; l < QR2_K*VDR_Q2_K_Q8_1_MMQ; ++l) {
+        v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303;
     }
 
-    return vec_dot_q2_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], u, scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], d8);
-}
+    const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4;
 
-#define VDR_q3_K_q8_1 1
-
-static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl(
-    const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales,
-    const int & scale_offset, const float & d, const float * __restrict__ d8) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    float sumf = 0.0f;
-
-    for (int i = 0; i < QR3_K; ++i) {
-        const int isc = scale_offset + 2*i;
-
-        const int isc_low = isc % (QK_K/32);
-        const int sc_shift_low = 4 * (isc / (QK_K/32));
-        const int sc_low  = (scales[isc_low] >> sc_shift_low) & 0xF;
-
-        const int isc_high = isc % (QK_K/64);
-        const int sc_shift_high = 2 * (isc / (QK_K/64));
-        const int sc_high = ((scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
-
-        const int sc = (sc_low | sc_high) - 32;
-
-        const int vil = (vl >> (2*i)) & 0x03030303;
-
-        const int vih = ((vh >> i) << 2) & 0x04040404;
-
-        const int vi = __vsubss4(vil, vih);
-
-        sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
-    }
-
-    return d*sumf;
-#else
-    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int index_y = j * (QR2_K*WARP_SIZE) + QR2_K*k;
+    return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
@@ -2277,15 +2539,16 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
     // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
     const int vh = ~get_int_from_uint8(bq3_K->hmask, iqs % (QI3_K/2)) >> bq8_offset;
 
-    int u[QR3_K];
+    int    u[QR3_K];
     float d8[QR3_K];
 
+#pragma unroll
     for (int i = 0; i < QR3_K; ++i) {
        u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
         d8[i] = bq8_1[bq8_offset + i].ds.x;
     }
 
-    return vec_dot_q3_K_q8_1_impl(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
+    return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
 }
 
 static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
@@ -2330,6 +2593,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
     const int kbxd = k % blocks_per_tile_x_row;
+    float * x_dmf = (float *) x_dm;
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI3_K) {
@@ -2341,7 +2605,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_
 
         const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;
 
-        x_dm[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd].x = bxi->d;
+        x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = bxi->d;
     }
 
 #pragma unroll
@@ -2354,7 +2618,8 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_
 
         const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);
 
-        x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
+        // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
+        x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
     }
 
 #pragma unroll
@@ -2367,7 +2632,19 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q3_
 
         const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);
 
-        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8(bxi->scales, k % (QI3_K/4));
+        const int ksc = k % (QI3_K/4);
+
+        const int ksc_low = ksc % (QI3_K/8);
+        const int shift_low = 4 * (ksc / (QI3_K/8));
+        const int sc_low = (get_int_from_uint8(bxi->scales, ksc_low) >> shift_low) & 0x0F0F0F0F;
+
+        const int ksc_high = QI3_K/8;
+        const int shift_high = 2 * ksc;
+        const int sc_high = ((get_int_from_uint8(bxi->scales, ksc_high) >> shift_high) << 4) & 0x30303030;
+
+        const int sc = __vsubss4(sc_low | sc_high, 0x20202020);
+
+        x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc;
     }
 }
 
@@ -2381,57 +2658,31 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q3_K_Q8_1_MMQ == 0);
 
     const int kbx  = k / QI3_K;
-    const int kqsx = k % QI3_K;
+    const int ky  = (k % QI3_K) * QR3_K;
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
 
-    const int bq8_offset = QR3_K * (kqsx / (QI3_K/2));
-    const int scale_offset = kqsx - kqsx % QI8_1 + (kqsx % QI8_1) / (QI8_1/2);
+    const int8_t * scales = ((int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
 
-    const uint8_t * scales = ((uint8_t *) (x_sc + i * (WARP_SIZE/4) + i / 4)) + kbx*16;
+    int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];
 
-    // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
-    const int vh = ~x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + kqsx % (QI3_K/2)] >> bq8_offset;
+#pragma unroll
+    for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) {
+        const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2);
+        const int shift = 2 * ((ky % 32) / 8);
+        const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303;
 
-    int u[QR3_K];
-    float d8[QR3_K];
+        const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8);
+        const int vlh = (vh << 2) & 0x04040404;
 
-    for (int l = 0; l < QR3_K; ++ l) {
-        const int y_qs_index = j * (QR3_K*WARP_SIZE) + kbx * (QR3_K*QI3_K) + (bq8_offset + l)*QI8_1 + kqsx % QI8_1;
-        u[l]  = y_qs[y_qs_index];
-        d8[l] = y_ds[y_qs_index / QI8_1].x;
+        v[l] = __vsubss4(vll, vlh);
     }
 
-    return vec_dot_q3_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], vh, u, scales, scale_offset,
-                                  x_dm[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx].x, d8);
-}
-
-#define VDR_q4_K_q8_1 2
-
-static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl(
-    const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
-    const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    float sumf_d = 0.0f;
-    float sumf_m = 0.0f;
-
-    for (int i = 0; i < QR4_K; ++i) {
-        const int v0i = (v[0] >> (4*i)) & 0x0F0F0F0F;
-        const int v1i = (v[1] >> (4*i)) & 0x0F0F0F0F;
-
-        const int dot1 = __dp4a(v1i, u[2*i+1], __dp4a(v0i, u[2*i+0], 0)); // SIMD dot product
-        const int dot2 = __dp4a(0x01010101, u[2*i+1], __dp4a(0x01010101, u[2*i+0], 0)); // sum of u
-
-        sumf_d += d8[i] * (dot1 * sc[i]);
-        sumf_m += d8[i] * (dot2 * m[i]); // multiply constant part of q4_K with sum of q8_1 values
-    }
-
-    return __half2float(dm4.x)*sumf_d - __half2float(dm4.y)*sumf_m;
-
-#else
-    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int index_y = j * (QR3_K*WARP_SIZE) + k*QR3_K;
+    return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
@@ -2478,7 +2729,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
         u[2*i+1] = q8[4];
     }
 
-    return vec_dot_q4_K_q8_1_impl(v, u, sc, m, bq4_K->dm, d8);
+    return vec_dot_q4_K_q8_1_impl_vmmq(v, u, sc, m, bq4_K->dm, d8);
 
 #else
 
@@ -2566,7 +2817,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
-    const int kbxd = k % blocks_per_tile_x_row;           // == 0 if QK_K == 256
+    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI4_K) {
@@ -2591,7 +2842,15 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q4_
 
         const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);
 
-        x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_uint8_aligned(bxi->scales, k % (QI4_K/8));
+        const int * scales = (int *) bxi->scales;
+
+        const int ksc = k % (WARP_SIZE/8);
+
+        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
+        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
+        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
+
+        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
     }
 }
 
@@ -2605,76 +2864,20 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q4_K_Q8_1_MMQ == 0);
 
-    const int kbx  = k / QI6_K; // == 0 if QK_K == 256
-    const int kqsx = k % QI6_K; // == k if QK_K == 256
+    int v[QR4_K*VDR_Q4_K_Q8_1_MMQ];
 
-    int v[2];
-    int u[2*QR4_K];
-    float d8[QR4_K];
-
-    // kqsx is in 0,2...30. bq8_offset = 2 * (kqsx/4) -> bq8_offset = 0, 2, 4, 6
-    const int bq8_offset = QR4_K * ((kqsx/2) / (QI8_1/2));
-
-    v[0] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + (kqsx/2) % 4 + 0];
-    v[1] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + (kqsx/2) % 4 + 4];
-
-    const uint16_t * scales = (const uint16_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + kbx * 4];
-    uint16_t aux[2];
-    const int l = bq8_offset/2;
-    if (l < 2) {
-        aux[0] = scales[l+0] & 0x3f3f;
-        aux[1] = scales[l+2] & 0x3f3f;
-    } else {
-        aux[0] = ((scales[l+2] >> 0) & 0x0f0f) | ((scales[l-2] & 0xc0c0) >> 2);
-        aux[1] = ((scales[l+2] >> 4) & 0x0f0f) | ((scales[l-0] & 0xc0c0) >> 2);
-    }
-    const uint8_t * sc = (const uint8_t *)aux;
-    const uint8_t * m = sc + 2;
-
-    for (int l = 0; l < QR4_K; ++l) {
-        const int kqsy = j * (QR4_K*WARP_SIZE) + kbx * (QR4_K*QI4_K) + (bq8_offset + l) * QI8_1 + (kqsx/2) % (QI8_1/2);
-        u[2*l+0] = y_qs[kqsy + 0*(QI8_1/2)];
-        u[2*l+1] = y_qs[kqsy + 1*(QI8_1/2)];
-        d8[l] = y_ds[kqsy / QI8_1].x;
+#pragma unroll
+    for (int l = 0; l < VDR_Q4_K_Q8_1_MMQ; ++l) {
+        v[l + 0]         = (x_ql[i * (WARP_SIZE + 1) + k + l] >> 0) & 0x0F0F0F0F;
+        v[l + (QI4_K/4)] = (x_ql[i * (WARP_SIZE + 1) + k + l] >> 4) & 0x0F0F0F0F;
     }
 
-    return vec_dot_q4_K_q8_1_impl(v, u, sc, m, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K + kbx], d8);
-}
+    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);
 
-#define VDR_q5_K_q8_1 2
-
-static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl(
-    const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc,
-    const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    float sumf_d = 0.0f;
-    float sumf_m = 0.0f;
-
-    for (int i = 0; i < QR5_K; ++i) {
-        const int vl0i = (vl[0] >> (4*i)) & 0x0F0F0F0F;
-        const int vl1i = (vl[1] >> (4*i)) & 0x0F0F0F0F;
-
-        const int vh0i = ((vh[0] >> i) << 4) & 0x10101010;
-        const int vh1i = ((vh[1] >> i) << 4) & 0x10101010;
-
-        const int v0i = vl0i | vh0i;
-        const int v1i = vl1i | vh1i;
-
-        const int dot1 = __dp4a(v0i, u[2*i+0], __dp4a(v1i, u[2*i+1], 0)); // SIMD dot product
-        const int dot2 = __dp4a(0x01010101, u[2*i+0], __dp4a(0x01010101, u[2*i+1], 0)); // sum of u
-
-        sumf_d += d8[i] * (dot1 * sc[i]);
-        sumf_m += d8[i] * (dot2 * m[i]);
-
-    }
-
-    return __half2float(dm5.x)*sumf_d - __half2float(dm5.y)*sumf_m;
-
-#else
-    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int index_y = j * (QR4_K*WARP_SIZE) + QR4_K*k;
+    return vec_dot_q4_K_q8_1_impl_mmq(v, &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
@@ -2711,6 +2914,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
     const uint8_t * sc = (const uint8_t *)aux;
     const uint8_t * m  = sc + 2;
 
+#pragma unroll
     for (int i = 0; i < QR5_K; ++i) {
         const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
         d8[i] = bq8i->ds.x;
@@ -2767,14 +2971,12 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 
 static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
-    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE) + GGML_CUDA_MMQ_Y];
+    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (2*WARP_SIZE) + GGML_CUDA_MMQ_Y];
     __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_K) + GGML_CUDA_MMQ_Y/QI5_K];
-    __shared__ int   tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/4) + GGML_CUDA_MMQ_Y/4];
     __shared__ int   tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8) + GGML_CUDA_MMQ_Y/8];
 
     *x_ql = tile_x_ql;
     *x_dm = tile_x_dm;
-    *x_qh = tile_x_qh;
     *x_sc = tile_x_sc;
 }
 
@@ -2801,12 +3003,25 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
         }
 
         const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
+        const int ky = QR5_K*kqsx;
 
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        const int ql = get_int_from_uint8_aligned(bxi->qs, kqsx);
+        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+        const int qh = get_int_from_uint8_aligned(bxi->qh, kqsx % (QI5_K/4));
+        const int qh0 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 0)) << 4) & 0x10101010;
+        const int qh1 = ((qh >> (2 * (kqsx / (QI5_K/4)) + 1)) << 4) & 0x10101010;
+
+        const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0;
+        const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4);
+
+        x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0;
+        x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1;
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
-    const int kbxd = k % blocks_per_tile_x_row;           // == 0 if QK_K == 256
+    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI5_K) {
@@ -2821,19 +3036,6 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
         x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
     }
 
-#pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 4) {
-        int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI5_K/4);
-
-        x_qh[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8(bxi->qh, k % (QI5_K/4));
-    }
-
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 8) {
         int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
@@ -2844,7 +3046,15 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q5_
 
         const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);
 
-        x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_uint8_aligned(bxi->scales, k % (QI5_K/8));
+        const int * scales = (int *) bxi->scales;
+
+        const int ksc = k % (WARP_SIZE/8);
+
+        // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m7
+        int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits
+        scales8    |= (scales[ksc/2]              >> (2 * (ksc % 2)))       & 0x30303030; // upper 2 bits
+
+        x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8;
     }
 }
 
@@ -2858,71 +3068,13 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q5_K_Q8_1_MMQ == 0);
 
-    const int kbx  = k / QI6_K; // == 0 if QK_K == 256
-    const int kqsx = k % QI6_K; // == k if QK_K == 256
+    const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);
 
-    int vl[2];
-    int vh[2];
-    int u[2*QR4_K];
-    float d8[QR4_K];
-
-    const int bq8_offset = QR5_K * ((kqsx/2) / (QI8_1/2));
-
-    vl[0] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + (kqsx/2) % 4 + 0];
-    vl[1] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + (kqsx/2) % 4 + 4];
-
-    vh[0] = x_qh[i * (WARP_SIZE/4) + i/4 + (kqsx/2) % 4 + 0] >> bq8_offset;
-    vh[1] = x_qh[i * (WARP_SIZE/4) + i/4 + (kqsx/2) % 4 + 4] >> bq8_offset;
-
-    const uint16_t * scales = (const uint16_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + kbx * 4];
-    uint16_t aux[2];
-    const int l = bq8_offset/2;
-    if (l < 2) {
-        aux[0] = scales[l+0] & 0x3f3f;
-        aux[1] = scales[l+2] & 0x3f3f;
-    } else {
-        aux[0] = ((scales[l+2] >> 0) & 0x0f0f) | ((scales[l-2] & 0xc0c0) >> 2);
-        aux[1] = ((scales[l+2] >> 4) & 0x0f0f) | ((scales[l-0] & 0xc0c0) >> 2);
-    }
-    const uint8_t * sc = (const uint8_t *)aux;
-    const uint8_t * m = sc + 2;
-
-    for (int l = 0; l < QR5_K; ++l) {
-        const int kqsy = j * (QR5_K*WARP_SIZE) + kbx * (QR5_K*QI5_K) + (bq8_offset + l) * QI8_1 + (kqsx/2) % (QI8_1/2);
-        u[2*l+0] = y_qs[kqsy + 0*(QI8_1/2)];
-        u[2*l+1] = y_qs[kqsy + 1*(QI8_1/2)];
-        d8[l] = y_ds[kqsy / QI8_1].x;
-    }
-
-    return vec_dot_q5_K_q8_1_impl(vl, vh, u, sc, m, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K + kbx], d8);
-}
-
-#define VDR_q6_K_q8_1 1
-
-static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl(
-    const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales,
-    const float & d, const float * __restrict__ d8) {
-
-#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    float sumf = 0.0f;
-
-    for (int i = 0; i < QR6_K; ++i) {
-        const int sc = scales[4*i];
-
-        const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
-
-        const int vih = ((vh >> (4*i)) << 4) & 0x30303030;
-
-        const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
-
-        sumf += d8[i] * (__dp4a(vi, u[i], 0) * sc); // SIMD dot product
-    }
-
-    return d*sumf;
-#else
-    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+    const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k;
+    const int index_y = j * (QR5_K*WARP_SIZE)     + QR5_K*k;
+    return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]);
 }
 
 static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
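Note how the rewritten q5_K tile dot product above calls `vec_dot_q4_K_q8_1_impl_mmq` rather than a q5_K-specific implementation: because `load_tiles_q5_K` now ORs the fifth bit into the low nibble when filling `x_ql`, the tile already contains complete 5-bit values and the q4_K kernel applies unchanged. The per-int fold, in isolation (a sketch; the shift offsets in the real hunk depend on `kqsx`):

    // Combine four packed low nibbles with their 5th bits, one byte each.
    __device__ int fold_q5(int ql_nibbles, int qh_bit_plane) {
        // low nibble | (high bit moved to bit 4) -> values 0..31 per byte
        return (ql_nibbles & 0x0F0F0F0F) | ((qh_bit_plane << 4) & 0x10101010);
    }

Doing this once at load time also removes the `tile_x_qh` shared-memory buffer entirely, which is why `allocate_tiles_q5_K` above shrinks.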
@@ -2942,24 +3094,23 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
     int    u[QR6_K];
     float d8[QR6_K];
 
+#pragma unroll
     for (int i = 0; i < QR6_K; ++i) {
        u[i]  = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
         d8[i] = bq8_1[bq8_offset + 2*i].ds.x;
     }
 
-    return vec_dot_q6_K_q8_1_impl(vl, vh, u, scales, bq6_K->d, d8);
+    return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
 }
 
 static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
 
-    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE) + GGML_CUDA_MMQ_Y];
+    __shared__ int   tile_x_ql[GGML_CUDA_MMQ_Y * (2*WARP_SIZE) + GGML_CUDA_MMQ_Y];
     __shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI6_K) + GGML_CUDA_MMQ_Y/QI6_K];
-    __shared__ int   tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/2) + GGML_CUDA_MMQ_Y/2];
     __shared__ int   tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8) + GGML_CUDA_MMQ_Y/8];
 
     *x_ql = tile_x_ql;
     *x_dm = tile_x_dm;
-    *x_qh = tile_x_qh;
     *x_sc = tile_x_sc;
 }
 
@@ -2986,12 +3137,26 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q6_
         }
 
         const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
+        const int ky = QR6_K*kqsx;
 
-        x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->ql, kqsx);
+        const int ql = get_int_from_uint8(bxi->ql, kqsx);
+        const int ql0 = (ql >> 0) & 0x0F0F0F0F;
+        const int ql1 = (ql >> 4) & 0x0F0F0F0F;
+
+        const int qh = get_int_from_uint8(bxi->qh, (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4));
+        const int qh0 = ((qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4)))) << 4) & 0x30303030;
+        const int qh1 =  (qh >> (2 * ((kqsx % (QI6_K/2)) / (QI6_K/4))))       & 0x30303030;
+
+        const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0;
+        const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2);
+
+        x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020);
+        x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020);
     }
 
     const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
-    const int kbxd = k % blocks_per_tile_x_row;           // == 0 if QK_K == 256
+    const int kbxd = k % blocks_per_tile_x_row;          // == 0 if QK_K == 256
+    float * x_dmf = (float *) x_dm;
 
 #pragma unroll
     for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI6_K) {
@@ -3003,20 +3168,7 @@ template <bool need_check> static __device__ __forceinline__ void load_tiles_q6_
 
         const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;
 
-        x_dm[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd].x = bxi->d;
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 2) {
-        int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
-
-        if (need_check) {
-            i = min(i, i_max);
-        }
-
-        const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI6_K/2);
-
-        x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = get_int_from_uint8(bxi->qh, k % (QI6_K/2));
+        x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = bxi->d;
     }
 
 #pragma unroll
@@ -3043,33 +3195,19 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
     __builtin_assume(j <  WARP_SIZE);
     __builtin_assume(k >= 0);
     __builtin_assume(k <  WARP_SIZE);
+    __builtin_assume(k % VDR_Q6_K_Q8_1_MMQ == 0);
 
-    const int kbx  = k / QI6_K; // == 0 if QK_K == 256
-    const int kqsx = k % QI6_K; // == k if QK_K == 256
+    const float * x_dmf = (const float *) x_dm;
+    const float * y_df  = (const float *) y_ds;
 
-    const int bq8_offset = 2 * QR6_K * (kqsx / (QI6_K/2)) + (kqsx % (QI6_K/2)) / (QI6_K/4);
-    const int scale_offset = (QI6_K/4) * (kqsx / (QI6_K/2)) + (kqsx % (QI6_K/2)) / (QI6_K/8);
-    const int vh_shift = 2 * ((kqsx % (QI6_K/2)) / (QI6_K/4));
+    const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]);
 
-    const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI6_K/2) + (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)] >> vh_shift;
-
-    const int x_sc_offset = i * (WARP_SIZE/8) + i/8 + kbx * (QI6_K/8);
-    const int8_t * scales = ((int8_t *) (x_sc + x_sc_offset)) + scale_offset;
-
-    int u[QR6_K];
-    float d8[QR6_K];
-
-    for (int l = 0; l < QR6_K; ++l) {
-        const int kqsy = j * (QR6_K*WARP_SIZE) + kbx * (QR6_K*QI6_K) + (bq8_offset + 2*l)*QI8_1 + kqsx % QI8_1;
-        u[l]  = y_qs[kqsy];
-        d8[l] = y_ds[kqsy / QI8_1].x;
-    }
-
-    return vec_dot_q6_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], vh, u, scales,
-                                  x_dm[i * (WARP_SIZE/QI6_K) + i/QI6_K + kbx].x, d8);
+    const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k;
+    const int index_y = j * (QR6_K*WARP_SIZE)     + QR6_K*k;
+    return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]);
 }
 
-template <int qk, int qr, int qi, typename block_q_t,
-          allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
+template <int qk, int qr, int qi, bool need_sum, typename block_q_t,
+          allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
 static __global__ void mul_mat_q(
     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
@@ -3130,7 +3268,16 @@ static __global__ void mul_mat_q(
             const int ids = (ids0 + tid_y * (WARP_SIZE/blocks_per_tile_y_col) + tid_x / blocks_per_tile_y_col) % WARP_SIZE;
             const int kby = tid_x % blocks_per_tile_y_col;
             const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
-            tile_y_ds[ids * (qr*WARP_SIZE/QI8_1) + kby] = y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kby].ds;
+
+            // if the sum is not needed it's faster to transform the scale to f32 ahead of time
+            const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kby].ds;
+            half2       * dsi_dst = &tile_y_ds[ids * (qr*WARP_SIZE/QI8_1) + kby];
+            if (need_sum) {
+                *dsi_dst = *dsi_src;
+            } else {
+                float * dfi_dst = (float *) dsi_dst;
+                *dfi_dst = (*dsi_src).x;
+            }
         }
 
         __syncthreads();
@@ -3780,7 +3927,7 @@ static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_q2_K_q8_1, vec_dot_q2_K_q8_1>
+    mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -3789,7 +3936,7 @@ static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_q3_K_q8_1, vec_dot_q3_K_q8_1>
+    mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -3798,7 +3945,7 @@ static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_q4_K_q8_1, vec_dot_q4_K_q8_1>
+    mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -3807,7 +3954,7 @@ static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_q5_K_q8_1, vec_dot_q5_K_q8_1>
+    mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
@@ -3816,7 +3963,7 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float *
     const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
     const dim3 block_nums(1, block_num_y, 1);
     const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
-    mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_q6_K_q8_1, vec_dot_q6_K_q8_1>
+    mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
         <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
 }
 
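The `need_sum` hunk above is worth dwelling on: quant types whose dot product never uses the q8_1 block sum only need the float scale d, so the kernel converts it from half once, while filling the tile, instead of on every inner-loop read. A sketch of how the two layouts coexist in the same buffer (names here are illustrative, not the kernel's):

    #include <cuda_fp16.h>

    // tile_y_ds normally holds half2 {d, s}; when the sum s is unused, the
    // writer stores a plain float d at the same slot instead.
    __device__ float read_scale(const half2 * tile_y_ds, int idx, bool stored_as_float) {
        if (stored_as_float) {
            return ((const float *) tile_y_ds)[idx]; // written by the need_sum == false path
        }
        return __half22float2(tile_y_ds[idx]).x;     // d component of {d, s}
    }

Since need_sum is a compile-time template parameter, the branch disappears entirely in each instantiation.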
@@ -3873,10 +4020,10 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<false>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+        mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<false>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK4_0, QR4_0, QI4_0, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<true>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
+        mul_mat_q<QK4_0, QR4_0, QI4_0, true, block_q4_0, allocate_tiles_q4_0, load_tiles_q4_0<true>, VDR_Q4_0_Q8_1_MMQ, vec_dot_q4_0_q8_1_mul_mat>
            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -3891,10 +4038,10 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK4_1, QR4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<false>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+        mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<false>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK4_1, QR4_1, QI4_1, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<true>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
+        mul_mat_q<QK4_1, QR4_1, QI4_1, true, block_q4_1, allocate_tiles_q4_1, load_tiles_q4_1<true>, VDR_Q4_1_Q8_1_MMQ, vec_dot_q4_1_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -3909,10 +4056,10 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK5_0, QR5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<false>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+        mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<false>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK5_0, QR5_0, QI5_0, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<true>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
+        mul_mat_q<QK5_0, QR5_0, QI5_0, false, block_q5_0, allocate_tiles_q5_0, load_tiles_q5_0<true>, VDR_Q5_0_Q8_1_MMQ, vec_dot_q5_0_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -3927,10 +4074,10 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK5_1, QR5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<false>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+        mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<false>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK5_1, QR5_1, QI5_1, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<true>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
+        mul_mat_q<QK5_1, QR5_1, QI5_1, true, block_q5_1, allocate_tiles_q5_1, load_tiles_q5_1<true>, VDR_Q5_1_Q8_1_MMQ, vec_dot_q5_1_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -3945,10 +4092,10 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK8_0, QR8_0, QI8_0, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<false>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+        mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<false>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK8_0, QR8_0, QI8_0, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<true>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
+        mul_mat_q<QK8_0, QR8_0, QI8_0, false, block_q8_0, allocate_tiles_q8_0, load_tiles_q8_0<true>, VDR_Q8_0_Q8_1_MMQ, vec_dot_q8_0_q8_1_mul_mat>
             <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -3963,10 +4110,10 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK_K, QR2_K, QI2_K, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K<false>, VDR_q2_K_q8_1, vec_dot_q2_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K<false>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK_K, QR2_K, QI2_K, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K<true>, VDR_q2_K_q8_1, vec_dot_q2_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR2_K, QI2_K, false, block_q2_K, allocate_tiles_q2_K, load_tiles_q2_K<true>, VDR_Q2_K_Q8_1_MMQ, vec_dot_q2_K_q8_1_mul_mat>
            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -3981,10 +4128,10 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK_K, QR3_K, QI3_K, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K<false>, VDR_q3_K_q8_1, vec_dot_q3_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K<false>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK_K, QR3_K, QI3_K, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K<true>, VDR_q3_K_q8_1, vec_dot_q3_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR3_K, QI3_K, false, block_q3_K, allocate_tiles_q3_K, load_tiles_q3_K<true>, VDR_Q3_K_Q8_1_MMQ, vec_dot_q3_K_q8_1_mul_mat>
            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -3999,10 +4146,10 @@ static void 
ggml_mul_mat_q4_K_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK_K, QR4_K, QI4_K, block_q4_K, allocate_tiles_q4_K, load_tiles_q4_K<false>, VDR_q4_K_q8_1, vec_dot_q4_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, allocate_tiles_q4_K, load_tiles_q4_K<false>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK_K, QR4_K, QI4_K, block_q4_K, allocate_tiles_q4_K, load_tiles_q4_K<true>, VDR_q4_K_q8_1, vec_dot_q4_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR4_K, QI4_K, true, block_q4_K, allocate_tiles_q4_K, load_tiles_q4_K<true>, VDR_Q4_K_Q8_1_MMQ, vec_dot_q4_K_q8_1_mul_mat>
            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -4017,10 +4164,10 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK_K, QR5_K, QI5_K, block_q5_K, allocate_tiles_q5_K, load_tiles_q5_K<false>, VDR_q5_K_q8_1, vec_dot_q5_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, allocate_tiles_q5_K, load_tiles_q5_K<false>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK_K, QR5_K, QI5_K, block_q5_K, allocate_tiles_q5_K, load_tiles_q5_K<true>, VDR_q5_K_q8_1, vec_dot_q5_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR5_K, QI5_K, true, block_q5_K, allocate_tiles_q5_K, load_tiles_q5_K<true>, VDR_Q5_K_Q8_1_MMQ, vec_dot_q5_K_q8_1_mul_mat>
            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
@@ -4035,10 +4182,10 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
     const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
 
     if (nrows_x % GGML_CUDA_MMQ_Y == 0) {
-        mul_mat_q<QK_K, QR6_K, QI6_K, block_q6_K, allocate_tiles_q6_K, load_tiles_q6_K<false>, VDR_q6_K_q8_1, vec_dot_q6_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, allocate_tiles_q6_K, load_tiles_q6_K<false>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     } else {
-        mul_mat_q<QK_K, QR6_K, QI6_K, block_q6_K, allocate_tiles_q6_K, load_tiles_q6_K<true>, VDR_q6_K_q8_1, vec_dot_q6_K_q8_1_mul_mat>
+        mul_mat_q<QK_K, QR6_K, QI6_K, false, block_q6_K, allocate_tiles_q6_K, load_tiles_q6_K<true>, VDR_Q6_K_Q8_1_MMQ, vec_dot_q6_K_q8_1_mul_mat>
            <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
 }
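For readers wanting to check the packing tricks in this patch off-device, a host-side reference model of `__dp4a` is enough: it accumulates four signed 8-bit products into a 32-bit int. A small self-contained example (the `dp4a_ref` helper is an assumption mirroring the documented dp4a.s32.s32 semantics, not part of the patch), here verifying the `__dp4a(0x01010101, u, 0)` "sum of u" idiom used in the q4_K/q5_K kernels:

    #include <cstdint>
    #include <cstdio>

    // Reference model of __dp4a: c + sum of four signed byte products.
    static int dp4a_ref(int a, int b, int c) {
        for (int i = 0; i < 4; ++i) {
            c += (int8_t)(a >> 8*i) * (int8_t)(b >> 8*i);
        }
        return c;
    }

    int main() {
        const int u = 0x01FF02FE;                    // bytes: -2, 2, -1, 1
        printf("%d\n", dp4a_ref(0x01010101, u, 0));  // prints 0 == (-2 + 2 - 1 + 1)
        return 0;
    }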