mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
CUDA: fixed redundant value dequantization (#4809)
This commit is contained in:
parent
9dede37d81
commit
d5a410e855
35
ggml-cuda.cu
35
ggml-cuda.cu
@ -1872,14 +1872,6 @@ static __device__ void convert_f16(const void * vx, const int ib, const int iqs,
|
||||
v.y = x[ib + iqs + 1];
|
||||
}
|
||||
|
||||
static __device__ void convert_f32(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||
const float * x = (const float *) vx;
|
||||
|
||||
// automatic half -> float type cast if dfloat == float
|
||||
v.x = x[ib + iqs + 0];
|
||||
v.y = x[ib + iqs + 1];
|
||||
}
|
||||
|
||||
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
|
||||
const int ix = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
@ -1983,7 +1975,7 @@ static __global__ void k_get_rows_float(
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
|
||||
const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
@ -2002,6 +1994,19 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
|
||||
y[iybs + iqs + y_offset] = v.y;
|
||||
}
|
||||
|
||||
template <typename src_t, typename dst_t>
|
||||
static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
|
||||
const int i = blockDim.x*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
|
||||
const src_t * x = (src_t *) vx;
|
||||
|
||||
y[i] = x[i];
|
||||
}
|
||||
|
||||
// VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
|
||||
// MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
|
||||
|
||||
@ -5609,7 +5614,7 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
|
||||
|
||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
||||
const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
|
||||
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
||||
}
|
||||
|
||||
@ -5659,6 +5664,12 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename src_t, typename dst_t>
|
||||
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
|
||||
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
|
||||
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
||||
}
|
||||
|
||||
static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q4_0:
|
||||
@ -5682,7 +5693,7 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
|
||||
case GGML_TYPE_Q6_K:
|
||||
return dequantize_row_q6_K_cuda;
|
||||
case GGML_TYPE_F32:
|
||||
return dequantize_block_cuda<1, 1, convert_f32>;
|
||||
return convert_unary_cuda<float>;
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
@ -5711,7 +5722,7 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
|
||||
case GGML_TYPE_Q6_K:
|
||||
return dequantize_row_q6_K_cuda;
|
||||
case GGML_TYPE_F16:
|
||||
return dequantize_block_cuda<1, 1, convert_f16>;
|
||||
return convert_unary_cuda<half>;
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user