ggml : remove k_quants_per_iteration macro

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-07-04 21:19:09 +03:00
parent 436787f170
commit e48fd74b45
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
12 changed files with 110 additions and 283 deletions

View File

@ -688,12 +688,6 @@ ifdef GGML_CUDA_DMMV_F16
MK_NVCCFLAGS += -DGGML_CUDA_F16 MK_NVCCFLAGS += -DGGML_CUDA_F16
endif # GGML_CUDA_DMMV_F16 endif # GGML_CUDA_DMMV_F16
ifdef GGML_CUDA_KQUANTS_ITER
MK_NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(GGML_CUDA_KQUANTS_ITER)
else
MK_NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2
endif
ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE
MK_NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=$(GGML_CUDA_PEER_MAX_BATCH_SIZE) MK_NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=$(GGML_CUDA_PEER_MAX_BATCH_SIZE)
else else
@ -810,7 +804,6 @@ ifdef GGML_HIPBLAS
GGML_CUDA_DMMV_X ?= 32 GGML_CUDA_DMMV_X ?= 32
GGML_CUDA_MMV_Y ?= 1 GGML_CUDA_MMV_Y ?= 1
GGML_CUDA_KQUANTS_ITER ?= 2
MK_CPPFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUDA MK_CPPFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUDA
@ -827,7 +820,6 @@ endif # GGML_HIP_UMA
HIPFLAGS += $(addprefix --offload-arch=,$(AMDGPU_TARGETS)) HIPFLAGS += $(addprefix --offload-arch=,$(AMDGPU_TARGETS))
HIPFLAGS += -DGGML_CUDA_DMMV_X=$(GGML_CUDA_DMMV_X) HIPFLAGS += -DGGML_CUDA_DMMV_X=$(GGML_CUDA_DMMV_X)
HIPFLAGS += -DGGML_CUDA_MMV_Y=$(GGML_CUDA_MMV_Y) HIPFLAGS += -DGGML_CUDA_MMV_Y=$(GGML_CUDA_MMV_Y)
HIPFLAGS += -DK_QUANTS_PER_ITERATION=$(GGML_CUDA_KQUANTS_ITER)
ifdef GGML_CUDA_FORCE_DMMV ifdef GGML_CUDA_FORCE_DMMV
HIPFLAGS += -DGGML_CUDA_FORCE_DMMV HIPFLAGS += -DGGML_CUDA_FORCE_DMMV

View File

@ -120,8 +120,6 @@ option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of
set (GGML_CUDA_DMMV_X "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels") set (GGML_CUDA_DMMV_X "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels")
set (GGML_CUDA_MMV_Y "1" CACHE STRING "ggml: y block size for mmv CUDA kernels") set (GGML_CUDA_MMV_Y "1" CACHE STRING "ggml: y block size for mmv CUDA kernels")
option(GGML_CUDA_F16 "ggml: use 16 bit floats for some calculations" OFF) option(GGML_CUDA_F16 "ggml: use 16 bit floats for some calculations" OFF)
set (GGML_CUDA_KQUANTS_ITER "2" CACHE STRING
"ggml: iters./thread per block for Q2_K/Q6_K")
set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
"ggml: max. batch size for using peer access") "ggml: max. batch size for using peer access")
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF) option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)

View File

@ -323,7 +323,6 @@ if (GGML_CUDA)
add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X}) add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y}) add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
if (GGML_CUDA_USE_GRAPHS) if (GGML_CUDA_USE_GRAPHS)
@ -471,7 +470,6 @@ if (GGML_HIPBLAS)
add_compile_definitions(GGML_USE_HIPBLAS) add_compile_definitions(GGML_USE_HIPBLAS)
add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X}) add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y}) add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
if (GGML_HIP_UMA) if (GGML_HIP_UMA)
add_compile_definitions(GGML_HIP_UMA) add_compile_definitions(GGML_HIP_UMA)

View File

