From 924dd22fd3ba93e097f8d19ba5cda919ca2fe2fb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Wed, 5 Jul 2023 14:19:42 +0200 Subject: [PATCH] Quantized dot products for CUDA mul mat vec (#2067) --- CMakeLists.txt | 13 +- Makefile | 13 +- README.md | 3 +- ggml-cuda.cu | 491 ++++++++++++++++++++++++++++++++++++++++--------- 4 files changed, 427 insertions(+), 93 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4ac0f6f4e..a2404548f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -68,8 +68,9 @@ option(LLAMA_ACCELERATE "llama: enable Accelerate framework option(LLAMA_BLAS "llama: use BLAS" OFF) set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor") option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) +option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF) set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") -set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels") +set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels") option(LLAMA_CUDA_DMMV_F16 "llama: use 16 bit floats for dmmv CUDA kernels" OFF) set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K") option(LLAMA_CLBLAST "llama: use CLBlast" OFF) @@ -246,8 +247,14 @@ if (LLAMA_CUBLAS) set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h) add_compile_definitions(GGML_USE_CUBLAS) + if (LLAMA_CUDA_FORCE_DMMV) + add_compile_definitions(GGML_CUDA_FORCE_DMMV) + endif() add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) - add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y}) + add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y}) + if (DEFINED LLAMA_CUDA_DMMV_Y) + add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_DMMV_Y}) # for backwards compatibility + endif() if (LLAMA_CUDA_DMMV_F16) add_compile_definitions(GGML_CUDA_DMMV_F16) endif() @@ -263,7 +270,7 @@ if (LLAMA_CUBLAS) if (LLAMA_CUDA_DMMV_F16) set(CMAKE_CUDA_ARCHITECTURES "61") # needed for f16 CUDA intrinsics else() - set(CMAKE_CUDA_ARCHITECTURES "52") # lowest CUDA 12 standard + set(CMAKE_CUDA_ARCHITECTURES "52;61") # lowest CUDA 12 standard + lowest for integer intrinsics endif() endif() message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") diff --git a/Makefile b/Makefile index 8966a3590..71415664b 100644 --- a/Makefile +++ b/Makefile @@ -164,16 +164,21 @@ ifdef LLAMA_CUBLAS OBJS += ggml-cuda.o NVCC = nvcc NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native +ifdef LLAMA_CUDA_FORCE_DMMV + NVCCFLAGS += -DGGML_CUDA_FORCE_DMMV +endif # LLAMA_CUDA_FORCE_DMMV ifdef LLAMA_CUDA_DMMV_X NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X) else NVCCFLAGS += -DGGML_CUDA_DMMV_X=32 endif # LLAMA_CUDA_DMMV_X -ifdef LLAMA_CUDA_DMMV_Y - NVCCFLAGS += -DGGML_CUDA_DMMV_Y=$(LLAMA_CUDA_DMMV_Y) +ifdef LLAMA_CUDA_MMV_Y + NVCCFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y) +else ifdef LLAMA_CUDA_DMMV_Y + NVCCFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_DMMV_Y) # for backwards compatibility else - NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1 -endif # LLAMA_CUDA_DMMV_Y + NVCCFLAGS += -DGGML_CUDA_MMV_Y=1 +endif # LLAMA_CUDA_MMV_Y ifdef LLAMA_CUDA_DMMV_F16 NVCCFLAGS += -DGGML_CUDA_DMMV_F16 endif # LLAMA_CUDA_DMMV_F16 diff --git a/README.md b/README.md index 6c2bb392e..32f17c2d1 100644 --- a/README.md +++ b/README.md @@ -345,8 +345,9 @@ Building the program with BLAS support may lead to some performance improvements | Option | Legal values | Default | Description | |-------------------------|------------------------|---------|-------------| + | LLAMA_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 7.0/Turing/RTX 2000 or higher). Does not affect k-quants. | | LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. | - | LLAMA_CUDA_DMMV_Y | Positive integer | 1 | Block size in y direction for the CUDA dequantization + mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. | + | LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. | | LLAMA_CUDA_DMMV_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels. Can improve performance on relatively recent GPUs. | | LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 0b12a9e76..7965ff741 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -70,9 +70,11 @@ typedef void (*ggml_cuda_op_t)( // QK = number of values after dequantization // QR = QK / number of values before dequantization +// QI = number of 32 bit integers before dequantization #define QK4_0 32 #define QR4_0 2 +#define QI4_0 4 typedef struct { half d; // delta uint8_t qs[QK4_0 / 2]; // nibbles / quants @@ -81,6 +83,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 #define QK4_1 32 #define QR4_1 2 +#define QI4_1 4 typedef struct { half d; // delta half m; // min @@ -90,6 +93,7 @@ static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong #define QK5_0 32 #define QR5_0 2 +#define QI5_0 4 typedef struct { half d; // delta uint8_t qh[4]; // 5-th bit of quants @@ -99,6 +103,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5 #define QK5_1 32 #define QR5_1 2 +#define QI5_1 4 typedef struct { half d; // delta half m; // min @@ -109,12 +114,25 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + #define QK8_0 32 #define QR8_0 1 +#define QI8_0 8 typedef struct { half d; // delta int8_t qs[QK8_0]; // quants } block_q8_0; static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); +#define QK8_1 32 +#define QR8_1 1 +#define QI8_1 8 +typedef struct { + half d; // delta + half s; // unquantized sum + int8_t qs[QK8_0]; // quants +} block_q8_1; +static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding"); + +typedef float (*vec_dot_q_cuda_t)(const void * vbq, const block_q8_1 * bq8_1, const int iqs); + //================================= k-quants #ifdef GGML_QKK_64 @@ -198,14 +216,15 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define CUDA_SCALE_BLOCK_SIZE 256 #define CUDA_ROPE_BLOCK_SIZE 256 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32 +#define CUDA_QUANTIZE_BLOCK_SIZE 256 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 // dmmv = dequantize_mul_mat_vec #ifndef GGML_CUDA_DMMV_X #define GGML_CUDA_DMMV_X 32 #endif -#ifndef GGML_CUDA_DMMV_Y -#define GGML_CUDA_DMMV_Y 1 +#ifndef GGML_CUDA_MMV_Y +#define GGML_CUDA_MMV_Y 1 #endif #ifndef K_QUANTS_PER_ITERATION @@ -270,7 +289,6 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol } // sum up partial sums - __syncthreads(); #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); @@ -714,7 +732,6 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float #endif // sum up partial sums and write back result - __syncthreads(); #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); @@ -819,7 +836,6 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float #endif // sum up partial sums and write back result - __syncthreads(); #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); @@ -923,7 +939,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float #endif // sum up partial sums and write back result - __syncthreads(); #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); @@ -1028,7 +1043,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float #endif // sum up partial sums and write back result - __syncthreads(); #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); @@ -1139,7 +1153,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float #endif // sum up partial sums and write back result - __syncthreads(); #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); @@ -1158,6 +1171,41 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs, v.y = x[ib + iqs + 1]; } +static __global__ void quantize_q8_1(const float * x, void * vy, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + block_q8_1 * y = (block_q8_1 *) vy; + + const int ib = i / QK8_0; // block index + const int iqs = i % QK8_0; // quant index + + const float xi = x[i]; + float amax = fabsf(xi); + float sum = xi; + +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32)); + sum += __shfl_xor_sync(0xffffffff, sum, mask, 32); + } + + const float d = amax / 127; + const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); + + y[ib].qs[iqs] = q; + + if (iqs > 0) { + return; + } + + y[ib].d = d; + y[ib].s = sum; +} + template static __global__ void dequantize_block(const void * vx, float * y, const int k) { const int i = blockDim.x*blockIdx.x + 2*threadIdx.x; @@ -1179,6 +1227,182 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k) y[iybs + iqs + y_offset] = v.y; } +static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) { +#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics + const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq; + + int vi; + memcpy(&vi, &bq4_0->qs[sizeof(int) * (iqs + 0)], sizeof(int)); + const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); + const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_0)]); + + const float d = __half2float(bq4_0->d) * __half2float(bq8_1->d); + + // subtract 8 from each quantized value + const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808); + const int vi1 = __vsub4((vi >> 4) & 0x0F0F0F0F, 0x08080808); + + // SIMD dot product of quantized values + int sumi = __dp4a(vi0, ui0, 0); + sumi = __dp4a(vi1, ui1, sumi); + + return sumi*d; +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= 600 +} + +static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) { +#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics + const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq; + + const int vi = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]); + const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); + const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]); + + const float d = __half2float(bq4_1->d) * __half2float(bq8_1->d); + const float m = bq4_1->m; + const float s = bq8_1->s; + + const int vi0 = (vi >> 0) & 0x0F0F0F0F; + const int vi1 = (vi >> 4) & 0x0F0F0F0F; + + // SIMD dot product of quantized values + int sumi = __dp4a(vi0, ui0, 0); + sumi = __dp4a(vi1, ui1, sumi); + + return sumi*d + m*s / QI4_1; // scale sum by QI4_1 because there are QI4_1 threads working on this block +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= 600 +} + +static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) { +#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics + const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq; + + int qs; + memcpy(&qs, &bq5_0->qs[sizeof(int) * (iqs + 0)], sizeof(int)); + const int qh0 = bq5_0->qh[iqs/2 + 0] >> 4*(iqs%2); + const int qh1 = bq5_0->qh[iqs/2 + 2] >> 4*(iqs%2); + const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); + const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_0)]); + + const float d = __half2float(bq5_0->d) * __half2float(bq8_1->d); + + int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits + vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5 + vi0 |= (qh0 << 11) & 0x00001000; // 2 -> 13 + vi0 |= (qh0 << 18) & 0x00100000; // 3 -> 21 + vi0 |= (qh0 << 25) & 0x10000000; // 4 -> 29 + vi0 = __vsub4(vi0, 0x10101010); // subtract 16 from quantized values + int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values + + int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits + vi1 |= (qh1 << 4) & 0x00000010; // 1 -> 5 + vi1 |= (qh1 << 11) & 0x00001000; // 2 -> 13 + vi1 |= (qh1 << 18) & 0x00100000; // 3 -> 21 + vi1 |= (qh1 << 25) & 0x10000000; // 4 -> 29 + vi1 = __vsub4(vi1, 0x10101010); // subtract 16 from quantized values + sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values + + return sumi*d; +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= 600 +} + +static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) { +#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics + const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq; + + const int qs = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]); + const int qh0 = bq5_1->qh[iqs/2 + 0] >> 4*(iqs%2); + const int qh1 = bq5_1->qh[iqs/2 + 2] >> 4*(iqs%2); + const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); + const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_1)]); + + const float d = __half2float(bq5_1->d) * __half2float(bq8_1->d); + const float m = bq5_1->m; + const float s = bq8_1->s; + + int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits + vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5 + vi0 |= (qh0 << 11) & 0x00001000; // 2 -> 13 + vi0 |= (qh0 << 18) & 0x00100000; // 3 -> 21 + vi0 |= (qh0 << 25) & 0x10000000; // 4 -> 29 + int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values + + int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits + vi1 |= (qh1 << 4) & 0x00000010; // 1 -> 5 + vi1 |= (qh1 << 11) & 0x00001000; // 2 -> 13 + vi1 |= (qh1 << 18) & 0x00100000; // 3 -> 21 + vi1 |= (qh1 << 25) & 0x10000000; // 4 -> 29 + sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values + + return sumi*d + m*s / QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= 600 +} + +static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) { +#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics + const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq; + + int vi; + memcpy(&vi, &bq8_0->qs[sizeof(int) * (iqs + 0)], sizeof(int)); + const int ui = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]); + + const float d = __half2float(bq8_0->d) * __half2float(bq8_1->d); + + // SIMD dot product of quantized values + int sumi = __dp4a(vi, ui, 0); + + return sumi*d; +#else + return 0.0f; // only to satisfy the compiler +#endif // __CUDA_ARCH__ >= 600 +} + +template +static __global__ void mul_mat_vec_q(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) { + const int row = blockIdx.y*blockDim.y + threadIdx.y; + + if (row >= nrows) { + return; + } + + const int blocks_per_row = ncols / qk; + const int blocks_per_warp = WARP_SIZE / qi; + +// partial sum for each thread + float tmp = 0.0f; + + const block_q_t * x = (const block_q_t *) vx; + const block_q8_1 * y = (const block_q8_1 *) vy; + + for (int i = 0; i < blocks_per_row; i += blocks_per_warp) { + const int ibx = row*blocks_per_row + i + threadIdx.x / qi; // x block index + + const int iby = i + threadIdx.x / qi; // y block index + + const int iqs = threadIdx.x % qi; // x block quant index when casting the quants to int + + tmp += vec_dot_q_cuda(&x[ibx], &y[iby], iqs); + } + + // sum up partial sums and write back result +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + } + + if (threadIdx.x == 0) { + dst[row] = tmp; + } +} + template static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) { // qk = quantized weights per x block @@ -1233,7 +1457,6 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, } // sum up partial sums and write back result - __syncthreads(); #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); @@ -1284,7 +1507,6 @@ static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, fl const int idst = channel*nrows_dst + row_dst; // sum up partial sums and write back result - __syncthreads(); #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); @@ -1330,7 +1552,6 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous } // sum up partial sums and write back result - __syncthreads(); #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); @@ -1440,7 +1661,6 @@ static __global__ void soft_max_f32(const float * x, float * dst, const int ncol } // sum up partial sums - __syncthreads(); #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); @@ -1494,6 +1714,11 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con rms_norm_f32<<>>(x, dst, ncols); } +static void quantize_row_q8_1_cuda(const float * x, void * vy, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; + quantize_q8_1<<>>(x, vy, k); +} + static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<<>>(vx, y, k); @@ -1562,45 +1787,45 @@ static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cu static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); } static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); } static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); } static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); } static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); dequantize_mul_mat_vec <<>>(vx, y, dst, ncols, nrows); } @@ -1647,6 +1872,51 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f dequantize_mul_mat_vec_q6_k<<>>(vx, y, dst, ncols, nrows); } +static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + +static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(1, block_num_y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + mul_mat_vec_q + <<>>(vx, vy, dst, ncols, nrows); +} + static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; dequantize_block<1, 1, convert_f16><<>>(vx, y, k); @@ -1654,9 +1924,9 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y; + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; const dim3 block_nums(1, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); dequantize_mul_mat_vec<1, 1, convert_f16> <<>>(vx, y, dst, ncols, nrows); } @@ -1822,6 +2092,7 @@ static size_t g_scratch_offset = 0; static int g_device_count = -1; static int g_main_device = 0; +static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES]; static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0}; static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; @@ -1839,9 +2110,12 @@ void ggml_init_cublas() { for (int id = 0; id < g_device_count; ++id) { cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); - fprintf(stderr, " Device %d: %s\n", id, prop.name); + fprintf(stderr, " Device %d: %s, compute capability %d.%d\n", id, prop.name, prop.major, prop.minor); + g_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; + + g_compute_capabilities[id] = 100*prop.major + 10*prop.minor; } for (int id = 0; id < g_device_count; ++id) { g_tensor_split[id] /= total_vram; @@ -2057,7 +2331,7 @@ inline void ggml_cuda_op_rms_norm( (void) i1; } -inline void ggml_cuda_op_dequantize_mul_mat_vec( +inline void ggml_cuda_op_mul_mat_vec( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i, float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1, cudaStream_t & cudaStream_main){ @@ -2069,69 +2343,116 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec( const int64_t ne00 = src0->ne[0]; const int64_t nrows = i01_high - i01_low; -// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics -#ifdef GGML_CUDA_DMMV_F16 - size_t ash; - dfloat * src1_dfloat = nullptr; // dfloat == half - - bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || - src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || - src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; - - if (src1_convert_f16) { - src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash); - ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00, - ne00, 1, sizeof(float), 0, 0, - ne00, 1, sizeof(half), 0, 0, cudaStream_main); - } +#ifdef GGML_CUDA_FORCE_DMMV + const bool use_mul_mat_vec_q = false; #else - dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion + int id; + CUDA_CHECK(cudaGetDevice(&id)); + + const bool mul_mat_vec_q_implemented = src0->type == GGML_TYPE_Q4_0 || + src0->type == GGML_TYPE_Q4_1 || + src0->type == GGML_TYPE_Q5_0 || + src0->type == GGML_TYPE_Q5_1 || + src0->type == GGML_TYPE_Q8_0; + + // The integer intrinsics used in mul_mat_vec_q are available with compute capability 6. + // However, they have bad performance with Pascal cards. + // Therefore, in a multi GPU setting decide at runtime which GPUs should use mul_mat_vec_q. + const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 700 && mul_mat_vec_q_implemented; +#endif + + if (use_mul_mat_vec_q) { + size_t as; + void * src1_q8_1 = ggml_cuda_pool_malloc(ne00*sizeof(block_q8_1)/QK8_1, &as); + quantize_row_q8_1_cuda(src1_ddf_i, src1_q8_1, ne00, cudaStream_main); + + switch (src0->type) { + case GGML_TYPE_Q4_0: + mul_mat_vec_q4_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q4_1: + mul_mat_vec_q4_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q5_0: + mul_mat_vec_q5_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q5_1: + mul_mat_vec_q5_1_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q8_0: + mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + default: + GGML_ASSERT(false); + break; + } + + ggml_cuda_pool_free(src1_q8_1, as); + } else { + // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics +#ifdef GGML_CUDA_DMMV_F16 + size_t ash; + dfloat * src1_dfloat = nullptr; // dfloat == half + + bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || + src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || + src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; + + if (src1_convert_f16) { + src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash); + ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00, + ne00, 1, sizeof(float), 0, 0, + ne00, 1, sizeof(half), 0, 0, cudaStream_main); + } +#else + dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion #endif // GGML_CUDA_DMMV_F16 - switch (src0->type) { - case GGML_TYPE_Q4_0: - dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q4_1: - dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q5_0: - dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q5_1: - dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q8_0: - dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q2_K: - dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q3_K: - dequantize_mul_mat_vec_q3_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q4_K: - dequantize_mul_mat_vec_q4_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q5_K: - dequantize_mul_mat_vec_q5_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_Q6_K: - dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - case GGML_TYPE_F16: - convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); - break; - default: - GGML_ASSERT(false); - break; - } + switch (src0->type) { + case GGML_TYPE_Q4_0: + dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q4_1: + dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q5_0: + dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q5_1: + dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q8_0: + dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q2_K: + dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q3_K: + dequantize_mul_mat_vec_q3_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q4_K: + dequantize_mul_mat_vec_q4_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q5_K: + dequantize_mul_mat_vec_q5_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_Q6_K: + dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + case GGML_TYPE_F16: + convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main); + break; + default: + GGML_ASSERT(false); + break; + } #ifdef GGML_CUDA_DMMV_F16 - if (src1_convert_f16) { - ggml_cuda_pool_free(src1_dfloat, ash); - } + if (src1_convert_f16) { + ggml_cuda_pool_free(src1_dfloat, ash); + } #endif // GGML_CUDA_DMMV_F16 + } (void) src1; (void) dst; @@ -2701,8 +3022,8 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ }else if (src0->type == GGML_TYPE_F32) { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false); } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { - if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0 && src0->ne[1] % GGML_CUDA_DMMV_Y == 0) { - ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false, false); + if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) { + ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false); } else { ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false); }