mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 20:14:29 +00:00
Faster than CPU without 80% runtime memcpy
This commit is contained in:
parent
229aa1f504
commit
d052a0ed4c
63
ggml-cuda.cu
63
ggml-cuda.cu
@ -225,21 +225,24 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
|
||||
}
|
||||
}
|
||||
|
||||
static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, int ncols, int nrows) {
|
||||
template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) {
|
||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||
|
||||
const int row = blockIdx.x*blockDim.x + threadIdx.x;
|
||||
const int row = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
if (row >= nrows) {
|
||||
return;
|
||||
}
|
||||
dst[row] = 0;
|
||||
for (int i = 0; i < ncols; i += 2) {
|
||||
const float d = x[(row*ncols + i)/QK4_0].d;
|
||||
__shared__ float tmp[block_size]; // separate sum for each thread
|
||||
tmp[tid] = 0;
|
||||
|
||||
const uint8_t * pp = x[(row*ncols + i)/QK4_0].qs;
|
||||
for (int i = 0; i < ncols/block_size; i += 2) {
|
||||
const int col = i*block_size + 2*tid;
|
||||
|
||||
const uint8_t vui = pp[((row*ncols + i)%QK4_0)/2];
|
||||
// dequantize
|
||||
const float d = x[(row*ncols + col)/QK4_0].d;
|
||||
|
||||
const uint8_t * pp = x[(row*ncols + col)/QK4_0].qs;
|
||||
|
||||
const uint8_t vui = pp[((row*ncols + col)%QK4_0)/2];
|
||||
|
||||
const int8_t vi0 = vui & 0xF;
|
||||
const int8_t vi1 = vui >> 4;
|
||||
@ -247,8 +250,20 @@ static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y,
|
||||
const float v0 = (vi0 - 8)*d;
|
||||
const float v1 = (vi1 - 8)*d;
|
||||
|
||||
dst[row] += v0 * y[i + 0];
|
||||
dst[row] += v1 * y[i + 1];
|
||||
// matrix multiplication
|
||||
tmp[tid] += v0 * y[col + 0];
|
||||
tmp[tid] += v1 * y[col + 1];
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
for (int s=block_size/2; s>0; s>>=1) {
|
||||
if (tid < s) {
|
||||
tmp[tid] += tmp[tid + s];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (tid == 0) {
|
||||
dst[row] = tmp[0];
|
||||
}
|
||||
}
|
||||
|
||||
@ -282,15 +297,21 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
|
||||
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, int ncols, int nrows, cudaStream_t stream) {
|
||||
static int block_size = -1;
|
||||
if (block_size == -1) {
|
||||
int min_grid_size;
|
||||
CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_mul_mat_q4_0, 0, 0));
|
||||
block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE);
|
||||
}
|
||||
const int grid_size = (nrows + block_size - 1) / block_size; // Round up.
|
||||
dequantize_mul_mat_q4_0<<<grid_size, block_size, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
// static int block_size = -1;
|
||||
// if (block_size == -1) {
|
||||
// int min_grid_size, max_block_size = 1;
|
||||
// CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &max_block_size, dequantize_mul_mat_q4_0<256>, 0, 0));
|
||||
// max_block_size = min(max_block_size, GGML_CUDA_MAX_BLOCK_SIZE);
|
||||
// block_size = 1;
|
||||
// while (block_size*2 <= max_block_size && block_size*2 % ncols == 0) {
|
||||
// block_size *= 2;
|
||||
// }
|
||||
// }
|
||||
// dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
|
||||
const int block_size = 32;
|
||||
GGML_ASSERT(ncols % block_size == 0);
|
||||
dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
|
||||
}
|
||||
|
||||
// TODO: optimize
|
||||
|
Loading…
Reference in New Issue
Block a user