@ -2,16 +2,7 @@
#include "dequantize.cuh" #include "dequantize.cuh"
#include "convert.cuh" #include "convert.cuh"
#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 2
#else
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
const int row = blockIdx.x*blockDim.y + threadIdx.y; const int row = blockIdx.x*blockDim.y + threadIdx.y;
if (row > nrows) return; if (row > nrows) return;
@ -22,15 +13,15 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
float tmp = 0; // partial sum for thread in warp float tmp = 0; // partial sum for thread in warp
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15 const int tid = threadIdx.x/2; // 0...15
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 const int ix = threadIdx.x%2; // 0,1
const int step = 16/K_QUANTS_PER_ITERATION; const int step = 8;
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7 const int in = tid - step*im; // 0...15 or 0...7
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2 const int l0 = 2*in; // 0...14 in steps of 2
const int q_offset = 32*im + l0; const int q_offset = 32*im + l0;
const int s_offset = 8*im; const int s_offset = 8*im;
const int y_offset = 128*im + l0; const int y_offset = 128*im + l0;
@ -39,7 +30,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
const uint8_t * d = (const uint8_t *)aux; const uint8_t * d = (const uint8_t *)aux;
const uint8_t * m = (const uint8_t *)(aux + 2); const uint8_t * m = (const uint8_t *)(aux + 2);
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { for (int i = ix; i < num_blocks_per_row; i += 2) {
const float * y = yy + i * QK_K + y_offset; const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset; const uint8_t * q = x[i].qs + q_offset;
@ -54,7 +45,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
aux[3] = (a[1] >> 4) & 0x0f0f0f0f; aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
float sum1 = 0, sum2 = 0; float sum1 = 0, sum2 = 0;
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { for (int l = 0; l < 2; ++l) {
sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3) sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+ y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3) + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+ y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3) + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
@ -94,11 +85,11 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
const uint16_t kmask1 = 0x0303; const uint16_t kmask1 = 0x0303;
const uint16_t kmask2 = 0x0f0f; const uint16_t kmask2 = 0x0f0f;
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const int tid = threadIdx.x/2; // 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 const int ix = threadIdx.x%2; // 0,1
const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop const int n = 2; // iterations in the inner loop
const int step = 16/K_QUANTS_PER_ITERATION; const int step = 8;
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0....15 or 0...7 const int in = tid - step*im; // 0....15 or 0...7
@ -113,7 +104,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
const uint16_t s_shift = 4*im; const uint16_t s_shift = 4*im;
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { for (int i = ix; i < num_blocks_per_row; i += 2) {
const float * y = yy + i * QK_K + y_offset; const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset; const uint8_t * q = x[i].qs + q_offset;
@ -163,14 +154,14 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
const uint16_t kmask2 = 0x0f0f; const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0; const uint16_t kmask3 = 0xc0c0;
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const int tid = threadIdx.x/2; // 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 const int ix = threadIdx.x%2; // 0,1
const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 const int step = 4;
const int il = tid/step; // 0...3 const int il = tid/step; // 0...3
const int ir = tid - step*il; // 0...7 or 0...3 const int ir = tid - step*il; // 0...7 or 0...3
const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 const int n = 4;
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2; const int in = il%2;
@ -182,17 +173,12 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
uint16_t aux[4]; uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux; const uint8_t * sc = (const uint8_t *)aux;
#if K_QUANTS_PER_ITERATION == 2
uint32_t q32[4]; uint32_t q32[4];
const uint8_t * q4 = (const uint8_t *)q32; const uint8_t * q4 = (const uint8_t *)q32;
#else
uint16_t q16[4];
const uint8_t * q4 = (const uint8_t *)q16;
#endif
float tmp = 0; // partial sum for thread in warp float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { for (int i = ix; i < num_blocks_per_row; i += 2) {
const float * y1 = yy + i*QK_K + y_offset; const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128; const float * y2 = y1 + 128;
@ -206,7 +192,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
#if K_QUANTS_PER_ITERATION == 2
const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset); const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
const uint32_t * q2 = q1 + 16; const uint32_t * q2 = q1 + 16;
@ -223,25 +208,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
} }
tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
#else
const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
const uint16_t * q2 = q1 + 32;
q16[0] = q1[0] & 0x0f0f;
q16[1] = q1[0] & 0xf0f0;
q16[2] = q2[0] & 0x0f0f;
q16[3] = q2[0] & 0xf0f0;
float4 s = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
for (int l = 0; l < 2; ++l) {
s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
}
tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
#endif
} }
// sum up partial sums and write back result // sum up partial sums and write back result
@ -341,9 +307,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
} }
static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
const int row = blockIdx.x*blockDim.y + threadIdx.y; const int row = blockIdx.x*blockDim.y + threadIdx.y;
if (row > nrows) return; if (row > nrows) return;
@ -352,21 +315,17 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
const block_q6_K * x = (const block_q6_K *)vx + ib0; const block_q6_K * x = (const block_q6_K *)vx + ib0;
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const int tid = threadIdx.x/2; // 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 const int ix = threadIdx.x%2; // 0, 1
const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 const int step = 8;
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7 const int in = tid - step*im; // 0...15 or 0...7
#if K_QUANTS_PER_ITERATION == 1
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
const int is = 0;
#else
const int l0 = 4 * in; // 0, 4, 8, ..., 28 const int l0 = 4 * in; // 0, 4, 8, ..., 28
const int is = in / 4; const int is = in / 4;
#endif
const int ql_offset = 64*im + l0; const int ql_offset = 64*im + l0;
const int qh_offset = 32*im + l0; const int qh_offset = 32*im + l0;
const int s_offset = 8*im + is; const int s_offset = 8*im + is;
@ -374,7 +333,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
float tmp = 0; // partial sum for thread in warp float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { for (int i = ix; i < num_blocks_per_row; i += 2) {
const float * y = yy + i * QK_K + y_offset; const float * y = yy + i * QK_K + y_offset;
const uint8_t * ql = x[i].ql + ql_offset; const uint8_t * ql = x[i].ql + ql_offset;
@ -383,17 +342,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
const float d = x[i].d; const float d = x[i].d;
#if K_QUANTS_PER_ITERATION == 1
float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+ y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+ y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+ y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+ y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+ y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+ y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
tmp += sum;
#else
float sum = 0; float sum = 0;
for (int l = 0; l < 4; ++l) { for (int l = 0; l < 4; ++l) {
sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
@ -402,8 +350,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
+ y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
} }
tmp += sum; tmp += sum;
#endif
} }
// sum up partial sums and write back result // sum up partial sums and write back result
@ -547,7 +493,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y,
static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 const int ny = 2;
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(32, ny, 1); const dim3 block_dims(32, ny, 1);
@ -556,7 +502,7 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION; const int ny = 1;
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(32, ny, 1); const dim3 block_dims(32, ny, 1);
@ -565,7 +511,7 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f
static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION; const int ny = 1;
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(32, ny, 1); const dim3 block_dims(32, ny, 1);
@ -580,7 +526,7 @@ static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, f
static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION; const int ny = 1;
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(block_num_y, 1, 1); const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(32, ny, 1); const dim3 block_dims(32, ny, 1);

