cuda : alternative q4_q8 kernel

This commit is contained in:
Georgi Gerganov 2023-05-12 15:54:07 +03:00
parent e7b9d97bae
commit a3e6d62283

View File

@ -274,6 +274,92 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
} }
} }
template <int NT, int NR> static __global__ void dequantize_mul_mat_q4_0_test(const void * vx, const void * vy, float * dst, const int ncols, const int nrows) {
const block_q4_0 * x = (const block_q4_0 *) vx;
const block_q8_0 * y = (const block_q8_0 *) vy;
const int bid = blockIdx.x;
const int tid = threadIdx.x;
__shared__ float tmp[NR][NT];
for (int i = 0; i < NR; ++i) {
tmp[i][tid] = 0.0f;
}
const int nbc = (ncols + 16*NT - 1)/(16*NT);
const int nbm = ncols/QK8_0;
uint64_t xa0;
uint64_t xa1;
const int8_t * xb0 = (const int8_t *) &xa0;
const int8_t * xb1 = (const int8_t *) &xa1;
for (int ibc = 0; ibc < nbc; ++ibc) {
const int iyb = (ibc*(16*NT) + 16*tid)/QK8_0;
const int iyq = (ibc*(16*NT) + 16*tid)%QK8_0;
if (iyb >= nbm) {
continue;
}
const int8_t * yb = (const int8_t *) &y[iyb].qs[iyq];
const float dy = y[iyb].d;
for (int ibr = 0; ibr < NR; ++ibr) {
const int ir = bid*NR + ibr;
if (ir >= nrows) {
continue;
}
// block offset
const int ixo = (ir*ncols)/QK4_0 + iyb;
memcpy(&xa0, &x[ixo].qs[iyq/2 + 0], sizeof(uint64_t));
xa1 = xa0;
xa0 = (xa0 ) & 0x0F0F0F0F0F0F0F0F;
xa1 = (xa1 >> 4) & 0x0F0F0F0F0F0F0F0F;
const float dx = x[ixo].d;
// the (int) cast is probably unnecessary, but just to make sure the result is accumulated in 32 bits
tmp[ibr][tid] += (
((int)(xb0[0] - 8))*yb[0] + ((int)(xb1[0] - 8))*yb[1] +
((int)(xb0[1] - 8))*yb[2] + ((int)(xb1[1] - 8))*yb[3] +
((int)(xb0[2] - 8))*yb[4] + ((int)(xb1[2] - 8))*yb[5] +
((int)(xb0[3] - 8))*yb[6] + ((int)(xb1[3] - 8))*yb[7] +
((int)(xb0[4] - 8))*yb[8] + ((int)(xb1[4] - 8))*yb[9] +
((int)(xb0[5] - 8))*yb[10] + ((int)(xb1[5] - 8))*yb[11] +
((int)(xb0[6] - 8))*yb[12] + ((int)(xb1[6] - 8))*yb[13] +
((int)(xb0[7] - 8))*yb[14] + ((int)(xb1[7] - 8))*yb[15]
)*dx*dy;
}
}
// reduce
__syncthreads();
for (int s = NT/2; s > 0; s >>= 1) {
if (tid < s) {
for (int ibr = 0; ibr < NR; ++ibr) {
tmp[ibr][tid] += tmp[ibr][tid + s];
}
}
__syncthreads();
}
if (tid == 0) {
for (int ibr = 0; ibr < NR; ++ibr) {
const int ir = bid*NR + ibr;
if (ir < nrows) {
dst[ir] = tmp[ibr][0];
}
}
}
}
static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
const int nb = k / QK4_0; const int nb = k / QK4_0;
dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y); dequantize_block_q4_0<<<nb, 1, 0, stream>>>(vx, y);
@ -316,9 +402,14 @@ static void dequantize_mul_mat_q4_0_cuda(const void * vx, const void * y, float
// } // }
// } // }
// dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols); // dequantize_mul_mat_q4_0<<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
const int block_size = 32; //const int block_size = 32;
GGML_ASSERT(ncols % block_size == 0); //GGML_ASSERT(ncols % block_size == 0);
dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols); //dequantize_mul_mat_q4_0<block_size><<<nrows, block_size, 0, stream>>>(vx, y, dst, ncols);
const int NR = 1; // unroll rows (seems to not help)
const int NT = 64; // number of thrads per row
dequantize_mul_mat_q4_0_test<NT, NR><<<(nrows + NR - 1)/NR, NT, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
// TODO: optimize // TODO: optimize