mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 03:14:35 +00:00
cuda : play with faster Q4_0 dequantization
This commit is contained in:
parent
d415669087
commit
6966474928
82
ggml-cuda.cu
82
ggml-cuda.cu
@ -4659,12 +4659,94 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
|
||||
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
|
||||
}
|
||||
|
||||
#ifdef GGML_CUDA_F16
|
||||
#define make_dfloat2(x, y) __halves2half2((x), (y))
|
||||
#else
|
||||
#define make_dfloat2(x, y) make_float2((x), (y))
|
||||
#endif
|
||||
|
||||
static __device__ __forceinline__ dfloat2 dfmul2(dfloat2 a, dfloat2 b) {
|
||||
#ifdef GGML_CUDA_F16
|
||||
return __hmul2(a, b);
|
||||
#else
|
||||
return make_float2(a.x * b.x, a.y * b.y);
|
||||
#endif
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ float2 dfloat22float2(dfloat2 a) {
|
||||
#ifdef GGML_CUDA_F16
|
||||
return __half22float2(a);
|
||||
#else
|
||||
return a;
|
||||
#endif
|
||||
}
|
||||
|
||||
static __global__ void dequantize_block_q4_0_f32(const void * __restrict__ vx, float * __restrict__ y, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i*4 >= k) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int ib = i/(QK4_0/4);
|
||||
const int iqs = i%(QK4_0/4);
|
||||
|
||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||
const uchar2 qs = *(const uchar2 *)(x[ib].qs + iqs*2);
|
||||
const dfloat d = x[ib].d;
|
||||
|
||||
dfloat2 dv0 = make_dfloat2((int)(qs.x & 0xf) - 8, (int)(qs.y & 0xf) - 8);
|
||||
const float2 v0 = dfloat22float2(dfmul2(dv0, {d, d}));
|
||||
*(float2 *)(y + ib*QK4_0 + iqs*2) = v0;
|
||||
|
||||
dfloat2 dv1 = make_dfloat2((int)(qs.x >> 4) - 8, (int)(qs.y >> 4) - 8);
|
||||
const float2 v1 = dfloat22float2(dfmul2(dv1, {d, d}));
|
||||
*(float2 *)(y + ib*QK4_0 + QK4_0/2 + iqs*2) = v1;
|
||||
}
|
||||
|
||||
static __global__ void dequantize_block_q4_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i*4 >= k) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int ib = i/(QK4_0/4);
|
||||
const int iqs = i%(QK4_0/4);
|
||||
|
||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||
const uchar2 qs = *(const uchar2 *)(x[ib].qs + iqs*2);
|
||||
const dfloat d = x[ib].d;
|
||||
|
||||
dfloat2 dv0 = make_dfloat2((int)(qs.x & 0xf) - 8, (int)(qs.y & 0xf) - 8);
|
||||
const float2 v0 = dfloat22float2(dfmul2(dv0, {d, d}));
|
||||
*(half2 *)(y + ib*QK4_0 + iqs*2) = __float22half2_rn(v0);
|
||||
|
||||
dfloat2 dv1 = make_dfloat2((int)(qs.x >> 4) - 8, (int)(qs.y >> 4) - 8);
|
||||
const float2 v1 = dfloat22float2(dfmul2(dv1, {d, d}));
|
||||
*(half2 *)(y + ib*QK4_0 + QK4_0/2 + iqs*2) = __float22half2_rn(v1);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
||||
dequantize_block<QK4_0, QR4_0, dequantize_q4_0><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
||||
}
|
||||
|
||||
template<>
|
||||
void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
|
||||
GGML_ASSERT(k % 4 == 0);
|
||||
const int num_blocks = (k/4 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
||||
dequantize_block_q4_0_f32<<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
||||
}
|
||||
|
||||
template<>
|
||||
void dequantize_row_q4_0_cuda(const void * vx, half * y, const int k, cudaStream_t stream) {
|
||||
GGML_ASSERT(k % 4 == 0);
|
||||
const int num_blocks = (k/4 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
||||
dequantize_block_q4_0_f16<<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
||||
}
|
||||
|
||||
template<typename dst_t>
|
||||
static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
||||
|
Loading…
Reference in New Issue
Block a user