View File

@ -124,9 +124,6 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
float *__restrict__ dst, float *__restrict__ dst,
const int ncols, int nrows, const int ncols, int nrows,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1); item_ct1.get_local_id(1);
if (row > nrows) return; if (row > nrows) return;
@ -140,16 +137,16 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
#if QK_K == 256 #if QK_K == 256
const int tid = const int tid =
item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...15 item_ct1.get_local_id(2) / 2; // 0...15
const int ix = const int ix =
item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 item_ct1.get_local_id(2) % 2; // 0,1
const int step = 16/K_QUANTS_PER_ITERATION; const int step = 8;
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7 const int in = tid - step*im; // 0...15 or 0...7
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2 const int l0 = 2*in; // 0...14 in steps of 2
const int q_offset = 32*im + l0; const int q_offset = 32*im + l0;
const int s_offset = 8*im; const int s_offset = 8*im;
const int y_offset = 128*im + l0; const int y_offset = 128*im + l0;
@ -158,7 +155,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
const uint8_t * d = (const uint8_t *)aux; const uint8_t * d = (const uint8_t *)aux;
const uint8_t * m = (const uint8_t *)(aux + 2); const uint8_t * m = (const uint8_t *)(aux + 2);
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { for (int i = ix; i < num_blocks_per_row; i += 2) {
const float * y = yy + i * QK_K + y_offset; const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset; const uint8_t * q = x[i].qs + q_offset;
@ -173,7 +170,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
aux[3] = (a[1] >> 4) & 0x0f0f0f0f; aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
float sum1 = 0, sum2 = 0; float sum1 = 0, sum2 = 0;
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { for (int l = 0; l < 2; ++l) {
sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3) sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+ y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3) + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+ y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3) + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
@ -190,18 +187,15 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
} }
#else #else
const int tid = item_ct1.get_local_id(2) / const int tid = item_ct1.get_local_id(2) / 4; // 0...7
(2 * K_QUANTS_PER_ITERATION); // 0...15 or 0...7 const int ix = item_ct1.get_local_id(2) % 4; // 0...3
const int ix = item_ct1.get_local_id(2) % const int offset = tid * 2;
(2 * K_QUANTS_PER_ITERATION); // 0....1 or 0...3
const int offset = tid * K_QUANTS_PER_ITERATION;
uint32_t uaux[2]; uint32_t uaux[2];
const uint8_t * d = (const uint8_t *)uaux; const uint8_t * d = (const uint8_t *)uaux;
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { for (int i = ix; i < num_blocks_per_row; i += 4) {
const float * y = yy + i * QK_K + offset; const float * y = yy + i * QK_K + offset;
const uint8_t * q = x[i].qs + offset; const uint8_t * q = x[i].qs + offset;
const uint32_t * s = (const uint32_t *)x[i].scales; const uint32_t * s = (const uint32_t *)x[i].scales;
@ -213,7 +207,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
x[i].dm.convert<float, sycl::rounding_mode::automatic>(); x[i].dm.convert<float, sycl::rounding_mode::automatic>();
float sum1 = 0, sum2 = 0; float sum1 = 0, sum2 = 0;
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { for (int l = 0; l < 2; ++l) {
const uint8_t ql = q[l]; const uint8_t ql = q[l];
sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3) sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
+ y[l+16] * d[1] * ((ql >> 2) & 3) + y[l+16] * d[1] * ((ql >> 2) & 3)
@ -268,12 +262,12 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
const uint16_t kmask2 = 0x0f0f; const uint16_t kmask2 = 0x0f0f;
const int tid = const int tid =
item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 item_ct1.get_local_id(2) / 2; // 0...16
const int ix = const int ix =
item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 item_ct1.get_local_id(2) % 2; // 0,1
const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop const int n = 2; // iterations in the inner loop
const int step = 16/K_QUANTS_PER_ITERATION; const int step = 8;
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0....15 or 0...7 const int in = tid - step*im; // 0....15 or 0...7
@ -288,7 +282,7 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
const uint16_t s_shift = 4*im; const uint16_t s_shift = 4*im;
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { for (int i = ix; i < num_blocks_per_row; i += 2) {
const float * y = yy + i * QK_K + y_offset; const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset; const uint8_t * q = x[i].qs + q_offset;
@ -318,13 +312,13 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
} }
#else #else
const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7 const int tid = item_ct1.get_local_id(2)/4; // 0...7
const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3 const int ix = item_ct1.get_local_id(2)%4; // 0...3
const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14 const int offset = tid * 2; // 0...14
const int in = offset/8; // 0 or 1 const int in = offset/8; // 0 or 1
const int im = offset%8; // 0...7 const int im = offset%8; // 0...7
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { for (int i = ix; i < num_blocks_per_row; i += 4) {
const float * y = yy + i * QK_K + offset; const float * y = yy + i * QK_K + offset;
const uint8_t * q = x[i].qs + offset; const uint8_t * q = x[i].qs + offset;
@ -333,7 +327,7 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
const float dall = (float)x[i].d; const float dall = (float)x[i].d;
float sum = 0; float sum = 0;
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { for (int l = 0; l < 2; ++l) {
const uint8_t hl = x[i].hmask[im+l] >> in; const uint8_t hl = x[i].hmask[im+l] >> in;
const uint8_t ql = q[l]; const uint8_t ql = q[l];
sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4)) sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
@ -384,15 +378,15 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
const uint16_t kmask3 = 0xc0c0; const uint16_t kmask3 = 0xc0c0;
const int tid = const int tid =
item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 item_ct1.get_local_id(2) / 2; // 0...16
const int ix = const int ix =
item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0,1 item_ct1.get_local_id(2) % 2; // 0,1
const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 const int step = 4;
const int il = tid/step; // 0...3 const int il = tid/step; // 0...3
const int ir = tid - step*il; // 0...7 or 0...3 const int ir = tid - step*il; // 0...7 or 0...3
const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 const int n = 4;
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2; const int in = il%2;
@ -404,17 +398,12 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
uint16_t aux[4]; uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux; const uint8_t * sc = (const uint8_t *)aux;
#if K_QUANTS_PER_ITERATION == 2
uint32_t q32[4]; uint32_t q32[4];
const uint8_t * q4 = (const uint8_t *)q32; const uint8_t * q4 = (const uint8_t *)q32;
#else
uint16_t q16[4];
const uint8_t * q4 = (const uint8_t *)q16;
#endif
float tmp = 0; // partial sum for thread in warp float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { for (int i = ix; i < num_blocks_per_row; i += 2) {
const float * y1 = yy + i*QK_K + y_offset; const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128; const float * y2 = y1 + 128;
@ -428,7 +417,6 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
#if K_QUANTS_PER_ITERATION == 2
const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset); const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
const uint32_t * q2 = q1 + 16; const uint32_t * q2 = q1 + 16;
@ -447,38 +435,19 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f / 16.f + tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f / 16.f +
s.z() * sc[4] + s.w() * sc[5] * 1.f / 16.f) - s.z() * sc[4] + s.w() * sc[5] * 1.f / 16.f) -
dmin * smin; dmin * smin;
#else
const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
const uint16_t * q2 = q1 + 32;
q16[0] = q1[0] & 0x0f0f;
q16[1] = q1[0] & 0xf0f0;
q16[2] = q2[0] & 0x0f0f;
q16[3] = q2[0] & 0xf0f0;
float4 s = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
for (int l = 0; l < 2; ++l) {
s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
}
tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
#endif
} }
#else #else
const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 const int tid = item_ct1.get_local_id(2)/4; // 0...15
const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); const int ix = item_ct1.get_local_id(2)%4;
const int step = tid * K_QUANTS_PER_ITERATION; const int step = tid * 2;
uint16_t aux16[2]; uint16_t aux16[2];
const uint8_t * s = (const uint8_t *)aux16; const uint8_t * s = (const uint8_t *)aux16;
float tmp = 0; float tmp = 0;
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { for (int i = ix; i < num_blocks_per_row; i += 4) {
const uint8_t * q = x[i].qs + step; const uint8_t * q = x[i].qs + step;
const float * y = yy + i*QK_K + step; const float * y = yy + i*QK_K + step;
const uint16_t * a = (const uint16_t *)x[i].scales; const uint16_t * a = (const uint16_t *)x[i].scales;
@ -487,7 +456,7 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
const float d = (float)x[i].dm[0]; const float d = (float)x[i].dm[0];
const float m = (float)x[i].dm[1]; const float m = (float)x[i].dm[1];
float sum = 0.f; float sum = 0.f;
for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { for (int j = 0; j < 2; ++j) {
sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2]) sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
+ y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2]) + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
+ y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3]) + y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3])
@ -609,19 +578,19 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
} }
#else #else
const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...15 const int tid = item_ct1.get_local_id(2)/4; // 0...15
const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); const int ix = item_ct1.get_local_id(2)%4;
const int step = tid * K_QUANTS_PER_ITERATION; const int step = tid * 2;
const int im = step/8; const int im = step/8;
const int in = step%8; const int in = step%8;
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { for (int i = ix; i < num_blocks_per_row; i += 4) {
const uint8_t * q = x[i].qs + step; const uint8_t * q = x[i].qs + step;
const int8_t * s = x[i].scales; const int8_t * s = x[i].scales;
const float * y = yy + i*QK_K + step; const float * y = yy + i*QK_K + step;
const float d = x[i].d; const float d = x[i].d;
float sum = 0.f; float sum = 0.f;
for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { for (int j = 0; j < 2; ++j) {
const uint8_t h = x[i].qh[in+j] >> im; const uint8_t h = x[i].qh[in+j] >> im;
sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16)) sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
+ y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16)) + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
@ -646,9 +615,6 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows, static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows,
const sycl::nd_item<3> &item_ct1) { const sycl::nd_item<3> &item_ct1) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) + const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
item_ct1.get_local_id(1); item_ct1.get_local_id(1);
if (row > nrows) return; if (row > nrows) return;
@ -661,22 +627,18 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
#if QK_K == 256 #if QK_K == 256
const int tid = const int tid =
item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION; // 0...31 or 0...16 item_ct1.get_local_id(2) / 2; // 0...16
const int ix = const int ix =
item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION; // 0 or 0, 1 item_ct1.get_local_id(2) % 2; // 0, 1
const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 const int step = 8;
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7 const int in = tid - step*im; // 0...15 or 0...7
#if K_QUANTS_PER_ITERATION == 1
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
const int is = 0;
#else
const int l0 = 4 * in; // 0, 4, 8, ..., 28 const int l0 = 4 * in; // 0, 4, 8, ..., 28
const int is = in / 4; const int is = in / 4;
#endif
const int ql_offset = 64*im + l0; const int ql_offset = 64*im + l0;
const int qh_offset = 32*im + l0; const int qh_offset = 32*im + l0;
const int s_offset = 8*im + is; const int s_offset = 8*im + is;
@ -684,7 +646,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
float tmp = 0; // partial sum for thread in warp float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { for (int i = ix; i < num_blocks_per_row; i += 2) {
const float * y = yy + i * QK_K + y_offset; const float * y = yy + i * QK_K + y_offset;
const uint8_t * ql = x[i].ql + ql_offset; const uint8_t * ql = x[i].ql + ql_offset;
@ -693,17 +655,6 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
const float d = x[i].d; const float d = x[i].d;
#if K_QUANTS_PER_ITERATION == 1
float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+ y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+ y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+ y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+ y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+ y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+ y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
tmp += sum;
#else
float sum = 0; float sum = 0;
for (int l = 0; l < 4; ++l) { for (int l = 0; l < 4; ++l) {
sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
@ -712,20 +663,18 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
+ y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
} }
tmp += sum; tmp += sum;
#endif
} }
#else #else
const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION); // 0...7 const int tid = item_ct1.get_local_id(2)/4; // 0...7
const int ix = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION); // 0...3 const int ix = item_ct1.get_local_id(2)%4; // 0...3
const int step = tid * K_QUANTS_PER_ITERATION; const int step = tid * 2;
float tmp = 0; // partial sum for thread in warp float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) { for (int i = ix; i < num_blocks_per_row; i += 4) {
const float * y = yy + i * QK_K + step; const float * y = yy + i * QK_K + step;
const uint8_t * ql = x[i].ql + step; const uint8_t * ql = x[i].ql + step;
@ -735,7 +684,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
const float d = x[i+0].d; const float d = x[i+0].d;
float sum = 0; float sum = 0;
for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) { for (int j = 0; j < 2; ++j) {
sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32) sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
+ y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32) + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
+ y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32) + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32)
@ -871,7 +820,7 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
const int nrows, const int nrows,
dpct::queue_ptr stream) { dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 const int ny = 2;
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
@ -887,7 +836,7 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
const int nrows, const int nrows,
dpct::queue_ptr stream) { dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION; const int ny = 1;
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
@ -903,7 +852,7 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
const int nrows, const int nrows,
dpct::queue_ptr stream) { dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION; const int ny = 1;
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);
@ -932,7 +881,7 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
const int nrows, const int nrows,
dpct::queue_ptr stream) { dpct::queue_ptr stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION; const int ny = 1;
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const sycl::range<3> block_nums(1, 1, block_num_y); const sycl::range<3> block_nums(1, 1, block_num_y);
const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE); const sycl::range<3> block_dims(1, ny, QK_WARP_SIZE);

