diff --git a/Makefile b/Makefile
index 332496cfc..3f0d91b85 100644
--- a/Makefile
+++ b/Makefile
@@ -688,12 +688,6 @@ ifdef GGML_CUDA_DMMV_F16
     MK_NVCCFLAGS += -DGGML_CUDA_F16
 endif # GGML_CUDA_DMMV_F16
 
-ifdef GGML_CUDA_KQUANTS_ITER
-    MK_NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(GGML_CUDA_KQUANTS_ITER)
-else
-    MK_NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2
-endif
-
 ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE
     MK_NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=$(GGML_CUDA_PEER_MAX_BATCH_SIZE)
 else
@@ -810,7 +804,6 @@ ifdef GGML_HIPBLAS
 
     GGML_CUDA_DMMV_X       ?= 32
     GGML_CUDA_MMV_Y        ?= 1
-    GGML_CUDA_KQUANTS_ITER ?= 2
 
     MK_CPPFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUDA
 
@@ -827,7 +820,6 @@ endif # GGML_HIP_UMA
     HIPFLAGS += $(addprefix --offload-arch=,$(AMDGPU_TARGETS))
     HIPFLAGS += -DGGML_CUDA_DMMV_X=$(GGML_CUDA_DMMV_X)
     HIPFLAGS += -DGGML_CUDA_MMV_Y=$(GGML_CUDA_MMV_Y)
-    HIPFLAGS += -DK_QUANTS_PER_ITERATION=$(GGML_CUDA_KQUANTS_ITER)
 
 ifdef GGML_CUDA_FORCE_DMMV
     HIPFLAGS += -DGGML_CUDA_FORCE_DMMV
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index cc1685884..c55d889f3 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -120,8 +120,6 @@ option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of
 set   (GGML_CUDA_DMMV_X "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels")
 set   (GGML_CUDA_MMV_Y   "1" CACHE STRING "ggml: y block size for mmv CUDA kernels")
 option(GGML_CUDA_F16                        "ggml: use 16 bit floats for some calculations" OFF)
-set   (GGML_CUDA_KQUANTS_ITER "2" CACHE STRING
-                                            "ggml: iters./thread per block for Q2_K/Q6_K")
 set   (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
                                             "ggml: max. batch size for using peer access")
 option(GGML_CUDA_NO_PEER_COPY               "ggml: do not use peer to peer copies" OFF)
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index ff84b9bb5..843c2a68f 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -323,7 +323,6 @@ if (GGML_CUDA)
 
         add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
         add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
-        add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
         add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
 
         if (GGML_CUDA_USE_GRAPHS)
@@ -471,7 +470,6 @@ if (GGML_HIPBLAS)
         add_compile_definitions(GGML_USE_HIPBLAS)
         add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
         add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
-        add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
 
         if (GGML_HIP_UMA)
             add_compile_definitions(GGML_HIP_UMA)
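Note: K_QUANTS_PER_ITERATION was a build-time knob (1 or 2) that traded per-thread work against thread-level parallelism in the K-quant dequantize-mul-mat-vec kernels; 2 has been the effective default on every backend for a long time, and the kernel changes below simply specialize for it. A minimal sketch of what the knob controlled, assuming the kernels' 32-lane warp layout (the template parameter K and the helper name are illustrative, not code from the tree):

// sketch.cpp -- illustrative only, not part of the patch
#include <cstdio>

template <int K>  // K_QUANTS_PER_ITERATION: 1 or 2; the patch fixes K = 2
static int blocks_visited(int num_blocks_per_row, int lane) {
    static_assert(K == 1 || K == 2, "K must be 1 or 2");
    const int ix = lane % K;   // which super-block parity this lane starts on
    int visited = 0;
    for (int i = ix; i < num_blocks_per_row; i += K) {
        visited++;             // one QK_K = 256 super-block handled per trip
    }
    return visited;
}

int main() {
    printf("K=1: lane 0 walks %d of 8 blocks\n", blocks_visited<1>(8, 0)); // all 8
    printf("K=2: lane 0 walks %d of 8 blocks\n", blocks_visited<2>(8, 0)); // every other one -> 4
}

With K fixed at 2, the even and odd lanes of each pair interleave over the super-blocks of a row, which is exactly what the hardcoded i += 2 strides below implement.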
diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu
index 96a5adef5..a45a92224 100644
--- a/ggml/src/ggml-cuda/dmmv.cu
+++ b/ggml/src/ggml-cuda/dmmv.cu
@@ -2,16 +2,7 @@
 #include "dequantize.cuh"
 #include "convert.cuh"
 
-#ifndef K_QUANTS_PER_ITERATION
-#define K_QUANTS_PER_ITERATION 2
-#else
-static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
-#endif
-
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
-
-    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
-
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     if (row > nrows) return;
@@ -22,15 +13,15 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
 
     float tmp = 0; // partial sum for thread in warp
 
-    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
-    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+    const int tid = threadIdx.x/2;  // 0...15
+    const int ix  = threadIdx.x%2;  // 0,1
 
-    const int step = 16/K_QUANTS_PER_ITERATION;
+    const int step = 8;
 
     const int im = tid/step;       // 0 or 1. 0 computes 0..., 1 computes 128...
     const int in = tid - step*im;  // 0...15 or 0...7
 
-    const int l0 = K_QUANTS_PER_ITERATION*in;  // 0...15 or 0...14 in steps of 2
+    const int l0 = 2*in;           // 0...14 in steps of 2
     const int q_offset = 32*im + l0;
     const int s_offset = 8*im;
     const int y_offset = 128*im + l0;
@@ -39,7 +30,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
     const uint8_t * d = (const uint8_t *)aux;
     const uint8_t * m = (const uint8_t *)(aux + 2);
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const float   * y = yy + i * QK_K + y_offset;
         const uint8_t * q = x[i].qs + q_offset;
@@ -54,7 +45,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
         aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
 
         float sum1 = 0, sum2 = 0;
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+        for (int l = 0; l < 2; ++l) {
             sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
                   + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
                   + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
@@ -94,17 +85,17 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
     const uint16_t kmask1 = 0x0303;
     const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+    const int tid = threadIdx.x/2;  // 0...15
+    const int ix  = threadIdx.x%2;  // 0,1
 
-    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
-    const int step = 16/K_QUANTS_PER_ITERATION;
-    const int im = tid/step;       // 0 or 1. 0 computes 0..., 1 computes 128...
-    const int in = tid - step*im;  // 0....15 or 0...7
+    const int n    = 2;              // iterations in the inner loop
+    const int step = 8;
+    const int im   = tid/step;       // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in   = tid - step*im;  // 0...7
 
     const uint8_t m = 1 << (4*im);
 
-    const int l0 = n*in;           // 0...15 or 0...14 in steps of 2
+    const int l0 = n*in;           // 0...14 in steps of 2
     const int q_offset =  32*im + l0;
     const int y_offset = 128*im + l0;
@@ -113,7 +104,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
 
     const uint16_t s_shift = 4*im;
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const float   * y = yy + i * QK_K + y_offset;
         const uint8_t * q = x[i].qs + q_offset;
@@ -163,14 +154,14 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
 
-    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+    const int tid = threadIdx.x/2;  // 0...15
+    const int ix  = threadIdx.x%2;  // 0,1
 
-    const int step = 8/K_QUANTS_PER_ITERATION;           // 8 or 4
+    const int step = 4;
 
-    const int il  = tid/step;                            // 0...3
-    const int ir  = tid - step*il;                       // 0...7 or 0...3
-    const int n   = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4
+    const int il  = tid/step;       // 0...3
+    const int ir  = tid - step*il;  // 0...3
+    const int n   = 4;
 
     const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
     const int in = il%2;
@@ -182,17 +173,12 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
 
-#if K_QUANTS_PER_ITERATION == 2
     uint32_t q32[4];
     const uint8_t * q4 = (const uint8_t *)q32;
-#else
-    uint16_t q16[4];
-    const uint8_t * q4 = (const uint8_t *)q16;
-#endif
 
     float tmp = 0; // partial sum for thread in warp
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const float * y1 = yy + i*QK_K + y_offset;
         const float * y2 = y1 + 128;
@@ -206,7 +192,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
         aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
         aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
 
-#if K_QUANTS_PER_ITERATION == 2
         const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
         const uint32_t * q2 = q1 + 16;
@@ -223,25 +208,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
             smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
         }
         tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
-#else
-        const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
-        const uint16_t * q2 = q1 + 32;
-
-        q16[0] = q1[0] & 0x0f0f;
-        q16[1] = q1[0] & 0xf0f0;
-        q16[2] = q2[0] & 0x0f0f;
-        q16[3] = q2[0] & 0xf0f0;
-
-        float4 s = {0.f, 0.f, 0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < 2; ++l) {
-            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
-            s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
-            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
-        }
-        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
-#endif
-
     }
 
     // sum up partial sums and write back result
@@ -341,9 +307,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
 }
 
 static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
-
-    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
-
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     if (row > nrows) return;
@@ -352,21 +315,17 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 
     const block_q6_K * x = (const block_q6_K *)vx + ib0;
 
-    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+    const int tid = threadIdx.x/2;  // 0...15
+    const int ix  = threadIdx.x%2;  // 0, 1
 
-    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8
+    const int step = 8;
 
-    const int im = tid/step;       // 0 or 1. 0 computes 0..., 1 computes 128...
-    const int in = tid - step*im;  // 0...15 or 0...7
+    const int im = tid/step;       // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;  // 0...7
 
-#if K_QUANTS_PER_ITERATION == 1
-    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
-    const int is = 0;
-#else
-    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
+    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
     const int is = in / 4;
-#endif
+
     const int ql_offset = 64*im + l0;
     const int qh_offset = 32*im + l0;
     const int s_offset  =  8*im + is;
@@ -374,7 +333,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 
     float tmp = 0; // partial sum for thread in warp
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const float   * y  = yy + i * QK_K + y_offset;
         const uint8_t * ql = x[i].ql + ql_offset;
@@ -383,17 +342,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 
         const float d = x[i].d;
 
-#if K_QUANTS_PER_ITERATION == 1
-        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
-                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
-                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
-                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
-                  + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
-                  + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
-                  + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
-                  +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
-        tmp += sum;
-#else
         float sum = 0;
         for (int l = 0; l < 4; ++l) {
             sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
@@ -402,8 +350,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
         }
         tmp += sum;
-#endif
-
     }
 
     // sum up partial sums and write back result
@@ -547,7 +493,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y,
 
 static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
+    const int ny = 2;
     const int block_num_y = (nrows + ny - 1) / ny;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
@@ -556,7 +502,7 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
 
 static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int ny = 1;
     const int block_num_y = (nrows + ny - 1) / ny;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
@@ -565,7 +511,7 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f
 
 static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int ny = 1;
     const int block_num_y = (nrows + ny - 1) / ny;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
@@ -580,7 +526,7 @@ static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, f
 
 static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int ny = 1;
     const int block_num_y = (nrows + ny - 1) / ny;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
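The hardcoded mapping is easy to sanity-check on the host. A short sketch (assumes QK_K == 256 and 32-lane warps; illustrative only, not part of the patch) reproducing the q2_K offset arithmetic and the ranges the updated comments claim:

// offsets.cpp -- illustrative only
#include <cassert>

int main() {
    for (int lane = 0; lane < 32; ++lane) {
        const int tid  = lane/2;         // 0...15
        const int ix   = lane%2;         // 0,1: even or odd super-block
        const int step = 8;
        const int im   = tid/step;       // 0 or 1: lower or upper 128 values
        const int in   = tid - step*im;  // 0...7
        const int l0   = 2*in;           // 0...14 in steps of 2
        assert(tid <= 15 && ix <= 1 && in <= 7 && l0 <= 14);
        assert(32*im  + l0 <  64);       // q_offset stays inside qs[]
        assert(128*im + l0 < 256);       // y_offset stays inside the super-block
    }
    return 0;
}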
diff --git a/ggml/src/ggml-sycl/dmmv.cpp b/ggml/src/ggml-sycl/dmmv.cpp
index 5c343822f..7b1b317fa 100644
--- a/ggml/src/ggml-sycl/dmmv.cpp
+++ b/ggml/src/ggml-sycl/dmmv.cpp
@@ -124,9 +124,6 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
                                         float *__restrict__ dst,
                                         const int ncols, int nrows,
                                         const sycl::nd_item<3> &item_ct1) {
-
-    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
-
     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                     item_ct1.get_local_id(1);
     if (row > nrows) return;
@@ -140,16 +137,16 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
 
 #if QK_K == 256
     const int tid =
-        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15
+        item_ct1.get_local_id(2) / 2;  // 0...15
     const int ix =
-        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
+        item_ct1.get_local_id(2) % 2;  // 0,1
 
-    const int step = 16/K_QUANTS_PER_ITERATION;
+    const int step = 8;
 
-    const int im = tid/step;       // 0 or 1. 0 computes 0..., 1 computes 128...
-    const int in = tid - step*im;  // 0...15 or 0...7
+    const int im = tid/step;       // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;  // 0...7
 
-    const int l0 = K_QUANTS_PER_ITERATION*in;  // 0...15 or 0...14 in steps of 2
+    const int l0 = 2*in;           // 0...14 in steps of 2
     const int q_offset = 32*im + l0;
     const int s_offset = 8*im;
     const int y_offset = 128*im + l0;
@@ -158,7 +155,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
     const uint8_t * d = (const uint8_t *)aux;
     const uint8_t * m = (const uint8_t *)(aux + 2);
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const float   * y = yy + i * QK_K + y_offset;
         const uint8_t * q = x[i].qs + q_offset;
@@ -173,7 +170,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
         aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
 
         float sum1 = 0, sum2 = 0;
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+        for (int l = 0; l < 2; ++l) {
             sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
                   + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
                   + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
@@ -190,18 +187,15 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
     }
 #else
-    const int tid = item_ct1.get_local_id(2) /
-                    (2 * K_QUANTS_PER_ITERATION); // 0...15 or 0...7
-    const int ix = item_ct1.get_local_id(2) %
-                   (2 * K_QUANTS_PER_ITERATION); // 0....1 or 0...3
-    const int offset = tid * K_QUANTS_PER_ITERATION;
+    const int tid = item_ct1.get_local_id(2) / 4;  // 0...7
+    const int ix  = item_ct1.get_local_id(2) % 4;  // 0...3
+    const int offset = tid * 2;
 
     uint32_t uaux[2];
     const uint8_t * d = (const uint8_t *)uaux;
 
-    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
-
+    for (int i = ix; i < num_blocks_per_row; i += 4) {
         const float   * y = yy + i * QK_K + offset;
         const uint8_t * q = x[i].qs + offset;
         const uint32_t * s = (const uint32_t *)x[i].scales;
@@ -213,7 +207,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
         x[i].dm.convert();
 
         float sum1 = 0, sum2 = 0;
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+        for (int l = 0; l < 2; ++l) {
             const uint8_t ql = q[l];
             sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
                   + y[l+16] * d[1] * ((ql >> 2) & 3)
@@ -268,14 +262,14 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
     const uint16_t kmask2 = 0x0f0f;
 
     const int tid =
-        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+        item_ct1.get_local_id(2) / 2;  // 0...15
     const int ix =
-        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
+        item_ct1.get_local_id(2) % 2;  // 0,1
 
-    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
-    const int step = 16/K_QUANTS_PER_ITERATION;
-    const int im = tid/step;       // 0 or 1. 0 computes 0..., 1 computes 128...
-    const int in = tid - step*im;  // 0....15 or 0...7
+    const int n    = 2;              // iterations in the inner loop
+    const int step = 8;
+    const int im   = tid/step;       // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in   = tid - step*im;  // 0...7
 
     const uint8_t m = 1 << (4*im);
@@ -288,7 +282,7 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
 
     const uint16_t s_shift = 4*im;
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const float   * y = yy + i * QK_K + y_offset;
         const uint8_t * q = x[i].qs + q_offset;
@@ -318,13 +312,13 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
-    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...15 or 0...7
-    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);  // 0....1 or 0...3
-    const int offset = tid * K_QUANTS_PER_ITERATION;  // 0...15 or 0...14
-    const int in = offset/8;                          // 0 or 1
-    const int im = offset%8;                          // 0...7
+    const int tid = item_ct1.get_local_id(2)/4;  // 0...7
+    const int ix  = item_ct1.get_local_id(2)%4;  // 0...3
+    const int offset = tid * 2;                  // 0...14
+    const int in = offset/8;                     // 0 or 1
+    const int im = offset%8;                     // 0...7
 
-    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 4) {
 
         const float * y = yy + i * QK_K + offset;
         const uint8_t * q = x[i].qs + offset;
@@ -333,7 +327,7 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
         const float dall = (float)x[i].d;
 
         float sum = 0;
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+        for (int l = 0; l < 2; ++l) {
             const uint8_t hl = x[i].hmask[im+l] >> in;
             const uint8_t ql = q[l];
             sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
@@ -384,15 +378,15 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
     const uint16_t kmask3 = 0xc0c0;
 
     const int tid =
-        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+        item_ct1.get_local_id(2) / 2;  // 0...15
     const int ix =
-        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1
+        item_ct1.get_local_id(2) % 2;  // 0,1
 
-    const int step = 8/K_QUANTS_PER_ITERATION;           // 8 or 4
+    const int step = 4;
 
     const int il  = tid/step;      // 0...3
     const int ir  = tid - step*il; // 0...7 or 0...3
-    const int n   = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4
+    const int n   = 4;
 
     const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
     const int in = il%2;
@@ -404,17 +398,12 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
 
-#if K_QUANTS_PER_ITERATION == 2
     uint32_t q32[4];
     const uint8_t * q4 = (const uint8_t *)q32;
-#else
-    uint16_t q16[4];
-    const uint8_t * q4 = (const uint8_t *)q16;
-#endif
 
     float tmp = 0; // partial sum for thread in warp
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const float * y1 = yy + i*QK_K + y_offset;
         const float * y2 = y1 + 128;
@@ -428,7 +417,6 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
         aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
         aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
 
-#if K_QUANTS_PER_ITERATION == 2
         const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
         const uint32_t * q2 = q1 + 16;
@@ -447,38 +435,19 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
         tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f / 16.f +
                        s.z() * sc[4] + s.w() * sc[5] * 1.f / 16.f) -
                dmin * smin;
-#else
-        const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
-        const uint16_t * q2 = q1 + 32;
-
-        q16[0] = q1[0] & 0x0f0f;
-        q16[1] = q1[0] & 0xf0f0;
-        q16[2] = q2[0] & 0x0f0f;
-        q16[3] = q2[0] & 0xf0f0;
-
-        float4 s = {0.f, 0.f, 0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < 2; ++l) {
-            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
-            s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
-            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
-        }
-        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
-#endif
-
     }
 #else
-    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...15
-    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);
+    const int tid = item_ct1.get_local_id(2)/4;  // 0...7
+    const int ix  = item_ct1.get_local_id(2)%4;
 
-    const int step = tid * K_QUANTS_PER_ITERATION;
+    const int step = tid * 2;
 
     uint16_t aux16[2];
     const uint8_t * s = (const uint8_t *)aux16;
 
     float tmp = 0;
 
-    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 4) {
         const uint8_t * q = x[i].qs + step;
         const float   * y = yy + i*QK_K + step;
         const uint16_t * a = (const uint16_t *)x[i].scales;
@@ -487,7 +456,7 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
         const float d = (float)x[i].dm[0];
         const float m = (float)x[i].dm[1];
         float sum = 0.f;
-        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+        for (int j = 0; j < 2; ++j) {
             sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
                  + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
                  + y[j+32] * (d * s[1] * (q[j+ 0] >>  4) - m * s[3])
@@ -609,19 +578,19 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
     }
 #else
-    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...15
-    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);
-    const int step = tid * K_QUANTS_PER_ITERATION;
+    const int tid = item_ct1.get_local_id(2)/4;  // 0...7
+    const int ix  = item_ct1.get_local_id(2)%4;
+    const int step = tid * 2;
     const int im = step/8;
     const int in = step%8;
 
-    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 4) {
         const uint8_t * q = x[i].qs + step;
         const int8_t  * s = x[i].scales;
         const float   * y = yy + i*QK_K + step;
         const float     d = x[i].d;
         float sum = 0.f;
-        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+        for (int j = 0; j < 2; ++j) {
             const uint8_t h = x[i].qh[in+j] >> im;
             sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
                  + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
@@ -646,9 +615,6 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
 
 static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows,
                                         const sycl::nd_item<3> &item_ct1) {
-
-    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
-
     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                     item_ct1.get_local_id(1);
     if (row > nrows) return;
@@ -661,22 +627,18 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
 
 #if QK_K == 256
 
     const int tid =
-        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16
+        item_ct1.get_local_id(2) / 2;  // 0...15
     const int ix =
-        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0, 1
+        item_ct1.get_local_id(2) % 2;  // 0, 1
 
-    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8
+    const int step = 8;
 
     const int im = tid/step;       // 0 or 1. 0 computes 0..., 1 computes 128...
     const int in = tid - step*im;  // 0...15 or 0...7
 
-#if K_QUANTS_PER_ITERATION == 1
-    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
-    const int is = 0;
-#else
     const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
     const int is = in / 4;
-#endif
+
     const int ql_offset = 64*im + l0;
     const int qh_offset = 32*im + l0;
     const int s_offset  =  8*im + is;
@@ -684,7 +646,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
 
     float tmp = 0; // partial sum for thread in warp
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const float   * y  = yy + i * QK_K + y_offset;
         const uint8_t * ql = x[i].ql + ql_offset;
@@ -693,17 +655,6 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
 
         const float d = x[i].d;
 
-#if K_QUANTS_PER_ITERATION == 1
-        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
-                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
-                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
-                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
-                  + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
-                  + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
-                  + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
-                  +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
-        tmp += sum;
-#else
         float sum = 0;
         for (int l = 0; l < 4; ++l) {
             sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
@@ -712,20 +663,18 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
         }
         tmp += sum;
-#endif
-
     }
 
 #else
 
-    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...7
-    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);  // 0...3
+    const int tid = item_ct1.get_local_id(2)/4;  // 0...7
+    const int ix  = item_ct1.get_local_id(2)%4;  // 0...3
 
-    const int step = tid * K_QUANTS_PER_ITERATION;
+    const int step = tid * 2;
 
     float tmp = 0; // partial sum for thread in warp
 
-    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 4) {
 
         const float   * y  = yy + i * QK_K + step;
         const uint8_t * ql = x[i].ql + step;
@@ -735,7 +684,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
         const float d = x[i+0].d;
 
         float sum = 0;
-        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+        for (int j = 0; j < 2; ++j) {
             sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
                  + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
                  + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >>  4) | ((qh[j] & 0x30) >> 0)) - 32)
@@ -871,7 +820,7 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
                                              const int nrows,
                                              dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
+    const int ny = 2;
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
@@ -887,7 +836,7 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
                                              const int nrows,
                                              dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int ny = 1;
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
@@ -903,7 +852,7 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
                                              const int nrows,
                                              dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int ny = 1;
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
@@ -932,7 +881,7 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
                                              const int nrows,
                                              dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int ny = 1;
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
diff --git a/ggml/src/ggml-sycl/presets.hpp b/ggml/src/ggml-sycl/presets.hpp
index 340ab8e93..a3a584108 100644
--- a/ggml/src/ggml-sycl/presets.hpp
+++ b/ggml/src/ggml-sycl/presets.hpp
@@ -52,12 +52,6 @@
 #define GGML_SYCL_MMV_Y 1
 #endif
 
-#ifndef K_QUANTS_PER_ITERATION
-#define K_QUANTS_PER_ITERATION 2
-#else
-static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
-#endif
-
 #ifndef GGML_SYCL_PEER_MAX_BATCH_SIZE
 #define GGML_SYCL_PEER_MAX_BATCH_SIZE 128
 #endif // GGML_SYCL_PEER_MAX_BATCH_SIZE
diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp
index 32fda32a8..9814e82b0 100644
--- a/ggml/src/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan.cpp
@@ -41,12 +41,6 @@
 
 #define MAX_VK_BUFFERS 256
 
-#ifndef K_QUANTS_PER_ITERATION
-#define K_QUANTS_PER_ITERATION 1
-#else
-static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
-#endif
-
 #define VK_CHECK(err, msg)                                          \
     do {                                                            \
         vk::Result err_ = (err);                                    \
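Note: the host-side default in ggml-vulkan.cpp was 1, but that macro was never handed to the shader compiler; mul_mat_vec_base.comp pinned 2 itself, so the shader diffs below only bake in the value that was already in effect. The launcher geometry is likewise unchanged, since ny = 2 / K_QUANTS_PER_ITERATION already collapsed to 1 under the default. A tiny check (illustrative only, not part of the patch):

// check_ny.cpp -- illustrative only
#include <cassert>

int main() {
    const int k_quants_per_iteration = 2;     // the old effective default
    assert(2 / k_quants_per_iteration == 1);  // q3_K/q4_K/q6_K launchers: ny was already 1
    const int nrows = 4097, ny = 1;
    assert((nrows + ny - 1) / ny == 4097);    // ceil-div block count is unchanged
    return 0;
}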
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_base.comp b/ggml/src/vulkan-shaders/mul_mat_vec_base.comp
index 5920bc936..cac4250fc 100644
--- a/ggml/src/vulkan-shaders/mul_mat_vec_base.comp
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_base.comp
@@ -2,8 +2,6 @@
 #extension GL_EXT_shader_16bit_storage : require
 #extension GL_EXT_shader_8bit_storage : require
 
-#define K_QUANTS_PER_ITERATION 2
-
 #ifdef MUL_MAT_ID
 #define EXPERT_COUNT 8
 #endif
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp
index ec8eadcd5..614f2bf8b 100644
--- a/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_q2_k.comp
@@ -15,22 +15,22 @@ void main() {
     const uint num_blocks_per_row = p.ncols / QUANT_K;
     const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+    const uint tid = gl_LocalInvocationID.x/2;  // 0...15
+    const uint ix  = gl_LocalInvocationID.x%2;  // 0, 1
 
-    const uint step = 16/K_QUANTS_PER_ITERATION;  // 16 or 8
+    const uint step = 8;
 
     const uint v_im = tid/step;        // 0 or 1. 0 computes 0..., 1 computes 128...
     const uint v_in = tid - step*v_im; // 0...15 or 0...7
 
-    const uint l0 = K_QUANTS_PER_ITERATION*v_in;  // 0...15
+    const uint l0 = 2*v_in;                       // 0...14 in steps of 2
     const uint q_offset = 32*v_im + l0;
     const uint s_offset = 8*v_im;
     const uint y_offset = 128*v_im + l0;
 
     tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
 
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
         const uint y_idx = i * QUANT_K + y_offset;
 
         const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
@@ -38,7 +38,7 @@ void main() {
 
         FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
         FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+        for (int l = 0; l < 2; ++l) {
             sum1 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3),
                    fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3),
                    fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3),
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp
index 3ca4ad85a..5d37e24b7 100644
--- a/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_q3_k.comp
@@ -15,17 +15,17 @@ void main() {
     const uint num_blocks_per_row = p.ncols / QUANT_K;
     const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+    const uint tid = gl_LocalInvocationID.x/2;  // 0...15
+    const uint ix  = gl_LocalInvocationID.x%2;  // 0, 1
 
-    const uint step = 16/K_QUANTS_PER_ITERATION;  // 16 or 8
+    const uint step = 8;
 
     const uint v_im = tid/step;        // 0 or 1. 0 computes 0..., 1 computes 128...
     const uint v_in = tid - step*v_im; // 0...15 or 0...7
 
     const uint8_t m = uint8_t(1 << (4 * v_im));
 
-    const uint l0 = K_QUANTS_PER_ITERATION*v_in;  // 0...15
+    const uint l0 = 2*v_in;                       // 0...14 in steps of 2
     const uint q_offset = 32*v_im + l0;
     const uint y_offset = 128*v_im + l0;
@@ -33,13 +33,13 @@ void main() {
 
     const uint s_shift = 4 * v_im;
 
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
         const uint y_idx = i * QUANT_K + y_offset;
 
         const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
 
         FLOAT_TYPE sum = FLOAT_TYPE(0.0);
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+        for (int l = 0; l < 2; ++l) {
             sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 0)) != 0) ? 0 : 4)),
                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4)),
                   fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4)),
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp
index d91e00e10..36b6e9460 100644
--- a/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_q4_k.comp
@@ -15,14 +15,14 @@ void main() {
     const uint num_blocks_per_row = p.ncols / QUANT_K;
     const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+    const uint tid = gl_LocalInvocationID.x/2;  // 0...15
+    const uint ix  = gl_LocalInvocationID.x%2;  // 0, 1
 
-    const uint step = 8/K_QUANTS_PER_ITERATION;  // 8 or 4
+    const uint step = 4;
 
     const uint il = tid/step;       // 0...3
     const uint ir = tid - step*il;  // 0...7 or 0...3
-    const uint n  = 2 * K_QUANTS_PER_ITERATION;  // 2 or 4
+    const uint n  = 4;
 
     const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
     const uint v_in = il % 2;
@@ -33,7 +33,7 @@ void main() {
 
     tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
 
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
         const uint y1_idx = i * QUANT_K + y_offset;
         const uint y2_idx = y1_idx + 128;
@@ -49,7 +49,6 @@ void main() {
         const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
         const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
 
-#if K_QUANTS_PER_ITERATION == 2
         const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset    ] & 0xf);
         const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
         const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] & 0xf);
@@ -78,30 +77,6 @@ void main() {
             fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 3]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 35]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 3]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7)))))))))))))));
         const uint tmp_idx = 16 * ix + tid;
         tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
-#else
-        const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset     ] & 0xf);
-        const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset +  1] & 0xf);
-        const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset     ] >> 4);
-        const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset +  1] >> 4);
-        const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
-        const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
-        const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
-        const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
-
-        const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx     ]), q4_0, FLOAT_TYPE(data_b[b_offset + y1_idx +  1]) * q4_1);
-        const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3);
-        const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx     ]), q4_4, FLOAT_TYPE(data_b[b_offset + y2_idx +  1]) * q4_5);
-        const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7);
-        const FLOAT_TYPE smin =
-            fma(FLOAT_TYPE(data_b[b_offset + y1_idx     ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx     ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7,
-          + fma(FLOAT_TYPE(data_b[b_offset + y1_idx +  1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx +  1]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7)))))));
-
-        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) +
-                                                 sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
-        const uint tmp_idx = 16 * ix + tid;
-        tmp[tmp_idx] = fma(dall, (fma(sx, FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f), fma(sy, FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f),
-                           fma(sz, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)), fma(sw, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))))))), fma(-dmin, smin, tmp[tmp_idx]));
-#endif
     }
 
     // sum up partial sums and write back result
diff --git a/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp b/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp
index 95c286eeb..122a82f1f 100644
--- a/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp
+++ b/ggml/src/vulkan-shaders/mul_mat_vec_q6_k.comp
@@ -15,21 +15,16 @@ void main() {
     const uint num_blocks_per_row = p.ncols / QUANT_K;
     const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+    const uint tid = gl_LocalInvocationID.x/2;  // 0...15
+    const uint ix  = gl_LocalInvocationID.x%2;  // 0, 1
 
-    const uint step = 16/K_QUANTS_PER_ITERATION;  // 16 or 8
+    const uint step = 8;
 
     const uint v_im = tid/step;        // 0 or 1. 0 computes 0..., 1 computes 128...
     const uint v_in = tid - step*v_im; // 0...15 or 0...7
 
-#if K_QUANTS_PER_ITERATION == 1
-    const uint l0 = v_in;       // 0...15
-    const uint is = 0;
-#else
     const uint l0 = 4 * v_in;   // 0, 4, 8, ..., 28
     const uint is = v_in / 4;
-#endif
 
     const uint ql_offset = 64*v_im + l0;
     const uint qh_offset = 32*v_im + l0;
@@ -38,22 +33,11 @@ void main() {
 
     tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
 
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
         const uint y_idx = i * QUANT_K + y_offset;
 
         const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
 
-#if K_QUANTS_PER_ITERATION == 1
-        const uint tmp_idx = 16 * ix + tid;
-        tmp[tmp_idx] = fma(FLOAT_TYPE(data_b[b_offset + y_idx +  0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x03) << 4)) - 32),
-                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32),
-                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x0c) << 2)) - 32),
-                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32),
-                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0]  >> 4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x30) >> 0)) - 32),
-                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16]  >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32),
-                       fma(FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32]  >> 4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0xc0) >> 2)) - 32),
-                       fma(FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48]  >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32), tmp[tmp_idx]))))))));
-#else
         FLOAT_TYPE sum = FLOAT_TYPE(0.0);
         [[unroll]] for (int l = 0; l < 4; ++l) {
             sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32),
@@ -62,7 +46,6 @@ void main() {
             fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32), sum))));
         }
         tmp[16 * ix + tid] += sum;
-#endif
     }
 
     // sum up partial sums and write back result