mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 11:24:35 +00:00
CUDA: faster dequantize kernels for Q4_0 and Q4_1 (#4938)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
parent
a836c8f534
commit
4a3156de2f
77
ggml-cuda.cu
77
ggml-cuda.cu
@ -1105,6 +1105,61 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in
|
|||||||
#endif // GGML_CUDA_F16
|
#endif // GGML_CUDA_F16
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename dst_t>
|
||||||
|
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
|
||||||
|
|
||||||
|
const int i = blockIdx.x;
|
||||||
|
|
||||||
|
// assume 32 threads
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
const int il = tid/8;
|
||||||
|
const int ir = tid%8;
|
||||||
|
const int ib = 8*i + ir;
|
||||||
|
if (ib >= nb32) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dst_t * y = yy + 256*i + 32*ir + 4*il;
|
||||||
|
|
||||||
|
const block_q4_0 * x = (const block_q4_0 *)vx + ib;
|
||||||
|
const float d = __half2float(x->d);
|
||||||
|
const float dm = -8*d;
|
||||||
|
|
||||||
|
const uint8_t * q = x->qs + 4*il;
|
||||||
|
|
||||||
|
for (int l = 0; l < 4; ++l) {
|
||||||
|
y[l+ 0] = d * (q[l] & 0xF) + dm;
|
||||||
|
y[l+16] = d * (q[l] >> 4) + dm;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename dst_t>
|
||||||
|
static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
|
||||||
|
|
||||||
|
const int i = blockIdx.x;
|
||||||
|
|
||||||
|
// assume 32 threads
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
const int il = tid/8;
|
||||||
|
const int ir = tid%8;
|
||||||
|
const int ib = 8*i + ir;
|
||||||
|
if (ib >= nb32) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
dst_t * y = yy + 256*i + 32*ir + 4*il;
|
||||||
|
|
||||||
|
const block_q4_1 * x = (const block_q4_1 *)vx + ib;
|
||||||
|
const float2 d = __half22float2(x->dm);
|
||||||
|
|
||||||
|
const uint8_t * q = x->qs + 4*il;
|
||||||
|
|
||||||
|
for (int l = 0; l < 4; ++l) {
|
||||||
|
y[l+ 0] = d.x * (q[l] & 0xF) + d.y;
|
||||||
|
y[l+16] = d.x * (q[l] >> 4) + d.y;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//================================== k-quants
|
//================================== k-quants
|
||||||
|
|
||||||
template<typename dst_t>
|
template<typename dst_t>
|
||||||
@ -6253,6 +6308,20 @@ static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename dst_t>
|
||||||
|
static void dequantize_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||||
|
const int nb32 = k / 32;
|
||||||
|
const int nb = (k + 255) / 256;
|
||||||
|
dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename dst_t>
|
||||||
|
static void dequantize_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||||
|
const int nb32 = k / 32;
|
||||||
|
const int nb = (k + 255) / 256;
|
||||||
|
dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
|
||||||
|
}
|
||||||
|
|
||||||
template<typename dst_t>
|
template<typename dst_t>
|
||||||
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||||
const int nb = k / QK_K;
|
const int nb = k / QK_K;
|
||||||
@ -6301,9 +6370,9 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
|||||||
int id;
|
int id;
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
|
return dequantize_q4_0_cuda;
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
|
return dequantize_q4_1_cuda;
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
|
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
|
||||||
case GGML_TYPE_Q5_1:
|
case GGML_TYPE_Q5_1:
|
||||||
@ -6338,9 +6407,9 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
|||||||
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
|
return dequantize_q4_0_cuda;
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
return dequantize_block_cuda<QK4_1, QR4_1, dequantize_q4_1>;
|
return dequantize_q4_1_cuda;
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
|
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
|
||||||
case GGML_TYPE_Q5_1:
|
case GGML_TYPE_Q5_1:
|
||||||
|
Loading…
Reference in New Issue
Block a user