View File

@ -52,12 +52,6 @@
#define GGML_SYCL_MMV_Y 1 #define GGML_SYCL_MMV_Y 1
#endif #endif
#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 2
#else
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif
#ifndef GGML_SYCL_PEER_MAX_BATCH_SIZE #ifndef GGML_SYCL_PEER_MAX_BATCH_SIZE
#define GGML_SYCL_PEER_MAX_BATCH_SIZE 128 #define GGML_SYCL_PEER_MAX_BATCH_SIZE 128
#endif // GGML_SYCL_PEER_MAX_BATCH_SIZE #endif // GGML_SYCL_PEER_MAX_BATCH_SIZE

View File

@ -41,12 +41,6 @@
#define MAX_VK_BUFFERS 256 #define MAX_VK_BUFFERS 256
#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 1
#else
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif
#define VK_CHECK(err, msg) \ #define VK_CHECK(err, msg) \
do { \ do { \
vk::Result err_ = (err); \ vk::Result err_ = (err); \

View File

@ -2,8 +2,6 @@
#extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_8bit_storage : require #extension GL_EXT_shader_8bit_storage : require
#define K_QUANTS_PER_ITERATION 2
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
#define EXPERT_COUNT 8 #define EXPERT_COUNT 8
#endif #endif

View File

@ -15,22 +15,22 @@ void main() {
const uint num_blocks_per_row = p.ncols / QUANT_K; const uint num_blocks_per_row = p.ncols / QUANT_K;
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const uint tid = gl_LocalInvocationID.x/2; // 0...16
const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 const uint ix = gl_LocalInvocationID.x%2; // 0, 1
const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 const uint step = 8;
const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = tid - step*v_im; // 0...15 or 0...7 const uint v_in = tid - step*v_im; // 0...15 or 0...7
const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15 const uint l0 = 2*v_in; // 0...15
const uint q_offset = 32*v_im + l0; const uint q_offset = 32*v_im + l0;
const uint s_offset = 8*v_im; const uint s_offset = 8*v_im;
const uint y_offset = 128*v_im + l0; const uint y_offset = 128*v_im + l0;
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
const uint y_idx = i * QUANT_K + y_offset; const uint y_idx = i * QUANT_K + y_offset;
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x); const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
@ -38,7 +38,7 @@ void main() {
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { for (int l = 0; l < 2; ++l) {
sum1 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3), sum1 = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3), fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3), fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]), FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3),

