mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-14 06:49:54 +00:00
CUDA: fewer memory bank conflicts for mul_mat_q (#2458)
This commit is contained in:
parent
9d2382b3e4
commit
2dbf518911
608
ggml-cuda.cu
608
ggml-cuda.cu
@ -162,7 +162,7 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
|
|||||||
typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc);
|
typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc);
|
||||||
typedef void (*load_tiles_cuda_t)(
|
typedef void (*load_tiles_cuda_t)(
|
||||||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||||
int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row);
|
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row);
|
||||||
typedef float (*vec_dot_q_mul_mat_cuda_t)(
|
typedef float (*vec_dot_q_mul_mat_cuda_t)(
|
||||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||||
const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k);
|
const int * __restrict__ y_qs, const half2 * __restrict__ y_ms, const int & i, const int & j, const int & k);
|
||||||
@ -1397,8 +1397,8 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
|
|||||||
|
|
||||||
static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||||
|
|
||||||
__shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
__shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE) + GGML_CUDA_MMQ_Y];
|
||||||
__shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_0)];
|
__shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_0) + GGML_CUDA_MMQ_Y/QI4_0];
|
||||||
|
|
||||||
*x_ql = tile_x_qs;
|
*x_ql = tile_x_qs;
|
||||||
*x_dm = tile_x_d;
|
*x_dm = tile_x_d;
|
||||||
@ -1406,26 +1406,61 @@ static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 **
|
|||||||
|
|
||||||
static __device__ __forceinline__ void load_tiles_q4_0(
|
static __device__ __forceinline__ void load_tiles_q4_0(
|
||||||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||||
int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
|
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
|
||||||
|
|
||||||
|
__builtin_assume(i_offset >= 0);
|
||||||
|
__builtin_assume(i_offset < 8);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const int kbx = k / QI4_0;
|
const int kbx = k / QI4_0;
|
||||||
const int kqsx = k % QI4_0;
|
const int kqsx = k % QI4_0;
|
||||||
|
|
||||||
const block_q4_0 * bx = ((block_q4_0 *) vx) + i*blocks_per_row + kbx;
|
const block_q4_0 * bx0 = (block_q4_0 *) vx;
|
||||||
|
|
||||||
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bx->qs, kqsx);
|
#pragma unroll
|
||||||
x_dm[i * (WARP_SIZE / QI4_0) + kbx].x = bx->d;
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
|
||||||
|
const int i = i0 + i_offset;
|
||||||
|
|
||||||
|
const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx;
|
||||||
|
|
||||||
|
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
|
||||||
|
x_dm[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx].x = bxi->d;
|
||||||
|
}
|
||||||
|
|
||||||
|
// const int blocks_per_tile_x_row = WARP_SIZE / QI4_0;
|
||||||
|
// const int kbxd = k % blocks_per_tile_x_row;
|
||||||
|
|
||||||
|
// #pragma unroll
|
||||||
|
// for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI4_0) {
|
||||||
|
// const int i = i0 + i_offset * QI4_0 + k / blocks_per_tile_x_row;
|
||||||
|
|
||||||
|
// if (i >= GGML_CUDA_MMQ_Y) {
|
||||||
|
// return;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd;
|
||||||
|
|
||||||
|
// x_dm[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd].x = bxi->d;
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
|
static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
|
||||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||||
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
||||||
|
|
||||||
|
__builtin_assume(i >= 0);
|
||||||
|
__builtin_assume(i < GGML_CUDA_MMQ_Y);
|
||||||
|
__builtin_assume(j >= 0);
|
||||||
|
__builtin_assume(j < WARP_SIZE);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
|
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
|
||||||
|
|
||||||
return vec_dot_q4_0_q8_1_impl(
|
return vec_dot_q4_0_q8_1_impl(
|
||||||
x_ql[i * (WARP_SIZE + 1) + k], y_qs[j * (2*WARP_SIZE) + kyqs], y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)],
|
x_ql[i * (WARP_SIZE + 1) + k], y_qs[j * (2*WARP_SIZE) + kyqs], y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)],
|
||||||
x_dm[i * (WARP_SIZE/QI4_0) + k/QI4_0].x, y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
|
x_dm[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0].x, y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define VDR_q4_1_q8_1 1
|
#define VDR_q4_1_q8_1 1
|
||||||
@ -1471,8 +1506,8 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
|
|||||||
|
|
||||||
static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||||
|
|
||||||
__shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
__shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE) + + GGML_CUDA_MMQ_Y];
|
||||||
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_1)];
|
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_1) + GGML_CUDA_MMQ_Y/QI4_1];
|
||||||
|
|
||||||
*x_ql = tile_x_qs;
|
*x_ql = tile_x_qs;
|
||||||
*x_dm = tile_x_dm;
|
*x_dm = tile_x_dm;
|
||||||
@ -1480,26 +1515,56 @@ static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 **
|
|||||||
|
|
||||||
static __device__ __forceinline__ void load_tiles_q4_1(
|
static __device__ __forceinline__ void load_tiles_q4_1(
|
||||||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||||
int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
|
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
|
||||||
|
|
||||||
|
__builtin_assume(i_offset >= 0);
|
||||||
|
__builtin_assume(i_offset < 8);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const int kbx = k / QI4_1;
|
const int kbx = k / QI4_1;
|
||||||
const int kqsx = k % QI4_1;
|
const int kqsx = k % QI4_1;
|
||||||
|
|
||||||
const block_q4_1 * bx = ((block_q4_1 *) vx) + i*blocks_per_row + kbx;
|
const block_q4_1 * bx0 = (block_q4_1 *) vx;
|
||||||
|
|
||||||
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bx->qs, kqsx);
|
#pragma unroll
|
||||||
x_dm[i * (WARP_SIZE / QI4_1) + kbx] = bx->dm;
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
|
||||||
|
const int i = i0 + i_offset;
|
||||||
|
|
||||||
|
const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx;
|
||||||
|
|
||||||
|
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int blocks_per_tile_x_row = WARP_SIZE / QI4_1;
|
||||||
|
const int kbxd = k % blocks_per_tile_x_row;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI4_1) {
|
||||||
|
const int i = i0 + i_offset * QI4_1 + k / blocks_per_tile_x_row;
|
||||||
|
|
||||||
|
const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd;
|
||||||
|
|
||||||
|
x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
|
static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
|
||||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||||
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
||||||
|
|
||||||
|
__builtin_assume(i >= 0);
|
||||||
|
__builtin_assume(i < GGML_CUDA_MMQ_Y);
|
||||||
|
__builtin_assume(j >= 0);
|
||||||
|
__builtin_assume(j < WARP_SIZE);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
|
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
|
||||||
|
|
||||||
return vec_dot_q4_1_q8_1_impl(
|
return vec_dot_q4_1_q8_1_impl(
|
||||||
x_ql[i * (WARP_SIZE + 1) + k], y_qs[j * (2*WARP_SIZE) + kyqs], y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)],
|
x_ql[i * (WARP_SIZE + 1) + k], y_qs[j * (2*WARP_SIZE) + kyqs], y_qs[j * (2*WARP_SIZE) + kyqs + (QI8_1/2)],
|
||||||
x_dm[i * (WARP_SIZE/QI4_1) + k/QI4_1], y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
|
x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1], y_ds[j * (2*WARP_SIZE/QI8_1) + 2*k/QI8_1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define VDR_q5_0_q8_1 1
|
#define VDR_q5_0_q8_1 1
|
||||||
@ -1543,9 +1608,9 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
|
|||||||
|
|
||||||
static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||||
|
|
||||||
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE) + GGML_CUDA_MMQ_Y];
|
||||||
__shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_0)];
|
__shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_0) + GGML_CUDA_MMQ_Y/QI5_0];
|
||||||
__shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_0)];
|
__shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_0) + GGML_CUDA_MMQ_Y/QI5_0];
|
||||||
|
|
||||||
*x_ql = tile_x_ql;
|
*x_ql = tile_x_ql;
|
||||||
*x_qh = tile_x_qh;
|
*x_qh = tile_x_qh;
|
||||||
@ -1554,24 +1619,54 @@ static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 **
|
|||||||
|
|
||||||
static __device__ __forceinline__ void load_tiles_q5_0(
|
static __device__ __forceinline__ void load_tiles_q5_0(
|
||||||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||||
int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
|
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
|
||||||
|
|
||||||
|
__builtin_assume(i_offset >= 0);
|
||||||
|
__builtin_assume(i_offset < 8);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const int kbx = k / QI5_0;
|
const int kbx = k / QI5_0;
|
||||||
const int kqsx = k % QI5_0;
|
const int kqsx = k % QI5_0;
|
||||||
|
|
||||||
const block_q5_0 * bx = ((block_q5_0 *) vx) + i*blocks_per_row + kbx;
|
const block_q5_0 * bx0 = (block_q5_0 *) vx;
|
||||||
|
|
||||||
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bx->qs, kqsx);
|
#pragma unroll
|
||||||
x_qh[i * (WARP_SIZE / QI5_0) + kbx] = get_int_from_uint8(bx->qh, 0);
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
|
||||||
x_dm[i * (WARP_SIZE / QI5_0) + kbx].x = bx->d;
|
const int i = i0 + i_offset;
|
||||||
|
|
||||||
|
const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbx;
|
||||||
|
|
||||||
|
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int blocks_per_tile_x_row = WARP_SIZE / QI5_0;
|
||||||
|
const int kbxd = k % blocks_per_tile_x_row;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI5_0) {
|
||||||
|
const int i = i0 + i_offset * QI5_0 + k / blocks_per_tile_x_row;
|
||||||
|
|
||||||
|
const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd;
|
||||||
|
|
||||||
|
x_qh[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = get_int_from_uint8(bxi->qh, 0);
|
||||||
|
x_dm[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd].x = bxi->d;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
|
static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
|
||||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||||
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
||||||
|
|
||||||
|
__builtin_assume(i >= 0);
|
||||||
|
__builtin_assume(i < GGML_CUDA_MMQ_Y);
|
||||||
|
__builtin_assume(j >= 0);
|
||||||
|
__builtin_assume(j < WARP_SIZE);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
|
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
|
||||||
const int index_bx = i * (WARP_SIZE/QI5_0) + k/QI5_0;
|
const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
|
||||||
|
|
||||||
return vec_dot_q5_0_q8_1_impl(
|
return vec_dot_q5_0_q8_1_impl(
|
||||||
x_ql[i * (WARP_SIZE + 1) + k], x_qh[index_bx] >> (4 * (k % QI5_0)), y_qs[j * (2*WARP_SIZE) + kyqs],
|
x_ql[i * (WARP_SIZE + 1) + k], x_qh[index_bx] >> (4 * (k % QI5_0)), y_qs[j * (2*WARP_SIZE) + kyqs],
|
||||||
@ -1629,9 +1724,9 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
|
|||||||
|
|
||||||
static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||||
|
|
||||||
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE ) + GGML_CUDA_MMQ_Y];
|
||||||
__shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_1)];
|
__shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_1) + GGML_CUDA_MMQ_Y/QI5_1];
|
||||||
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_1)];
|
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_1) + GGML_CUDA_MMQ_Y/QI5_1];
|
||||||
|
|
||||||
*x_ql = tile_x_ql;
|
*x_ql = tile_x_ql;
|
||||||
*x_qh = tile_x_qh;
|
*x_qh = tile_x_qh;
|
||||||
@ -1640,24 +1735,54 @@ static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 **
|
|||||||
|
|
||||||
static __device__ __forceinline__ void load_tiles_q5_1(
|
static __device__ __forceinline__ void load_tiles_q5_1(
|
||||||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||||
int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
|
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
|
||||||
|
|
||||||
|
__builtin_assume(i_offset >= 0);
|
||||||
|
__builtin_assume(i_offset < 8);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const int kbx = k / QI5_1;
|
const int kbx = k / QI5_1;
|
||||||
const int kqsx = k % QI5_1;
|
const int kqsx = k % QI5_1;
|
||||||
|
|
||||||
const block_q5_1 * bx = ((block_q5_1 *) vx) + i*blocks_per_row + kbx;
|
const block_q5_1 * bx0 = (block_q5_1 *) vx;
|
||||||
|
|
||||||
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bx->qs, kqsx);
|
#pragma unroll
|
||||||
x_qh[i * (WARP_SIZE / QI5_1) + kbx] = get_int_from_uint8(bx->qh, 0);
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
|
||||||
x_dm[i * (WARP_SIZE / QI5_1) + kbx] = bx->dm;
|
const int i = i0 + i_offset;
|
||||||
|
|
||||||
|
const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbx;
|
||||||
|
|
||||||
|
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int blocks_per_tile_x_row = WARP_SIZE / QI5_1;
|
||||||
|
const int kbxd = k % blocks_per_tile_x_row;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI5_1) {
|
||||||
|
const int i = i0 + i_offset * QI5_1 + k / blocks_per_tile_x_row;
|
||||||
|
|
||||||
|
const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd;
|
||||||
|
|
||||||
|
x_qh[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = get_int_from_uint8_aligned(bxi->qh, 0);
|
||||||
|
x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
|
static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
|
||||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||||
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
||||||
|
|
||||||
|
__builtin_assume(i >= 0);
|
||||||
|
__builtin_assume(i < GGML_CUDA_MMQ_Y);
|
||||||
|
__builtin_assume(j >= 0);
|
||||||
|
__builtin_assume(j < WARP_SIZE);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
|
const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
|
||||||
const int index_bx = i * (WARP_SIZE/QI5_0) + k/QI5_0;
|
const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
|
||||||
|
|
||||||
return vec_dot_q5_1_q8_1_impl(
|
return vec_dot_q5_1_q8_1_impl(
|
||||||
x_ql[i * (WARP_SIZE + 1) + k], x_qh[index_bx] >> (4 * (k % QI5_1)), y_qs[j * (2*WARP_SIZE) + kyqs],
|
x_ql[i * (WARP_SIZE + 1) + k], x_qh[index_bx] >> (4 * (k % QI5_1)), y_qs[j * (2*WARP_SIZE) + kyqs],
|
||||||
@ -1692,8 +1817,8 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
|
|||||||
|
|
||||||
static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||||
|
|
||||||
__shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
__shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE) + GGML_CUDA_MMQ_Y];
|
||||||
__shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI8_0)];
|
__shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI8_0) + GGML_CUDA_MMQ_Y/QI8_0];
|
||||||
|
|
||||||
*x_ql = tile_x_qs;
|
*x_ql = tile_x_qs;
|
||||||
*x_dm = tile_x_d;
|
*x_dm = tile_x_d;
|
||||||
@ -1701,24 +1826,61 @@ static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 **
|
|||||||
|
|
||||||
static __device__ __forceinline__ void load_tiles_q8_0(
|
static __device__ __forceinline__ void load_tiles_q8_0(
|
||||||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||||
int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
|
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
|
||||||
|
|
||||||
|
__builtin_assume(i_offset >= 0);
|
||||||
|
__builtin_assume(i_offset < 8);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const int kbx = k / QI8_0;
|
const int kbx = k / QI8_0;
|
||||||
const int kqsx = k % QI8_0;
|
const int kqsx = k % QI8_0;
|
||||||
|
|
||||||
const block_q8_0 * bx = ((block_q8_0 *) vx) + i*blocks_per_row + kbx;
|
const block_q8_0 * bx0 = (block_q8_0 *) vx;
|
||||||
|
|
||||||
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bx->qs, kqsx);
|
#pragma unroll
|
||||||
x_dm[i * (WARP_SIZE / QI8_0) + kbx].x = bx->d;
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
|
||||||
|
const int i = i0 + i_offset;
|
||||||
|
|
||||||
|
const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx;
|
||||||
|
|
||||||
|
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx);
|
||||||
|
x_dm[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbx].x = bxi->d;
|
||||||
|
}
|
||||||
|
|
||||||
|
// const int blocks_per_tile_x_row = WARP_SIZE / QI8_0;
|
||||||
|
// const int kbxd = k % blocks_per_tile_x_row;
|
||||||
|
|
||||||
|
// #pragma unroll
|
||||||
|
// for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI8_0) {
|
||||||
|
// const int i = i0 + i_offset * QI8_0 + k / blocks_per_tile_x_row;
|
||||||
|
|
||||||
|
// #if GGML_CUDA_MMQ_Y < 64
|
||||||
|
// if (i >= GGML_CUDA_MMQ_Y) {
|
||||||
|
// return;
|
||||||
|
// }
|
||||||
|
// #endif // GGML_CUDA_MMQ_Y < 64
|
||||||
|
|
||||||
|
// const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd;
|
||||||
|
|
||||||
|
// x_dm[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd].x = bxi->d;
|
||||||
|
// }
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
|
static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
|
||||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||||
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
||||||
|
|
||||||
|
__builtin_assume(i >= 0);
|
||||||
|
__builtin_assume(i < GGML_CUDA_MMQ_Y);
|
||||||
|
__builtin_assume(j >= 0);
|
||||||
|
__builtin_assume(j < WARP_SIZE);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
return vec_dot_q8_0_q8_1_impl(
|
return vec_dot_q8_0_q8_1_impl(
|
||||||
x_ql[i * (WARP_SIZE + 1) + k], y_qs[j*WARP_SIZE + k],
|
x_ql[i * (WARP_SIZE + 1) + k], y_qs[j*WARP_SIZE + k],
|
||||||
x_dm[i * (WARP_SIZE/QI8_0) + k/QI8_0].x, y_ds[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
|
x_dm[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0].x, y_ds[j * (WARP_SIZE/QI8_1) + k/QI8_1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define VDR_q2_K_q8_1 1
|
#define VDR_q2_K_q8_1 1
|
||||||
@ -1776,9 +1938,9 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
|
|||||||
|
|
||||||
static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||||
|
|
||||||
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE) + GGML_CUDA_MMQ_Y];
|
||||||
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE / QI2_K)];
|
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI2_K) + GGML_CUDA_MMQ_Y/QI2_K];
|
||||||
__shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE / 4)];
|
__shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/4) + GGML_CUDA_MMQ_Y/4];
|
||||||
|
|
||||||
*x_ql = tile_x_ql;
|
*x_ql = tile_x_ql;
|
||||||
*x_dm = tile_x_dm;
|
*x_dm = tile_x_dm;
|
||||||
@ -1787,25 +1949,59 @@ static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 **
|
|||||||
|
|
||||||
static __device__ __forceinline__ void load_tiles_q2_K(
|
static __device__ __forceinline__ void load_tiles_q2_K(
|
||||||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||||
int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
|
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
|
||||||
|
|
||||||
|
__builtin_assume(i_offset >= 0);
|
||||||
|
__builtin_assume(i_offset < 8);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const int kbx = k / QI2_K;
|
const int kbx = k / QI2_K;
|
||||||
const int kqsx = k % QI2_K;
|
const int kqsx = k % QI2_K;
|
||||||
|
|
||||||
const block_q2_K * bx = ((block_q2_K *) vx) + i*blocks_per_row + kbx;
|
const block_q2_K * bx0 = (block_q2_K *) vx;
|
||||||
|
|
||||||
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bx->qs, kqsx);
|
#pragma unroll
|
||||||
x_dm[i * (WARP_SIZE / QI2_K) + kbx] = bx->dm;
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
|
||||||
x_sc[i * (WARP_SIZE / 4) + k/4] = get_int_from_uint8_aligned(bx->scales, kqsx / 4);
|
const int i = i0 + i_offset;
|
||||||
|
|
||||||
|
const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx;
|
||||||
|
|
||||||
|
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int blocks_per_tile_x_row = WARP_SIZE / QI2_K;
|
||||||
|
const int kbxd = k % blocks_per_tile_x_row;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI2_K) {
|
||||||
|
const int i = (i0 + i_offset * QI2_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
|
||||||
|
|
||||||
|
const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd;
|
||||||
|
|
||||||
|
x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 4) {
|
||||||
|
const int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
|
||||||
|
|
||||||
|
const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4);
|
||||||
|
|
||||||
|
x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
|
static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
|
||||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||||
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
||||||
|
|
||||||
__builtin_assume(i < GGML_CUDA_MMQ_Y);
|
__builtin_assume(i >= 0);
|
||||||
__builtin_assume(j < WARP_SIZE);
|
__builtin_assume(i < GGML_CUDA_MMQ_Y);
|
||||||
__builtin_assume(k < WARP_SIZE);
|
__builtin_assume(j >= 0);
|
||||||
|
__builtin_assume(j < WARP_SIZE);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const int kbx = k / QI2_K;
|
const int kbx = k / QI2_K;
|
||||||
const int kqsx = k % QI2_K;
|
const int kqsx = k % QI2_K;
|
||||||
@ -1813,7 +2009,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
|
|||||||
const int bq8_offset = QR2_K * (kqsx / QI8_1);
|
const int bq8_offset = QR2_K * (kqsx / QI8_1);
|
||||||
const int scale_offset = kqsx - kqsx % QI8_1 + (kqsx % QI8_1) / (QI8_1/2);
|
const int scale_offset = kqsx - kqsx % QI8_1 + (kqsx % QI8_1) / (QI8_1/2);
|
||||||
|
|
||||||
const uint8_t * scales = ((uint8_t *) (x_sc + i * (WARP_SIZE/4))) + kbx*16 + scale_offset;
|
const uint8_t * scales = ((uint8_t *) (x_sc + i * (WARP_SIZE/4) + i / 4)) + kbx*16 + scale_offset;
|
||||||
|
|
||||||
int u[QR2_K];
|
int u[QR2_K];
|
||||||
float d8[QR2_K];
|
float d8[QR2_K];
|
||||||
@ -1824,7 +2020,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
|
|||||||
d8[l] = y_ds[y_qs_index / QI8_1].x;
|
d8[l] = y_ds[y_qs_index / QI8_1].x;
|
||||||
}
|
}
|
||||||
|
|
||||||
return vec_dot_q2_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], u, scales, x_dm[i * (WARP_SIZE/QI2_K) + kbx], d8);
|
return vec_dot_q2_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], u, scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], d8);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define VDR_q3_K_q8_1 1
|
#define VDR_q3_K_q8_1 1
|
||||||
@ -1892,10 +2088,10 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
|
|||||||
|
|
||||||
static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||||
|
|
||||||
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE) + GGML_CUDA_MMQ_Y];
|
||||||
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE / QI2_K)];
|
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI3_K) + GGML_CUDA_MMQ_Y/QI3_K];
|
||||||
__shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE / 2)];
|
__shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/2) + GGML_CUDA_MMQ_Y/2];
|
||||||
__shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE / 4)];
|
__shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/4) + GGML_CUDA_MMQ_Y/4];
|
||||||
|
|
||||||
*x_ql = tile_x_ql;
|
*x_ql = tile_x_ql;
|
||||||
*x_dm = tile_x_dm;
|
*x_dm = tile_x_dm;
|
||||||
@ -1905,33 +2101,79 @@ static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 **
|
|||||||
|
|
||||||
static __device__ __forceinline__ void load_tiles_q3_K(
|
static __device__ __forceinline__ void load_tiles_q3_K(
|
||||||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||||
int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
|
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
|
||||||
|
|
||||||
|
__builtin_assume(i_offset >= 0);
|
||||||
|
__builtin_assume(i_offset < 8);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const int kbx = k / QI3_K;
|
const int kbx = k / QI3_K;
|
||||||
const int kqsx = k % QI3_K;
|
const int kqsx = k % QI3_K;
|
||||||
|
|
||||||
const block_q3_K * bx = ((block_q3_K *) vx) + i*blocks_per_row + kbx;
|
const block_q3_K * bx0 = (block_q3_K *) vx;
|
||||||
|
|
||||||
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bx->qs, kqsx);
|
#pragma unroll
|
||||||
x_dm[i * (WARP_SIZE / QI3_K) + kbx].x = bx->d;
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
|
||||||
x_qh[i * (WARP_SIZE / 2) + k/2] = get_int_from_uint8(bx->hmask, kqsx / 2);
|
const int i = i0 + i_offset;
|
||||||
x_sc[i * (WARP_SIZE / 4) + k/4] = get_int_from_uint8(bx->scales, kqsx / 4);
|
|
||||||
|
const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx;
|
||||||
|
|
||||||
|
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int blocks_per_tile_x_row = WARP_SIZE / QI3_K;
|
||||||
|
const int kbxd = k % blocks_per_tile_x_row;
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI3_K) {
|
||||||
|
const int i = (i0 + i_offset * QI3_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
|
||||||
|
|
||||||
|
const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd;
|
||||||
|
|
||||||
|
x_dm[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd].x = bxi->d;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 2) {
|
||||||
|
const int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
|
||||||
|
|
||||||
|
const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2);
|
||||||
|
|
||||||
|
x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = get_int_from_uint8(bxi->hmask, k % (QI3_K/2));
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 4) {
|
||||||
|
const int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
|
||||||
|
|
||||||
|
const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4);
|
||||||
|
|
||||||
|
x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8(bxi->scales, k % (QI3_K/4));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
|
static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
|
||||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||||
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
||||||
|
|
||||||
|
__builtin_assume(i >= 0);
|
||||||
|
__builtin_assume(i < GGML_CUDA_MMQ_Y);
|
||||||
|
__builtin_assume(j >= 0);
|
||||||
|
__builtin_assume(j < WARP_SIZE);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const int kbx = k / QI3_K;
|
const int kbx = k / QI3_K;
|
||||||
const int kqsx = k % QI3_K;
|
const int kqsx = k % QI3_K;
|
||||||
|
|
||||||
const int bq8_offset = QR3_K * (kqsx / (QI3_K/2));
|
const int bq8_offset = QR3_K * (kqsx / (QI3_K/2));
|
||||||
const int scale_offset = kqsx - kqsx % QI8_1 + (kqsx % QI8_1) / (QI8_1/2);
|
const int scale_offset = kqsx - kqsx % QI8_1 + (kqsx % QI8_1) / (QI8_1/2);
|
||||||
|
|
||||||
const uint8_t * scales = ((uint8_t *) (x_sc + i * (WARP_SIZE/4))) + kbx*16;
|
const uint8_t * scales = ((uint8_t *) (x_sc + i * (WARP_SIZE/4) + i / 4)) + kbx*16;
|
||||||
|
|
||||||
// invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
|
// invert the mask with ~ so that a 0/1 results in 4/0 being subtracted
|
||||||
const int vh = ~x_qh[i * (WARP_SIZE/2) + kbx * (QI3_K/2) + kqsx % (QI3_K/2)] >> bq8_offset;
|
const int vh = ~x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + kqsx % (QI3_K/2)] >> bq8_offset;
|
||||||
|
|
||||||
int u[QR3_K];
|
int u[QR3_K];
|
||||||
float d8[QR3_K];
|
float d8[QR3_K];
|
||||||
@ -1942,7 +2184,8 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
|
|||||||
d8[l] = y_ds[y_qs_index / QI8_1].x;
|
d8[l] = y_ds[y_qs_index / QI8_1].x;
|
||||||
}
|
}
|
||||||
|
|
||||||
return vec_dot_q3_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], vh, u, scales, scale_offset, x_dm[i * (WARP_SIZE/QI3_K) + kbx].x, d8);
|
return vec_dot_q3_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], vh, u, scales, scale_offset,
|
||||||
|
x_dm[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx].x, d8);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define VDR_q4_K_q8_1 2
|
#define VDR_q4_K_q8_1 2
|
||||||
@ -2068,9 +2311,9 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
|
|||||||
|
|
||||||
static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||||
|
|
||||||
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE) + GGML_CUDA_MMQ_Y];
|
||||||
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_K)];
|
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_K) + GGML_CUDA_MMQ_Y/QI4_K];
|
||||||
__shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (3*WARP_SIZE/32)];
|
__shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8) + GGML_CUDA_MMQ_Y/8];
|
||||||
|
|
||||||
*x_ql = tile_x_ql;
|
*x_ql = tile_x_ql;
|
||||||
*x_dm = tile_x_dm;
|
*x_dm = tile_x_dm;
|
||||||
@ -2079,25 +2322,59 @@ static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 **
|
|||||||
|
|
||||||
static __device__ __forceinline__ void load_tiles_q4_K(
|
static __device__ __forceinline__ void load_tiles_q4_K(
|
||||||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||||
int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
|
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
|
||||||
|
|
||||||
const int kbx = k / QI4_K;
|
__builtin_assume(i_offset >= 0);
|
||||||
const int kqsx = k % QI4_K;
|
__builtin_assume(i_offset < 8);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const block_q4_K * bx = ((block_q4_K *) vx) + i*blocks_per_row + kbx;
|
const int kbx = k / QI4_K; // == 0 if QK_K == 256
|
||||||
|
const int kqsx = k % QI4_K; // == k if QK_K == 256
|
||||||
|
|
||||||
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bx->qs, kqsx);
|
const block_q4_K * bx0 = (block_q4_K *) vx;
|
||||||
x_dm[i * (WARP_SIZE / QI6_K) + kbx] = bx->dm;
|
|
||||||
x_sc[i * (3*WARP_SIZE/32) + k % (3*WARP_SIZE/32)] = get_int_from_uint8_aligned(bx->scales, k % (3*WARP_SIZE/32));
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
|
||||||
|
const int i = i0 + i_offset;
|
||||||
|
|
||||||
|
const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx;
|
||||||
|
|
||||||
|
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256
|
||||||
|
const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI4_K) {
|
||||||
|
const int i = (i0 + i_offset * QI4_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
|
||||||
|
|
||||||
|
const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;
|
||||||
|
|
||||||
|
x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 8) {
|
||||||
|
const int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
|
||||||
|
|
||||||
|
const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);
|
||||||
|
|
||||||
|
x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_uint8_aligned(bxi->scales, k % (QI4_K/8));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
|
static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
|
||||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||||
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
||||||
|
|
||||||
__builtin_assume(i < GGML_CUDA_MMQ_Y);
|
__builtin_assume(i >= 0);
|
||||||
__builtin_assume(j < WARP_SIZE);
|
__builtin_assume(i < GGML_CUDA_MMQ_Y);
|
||||||
__builtin_assume(k < WARP_SIZE);
|
__builtin_assume(j >= 0);
|
||||||
|
__builtin_assume(j < WARP_SIZE);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const int kbx = k / QI6_K; // == 0 if QK_K == 256
|
const int kbx = k / QI6_K; // == 0 if QK_K == 256
|
||||||
const int kqsx = k % QI6_K; // == k if QK_K == 256
|
const int kqsx = k % QI6_K; // == k if QK_K == 256
|
||||||
@ -2112,7 +2389,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
|
|||||||
v[0] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 0];
|
v[0] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 0];
|
||||||
v[1] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 4];
|
v[1] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 4];
|
||||||
|
|
||||||
const uint16_t * scales = (const uint16_t *) &x_sc[i * (3*WARP_SIZE/32) + kbx * (3*WARP_SIZE/32)];
|
const uint16_t * scales = (const uint16_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + kbx * 4];
|
||||||
uint16_t aux[2];
|
uint16_t aux[2];
|
||||||
const int l = bq8_offset/2;
|
const int l = bq8_offset/2;
|
||||||
if (l < 2) {
|
if (l < 2) {
|
||||||
@ -2132,7 +2409,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
|
|||||||
d8[l] = y_ds[kqsy / QI8_1].x;
|
d8[l] = y_ds[kqsy / QI8_1].x;
|
||||||
}
|
}
|
||||||
|
|
||||||
return vec_dot_q4_K_q8_1_impl(v, u, sc, m, x_dm[i * (WARP_SIZE/QI4_K) + kbx], d8);
|
return vec_dot_q4_K_q8_1_impl(v, u, sc, m, x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K + kbx], d8);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define VDR_q5_K_q8_1 2
|
#define VDR_q5_K_q8_1 2
|
||||||
@ -2260,10 +2537,10 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
|
|||||||
|
|
||||||
static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||||
|
|
||||||
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE) + GGML_CUDA_MMQ_Y];
|
||||||
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_K)];
|
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_K) + GGML_CUDA_MMQ_Y/QI5_K];
|
||||||
__shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/4)];
|
__shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/4) + GGML_CUDA_MMQ_Y/4];
|
||||||
__shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (3*WARP_SIZE/32)];
|
__shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8) + GGML_CUDA_MMQ_Y/8];
|
||||||
|
|
||||||
*x_ql = tile_x_ql;
|
*x_ql = tile_x_ql;
|
||||||
*x_dm = tile_x_dm;
|
*x_dm = tile_x_dm;
|
||||||
@ -2273,26 +2550,68 @@ static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 **
|
|||||||
|
|
||||||
static __device__ __forceinline__ void load_tiles_q5_K(
|
static __device__ __forceinline__ void load_tiles_q5_K(
|
||||||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||||
int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
|
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
|
||||||
|
|
||||||
const int kbx = k / QI5_K;
|
__builtin_assume(i_offset >= 0);
|
||||||
const int kqsx = k % QI5_K;
|
__builtin_assume(i_offset < 8);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const block_q5_K * bx = ((block_q5_K *) vx) + i*blocks_per_row + kbx;
|
const int kbx = k / QI5_K; // == 0 if QK_K == 256
|
||||||
|
const int kqsx = k % QI5_K; // == k if QK_K == 256
|
||||||
|
|
||||||
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bx->qs, kqsx);
|
const block_q5_K * bx0 = (block_q5_K *) vx;
|
||||||
x_dm[i * (WARP_SIZE / QI6_K) + kbx] = bx->dm;
|
|
||||||
x_qh[i * (WARP_SIZE / 4) + k/4] = get_int_from_uint8_aligned(bx->qh, kqsx/4);
|
#pragma unroll
|
||||||
x_sc[i * (3*WARP_SIZE/32) + k % (3*WARP_SIZE/32)] = get_int_from_uint8_aligned(bx->scales, k % (3*WARP_SIZE/32));
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
|
||||||
|
const int i = i0 + i_offset;
|
||||||
|
|
||||||
|
const block_q5_K * bxi = bx0 + i*blocks_per_row + kbx;
|
||||||
|
|
||||||
|
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256
|
||||||
|
const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI5_K) {
|
||||||
|
const int i = (i0 + i_offset * QI5_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
|
||||||
|
|
||||||
|
const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;
|
||||||
|
|
||||||
|
x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 4) {
|
||||||
|
const int i = i0 + i_offset * 4 + k / (WARP_SIZE/4);
|
||||||
|
|
||||||
|
const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI5_K/4);
|
||||||
|
|
||||||
|
x_qh[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8(bxi->qh, k % (QI5_K/4));
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 8) {
|
||||||
|
const int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
|
||||||
|
|
||||||
|
const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);
|
||||||
|
|
||||||
|
x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_uint8_aligned(bxi->scales, k % (QI5_K/8));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
|
static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
|
||||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||||
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
||||||
|
|
||||||
__builtin_assume(i < 2*WARP_SIZE);
|
__builtin_assume(i >= 0);
|
||||||
__builtin_assume(j < WARP_SIZE);
|
__builtin_assume(i < GGML_CUDA_MMQ_Y);
|
||||||
__builtin_assume(k < WARP_SIZE);
|
__builtin_assume(j >= 0);
|
||||||
|
__builtin_assume(j < WARP_SIZE);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const int kbx = k / QI6_K; // == 0 if QK_K == 256
|
const int kbx = k / QI6_K; // == 0 if QK_K == 256
|
||||||
const int kqsx = k % QI6_K; // == k if QK_K == 256
|
const int kqsx = k % QI6_K; // == k if QK_K == 256
|
||||||
@ -2307,10 +2626,10 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
|
|||||||
vl[0] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 0];
|
vl[0] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 0];
|
||||||
vl[1] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 4];
|
vl[1] = x_ql[i * (WARP_SIZE + 1) + 4 * bq8_offset + kqsx % 4 + 4];
|
||||||
|
|
||||||
vh[0] = x_qh[i * (WARP_SIZE/4) + kqsx % 4 + 0] >> bq8_offset;
|
vh[0] = x_qh[i * (WARP_SIZE/4) + i/4 + kqsx % 4 + 0] >> bq8_offset;
|
||||||
vh[1] = x_qh[i * (WARP_SIZE/4) + kqsx % 4 + 4] >> bq8_offset;
|
vh[1] = x_qh[i * (WARP_SIZE/4) + i/4 + kqsx % 4 + 4] >> bq8_offset;
|
||||||
|
|
||||||
const uint16_t * scales = (const uint16_t *) &x_sc[i * (3*WARP_SIZE/32) + kbx * (3*WARP_SIZE/32)];
|
const uint16_t * scales = (const uint16_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + kbx * 4];
|
||||||
uint16_t aux[2];
|
uint16_t aux[2];
|
||||||
const int l = bq8_offset/2;
|
const int l = bq8_offset/2;
|
||||||
if (l < 2) {
|
if (l < 2) {
|
||||||
@ -2330,7 +2649,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
|
|||||||
d8[l] = y_ds[kqsy / QI8_1].x;
|
d8[l] = y_ds[kqsy / QI8_1].x;
|
||||||
}
|
}
|
||||||
|
|
||||||
return vec_dot_q5_K_q8_1_impl(vl, vh, u, sc, m, x_dm[i * (WARP_SIZE/QI4_K) + kbx], d8);
|
return vec_dot_q5_K_q8_1_impl(vl, vh, u, sc, m, x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K + kbx], d8);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define VDR_q6_K_q8_1 1
|
#define VDR_q6_K_q8_1 1
|
||||||
@ -2387,10 +2706,10 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
|
|||||||
|
|
||||||
static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||||
|
|
||||||
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE) + GGML_CUDA_MMQ_Y];
|
||||||
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI6_K)];
|
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI6_K) + GGML_CUDA_MMQ_Y/QI6_K];
|
||||||
__shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/2)];
|
__shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/2) + GGML_CUDA_MMQ_Y/2];
|
||||||
__shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8)];
|
__shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8) + GGML_CUDA_MMQ_Y/8];
|
||||||
|
|
||||||
*x_ql = tile_x_ql;
|
*x_ql = tile_x_ql;
|
||||||
*x_dm = tile_x_dm;
|
*x_dm = tile_x_dm;
|
||||||
@ -2400,26 +2719,68 @@ static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 **
|
|||||||
|
|
||||||
static __device__ __forceinline__ void load_tiles_q6_K(
|
static __device__ __forceinline__ void load_tiles_q6_K(
|
||||||
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
|
||||||
int * __restrict__ x_sc, const int & i, const int & k, const int & blocks_per_row) {
|
int * __restrict__ x_sc, const int & i_offset, const int & k, const int & blocks_per_row) {
|
||||||
|
|
||||||
const int kbx = k / QI6_K;
|
__builtin_assume(i_offset >= 0);
|
||||||
const int kqsx = k % QI6_K;
|
__builtin_assume(i_offset < 8);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const block_q6_K * bx = ((block_q6_K *) vx) + i*blocks_per_row + kbx;
|
const int kbx = k / QI6_K; // == 0 if QK_K == 256
|
||||||
|
const int kqsx = k % QI6_K; // == k if QK_K == 256
|
||||||
|
|
||||||
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bx->ql, kqsx);
|
const block_q6_K * bx0 = (block_q6_K *) vx;
|
||||||
x_dm[i * (WARP_SIZE / QI6_K) + kbx].x = bx->d;
|
|
||||||
x_qh[i * (WARP_SIZE / 2) + k/2] = get_int_from_uint8(bx->qh, kqsx/2);
|
#pragma unroll
|
||||||
x_sc[i * (WARP_SIZE / 8) + k/8] = get_int_from_int8(bx->scales, kqsx/8);
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8) {
|
||||||
|
const int i = i0 + i_offset;
|
||||||
|
|
||||||
|
const block_q6_K * bxi = bx0 + i*blocks_per_row + kbx;
|
||||||
|
|
||||||
|
x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->ql, kqsx);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256
|
||||||
|
const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * QI6_K) {
|
||||||
|
const int i = (i0 + i_offset * QI6_K + k / blocks_per_tile_x_row) % GGML_CUDA_MMQ_Y;
|
||||||
|
|
||||||
|
const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd;
|
||||||
|
|
||||||
|
x_dm[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd].x = bxi->d;
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 2) {
|
||||||
|
const int i = i0 + i_offset * 2 + k / (WARP_SIZE/2);
|
||||||
|
|
||||||
|
const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI6_K/2);
|
||||||
|
|
||||||
|
x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = get_int_from_uint8(bxi->qh, k % (QI6_K/2));
|
||||||
|
}
|
||||||
|
|
||||||
|
#pragma unroll
|
||||||
|
for (int i0 = 0; i0 < GGML_CUDA_MMQ_Y; i0 += 8 * 8) {
|
||||||
|
const int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % GGML_CUDA_MMQ_Y;
|
||||||
|
|
||||||
|
const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4;
|
||||||
|
|
||||||
|
x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
|
static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
|
||||||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||||
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
||||||
|
|
||||||
__builtin_assume(i < GGML_CUDA_MMQ_Y);
|
__builtin_assume(i >= 0);
|
||||||
__builtin_assume(j < WARP_SIZE);
|
__builtin_assume(i < GGML_CUDA_MMQ_Y);
|
||||||
__builtin_assume(k < WARP_SIZE);
|
__builtin_assume(j >= 0);
|
||||||
|
__builtin_assume(j < WARP_SIZE);
|
||||||
|
__builtin_assume(k >= 0);
|
||||||
|
__builtin_assume(k < WARP_SIZE);
|
||||||
|
|
||||||
const int kbx = k / QI6_K; // == 0 if QK_K == 256
|
const int kbx = k / QI6_K; // == 0 if QK_K == 256
|
||||||
const int kqsx = k % QI6_K; // == k if QK_K == 256
|
const int kqsx = k % QI6_K; // == k if QK_K == 256
|
||||||
@ -2428,9 +2789,9 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
|
|||||||
const int scale_offset = (QI6_K/4) * (kqsx / (QI6_K/2)) + (kqsx % (QI6_K/2)) / (QI6_K/8);
|
const int scale_offset = (QI6_K/4) * (kqsx / (QI6_K/2)) + (kqsx % (QI6_K/2)) / (QI6_K/8);
|
||||||
const int vh_shift = 2 * ((kqsx % (QI6_K/2)) / (QI6_K/4));
|
const int vh_shift = 2 * ((kqsx % (QI6_K/2)) / (QI6_K/4));
|
||||||
|
|
||||||
const int vh = x_qh[i * (WARP_SIZE/2) + kbx * (QI6_K/2) + (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)] >> vh_shift;
|
const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI6_K/2) + (QI6_K/4) * (kqsx / (QI6_K/2)) + kqsx % (QI6_K/4)] >> vh_shift;
|
||||||
|
|
||||||
const int x_sc_offset = i * (WARP_SIZE/8) + kbx * (QI6_K/8);
|
const int x_sc_offset = i * (WARP_SIZE/8) + i/8 + kbx * (QI6_K/8);
|
||||||
const int8_t * scales = ((int8_t *) (x_sc + x_sc_offset)) + scale_offset;
|
const int8_t * scales = ((int8_t *) (x_sc + x_sc_offset)) + scale_offset;
|
||||||
|
|
||||||
int u[QR6_K];
|
int u[QR6_K];
|
||||||
@ -2442,7 +2803,8 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
|
|||||||
d8[l] = y_ds[kqsy / QI8_1].x;
|
d8[l] = y_ds[kqsy / QI8_1].x;
|
||||||
}
|
}
|
||||||
|
|
||||||
return vec_dot_q6_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], vh, u, scales, x_dm[i * (WARP_SIZE/QI6_K) + kbx].x, d8);
|
return vec_dot_q6_K_q8_1_impl(x_ql[i * (WARP_SIZE + 1) + k], vh, u, scales,
|
||||||
|
x_dm[i * (WARP_SIZE/QI6_K) + i/QI6_K + kbx].x, d8);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int qk, int qr, int qi, typename block_q_t,
|
template <int qk, int qr, int qi, typename block_q_t,
|
||||||
@ -2486,19 +2848,17 @@ static __global__ void mul_mat_q(
|
|||||||
|
|
||||||
for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
|
for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
|
||||||
|
|
||||||
for (int i = 0; i < GGML_CUDA_MMQ_Y; i += 8) {
|
load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
|
||||||
load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
|
tid_y, tid_x, blocks_per_row_x);
|
||||||
i + tid_y, tid_x, blocks_per_row_x);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int ir = 0; ir < qr; ++ir) {
|
for (int ir = 0; ir < qr; ++ir) {
|
||||||
const int kqs = ir*WARP_SIZE + tid_x;
|
const int kqs = ir*WARP_SIZE + tid_x;
|
||||||
const int kby = kqs / QI8_1;
|
const int kbxd = kqs / QI8_1;
|
||||||
|
|
||||||
for (int i = 0; i < WARP_SIZE; i += 8) {
|
for (int i = 0; i < WARP_SIZE; i += 8) {
|
||||||
const int col_y_eff = min(col_y_0 + tid_y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
|
const int col_y_eff = min(col_y_0 + tid_y + i, ncols_y-1); // to prevent out-of-bounds memory accesses
|
||||||
|
|
||||||
const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kby];
|
const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd];
|
||||||
|
|
||||||
tile_y_qs[(tid_y + i) * (qr*WARP_SIZE) + kqs] = get_int_from_int8_aligned(by0->qs, tid_x % QI8_1);
|
tile_y_qs[(tid_y + i) * (qr*WARP_SIZE) + kqs] = get_int_from_int8_aligned(by0->qs, tid_x % QI8_1);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user