mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 03:44:35 +00:00
Works but slower than CPU
This commit is contained in:
parent
41654efea8
commit
229aa1f504
104
ggml-cuda.cu
104
ggml-cuda.cu
@ -225,6 +225,33 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
|
||||
}
|
||||
}
|
||||
|
||||
static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, int ncols, int nrows) {
|
||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||
|
||||
const int row = blockIdx.x*blockDim.x + threadIdx.x;
|
||||
|
||||
if (row >= nrows) {
|
||||
return;
|
||||
}
|
||||
dst[row] = 0;
|
||||
for (int i = 0; i < ncols; i += 2) {
|
||||
const float d = x[(row*ncols + i)/QK4_0].d;
|
||||
|
||||
const uint8_t * pp = x[(row*ncols + i)/QK4_0].qs;
|
||||
|
||||
const uint8_t vui = pp[((row*ncols + i)%QK4_0)/2];
|
||||
|
||||
const int8_t vi0 = vui & 0xF;
|
||||
const int8_t vi1 = vui >> 4;
|
||||
|
||||
const float v0 = (vi0 - 8)*d;
|
||||
const float v1 = (vi1 - 8)*d;
|
||||
|
||||
dst[row] += v0 * y[i + 0];
|
||||
dst[row] += v1 * y[i + 1];
|
||||
}
|
||||
}
|
||||
|
||||
static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||
const int nb = k / QK4_0;
|
||||
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
|
||||
@ -255,6 +282,17 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
|
||||
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, int ncols, int nrows, cudaStream_t stream) {
|
||||
static int block_size = -1;
|
||||
if (block_size == -1) {
|
||||
int min_grid_size;
|
||||
CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_mul_mat_q4_0, 0, 0));
|
||||
block_size = min(block_size, GGML_CUDA_MAX_BLOCK_SIZE);
|
||||
}
|
||||
const int grid_size = (nrows + block_size - 1) / block_size; // Round up.
|
||||
dequantize_mul_mat_q4_0<<<grid_size, block_size, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||
}
|
||||
|
||||
// TODO: optimize
|
||||
static __global__ void convert_fp16_to_fp32(const void * vx, float * y) {
|
||||
const half * x = (const half *) vx;
|
||||
@ -597,7 +635,10 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
|
||||
const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
|
||||
|
||||
size_t x_size, y_size, d_size, q_size;
|
||||
float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
|
||||
float * d_X;
|
||||
if (ne11 > 1) {
|
||||
d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
|
||||
}
|
||||
float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
|
||||
float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
|
||||
char * d_Q = (char *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
|
||||
@ -612,31 +653,49 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
|
||||
cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
|
||||
cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
|
||||
|
||||
float * c_X = d_X + i * x_ne;
|
||||
float * c_Y = d_Y + i * y_ne;
|
||||
float * c_D = d_D + i * d_ne;
|
||||
char * c_Q = d_Q + i * q_sz;
|
||||
|
||||
// copy src0 and convert to fp32 on device
|
||||
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
|
||||
to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
|
||||
if (ne11 == 1) {
|
||||
// copy src0 to device
|
||||
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
|
||||
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
|
||||
|
||||
// copy src1 to device
|
||||
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
|
||||
// copy src1 to device
|
||||
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
|
||||
|
||||
// wait for conversion
|
||||
CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
|
||||
// wait for data
|
||||
CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
|
||||
|
||||
// compute
|
||||
CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
&alpha, c_X, ne00,
|
||||
c_Y, ne10,
|
||||
&beta, c_D, ne01));
|
||||
// compute
|
||||
dequantize_mul_mat_q4_0_cuda(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
} else {
|
||||
float * c_X = d_X + i * x_ne;
|
||||
|
||||
// copy src0 and convert to fp32 on device
|
||||
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
|
||||
to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
|
||||
|
||||
// copy src1 to device
|
||||
CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
|
||||
|
||||
// wait for conversion
|
||||
CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
|
||||
|
||||
// compute
|
||||
CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
|
||||
CUBLAS_CHECK(
|
||||
cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
ne01, ne11, ne10,
|
||||
&alpha, c_X, ne00,
|
||||
c_Y, ne10,
|
||||
&beta, c_D, ne01));
|
||||
}
|
||||
|
||||
// copy dst to host
|
||||
float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
|
||||
@ -645,7 +704,9 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
|
||||
}
|
||||
|
||||
CUDA_CHECK(cudaDeviceSynchronize());
|
||||
ggml_cuda_pool_free(d_X, x_size);
|
||||
if (ne11 > 1) {
|
||||
ggml_cuda_pool_free(d_X, x_size);
|
||||
}
|
||||
ggml_cuda_pool_free(d_Y, y_size);
|
||||
ggml_cuda_pool_free(d_D, d_size);
|
||||
ggml_cuda_pool_free(d_Q, q_size);
|
||||
@ -660,8 +721,7 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
|
||||
// TODO: find the optimal values for these
|
||||
if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
|
||||
src1->type == GGML_TYPE_F32 &&
|
||||
dst->type == GGML_TYPE_F32 &&
|
||||
(ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
|
||||
dst->type == GGML_TYPE_F32) {
|
||||
|
||||
return true;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user