View File

@ -15,17 +15,17 @@ void main() {
const uint num_blocks_per_row = p.ncols / QUANT_K; const uint num_blocks_per_row = p.ncols / QUANT_K;
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const uint tid = gl_LocalInvocationID.x/2; // 0...16
const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 const uint ix = gl_LocalInvocationID.x%2; // 0, 1
const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 const uint step = 8;
const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = tid - step*v_im; // 0...15 or 0...7 const uint v_in = tid - step*v_im; // 0...15 or 0...7
const uint8_t m = uint8_t(1 << (4 * v_im)); const uint8_t m = uint8_t(1 << (4 * v_im));
const uint l0 = K_QUANTS_PER_ITERATION*v_in; // 0...15 const uint l0 = 2*v_in; // 0...15
const uint q_offset = 32*v_im + l0; const uint q_offset = 32*v_im + l0;
const uint y_offset = 128*v_im + l0; const uint y_offset = 128*v_im + l0;
@ -33,13 +33,13 @@ void main() {
const uint s_shift = 4 * v_im; const uint s_shift = 4 * v_im;
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
const uint y_idx = i * QUANT_K + y_offset; const uint y_idx = i * QUANT_K + y_offset;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
FLOAT_TYPE sum = FLOAT_TYPE(0.0); FLOAT_TYPE sum = FLOAT_TYPE(0.0);
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { for (int l = 0; l < 2; ++l) {
sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)), sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)), fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)), fma(FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),

View File

@ -15,14 +15,14 @@ void main() {
const uint num_blocks_per_row = p.ncols / QUANT_K; const uint num_blocks_per_row = p.ncols / QUANT_K;
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const uint tid = gl_LocalInvocationID.x/2; // 0...16
const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 const uint ix = gl_LocalInvocationID.x%2; // 0, 1
const uint step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 const uint step = 4;
const uint il = tid/step; // 0...3 const uint il = tid/step; // 0...3
const uint ir = tid - step*il; // 0...7 or 0...3 const uint ir = tid - step*il; // 0...7 or 0...3
const uint n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 const uint n = 4;
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const uint v_in = il % 2; const uint v_in = il % 2;
@ -33,7 +33,7 @@ void main() {
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
const uint y1_idx = i * QUANT_K + y_offset; const uint y1_idx = i * QUANT_K + y_offset;
const uint y2_idx = y1_idx + 128; const uint y2_idx = y1_idx + 128;
@ -49,7 +49,6 @@ void main() {
const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2)); const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2)); const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
#if K_QUANTS_PER_ITERATION == 2
const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf); const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf); const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] & 0xf); const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] & 0xf);
@ -78,30 +77,6 @@ void main() {
fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 3]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 35]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 3]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7))))))))))))))); fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 3]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 35]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 3]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7)))))))))))))));
const uint tmp_idx = 16 * ix + tid; const uint tmp_idx = 16 * ix + tid;
tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx])); tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
#else
const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), q4_0, FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1);
const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3);
const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), q4_4, FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_5);
const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7);
const FLOAT_TYPE smin =
fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7,
+ fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7)))))));
tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) +
sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
const uint tmp_idx = 16 * ix + tid;
tmp[tmp_idx] = fma(dall, (fma(sx, FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f), fma(sy, FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f),
fma(sz, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)), fma(sw, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))))))), fma(-dmin, smin, tmp[tmp_idx]));
#endif
} }
// sum up partial sums and write back result // sum up partial sums and write back result

View File

@ -15,21 +15,16 @@ void main() {
const uint num_blocks_per_row = p.ncols / QUANT_K; const uint num_blocks_per_row = p.ncols / QUANT_K;
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row; const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 const uint tid = gl_LocalInvocationID.x/2; // 0...16
const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 const uint ix = gl_LocalInvocationID.x%2; // 0, 1
const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 const uint step = 8;
const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = tid - step*v_im; // 0...15 or 0...7 const uint v_in = tid - step*v_im; // 0...15 or 0...7
#if K_QUANTS_PER_ITERATION == 1
const uint l0 = v_in; // 0...15
const uint is = 0;
#else
const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28
const uint is = v_in / 4; const uint is = v_in / 4;
#endif
const uint ql_offset = 64*v_im + l0; const uint ql_offset = 64*v_im + l0;
const uint qh_offset = 32*v_im + l0; const uint qh_offset = 32*v_im + l0;
@ -38,22 +33,11 @@ void main() {
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
const uint y_idx = i * QUANT_K + y_offset; const uint y_idx = i * QUANT_K + y_offset;
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
#if K_QUANTS_PER_ITERATION == 1
const uint tmp_idx = 16 * ix + tid;
tmp[tmp_idx] = fma(FLOAT_TYPE(data_b[b_offset + y_idx + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 0] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32),
fma(FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32),
fma(FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32), tmp[tmp_idx]))))))));
#else
FLOAT_TYPE sum = FLOAT_TYPE(0.0); FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 4; ++l) { [[unroll]] for (int l = 0; l < 4; ++l) {
sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32), sum = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32),
@ -62,7 +46,6 @@ void main() {
fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32), sum)))); fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32), sum))));
} }
tmp[16 * ix + tid] += sum; tmp[16 * ix + tid] += sum;
#endif
} }
// sum up partial sums and write back result // sum up partial sums and write back result