Mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-25 02:44:36 +00:00)

parent 91deef4606
commit 117f7adbd9

Makefile (8 changed lines)
@@ -640,12 +640,6 @@ ifdef GGML_CUDA_DMMV_F16
 	MK_NVCCFLAGS += -DGGML_CUDA_F16
 endif # GGML_CUDA_DMMV_F16
 
-ifdef GGML_CUDA_KQUANTS_ITER
-	MK_NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(GGML_CUDA_KQUANTS_ITER)
-else
-	MK_NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2
-endif
-
 ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE
 	MK_NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=$(GGML_CUDA_PEER_MAX_BATCH_SIZE)
 else
@@ -734,7 +728,6 @@ ifdef GGML_HIPBLAS
 
 	GGML_CUDA_DMMV_X       ?= 32
 	GGML_CUDA_MMV_Y        ?= 1
-	GGML_CUDA_KQUANTS_ITER ?= 2
 
 	MK_CPPFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUDA
 
@@ -751,7 +744,6 @@ endif # GGML_HIP_UMA
 	HIPFLAGS += $(addprefix --offload-arch=,$(AMDGPU_TARGETS))
 	HIPFLAGS += -DGGML_CUDA_DMMV_X=$(GGML_CUDA_DMMV_X)
 	HIPFLAGS += -DGGML_CUDA_MMV_Y=$(GGML_CUDA_MMV_Y)
-	HIPFLAGS += -DK_QUANTS_PER_ITERATION=$(GGML_CUDA_KQUANTS_ITER)
 
 ifdef GGML_CUDA_FORCE_DMMV
 	HIPFLAGS += -DGGML_CUDA_FORCE_DMMV
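For context, the `-DK_QUANTS_PER_ITERATION=...` define the Makefile used to forward was consumed by a default-or-validate preprocessor guard in the sources (deleted further down in this commit). A minimal standalone sketch of that idiom — a hypothetical demo file, not the library's own:

```cpp
// demo.cpp — sketch of the compile-time switch the build flag fed (hypothetical file).
// Default path:  g++ -std=c++17 demo.cpp
// Override path: g++ -std=c++17 -DK_QUANTS_PER_ITERATION=1 demo.cpp
#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 2    // default used when the build system passes nothing
#else
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2,
              "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif

#include <cstdio>

int main() {
    std::printf("K_QUANTS_PER_ITERATION = %d\n", K_QUANTS_PER_ITERATION);
}
```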
@@ -521,7 +521,6 @@ Building the program with BLAS support may lead to some performance improvements
 | GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. |
 | GGML_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models |
 | GGML_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
-| GGML_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
 | GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. |
 | GGML_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. |
 
@@ -582,7 +581,6 @@ Building the program with BLAS support may lead to some performance improvements
 |------------------------|------------------------|---------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
 | GGML_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the HIP dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
 | GGML_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the HIP mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
-| GGML_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per HIP thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
 
 - #### Vulkan
 
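As a back-of-the-envelope for the removed GGML_CUDA_KQUANTS_ITER rows above: with K values handled per thread per iteration, a 32-thread warp splits into 32/K distinct lanes and each lane strides K superblocks per loop step. A small host-side sketch of that arithmetic (my model of the table's description, not library code):

```cpp
#include <cstdio>

int main() {
    for (int K : {1, 2}) {                  // the two values GGML_CUDA_KQUANTS_ITER accepted
        const int lanes  = 32 / K;          // distinct positions a warp covers at once
        const int stride = K;               // superblocks each lane advances per loop step
        std::printf("K=%d: %2d lanes x %d value(s)/thread/iteration, block stride %d\n",
                    K, lanes, K, stride);
    }
}
```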
@@ -113,8 +113,6 @@ option(GGML_CUDA_FORCE_CUBLAS                     "ggml: always use cuBLAS instead of
 set   (GGML_CUDA_DMMV_X   "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels")
 set   (GGML_CUDA_MMV_Y     "1" CACHE STRING "ggml: y block size for mmv CUDA kernels")
 option(GGML_CUDA_F16                              "ggml: use 16 bit floats for some calculations" OFF)
-set   (GGML_CUDA_KQUANTS_ITER       "2" CACHE STRING
-                                            "ggml: iters./thread per block for Q2_K/Q6_K")
 set   (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING
                                             "ggml: max. batch size for using peer access")
 option(GGML_CUDA_NO_PEER_COPY                     "ggml: do not use peer to peer copies" OFF)
@@ -297,7 +297,6 @@ if (GGML_CUDA)
 
     add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
     add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
-    add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
     add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
 
     if (GGML_CUDA_USE_GRAPHS)
@@ -426,7 +425,6 @@ if (GGML_HIPBLAS)
     add_compile_definitions(GGML_USE_HIPBLAS)
     add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X})
     add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y})
-    add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER})
 
     if (GGML_HIP_UMA)
         add_compile_definitions(GGML_HIP_UMA)
@@ -2,16 +2,7 @@
 #include "dequantize.cuh"
 #include "convert.cuh"
 
-#ifndef K_QUANTS_PER_ITERATION
-#define K_QUANTS_PER_ITERATION 2
-#else
-static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
-#endif
-
 static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
-
-    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
-
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     if (row > nrows) return;
 
@@ -22,15 +13,15 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
 
     float tmp = 0; // partial sum for thread in warp
 
-    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
-    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+    const int tid = threadIdx.x/2;  // 0...15
+    const int ix  = threadIdx.x%2;  // 0,1
 
-    const int step = 16/K_QUANTS_PER_ITERATION;
+    const int step = 8;
 
     const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
     const int in = tid - step*im;                        // 0...15 or 0...7
 
-    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2
+    const int l0 = 2*in;                                 // 0...14 in steps of 2
     const int q_offset = 32*im + l0;
     const int s_offset = 8*im;
     const int y_offset = 128*im + l0;
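The hardcoded constants above are exactly the old expressions evaluated at K_QUANTS_PER_ITERATION == 2. A throwaway host-side check (assumed names, not kernel code) makes that concrete:

```cpp
#include <cassert>

int main() {
    const int K = 2;                       // the only value this commit keeps
    for (int t = 0; t < 32; ++t) {         // one 32-thread warp
        const int tid  = t / K;            // was threadIdx.x / K_QUANTS_PER_ITERATION
        const int ix   = t % K;            // was threadIdx.x % K_QUANTS_PER_ITERATION
        const int step = 16 / K;           // was 16 / K_QUANTS_PER_ITERATION
        const int im   = tid / step;
        const int in   = tid - step * im;
        const int l0   = K * in;           // was K_QUANTS_PER_ITERATION * in
        assert(tid == t / 2 && ix == t % 2 && step == 8 && l0 == 2 * in);
        (void)im;
    }
    return 0;
}
```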
@@ -39,7 +30,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
     const uint8_t * d = (const uint8_t *)aux;
     const uint8_t * m = (const uint8_t *)(aux + 2);
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const float * y = yy + i * QK_K + y_offset;
         const uint8_t * q = x[i].qs + q_offset;
@@ -54,7 +45,7 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
         aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
 
         float sum1 = 0, sum2 = 0;
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+        for (int l = 0; l < 2; ++l) {
             sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
                   + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
                   + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
@@ -94,17 +85,17 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
     const uint16_t kmask1 = 0x0303;
     const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+    const int tid = threadIdx.x/2;  // 0...16
+    const int ix  = threadIdx.x%2;  // 0,1
 
-    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
-    const int step = 16/K_QUANTS_PER_ITERATION;
-    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const int in = tid - step*im;                        // 0....15 or 0...7
+    const int n  = 2;                                    // iterations in the inner loop
+    const int step = 8;
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0....15 or 0...7
 
     const uint8_t m = 1 << (4*im);
 
-    const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
+    const int l0 = n*in;                                 // 0...15 or 0...14 in steps of 2
     const int q_offset =  32*im + l0;
     const int y_offset = 128*im + l0;
 
@@ -113,7 +104,7 @@ static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx,
 
     const uint16_t s_shift = 4*im;
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const float * y  = yy + i * QK_K + y_offset;
         const uint8_t * q = x[i].qs + q_offset;
@@ -163,14 +154,14 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;
 
-    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0,1
+    const int tid = threadIdx.x/2;  // 0...16
+    const int ix  = threadIdx.x%2;  // 0,1
 
-    const int step = 8/K_QUANTS_PER_ITERATION;           // 8 or 4
+    const int step = 4;
 
-    const int il  = tid/step;                            // 0...3
-    const int ir  = tid - step*il;                       // 0...7 or 0...3
-    const int n   = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4
+    const int il  = tid/step;                            // 0...3
+    const int ir  = tid - step*il;                       // 0...7 or 0...3
+    const int n   = 4;
 
     const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
     const int in = il%2;
@@ -182,17 +173,12 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
 
-#if K_QUANTS_PER_ITERATION == 2
     uint32_t q32[4];
     const uint8_t * q4 = (const uint8_t *)q32;
-#else
-    uint16_t q16[4];
-    const uint8_t * q4 = (const uint8_t *)q16;
-#endif
 
     float tmp = 0; // partial sum for thread in warp
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const float * y1 = yy + i*QK_K + y_offset;
         const float * y2 = y1 + 128;
@@ -206,7 +192,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
         aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
         aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
 
-#if K_QUANTS_PER_ITERATION == 2
         const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
         const uint32_t * q2 = q1 + 16;
 
@@ -223,25 +208,6 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
             smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
         }
         tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
-#else
-        const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
-        const uint16_t * q2 = q1 + 32;
-
-        q16[0] = q1[0] & 0x0f0f;
-        q16[1] = q1[0] & 0xf0f0;
-        q16[2] = q2[0] & 0x0f0f;
-        q16[3] = q2[0] & 0xf0f0;
-
-        float4 s = {0.f, 0.f, 0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < 2; ++l) {
-            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
-            s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
-            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
-        }
-        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
-#endif
 
     }
 
     // sum up partial sums and write back result
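The deleted #else branch above was the K_QUANTS_PER_ITERATION == 1 variant of q4_K, which loaded 16-bit instead of 32-bit chunks. In both variants the high nibbles are masked but never shifted down; the missing >>4 is folded into the scale, which is why sc[1] and sc[5] carry an extra 1.f/16.f. A hedged host-side illustration of that unpacking (made-up input word):

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    const uint32_t packed = 0x4321dcbau;   // four packed q4_K bytes (made-up value)
    uint32_t lo = packed & 0x0f0f0f0f;     // low nibbles, already in 0..15
    uint32_t hi = packed & 0xf0f0f0f0;     // high nibbles, left in each byte's upper half
    uint8_t lob[4], hib[4];
    std::memcpy(lob, &lo, 4);              // byte view, avoiding pointer aliasing on the host
    std::memcpy(hib, &hi, 4);
    for (int i = 0; i < 4; ++i) {
        // the kernel multiplies the unshifted high nibble by sc * (1.f/16.f) instead of shifting
        std::printf("byte %d: low=%2u high=%3u (high/16=%2u)\n",
                    i, (unsigned)lob[i], (unsigned)hib[i], (unsigned)hib[i] / 16);
    }
}
```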
@@ -341,9 +307,6 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
 }
 
 static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
-
-    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
-
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
     if (row > nrows) return;
 
@@ -352,21 +315,17 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 
     const block_q6_K * x = (const block_q6_K *)vx + ib0;
 
-    const int tid = threadIdx.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const int ix  = threadIdx.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+    const int tid = threadIdx.x/2;  // 0...16
+    const int ix  = threadIdx.x%2;  // 0, 1
 
-    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8
+    const int step = 8;
 
-    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const int in = tid - step*im;                        // 0...15 or 0...7
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
 
-#if K_QUANTS_PER_ITERATION == 1
-    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
-    const int is = 0;
-#else
-    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
+    const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
     const int is = in / 4;
-#endif
 
     const int ql_offset = 64*im + l0;
     const int qh_offset = 32*im + l0;
     const int s_offset  =  8*im + is;
@@ -374,7 +333,7 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 
     float tmp = 0; // partial sum for thread in warp
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const float * y  = yy + i * QK_K + y_offset;
         const uint8_t * ql = x[i].ql + ql_offset;
@@ -383,17 +342,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 
         const float d = x[i].d;
 
-#if K_QUANTS_PER_ITERATION == 1
-        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
-                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
-                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
-                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
-                  + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
-                  + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
-                  + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
-                  +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
-        tmp += sum;
-#else
         float sum = 0;
         for (int l = 0; l < 4; ++l) {
             sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
@@ -402,8 +350,6 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
                  + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
         }
         tmp += sum;
-#endif
 
     }
 
     // sum up partial sums and write back result
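Both the deleted q6_K path and the surviving loop rebuild each signed 6-bit quant from a low nibble in ql plus two bits from qh, then recentre by 32. A standalone check of that bit surgery (input bytes assumed):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const uint8_t ql = 0x2a;   // one quant's low 4 bits live in ql's low nibble (made-up)
    const uint8_t qh = 0x01;   // the quant's top 2 bits sit in qh (made-up)
    // (ql & 0xF) gives bits 0..3, ((qh >> 0) & 3) << 4 gives bits 4..5, minus the 32 offset:
    const int v = (int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32;
    std::printf("reconstructed q6 value: %d (range -32..31)\n", v);  // prints -6
}
```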
@@ -547,7 +493,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y,
 
 static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
+    const int ny = 2;
     const int block_num_y = (nrows + ny - 1) / ny;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
@@ -556,7 +502,7 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
 
 static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int ny = 1;
     const int block_num_y = (nrows + ny - 1) / ny;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
@@ -565,7 +511,7 @@ static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, f
 
 static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int ny = 1;
     const int block_num_y = (nrows + ny - 1) / ny;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
@@ -580,7 +526,7 @@ static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, f
 
 static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int ny = 1;
     const int block_num_y = (nrows + ny - 1) / ny;
     const dim3 block_nums(block_num_y, 1, 1);
     const dim3 block_dims(32, ny, 1);
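The launchers above keep the q2_K geometry at two rows per block and pin the q3/q4/q6_K ones to one row per block (the old 2 / K_QUANTS_PER_ITERATION evaluated at K = 2). The host-side arithmetic, sketched with an assumed row count:

```cpp
#include <cstdio>

int main() {
    const int nrows = 4096;                              // assumed matrix height
    for (int ny : {2, 1}) {                              // rows handled per thread block
        const int block_num_y = (nrows + ny - 1) / ny;   // ceil(nrows / ny) blocks in y
        std::printf("ny=%d -> %d blocks of dim3(32, %d, 1) = %d threads each\n",
                    ny, block_num_y, ny, 32 * ny);
    }
}
```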
@@ -123,9 +123,6 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
                                         float *__restrict__ dst,
                                         const int ncols, int nrows,
                                         const sycl::nd_item<3> &item_ct1) {
-
-    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
-
     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                     item_ct1.get_local_id(1);
     if (row > nrows) return;
@@ -139,16 +136,16 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
 
 #if QK_K == 256
     const int tid =
-        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION;  // 0...31 or 0...15
+        item_ct1.get_local_id(2) / 2;  // 0...15
     const int ix =
-        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION;  // 0 or 0,1
+        item_ct1.get_local_id(2) % 2;  // 0,1
 
-    const int step = 16/K_QUANTS_PER_ITERATION;
+    const int step = 8;
 
-    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const int in = tid - step*im;                        // 0...15 or 0...7
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0...15 or 0...7
 
-    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15 or 0...14 in steps of 2
+    const int l0 = 2*in;                                 // 0...14 in steps of 2
     const int q_offset = 32*im + l0;
     const int s_offset = 8*im;
     const int y_offset = 128*im + l0;
@@ -157,7 +154,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
     const uint8_t * d = (const uint8_t *)aux;
     const uint8_t * m = (const uint8_t *)(aux + 2);
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const float * y = yy + i * QK_K + y_offset;
         const uint8_t * q = x[i].qs + q_offset;
@@ -172,7 +169,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
         aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
 
         float sum1 = 0, sum2 = 0;
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+        for (int l = 0; l < 2; ++l) {
             sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
                   + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
                   + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
@@ -189,18 +186,15 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
 
     }
 #else
-    const int tid = item_ct1.get_local_id(2) /
-                    (2 * K_QUANTS_PER_ITERATION);  // 0...15 or 0...7
-    const int ix = item_ct1.get_local_id(2) %
-                   (2 * K_QUANTS_PER_ITERATION);  // 0....1 or 0...3
-    const int offset = tid * K_QUANTS_PER_ITERATION;
+    const int tid = item_ct1.get_local_id(2) / 4;  // 0...7
+    const int ix  = item_ct1.get_local_id(2) % 4;  // 0...3
+    const int offset = tid * 2;
 
     uint32_t uaux[2];
     const uint8_t * d = (const uint8_t *)uaux;
 
-    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 4) {
         const float * y = yy + i * QK_K + offset;
         const uint8_t * q = x[i].qs + offset;
         const uint32_t * s = (const uint32_t *)x[i].scales;
@@ -212,7 +206,7 @@ static void dequantize_mul_mat_vec_q2_k(const void *__restrict__ vx,
             x[i].dm.convert<float, sycl::rounding_mode::automatic>();
 
         float sum1 = 0, sum2 = 0;
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+        for (int l = 0; l < 2; ++l) {
             const uint8_t ql = q[l];
             sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
                   + y[l+16] * d[1] * ((ql >> 2) & 3)
@@ -267,14 +261,14 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
     const uint16_t kmask2 = 0x0f0f;
 
     const int tid =
-        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+        item_ct1.get_local_id(2) / 2;  // 0...16
     const int ix =
-        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION;  // 0 or 0,1
+        item_ct1.get_local_id(2) % 2;  // 0,1
 
-    const int n  = K_QUANTS_PER_ITERATION;               // iterations in the inner loop
-    const int step = 16/K_QUANTS_PER_ITERATION;
-    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
-    const int in = tid - step*im;                        // 0....15 or 0...7
+    const int n  = 2;                                    // iterations in the inner loop
+    const int step = 8;
+    const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
+    const int in = tid - step*im;                        // 0....15 or 0...7
 
     const uint8_t m = 1 << (4*im);
 
@@ -287,7 +281,7 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
 
     const uint16_t s_shift = 4*im;
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const float * y  = yy + i * QK_K + y_offset;
         const uint8_t * q = x[i].qs + q_offset;
@@ -317,13 +311,13 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
     }
 #else
 
-    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...15 or 0...7
-    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);  // 0....1 or 0...3
-    const int offset = tid * K_QUANTS_PER_ITERATION;         // 0...15 or 0...14
-    const int in = offset/8;                                 // 0 or 1
-    const int im = offset%8;                                 // 0...7
+    const int tid = item_ct1.get_local_id(2)/4;  // 0...7
+    const int ix  = item_ct1.get_local_id(2)%4;  // 0...3
+    const int offset = tid * 2;                  // 0...14
+    const int in = offset/8;                     // 0 or 1
+    const int im = offset%8;                     // 0...7
 
-    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 4) {
 
         const float * y = yy + i * QK_K + offset;
         const uint8_t * q = x[i].qs + offset;
@@ -332,7 +326,7 @@ static void dequantize_mul_mat_vec_q3_k(const void *__restrict__ vx,
         const float dall = (float)x[i].d;
 
         float sum = 0;
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+        for (int l = 0; l < 2; ++l) {
             const uint8_t hl = x[i].hmask[im+l] >> in;
             const uint8_t ql = q[l];
             sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
@@ -383,15 +377,15 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
     const uint16_t kmask3 = 0xc0c0;
 
     const int tid =
-        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+        item_ct1.get_local_id(2) / 2;  // 0...16
     const int ix =
-        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION;  // 0 or 0,1
+        item_ct1.get_local_id(2) % 2;  // 0,1
 
-    const int step = 8/K_QUANTS_PER_ITERATION;           // 8 or 4
+    const int step = 4;
 
     const int il  = tid/step;                            // 0...3
     const int ir  = tid - step*il;                       // 0...7 or 0...3
-    const int n   = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4
+    const int n   = 4;
 
     const int im = il/2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
     const int in = il%2;
@@ -403,17 +397,12 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
     uint16_t aux[4];
     const uint8_t * sc = (const uint8_t *)aux;
 
-#if K_QUANTS_PER_ITERATION == 2
     uint32_t q32[4];
     const uint8_t * q4 = (const uint8_t *)q32;
-#else
-    uint16_t q16[4];
-    const uint8_t * q4 = (const uint8_t *)q16;
-#endif
 
     float tmp = 0; // partial sum for thread in warp
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const float * y1 = yy + i*QK_K + y_offset;
         const float * y2 = y1 + 128;
@@ -427,7 +416,6 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
         aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
         aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
 
-#if K_QUANTS_PER_ITERATION == 2
         const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
         const uint32_t * q2 = q1 + 16;
 
@@ -446,38 +434,19 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
         tmp += dall * (s.x() * sc[0] + s.y() * sc[1] * 1.f / 16.f +
                        s.z() * sc[4] + s.w() * sc[5] * 1.f / 16.f) -
                dmin * smin;
-#else
-        const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
-        const uint16_t * q2 = q1 + 32;
-
-        q16[0] = q1[0] & 0x0f0f;
-        q16[1] = q1[0] & 0xf0f0;
-        q16[2] = q2[0] & 0x0f0f;
-        q16[3] = q2[0] & 0xf0f0;
-
-        float4 s = {0.f, 0.f, 0.f, 0.f};
-        float smin = 0;
-        for (int l = 0; l < 2; ++l) {
-            s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
-            s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
-            smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
-        }
-        tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
-#endif
 
     }
 #else
-    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...15
-    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);
+    const int tid = item_ct1.get_local_id(2)/4;  // 0...15
+    const int ix  = item_ct1.get_local_id(2)%4;
 
-    const int step = tid * K_QUANTS_PER_ITERATION;
+    const int step = tid * 2;
 
     uint16_t aux16[2];
     const uint8_t * s = (const uint8_t *)aux16;
 
     float tmp = 0;
 
-    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 4) {
         const uint8_t * q = x[i].qs + step;
         const float   * y = yy + i*QK_K + step;
         const uint16_t * a = (const uint16_t *)x[i].scales;
@@ -486,7 +455,7 @@ static void dequantize_mul_mat_vec_q4_k(const void *__restrict__ vx,
         const float d = (float)x[i].dm[0];
         const float m = (float)x[i].dm[1];
         float sum = 0.f;
-        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+        for (int j = 0; j < 2; ++j) {
             sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
                  + y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
                  + y[j+32] * (d * s[1] * (q[j+ 0] >>  4) - m * s[3])
@@ -608,19 +577,19 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
     }
 
 #else
-    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...15
-    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);
-    const int step = tid * K_QUANTS_PER_ITERATION;
+    const int tid = item_ct1.get_local_id(2)/4;  // 0...15
+    const int ix  = item_ct1.get_local_id(2)%4;
+    const int step = tid * 2;
     const int im = step/8;
     const int in = step%8;
 
-    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 4) {
         const uint8_t * q = x[i].qs + step;
         const int8_t  * s = x[i].scales;
         const float   * y = yy + i*QK_K + step;
         const float     d = x[i].d;
         float sum = 0.f;
-        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+        for (int j = 0; j < 2; ++j) {
             const uint8_t h = x[i].qh[in+j] >> im;
             sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
                  + y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
@@ -645,9 +614,6 @@ static void dequantize_mul_mat_vec_q5_k(const void *__restrict__ vx,
 
 static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows,
                                         const sycl::nd_item<3> &item_ct1) {
-
-    static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
-
     const int row = item_ct1.get_group(2) * item_ct1.get_local_range(1) +
                     item_ct1.get_local_id(1);
     if (row > nrows) return;
@@ -660,22 +626,18 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
 #if QK_K == 256
 
     const int tid =
-        item_ct1.get_local_id(2) / K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
+        item_ct1.get_local_id(2) / 2;  // 0...16
    const int ix =
-        item_ct1.get_local_id(2) % K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+        item_ct1.get_local_id(2) % 2;  // 0, 1
 
-    const int step = 16/K_QUANTS_PER_ITERATION;          // 16 or 8
+    const int step = 8;
 
     const int im = tid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128...
     const int in = tid - step*im;                        // 0...15 or 0...7
 
-#if K_QUANTS_PER_ITERATION == 1
-    const int l0 = K_QUANTS_PER_ITERATION*in;            // 0...15
-    const int is = 0;
-#else
     const int l0 = 4 * in;                               // 0, 4, 8, ..., 28
     const int is = in / 4;
-#endif
 
     const int ql_offset = 64*im + l0;
     const int qh_offset = 32*im + l0;
     const int s_offset  =  8*im + is;
@@ -683,7 +645,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
 
     float tmp = 0; // partial sum for thread in warp
 
-    for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 2) {
 
         const float * y  = yy + i * QK_K + y_offset;
         const uint8_t * ql = x[i].ql + ql_offset;
@@ -692,17 +654,6 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
 
         const float d = x[i].d;
 
-#if K_QUANTS_PER_ITERATION == 1
-        float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
-                  + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
-                  + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
-                  + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
-                  + y[64] * s[4] * d * ((int8_t)((ql[ 0]  >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
-                  + y[80] * s[5] * d * ((int8_t)((ql[16]  >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
-                  + y[96] * s[6] * d * ((int8_t)((ql[32]  >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
-                  +y[112] * s[7] * d * ((int8_t)((ql[48]  >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
-        tmp += sum;
-#else
         float sum = 0;
         for (int l = 0; l < 4; ++l) {
             sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
@@ -711,20 +662,18 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
                  + y[l+96] * s[6] * d * ((int8_t)((ql[l+32]  >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
         }
         tmp += sum;
-#endif
 
     }
 
 #else
 
-    const int tid = item_ct1.get_local_id(2)/(2*K_QUANTS_PER_ITERATION);  // 0...7
-    const int ix  = item_ct1.get_local_id(2)%(2*K_QUANTS_PER_ITERATION);  // 0...3
+    const int tid = item_ct1.get_local_id(2)/4;  // 0...7
+    const int ix  = item_ct1.get_local_id(2)%4;  // 0...3
 
-    const int step = tid * K_QUANTS_PER_ITERATION;
+    const int step = tid * 2;
 
     float tmp = 0; // partial sum for thread in warp
 
-    for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
+    for (int i = ix; i < num_blocks_per_row; i += 4) {
 
         const float * y  = yy + i * QK_K + step;
         const uint8_t * ql = x[i].ql + step;
@@ -734,7 +683,7 @@ static void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const floa
         const float d = x[i+0].d;
 
         float sum = 0;
-        for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
+        for (int j = 0; j < 2; ++j) {
             sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
                  + y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
                  + y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >>  4) | ((qh[j] & 0x30) >> 0)) - 32)
@@ -870,7 +819,7 @@ static void dequantize_mul_mat_vec_q2_K_sycl(const void *vx, const float *y,
                                              const int nrows,
                                              dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
+    const int ny = 2;
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, ny, WARP_SIZE);
@@ -886,7 +835,7 @@ static void dequantize_mul_mat_vec_q3_K_sycl(const void *vx, const float *y,
                                              const int nrows,
                                              dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int ny = 1;
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, ny, WARP_SIZE);
@@ -902,7 +851,7 @@ static void dequantize_mul_mat_vec_q4_K_sycl(const void *vx, const float *y,
                                              const int nrows,
                                              dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int ny = 1;
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, ny, WARP_SIZE);
@@ -931,7 +880,7 @@ static void dequantize_mul_mat_vec_q6_K_sycl(const void *vx, const float *y,
                                              const int nrows,
                                              dpct::queue_ptr stream) {
     GGML_ASSERT(ncols % QK_K == 0);
-    const int ny = 2 / K_QUANTS_PER_ITERATION;
+    const int ny = 1;
     const int block_num_y = (nrows + ny - 1) / ny;
     const sycl::range<3> block_nums(1, 1, block_num_y);
     const sycl::range<3> block_dims(1, ny, WARP_SIZE);
@@ -50,12 +50,6 @@
 #define GGML_SYCL_MMV_Y 1
 #endif
 
-#ifndef K_QUANTS_PER_ITERATION
-#define K_QUANTS_PER_ITERATION 2
-#else
-static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
-#endif
-
 #ifndef GGML_SYCL_PEER_MAX_BATCH_SIZE
 #define GGML_SYCL_PEER_MAX_BATCH_SIZE 128
 #endif // GGML_SYCL_PEER_MAX_BATCH_SIZE
@@ -44,12 +44,6 @@
 
 #define MAX_VK_BUFFERS 256
 
-#ifndef K_QUANTS_PER_ITERATION
-#define K_QUANTS_PER_ITERATION 1
-#else
-static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
-#endif
-
 #define VK_CHECK(err, msg)                                          \
     do {                                                            \
         vk::Result err_ = (err);                                    \
@@ -2,8 +2,6 @@
 #extension GL_EXT_shader_16bit_storage : require
 #extension GL_EXT_shader_8bit_storage : require
 
-#define K_QUANTS_PER_ITERATION 2
-
 #ifdef MUL_MAT_ID
 #define EXPERT_COUNT 8
 #endif
@@ -15,22 +15,22 @@ void main() {
     const uint num_blocks_per_row = p.ncols / QUANT_K;
     const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+    const uint tid = gl_LocalInvocationID.x/2;  // 0...16
+    const uint ix  = gl_LocalInvocationID.x%2;  // 0, 1
 
-    const uint step = 16/K_QUANTS_PER_ITERATION;         // 16 or 8
+    const uint step = 8;
 
     const uint v_im = tid/step;                          // 0 or 1. 0 computes 0..., 1 computes 128...
     const uint v_in = tid - step*v_im;                   // 0...15 or 0...7
 
-    const uint l0 = K_QUANTS_PER_ITERATION*v_in;         // 0...15
+    const uint l0 = 2*v_in;                              // 0...15
     const uint q_offset = 32*v_im + l0;
     const uint s_offset = 8*v_im;
     const uint y_offset = 128*v_im + l0;
 
     tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
 
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
         const uint y_idx = i * QUANT_K + y_offset;
 
         const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
@@ -38,7 +38,7 @@ void main() {
 
         FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
         FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+        for (int l = 0; l < 2; ++l) {
             sum1 += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 0) & 3)
                   + FLOAT_TYPE(data_b[b_offset + y_idx + l + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l +16] >> 0) & 3)
                   + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2] & 0xF) * FLOAT_TYPE((data_a[ib0 + i].qs[q_offset + l + 0] >> 2) & 3)
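In the shaders, partial sums land in a shared array indexed as tmp[16 * ix + tid]; with the hardcoded values (ix in {0, 1}, tid in 0..15) that addressing remains a bijection onto slots 0..31, one per invocation. A quick host-side check of the layout (C++ stand-in for the GLSL, assumed subgroup size 32):

```cpp
#include <cassert>

int main() {
    bool seen[32] = {};
    for (unsigned x = 0; x < 32; ++x) {     // gl_LocalInvocationID.x over one subgroup
        const unsigned tid = x / 2;         // 0...15
        const unsigned ix  = x % 2;         // 0, 1
        const unsigned slot = 16 * ix + tid;
        assert(slot < 32 && !seen[slot]);   // every invocation gets a distinct slot
        seen[slot] = true;
    }
    return 0;
}
```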
@@ -15,17 +15,17 @@ void main() {
     const uint num_blocks_per_row = p.ncols / QUANT_K;
     const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+    const uint tid = gl_LocalInvocationID.x/2;  // 0...16
+    const uint ix  = gl_LocalInvocationID.x%2;  // 0, 1
 
-    const uint step = 16/K_QUANTS_PER_ITERATION;         // 16 or 8
+    const uint step = 8;
 
     const uint v_im = tid/step;                          // 0 or 1. 0 computes 0..., 1 computes 128...
     const uint v_in = tid - step*v_im;                   // 0...15 or 0...7
 
     const uint8_t m = uint8_t(1 << (4 * v_im));
 
-    const uint l0 = K_QUANTS_PER_ITERATION*v_in;         // 0...15
+    const uint l0 = 2*v_in;                              // 0...15
     const uint q_offset = 32*v_im + l0;
     const uint y_offset = 128*v_im + l0;
 
@@ -33,13 +33,13 @@ void main() {
 
     const uint s_shift = 4 * v_im;
 
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
         const uint y_idx = i * QUANT_K + y_offset;
 
         const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
 
         FLOAT_TYPE sum = FLOAT_TYPE(0.0);
-        for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
+        for (int l = 0; l < 2; ++l) {
             sum += FLOAT_TYPE(data_b[b_offset + y_idx + l + 0]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[0] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 0)) != 0) ? 0 : 4))
                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 32]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[2] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[10] >> (s_shift + 0) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4))
                  + FLOAT_TYPE(data_b[b_offset + y_idx + l + 64]) * FLOAT_TYPE(int8_t(((data_a[ib0 + i].scales[4] >> s_shift) & 0xF) | ((data_a[ib0 + i].scales[ 8] >> (s_shift + 2) & 0x3) << 4)) - 32) * FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4))
@@ -15,14 +15,14 @@ void main() {
     const uint num_blocks_per_row = p.ncols / QUANT_K;
     const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+    const uint tid = gl_LocalInvocationID.x/2;  // 0...16
+    const uint ix  = gl_LocalInvocationID.x%2;  // 0, 1
 
-    const uint step = 8/K_QUANTS_PER_ITERATION;          // 8 or 4
+    const uint step = 4;
 
     const uint il = tid/step;                            // 0...3
     const uint ir = tid - step*il;                       // 0...7 or 0...3
-    const uint n  = 2 * K_QUANTS_PER_ITERATION;          // 2 or 4
+    const uint n  = 4;
 
     const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
     const uint v_in = il % 2;
@@ -33,7 +33,7 @@ void main() {
 
     tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
 
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
         const uint y1_idx = i * QUANT_K + y_offset;
         const uint y2_idx = y1_idx + 128;
 
@@ -49,7 +49,6 @@ void main() {
         const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
         const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
 
-#if K_QUANTS_PER_ITERATION == 2
         const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset    ] & 0xf);
         const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
         const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] & 0xf);
@@ -78,27 +77,6 @@ void main() {
                                         + FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7
         );
         tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * sc0 + sy * sc1 + sz * sc4 + sw * sc5) - dmin * smin);
-#else
-        const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset     ] & 0xf);
-        const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset +  1] & 0xf);
-        const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset     ] >> 4);
-        const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset +  1] >> 4);
-        const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
-        const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
-        const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
-        const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
-
-        const FLOAT_TYPE sx = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx     ]) * q4_0 + FLOAT_TYPE(data_b[b_offset + y1_idx +  1]) * q4_1);
-        const FLOAT_TYPE sy = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * q4_2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3);
-        const FLOAT_TYPE sz = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx     ]) * q4_4 + FLOAT_TYPE(data_b[b_offset + y2_idx +  1]) * q4_5);
-        const FLOAT_TYPE sw = FLOAT_TYPE(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * q4_6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7);
-        const FLOAT_TYPE smin = FLOAT_TYPE(
-            FLOAT_TYPE(data_b[b_offset + y1_idx]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) * sc7
-          + FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * sc2 + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * sc3 + FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * sc6 + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7
-        );
-
-        tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) + sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
-#endif
     }
 
     // sum up partial sums and write back result
@@ -15,21 +15,16 @@ void main() {
     const uint num_blocks_per_row = p.ncols / QUANT_K;
     const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
 
-    const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION;  // 0...31 or 0...16
-    const uint ix  = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION;  // 0 or 0, 1
+    const uint tid = gl_LocalInvocationID.x/2;  // 0...16
+    const uint ix  = gl_LocalInvocationID.x%2;  // 0, 1
 
-    const uint step = 16/K_QUANTS_PER_ITERATION;         // 16 or 8
+    const uint step = 8;
 
     const uint v_im = tid/step;                          // 0 or 1. 0 computes 0..., 1 computes 128...
     const uint v_in = tid - step*v_im;                   // 0...15 or 0...7
 
-#if K_QUANTS_PER_ITERATION == 1
-    const uint l0 = v_in;                                // 0...15
-    const uint is = 0;
-#else
     const uint l0 = 4 * v_in;                            // 0, 4, 8, ..., 28
     const uint is = v_in / 4;
-#endif
 
     const uint ql_offset = 64*v_im + l0;
     const uint qh_offset = 32*v_im + l0;
@@ -38,22 +33,11 @@ void main() {
 
     tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
 
-    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
+    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
         const uint y_idx = i * QUANT_K + y_offset;
 
         const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
 
-#if K_QUANTS_PER_ITERATION == 1
-        FLOAT_TYPE sum = FLOAT_TYPE(data_b[b_offset + y_idx +  0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x03) << 4)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 16]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32] & 0xF) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x0c) << 2)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 48]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48] & 0xF) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset +  0]  >> 4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0x30) >> 0)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 80]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 16]  >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx + 96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 32]  >> 4) | ((data_a[ib0 + i].qh[qh_offset +  0] & 0xc0) >> 2)) - 32)
-                       + FLOAT_TYPE(data_b[b_offset + y_idx +112]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + 48]  >> 4) | ((data_a[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32);
-        tmp[16 * ix + tid] += sum;
-#else
         FLOAT_TYPE sum = FLOAT_TYPE(0.0);
         [[unroll]] for (int l = 0; l < 4; ++l) {
             sum += FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
@@ -62,7 +46,6 @@ void main() {
                  + FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32]  >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32);
         }
         tmp[16 * ix + tid] += sum;
-#endif
     }
 
     // sum up partial sums and write back result