mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 11:24:35 +00:00
cuda : alternative q4_q8 kernel
This commit is contained in:
parent
e7b9d97bae
commit
a3e6d62283
97
ggml-cuda.cu
97
ggml-cuda.cu
@ -274,6 +274,92 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <int NT, int NR> static __global__ void dequantize_mul_mat_q4_0_test(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
|
||||||
|
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||||
|
const block_q8_0 * y = (const block_q8_0 *) vy;
|
||||||
|
|
||||||
|
const int bid = blockIdx.x;
|
||||||
|
const int tid = threadIdx.x;
|
||||||
|
|
||||||
|
__shared__ float tmp[NR][NT];
|
||||||
|
for (int i = 0; i < NR; ++i) {
|
||||||
|
tmp[i][tid] = 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int nbc = (ncols + 16*NT - 1)/(16*NT);
|
||||||
|
const int nbm = ncols/QK8_0;
|
||||||
|
|
||||||
|
uint64_t xa0;
|
||||||
|
uint64_t xa1;
|
||||||
|
|
||||||
|
const int8_t * xb0 = (const int8_t *) &xa0;
|
||||||
|
const int8_t * xb1 = (const int8_t *) &xa1;
|
||||||
|
|
||||||
|
for (int ibc = 0; ibc < nbc; ++ibc) {
|
||||||
|
const int iyb = (ibc*(16*NT) + 16*tid)/QK8_0;
|
||||||
|
const int iyq = (ibc*(16*NT) + 16*tid)%QK8_0;
|
||||||
|
|
||||||
|
if (iyb >= nbm) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int8_t * yb = (const int8_t *) &y[iyb].qs[iyq];
|
||||||
|
|
||||||
|
const float dy = y[iyb].d;
|
||||||
|
|
||||||
|
for (int ibr = 0; ibr < NR; ++ibr) {
|
||||||
|
const int ir = bid*NR + ibr;
|
||||||
|
if (ir >= nrows) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// block offset
|
||||||
|
const int ixo = (ir*ncols)/QK4_0 + iyb;
|
||||||
|
|
||||||
|
memcpy(&xa0, &x[ixo].qs[iyq/2 + 0], sizeof(uint64_t));
|
||||||
|
xa1 = xa0;
|
||||||
|
|
||||||
|
xa0 = (xa0 ) & 0x0F0F0F0F0F0F0F0F;
|
||||||
|
xa1 = (xa1 >> 4) & 0x0F0F0F0F0F0F0F0F;
|
||||||
|
|
||||||
|
const float dx = x[ixo].d;
|
||||||
|
|
||||||
|
// the (int) cast is probably unnecessary, but just to make sure the result is accumulated in 32 bits
|
||||||
|
tmp[ibr][tid] += (
|
||||||
|
((int)(xb0[0] - 8))*yb[0] + ((int)(xb1[0] - 8))*yb[1] +
|
||||||
|
((int)(xb0[1] - 8))*yb[2] + ((int)(xb1[1] - 8))*yb[3] +
|
||||||
|
((int)(xb0[2] - 8))*yb[4] + ((int)(xb1[2] - 8))*yb[5] +
|
||||||
|
((int)(xb0[3] - 8))*yb[6] + ((int)(xb1[3] - 8))*yb[7] +
|
||||||
|
((int)(xb0[4] - 8))*yb[8] + ((int)(xb1[4] - 8))*yb[9] +
|
||||||
|
((int)(xb0[5] - 8))*yb[10] + ((int)(xb1[5] - 8))*yb[11] +
|
||||||
|
((int)(xb0[6] - 8))*yb[12] + ((int)(xb1[6] - 8))*yb[13] +
|
||||||
|
((int)(xb0[7] - 8))*yb[14] + ((int)(xb1[7] - 8))*yb[15]
|
||||||
|
)*dx*dy;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reduce
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (int s = NT/2; s > 0; s >>= 1) {
|
||||||
|
if (tid < s) {
|
||||||
|
for (int ibr = 0; ibr < NR; ++ibr) {
|
||||||
|
tmp[ibr][tid] += tmp[ibr][tid + s];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tid == 0) {
|
||||||
|
for (int ibr = 0; ibr < NR; ++ibr) {
|
||||||
|
const int ir = bid*NR + ibr;
|
||||||
|
if (ir < nrows) {
|
||||||
|
dst[ir] = tmp[ibr][0];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
|
||||||
const int nb = k / QK4_0;
|
const int nb = k / QK4_0;
|
||||||
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
|
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
|
||||||
@ -316,9 +402,14 @@ static void dequantize_mul_mat_q4_0_cuda(const void * vx, const void * y, float
|
|||||||
// }
|
// }
|
||||||
// }
|
// }
|
||||||
// dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
|
// dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
|
||||||
const int block_size = 32;
|
//const int block_size = 32;
|
||||||
GGML_ASSERT(ncols % block_size == 0);
|
//GGML_ASSERT(ncols % block_size == 0);
|
||||||
dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
|
//dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
|
||||||
|
|
||||||
|
const int NR = 1; // unroll rows (seems to not help)
|
||||||
|
const int NT = 64; // number of thrads per row
|
||||||
|
|
||||||
|
dequantize_mul_mat_q4_0_test<NT, NR><<<(nrows + NR - 1)/NR, NT, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: optimize
|
// TODO: optimize
|
||||||
|
Loading…
Reference in New Issue
Block a user