cuda : performance optimizations (#1530)

* xor hack

* block y dim

* loop unrolling

* Fixed cmake LLAMA_CUDA_BY option

* Removed hipblas compatibility code

* Define GGML_CUDA_DMMV_BLOCK_Y if not defined

* Fewer iters, more ops per iter

* Renamed DMMV X/Y compilation options
This commit is contained in:
Johannes Gäßler 2023-05-25 23:07:29 +02:00 committed by GitHub
parent ac7876ac20
commit 1fcdcc28b1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 110 additions and 64 deletions

View File

@ -68,6 +68,8 @@ option(LLAMA_ACCELERATE "llama: enable Accelerate framework"
option(LLAMA_BLAS "llama: use BLAS" OFF) option(LLAMA_BLAS "llama: use BLAS" OFF)
option(LLAMA_BLAS_VENDOR "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic) option(LLAMA_BLAS_VENDOR "llama: BLA_VENDOR from https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" Generic)
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
@ -184,6 +186,8 @@ if (LLAMA_CUBLAS)
set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h) set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
add_compile_definitions(GGML_USE_CUBLAS) add_compile_definitions(GGML_USE_CUBLAS)
add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
if (LLAMA_STATIC) if (LLAMA_STATIC)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)

View File

@ -133,9 +133,19 @@ ifdef LLAMA_CUBLAS
OBJS += ggml-cuda.o OBJS += ggml-cuda.o
NVCC = nvcc NVCC = nvcc
NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
ifdef LLAMA_CUDA_DMMV_X
NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
else
NVCCFLAGS += -DGGML_CUDA_DMMV_X=32
endif # LLAMA_CUDA_DMMV_X
ifdef LLAMA_CUDA_DMMV_Y
NVCCFLAGS += -DGGML_CUDA_DMMV_Y=$(LLAMA_CUDA_DMMV_Y)
else
NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1
endif # LLAMA_CUDA_DMMV_Y
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
endif endif # LLAMA_CUBLAS
ifdef LLAMA_CLBLAST ifdef LLAMA_CLBLAST
CFLAGS += -DGGML_USE_CLBLAST CFLAGS += -DGGML_USE_CLBLAST
CXXFLAGS += -DGGML_USE_CLBLAST CXXFLAGS += -DGGML_USE_CLBLAST

View File

@ -83,9 +83,19 @@ typedef struct {
} block_q8_0; } block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding"); static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
#define WARP_SIZE 32
#define CUDA_MUL_BLOCK_SIZE 256 #define CUDA_MUL_BLOCK_SIZE 256
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256
#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec
// dmmv = dequantize_mul_mat_vec
#ifndef GGML_CUDA_DMMV_X
#define GGML_CUDA_DMMV_X 32
#endif
#ifndef GGML_CUDA_DMMV_Y
#define GGML_CUDA_DMMV_Y 1
#endif
static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
const int i = blockDim.x*blockIdx.x + threadIdx.x; const int i = blockDim.x*blockIdx.x + threadIdx.x;
@ -200,41 +210,51 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
dequantize_kernel(vx, ib, iqs, v0, v1); dequantize_kernel(vx, ib, iqs, v0, v1);
} }
template <int block_size, int qk, int qr, dequantize_kernel_t dequantize_kernel> template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) { static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) {
const int row = blockIdx.x; // qk = quantized weights per x block
// qr = number of quantized weights per data value in x block
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x; const int tid = threadIdx.x;
const int iter_stride = 2*GGML_CUDA_DMMV_X;
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
const int y_offset = qr == 1 ? 1 : qk/2; const int y_offset = qr == 1 ? 1 : qk/2;
__shared__ float tmp[block_size]; // separate sum for each thread float tmp = 0; // partial sum for thread in warp
tmp[tid] = 0;
for (int i = 0; i < ncols/block_size; i += 2) { for (int i = 0; i < ncols; i += iter_stride) {
const int col = i*block_size + 2*tid; const int col = i + vals_per_iter*tid;
const int ib = (row*ncols + col)/qk; // block index const int ib = (row*ncols + col)/qk; // x block index
const int iqs = (col%qk)/qr; // quant index const int iqs = (col%qk)/qr; // x quant index
const int iybs = col - col%qk; // y block start index const int iybs = col - col%qk; // y block start index
// processing >2 values per i iter is faster for fast GPUs
#pragma unroll
for (int j = 0; j < vals_per_iter; j += 2) {
// process 2 vals per j iter
// dequantize // dequantize
float v0, v1; float v0, v1;
dequantize_kernel(vx, ib, iqs, v0, v1); dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
// matrix multiplication // matrix multiplication
tmp[tid] += v0 * y[iybs + iqs + 0]; tmp += v0 * y[iybs + iqs + j/qr + 0];
tmp[tid] += v1 * y[iybs + iqs + y_offset]; tmp += v1 * y[iybs + iqs + j/qr + y_offset];
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
}
} }
// sum up partial sums and write back result // sum up partial sums and write back result
__syncthreads(); __syncthreads();
for (int s=block_size/2; s>0; s>>=1) { #pragma unroll
if (tid < s) { for (int mask = 16; mask > 0; mask >>= 1) {
tmp[tid] += tmp[tid + s]; tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
__syncthreads();
} }
if (tid == 0) { if (tid == 0) {
dst[row] = tmp[0]; dst[row] = tmp;
} }
} }
@ -269,33 +289,43 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cu
} }
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_0, QR4_0, dequantize_q4_0> GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
} }
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK4_1, QR4_1, dequantize_q4_1> GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
} }
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_0, QR5_0, dequantize_q5_0> GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
} }
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK5_1, QR5_1, dequantize_q5_1> GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
} }
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, QK8_0, QR8_0, dequantize_q8_0> GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
} }
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@ -304,9 +334,11 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
} }
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
dequantize_mul_mat_vec<CUDA_DMMV_BLOCK_SIZE, 32, 1, convert_f16> GGML_ASSERT(nrows % GGML_CUDA_DMMV_Y == 0);
<<<nrows, CUDA_DMMV_BLOCK_SIZE, 0, stream>>>(vx, y, dst, ncols); const dim3 block_dims(WARP_SIZE, GGML_CUDA_DMMV_Y, 1);
dequantize_mul_mat_vec<1, 1, convert_f16>
<<<nrows/GGML_CUDA_DMMV_Y, block_dims, 0, stream>>>(vx, y, dst, ncols);
} }
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {