iq1_s: CUDA is working

This commit is contained in:
Iwan Kawrakow 2024-02-11 13:08:26 +02:00
parent 80cd5bae99
commit a9d48e9718
6 changed files with 412 additions and 9 deletions

View File

@ -23,6 +23,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1, " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", },
{ "IQ2_XXS",LLAMA_FTYPE_MOSTLY_IQ2_XXS," 2.06 bpw quantization", },
{ "IQ2_XS", LLAMA_FTYPE_MOSTLY_IQ2_XS, " 2.31 bpw quantization", },
{ "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", },
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.63G, +0.6717 ppl @ LLaMA-v1-7B", },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.16G, +9.0634 ppl @ LLaMA-v1-7B", },
{ "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", },
@ -287,9 +288,10 @@ int main(int argc, char ** argv) {
}
}
if ((params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S) && imatrix_data.empty()) {
if ((params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || params.ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS ||
params.ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S || params.ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) && imatrix_data.empty()) {
fprintf(stderr, "\n===============================================================================================\n");
fprintf(stderr, "Please do not use IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n");
fprintf(stderr, "Please do not use IQ1_S, IQ2_XXS, IQ2_XS or Q2_K_S quantization without an importance matrix\n");
fprintf(stderr, "===============================================================================================\n\n\n");
return 1;
}

View File

@ -517,6 +517,15 @@ typedef struct {
} block_iq3_xxs;
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
#define QR1_S 8
#define QI1_S (QK_K / (4*QR1_S))
typedef struct {
half d;
uint8_t qs[QK_K/8];
uint8_t scales[QK_K/16];
} block_iq1_s;
static_assert(sizeof(block_iq1_s) == sizeof(ggml_fp16_t) + QK_K/8 + QK_K/16, "wrong iq1_s block size/padding");
#define WARP_SIZE 32
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
@ -1681,6 +1690,137 @@ static const __device__ uint32_t iq3xxs_grid[256] = {
0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04,
};
static const __device__ uint64_t iq1s_grid[512] = {
0xffffffffff000101, 0xffffffffff01ff00, 0xffffffff00000000, 0xffffffff01ff00ff,
0xffffffff0101ffff, 0xffffffff0101ff01, 0xffffffff01010101, 0xffffff00ff000000,
0xffffff000000ff00, 0xffffff00000000ff, 0xffffff0000000100, 0xffffff0000010000,
0xffffff0001000000, 0xffffff01ff00ffff, 0xffffff01ff010100, 0xffffff0100ff01ff,
0xffffff0100000001, 0xffffff0101ffff00, 0xffffff0101010100, 0xffff00ffff00ff01,
0xffff00ffff0000ff, 0xffff00ff00ff0000, 0xffff00ff00010000, 0xffff00ff0100ff00,
0xffff00ff010001ff, 0xffff0000ffff01ff, 0xffff0000ff010001, 0xffff0000ff0101ff,
0xffff000000ffff00, 0xffff000000000000, 0xffff00000001ff00, 0xffff000001000101,
0xffff0000010100ff, 0xffff00010000ff00, 0xffff000101000000, 0xffff01ffffff0000,
0xffff01ff00000000, 0xffff01ff01ffffff, 0xffff01ff01ffff01, 0xffff01ff01010001,
0xffff0100ffffff01, 0xffff0100ff000101, 0xffff01000000ffff, 0xffff010000000001,
0xffff010000000100, 0xffff010001000000, 0xffff0101ff000000, 0xffff0101ff01ffff,
0xffff010100ff01ff, 0xffff010100ff0101, 0xffff0101000101ff, 0xffff010101ffffff,
0xffff01010101ff01, 0xffff010101010101, 0xff00ffffff000000, 0xff00ffff00ffff00,
0xff00ffff00000001, 0xff00ffff000001ff, 0xff00ffff01010000, 0xff00ff00ffff0000,
0xff00ff00ff00ff00, 0xff00ff00ff0000ff, 0xff00ff00ff000001, 0xff00ff00ff000100,
0xff00ff00ff010000, 0xff00ff0000ff0001, 0xff00ff000000ffff, 0xff00ff0000000000,
0xff00ff000001ff00, 0xff00ff0000010100, 0xff00ff0001ff0100, 0xff00ff000100ff00,
0xff00ff01ff000000, 0xff00ff010000ff00, 0xff00ff01010100ff, 0xff00ff0101010001,
0xff0000ffffffffff, 0xff0000ffffff0100, 0xff0000ff00000000, 0xff0000ff0001ff00,
0xff0000ff00010100, 0xff0000ff01ff0001, 0xff000000ff000000, 0xff000000ff01ff00,
0xff00000000ff00ff, 0xff0000000000ff00, 0xff00000000000000, 0xff000000000001ff,
0xff00000000000101, 0xff0000000001ffff, 0xff00000000010000, 0xff00000001000000,
0xff00000001010100, 0xff000001ff00ff01, 0xff000001ff0100ff, 0xff00000100ff0001,
0xff000001000000ff, 0xff00000100000100, 0xff0000010001ff00, 0xff00000101ff00ff,
0xff0000010100ff00, 0xff00000101010000, 0xff0001ffff000000, 0xff0001ffff01ffff,
0xff0001ff00ff00ff, 0xff0001ff00000101, 0xff0001ff00010000, 0xff0001ff01000000,
0xff0001000000ff01, 0xff00010000000000, 0xff00010000010100, 0xff00010001ffff00,
0xff00010001ff0100, 0xff000100010000ff, 0xff00010001010000, 0xff000101ffffff00,
0xff000101ff010001, 0xff00010101000001, 0xff000101010100ff, 0xff01ffffffff00ff,
0xff01ffffffff0001, 0xff01ffffff01ffff, 0xff01ffffff01ff01, 0xff01ffffff0101ff,
0xff01ffffff010101, 0xff01ffff00000000, 0xff01ffff0101ff01, 0xff01ffff010101ff,
0xff01ffff01010101, 0xff01ff00ff000000, 0xff01ff000000ff01, 0xff01ff0000000101,
0xff01ff0000010000, 0xff01ff00010000ff, 0xff01ff01ffffff01, 0xff01ff01ff01ff00,
0xff01ff01ff010101, 0xff01ff0100ff0000, 0xff01ff01000001ff, 0xff01ff0101ffff01,
0xff0100ffff010000, 0xff0100ff0100ffff, 0xff0100ff01000100, 0xff010000ffff0000,
0xff01000000ff0100, 0xff010000000000ff, 0xff01000000000001, 0xff0100000101ff00,
0xff010001ff00ffff, 0xff010001ff000100, 0xff01000100000000, 0xff01000100010001,
0xff01000101000101, 0xff0101ffffff0001, 0xff0101ffff0001ff, 0xff0101ffff010101,
0xff0101ff0000ff00, 0xff0101ff01ffff01, 0xff0101ff01ff01ff, 0xff0101ff01ff0101,
0xff0101ff01010001, 0xff010100ffffffff, 0xff010100ff000000, 0xff010100ff01ff01,
0xff01010000000100, 0xff01010001000000, 0xff010100010101ff, 0xff010101ffff0101,
0xff01010100ffff01, 0xff01010100ff01ff, 0xff0101010100ffff, 0x00ffffffffffffff,
0x00ffffffffff01ff, 0x00ffffff000000ff, 0x00ffffff00000100, 0x00ffffff00010000,
0x00ffffff01ff0101, 0x00ffff00ffff0001, 0x00ffff00ff0000ff, 0x00ffff00ff000100,
0x00ffff0000000000, 0x00ffff0001ffffff, 0x00ffff0001000100, 0x00ffff0001010001,
0x00ffff01ff01ff01, 0x00ffff010000ff00, 0x00ffff01000100ff, 0x00ff00ffff010101,
0x00ff00ff00ff0100, 0x00ff00ff0000ffff, 0x00ff00ff00000001, 0x00ff00ff000101ff,
0x00ff0000ff000000, 0x00ff0000ff01ffff, 0x00ff000000ff0001, 0x00ff00000000ff00,
0x00ff0000000000ff, 0x00ff000000000000, 0x00ff000000000101, 0x00ff000000010000,
0x00ff000001ffff01, 0x00ff000001000000, 0x00ff0001ff0001ff, 0x00ff0001ff000101,
0x00ff000100ffffff, 0x00ff000100ff0100, 0x00ff000100000001, 0x00ff0001010001ff,
0x00ff00010101ff00, 0x00ff01ffffffffff, 0x00ff01ffffff01ff, 0x00ff01ffff000000,
0x00ff01ff0001ff01, 0x00ff01ff01ff01ff, 0x00ff01ff01000101, 0x00ff01ff0101ffff,
0x00ff0100ff010000, 0x00ff010000ff00ff, 0x00ff010000000000, 0x00ff010000010101,
0x00ff01000100ff00, 0x00ff010001010000, 0x00ff01010000ff01, 0x00ff010100000100,
0x00ff010101ff0000, 0x0000ffffff00ff00, 0x0000ffffff0000ff, 0x0000ffffff010000,
0x0000ffff00ff0100, 0x0000ffff00000000, 0x0000ffff000100ff, 0x0000ffff00010101,
0x0000ffff01ffff01, 0x0000ffff01000100, 0x0000ff00ff000000, 0x0000ff00ff01ff00,
0x0000ff00ff0101ff, 0x0000ff0000ff0000, 0x0000ff000000ff00, 0x0000ff00000000ff,
0x0000ff0000000000, 0x0000ff0000000001, 0x0000ff0000000100, 0x0000ff0000010000,
0x0000ff0001000000, 0x0000ff000101ffff, 0x0000ff01ffffff00, 0x0000ff01ffff0101,
0x0000ff01ff010000, 0x0000ff0101000101, 0x0000ff0101010000, 0x000000ffffff0001,
0x000000ffff01ff01, 0x000000ff00ffff00, 0x000000ff000000ff, 0x000000ff00000000,
0x000000ff00010000, 0x000000ff01000000, 0x000000ff0101ff00, 0x00000000ff00ffff,
0x00000000ff00ff01, 0x00000000ff000000, 0x00000000ff000100, 0x00000000ff010000,
0x0000000000ffffff, 0x0000000000ffff01, 0x0000000000ff0000, 0x0000000000ff01ff,
0x0000000000ff0101, 0x000000000000ff00, 0x00000000000000ff, 0x0000000000000001,
0x0000000000000100, 0x000000000001ff00, 0x0000000000010000, 0x0000000001ff00ff,
0x000000000100ff00, 0x0000000001000000, 0x0000000001000100, 0x0000000001010000,
0x00000001ffff00ff, 0x00000001ff00ff00, 0x0000000100ffff00, 0x0000000100000000,
0x00000001000101ff, 0x0000000100010101, 0x0000000101ff0000, 0x0000000101000001,
0x000001ffff00ff00, 0x000001ffff0000ff, 0x000001ffff010100, 0x000001ff00ffff01,
0x000001ff0000ffff, 0x000001ff00000000, 0x000001ff0100ff00, 0x000001ff010000ff,
0x000001ff01010100, 0x00000100ffff0100, 0x00000100ff000000, 0x00000100ff01ff00,
0x0000010000ff0000, 0x000001000000ff00, 0x0000010000000000, 0x0000010000000100,
0x00000100000100ff, 0x0000010000010001, 0x0000010001000000, 0x000001000101ff01,
0x00000101ffff0001, 0x00000101000000ff, 0x0000010100000001, 0x0000010100010000,
0x0000010101ffff01, 0x0000010101ff01ff, 0x0000010101ff0101, 0x0001ffff00ffffff,
0x0001ffff00000100, 0x0001ffff0001ff00, 0x0001ffff01ff0101, 0x0001ffff01000000,
0x0001ff00ffff0000, 0x0001ff00ff00ff00, 0x0001ff00ff010001, 0x0001ff0000000000,
0x0001ff0001ffff00, 0x0001ff0001ff01ff, 0x0001ff0001010100, 0x0001ff01ff0000ff,
0x0001ff0100ffffff, 0x0001ff010001ffff, 0x0001ff010001ff01, 0x000100ffffff01ff,
0x000100ffff00ffff, 0x000100ffff000000, 0x000100ff00ff0000, 0x000100ff0000ff01,
0x000100ff00000101, 0x000100ff01010000, 0x00010000ffffff00, 0x00010000ff0000ff,
0x00010000ff010100, 0x0001000000ff00ff, 0x000100000000ffff, 0x0001000000000000,
0x00010000000001ff, 0x0001000000010000, 0x0001000001ff0001, 0x00010001ff000001,
0x00010001ff010000, 0x0001000100ff0000, 0x0001000100ff0101, 0x000100010000ff00,
0x0001000100000100, 0x000100010100ff01, 0x00010001010000ff, 0x000101ff00010000,
0x000101ff01ff0000, 0x00010100ffff0000, 0x0001010000000000, 0x000101000001ffff,
0x0001010000010101, 0x00010101ff00ff00, 0x00010101ff0001ff, 0x0001010100ffffff,
0x0001010101ff0000, 0x000101010101ff00, 0x01ffffffff000101, 0x01ffffffff01ffff,
0x01ffffffff01ff01, 0x01ffffff00000000, 0x01ffffff010100ff, 0x01ffff000000ff00,
0x01ffff0000000001, 0x01ffff00000001ff, 0x01ffff0000010000, 0x01ffff0001ff0000,
0x01ffff01ff000000, 0x01ffff01ff01ffff, 0x01ffff0101ff0101, 0x01ffff01010000ff,
0x01ff00ffffff0000, 0x01ff00ffff010000, 0x01ff00ff01ffffff, 0x01ff00ff0100ff01,
0x01ff00ff01010100, 0x01ff0000ffffffff, 0x01ff0000ffffff01, 0x01ff0000ffff01ff,
0x01ff0000ff00ff00, 0x01ff000000000000, 0x01ff00000001ff01, 0x01ff000001ff01ff,
0x01ff00000101ffff, 0x01ff0001ff010100, 0x01ff000101000001, 0x01ff000101010100,
0x01ff01ffff01ffff, 0x01ff01ff00ff0101, 0x01ff01ff01000000, 0x01ff0100ff000001,
0x01ff010000ffff00, 0x01ff010000000100, 0x01ff0100010101ff, 0x01ff0101ffff00ff,
0x01ff0101ffff0101, 0x01ff0101ff00ff00, 0x01ff01010001ffff, 0x01ff010100010001,
0x01ff0101010000ff, 0x0100ffff00ff00ff, 0x0100ffff00ff0001, 0x0100ffff00000100,
0x0100ffff0100ff00, 0x0100ff00ffff0000, 0x0100ff00ff00ffff, 0x0100ff00ff00ff01,
0x0100ff00ff000100, 0x0100ff00ff010000, 0x0100ff0000000000, 0x0100ff0000000101,
0x0100ff0001000100, 0x0100ff000101ff01, 0x0100ff0100ff00ff, 0x0100ff0100ff0001,
0x0100ff0100000100, 0x0100ff0100010001, 0x0100ff0101ffff00, 0x0100ff01010101ff,
0x010000ffff00ff00, 0x010000ffff0101ff, 0x010000ff0000ffff, 0x010000ff00000001,
0x010000ff01ff0101, 0x010000ff010001ff, 0x010000ff01010001, 0x01000000ffffff00,
0x01000000ffff0101, 0x01000000ff0000ff, 0x01000000ff000001, 0x01000000ff010101,
0x0100000000ff0000, 0x0100000000000000, 0x0100000000000100, 0x01000000000100ff,
0x0100000000010001, 0x01000000010000ff, 0x0100000001000001, 0x01000001ff000000,
0x010000010000ffff, 0x010000010000ff01, 0x0100000100010000, 0x0100000101000000,
0x010001ffff000101, 0x010001ff00ff00ff, 0x010001ff0000ff00, 0x010001ff000100ff,
0x01000100ffff0000, 0x01000100ff00ffff, 0x01000100ff0001ff, 0x0100010000000000,
0x010001000001ff00, 0x0100010001ff0000, 0x0100010001000101, 0x01000101ff0100ff,
0x0100010100ff0100, 0x0100010100010100, 0x0100010101ffffff, 0x0101ffffffffff00,
0x0101ffffff000101, 0x0101ffff00000000, 0x0101ffff000101ff, 0x0101ffff01010101,
0x0101ff00ff000000, 0x0101ff0000ff0100, 0x0101ff000000ff00, 0x0101ff0000010000,
0x0101ff00010000ff, 0x0101ff0001000001, 0x0101ff01ffff01ff, 0x0101ff01ff000101,
0x0101ff01ff01ff00, 0x0101ff01ff0101ff, 0x0101ff0100000000, 0x0101ff010101ff00,
0x010100ffff010000, 0x010100ff000000ff, 0x01010000ff000100, 0x01010000ff01ffff,
0x01010000ff01ff01, 0x0101000000ffff01, 0x0101000000000000, 0x0101000001ffffff,
0x010100000101ffff, 0x01010001ffff0000, 0x01010001000001ff, 0x0101000101ff0100,
0x0101000101010001, 0x010101ffffffff00, 0x010101ff00ff0001, 0x010101ff00000100,
0x010101ff0100ffff, 0x010101ff0100ff01, 0x010101ff01010101, 0x01010100ff000001,
0x0101010000ff01ff, 0x010101000000ff00, 0x01010100000101ff, 0x0101010001000000,
0x01010101ffffff01, 0x0101010100000101, 0x010101010001ff01, 0x01010101010100ff,
};
static const __device__ uint8_t ksigns_iq2xs[128] = {
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
@ -1823,6 +1963,29 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
}
template<typename dst_t>
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x;
const block_iq1_s * x = (const block_iq1_s *) vx;
const int tid = threadIdx.x;
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const int i8 = 4*ib+il;
uint8_t h = x[i].scales[i8/2] >> 4*(i8%2);
const int8_t * grid = (const int8_t *)(iq1s_grid + (x[i].qs[i8] | ((h & 8) << 5)));
const float d = (float)x[i].d * (2*(h & 7) + 1);
for (int j = 0; j < 8; ++j) y[j] = d * grid[j];
#else
assert(false);
#endif
}
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
@ -4522,6 +4685,49 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
#endif
}
static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
#if QK_K == 256
const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
const int ib32 = iqs;
int sumi1 = 0, sumi2 = 0, sumi3 = 0, sumi4 = 0;
const uint8_t h1 = bq1->scales[2*ib32+0];
const uint8_t h2 = bq1->scales[2*ib32+1];
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const int * q8 = (const int *)bq8_1[ib32].qs;
const int * grid1 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
const int * grid2 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
const int * grid3 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
const int * grid4 = (const int *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
for (int j = 0; j < 2; ++j) {
sumi1 = __dp4a(q8[j+0], grid1[j], sumi1);
sumi2 = __dp4a(q8[j+2], grid2[j], sumi2);
sumi3 = __dp4a(q8[j+4], grid3[j], sumi3);
sumi4 = __dp4a(q8[j+6], grid4[j], sumi4);
}
#else
const int8_t * q8 = bq8_1[ib32].qs;
const int8_t * grid1 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+0] | ((h1 & 0x08) << 5)));
const int8_t * grid2 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+1] | ((h1 & 0x80) << 1)));
const int8_t * grid3 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+2] | ((h2 & 0x08) << 5)));
const int8_t * grid4 = (const int8_t *)(iq1s_grid + (bq1->qs[4*ib32+3] | ((h2 & 0x80) << 1)));
for (int j = 0; j < 8; ++j) {
sumi1 += q8[j+ 0] * grid1[j];
sumi2 += q8[j+ 8] * grid2[j];
sumi3 += q8[j+16] * grid3[j];
sumi4 += q8[j+24] * grid4[j];
}
#endif
const float d = (float)bq1->d * __low2float(bq8_1[ib32].ds);
return d * (sumi1 * (2*(h1 & 7) + 1) + sumi2 * (2*((h1 >> 4) & 7) + 1) +
sumi3 * (2*(h2 & 7) + 1) + sumi4 * (2*((h2 >> 4) & 7) + 1));
#else
assert(false);
return 0.f;
#endif
}
template <int qk, int qr, int qi, bool need_sum, typename block_q_t, int mmq_x, int mmq_y, int nwarps,
allocate_tiles_cuda_t allocate_tiles, load_tiles_cuda_t load_tiles, int vdr, vec_dot_q_mul_mat_cuda_t vec_dot>
static __device__ __forceinline__ void mul_mat_q(
@ -6678,6 +6884,12 @@ static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k,
dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
}
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
@ -6717,6 +6929,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq2_xs_cuda;
case GGML_TYPE_IQ3_XXS:
return dequantize_row_iq3_xxs_cuda;
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_cuda;
case GGML_TYPE_F32:
return convert_unary_cuda<float>;
default:
@ -6752,6 +6966,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq2_xs_cuda;
case GGML_TYPE_IQ3_XXS:
return dequantize_row_iq3_xxs_cuda;
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_cuda;
case GGML_TYPE_F16:
return convert_unary_cuda<half>;
default:
@ -8530,6 +8746,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S:
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
default:
GGML_ASSERT(false);
@ -8553,6 +8770,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S:
return max_compute_capability >= CC_VOLTA ? 128 : 64;
case GGML_TYPE_Q6_K:
return 64;
@ -8650,6 +8868,10 @@ static void ggml_cuda_op_mul_mat_vec_q(
mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ1_S:
mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
default:
GGML_ASSERT(false);
break;
@ -11360,7 +11582,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
return false;
}
ggml_type a_type = a->type;
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS) {
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS || a_type == GGML_TYPE_IQ1_S) {
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
return false;
}

View File

@ -3710,6 +3710,49 @@ void dequantize_row_iq3_xxs(const block_iq3_xxs * restrict x, float * restrict y
}
}
// ====================== 1.5625 bpw (de)-quantization
void dequantize_row_iq1_s(const block_iq1_s * restrict x, float * restrict y, int k) {
assert(k % QK_K == 0);
const int nb = k / QK_K;
float db[4];
uint16_t idx[4];
//const int8_t * grid[4];
for (int i = 0; i < nb; i++) {
const float d = GGML_FP16_TO_FP32(x[i].d);
const uint8_t * sc = x[i].scales;
const uint8_t * qs = x[i].qs;
for (int i8 = 0; i8 < QK_K/8; i8 += 4) {
idx[0] = qs[0] | ((sc[0] & 0x08) << 5);
idx[1] = qs[1] | ((sc[0] & 0x80) << 1);
idx[2] = qs[2] | ((sc[1] & 0x08) << 5);
idx[3] = qs[3] | ((sc[1] & 0x80) << 1);
//grid[0] = (const int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5)));
//grid[1] = (const int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1)));
//grid[2] = (const int8_t *)(iq1s_grid + (qs[2] | ((sc[1] & 0x08) << 5)));
//grid[3] = (const int8_t *)(iq1s_grid + (qs[3] | ((sc[1] & 0x80) << 1)));
db[0] = d * (2*(sc[0] & 7) + 1);
db[1] = d * (2*((sc[0] >> 4) & 7) + 1);
db[2] = d * (2*(sc[1] & 7) + 1);
db[3] = d * (2*((sc[1] >> 4) & 7) + 1);
for (int l = 0; l < 4; ++l) {
const int8_t * grid = (const int8_t *)(iq1s_grid + idx[l]);
for (int j = 0; j < 8; ++j) {
//y[j] = db[l] * grid[l][j];
y[j] = db[l] * grid[j];
}
y += 8;
}
qs += 4;
sc += 2;
}
}
}
//===================================== Q8_K ==============================================
void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int k) {
@ -10378,3 +10421,131 @@ static int iq1_sort_helper(const void * left, const void * right) {
return *l < *r ? -1 : *l > *r ? 1 : 0;
}
static void quantize_row_iq1_s_impl(const float * restrict x, void * restrict vy, int n, const float * restrict quant_weights) {
const int gindex = iq2_data_index(GGML_TYPE_IQ1_S);
const uint64_t * kgrid_q2xs = iq2_data[gindex].grid;
const int * kmap_q2xs = iq2_data[gindex].map;
const uint16_t * kneighbors_q2xs = iq2_data[gindex].neighbours;
GGML_ASSERT(quant_weights && "missing quantization weights");
GGML_ASSERT(kgrid_q2xs && "forgot to call ggml_quantize_init()?");
GGML_ASSERT(kmap_q2xs && "forgot to call ggml_quantize_init()?");
GGML_ASSERT(kneighbors_q2xs && "forgot to call ggml_quantize_init()?");
GGML_ASSERT(n%QK_K == 0);
const int nbl = n/256;
block_iq1_s * y = vy;
float scales[QK_K/8];
float weight[8];
int8_t L[8];
float sumx[9];
float sumw[9];
float pairs[16];
int * idx = (int *)(pairs + 1);
uint8_t hbit[QK_K/8];
for (int ibl = 0; ibl < nbl; ++ibl) {
y[ibl].d = GGML_FP32_TO_FP16(0.f);
memset(y[ibl].qs, 0, QK_K/8);
memset(y[ibl].scales, 0, QK_K/16);
float max_scale = 0;
const float * xbl = x + QK_K*ibl;
float sumx2 = 0;
for (int i = 0; i < QK_K; ++i) sumx2 += xbl[i]*xbl[i];
float sigma2 = sumx2/QK_K;
for (int ib = 0; ib < QK_K/8; ++ib) {
const float * xb = xbl + 8*ib;
const float * qw = quant_weights + QK_K*ibl + 8*ib;
for (int i = 0; i < 8; ++i) weight[i] = qw[i] * sqrtf(sigma2 + xb[i]*xb[i]);
float max = fabsf(xb[0]);
for (int i = 1; i < 8; ++i) max = MAX(max, fabsf(xb[i]));
if (!max) {
scales[ib] = 0;
memset(L, 1, 8);
continue;
}
for (int j = 0; j < 8; ++j) {
pairs[2*j] = xb[j];
idx[2*j] = j;
}
qsort(pairs, 8, 2*sizeof(float), iq1_sort_helper);
{
sumx[0] = sumw[0] = 0;
for (int j = 0; j < 8; ++j) {
int i = idx[2*j];
sumx[j+1] = sumx[j] + weight[i]*xb[i];
sumw[j+1] = sumw[j] + weight[i];
}
}
float best_score = 0, scale = max;
int besti1 = 0, besti2 = 0;
for (int i1 = 0; i1 <= 8; ++i1) {
for (int i2 = i1; i2 <= 8; ++i2) {
float sumqx = -(sumx[i1] - sumx[0]) + (sumx[8] - sumx[i2]);
float sumq2 = (sumw[i1] - sumw[0]) + (sumw[8] - sumw[i2]);
if (sumq2 > 0 && sumqx*sumqx > best_score*sumq2) {
scale = sumqx/sumq2; best_score = scale*sumqx;
besti1 = i1; besti2 = i2;
}
}
}
for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0;
for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1;
for (int j = besti2; j < 8; ++j) L[idx[2*j]] = 2;
if (scale < 0) {
for (int j = 0; j < 8; ++j) L[j] = 2 - L[j];
scale = -scale;
}
uint16_t u = 0;
for (int j = 0; j < 8; ++j) u |= (L[j] << 2*j);
int grid_index = kmap_q2xs[u];
if (grid_index < 0) {
const uint16_t * neighbours = kneighbors_q2xs - kmap_q2xs[u] - 1;
grid_index = iq1_find_best_neighbour(neighbours, kgrid_q2xs, xb, weight, &scale, L, NGRID_IQ2XXS);
GGML_ASSERT(grid_index >= 0);
}
y[ibl].qs[ib] = grid_index & 255;
hbit[ib] = grid_index >> 8;
GGML_ASSERT(scale >= 0);
scales[ib] = scale;
max_scale = MAX(max_scale, scale);
}
if (!max_scale) {
memset(y[ibl].qs, 0, QK_K/8);
continue;
}
float d = max_scale/15;
//y[ibl].d = GGML_FP32_TO_FP16(d*1.075f); // 1.075f is another fudge factor. Don't ask me why it is needed.
y[ibl].d = GGML_FP32_TO_FP16(d*1.085f); // 1.08f is another fudge factor. Don't ask me why it is needed.
float id = 1/d;
for (int ib = 0; ib < QK_K/8; ++ib) {
int l = nearest_int(0.5f*(id*scales[ib]-1));
l = MAX(0, MIN(7, l));
if (hbit[ib]) l |= 8;
y[ibl].scales[ib/2] |= (l << 4*(ib%2));
}
}
}
size_t quantize_iq1_s(const float * src, void * dst, int nrow, int n_per_row, int64_t * hist, const float * quant_weights) {
(void)hist;
GGML_ASSERT(n_per_row%QK_K == 0);
int nblock = n_per_row/QK_K;
char * qrow = (char *)dst;
for (int row = 0; row < nrow; ++row) {
quantize_row_iq1_s_impl(src, qrow, n_per_row, quant_weights);
src += n_per_row;
qrow += nblock*sizeof(block_iq1_s);
}
return nrow * nblock * sizeof(block_iq1_s);
}

6
ggml.c
View File

@ -19071,8 +19071,10 @@ void ggml_quantize_init(enum ggml_type type) {
void ggml_quantize_free(void) {
ggml_critical_section_start();
iq2xs_free_impl(256);
iq2xs_free_impl(512);
iq2xs_free_impl(GGML_TYPE_IQ2_XXS);
iq2xs_free_impl(GGML_TYPE_IQ2_XS);
iq2xs_free_impl(GGML_TYPE_IQ1_S);
iq3xs_free_impl(256);
ggml_critical_section_end();
}

View File

@ -2495,6 +2495,7 @@ struct llama_model_loader {
case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break;
case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break;
case GGML_TYPE_IQ3_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ3_XXS; break;
case GGML_TYPE_IQ1_S: ftype = LLAMA_FTYPE_MOSTLY_IQ1_S; break;
default:
{
LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
@ -2844,6 +2845,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw";
case LLAMA_FTYPE_MOSTLY_Q3_K_XS:return "Q3_K - Extra small";
case LLAMA_FTYPE_MOSTLY_IQ3_XXS:return "IQ3_XXS - 3.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ1_S :return "IQ1_S - 1.5625 bpw";
default: return "unknown, may not work";
}
@ -10102,20 +10104,20 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
if (arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
new_type = GGML_TYPE_Q8_0;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS) {
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) {
new_type = GGML_TYPE_Q5_K;
}
else if (new_type != GGML_TYPE_Q8_0) {
new_type = GGML_TYPE_Q6_K;
}
} else if (name == "token_embd.weight") {
if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS) {
if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) {
new_type = GGML_TYPE_Q2_K;
}
else if (ftype == LLAMA_FTYPE_MOSTLY_IQ3_XXS) {
new_type = GGML_TYPE_Q4_K;
}
} else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS) {
} else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) {
if (name.find("attn_v.weight") != std::string::npos) {
if (qs.model.hparams.n_gqa() >= 4 || qs.model.hparams.n_expert >= 4) new_type = GGML_TYPE_Q4_K;
else new_type = GGML_TYPE_Q2_K;
@ -10258,7 +10260,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K ||
new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K ||
new_type == GGML_TYPE_IQ2_XS || new_type == GGML_TYPE_IQ2_XXS ||
new_type == GGML_TYPE_IQ3_XXS) {
new_type == GGML_TYPE_IQ3_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S) {
int nx = tensor->ne[0];
int ny = tensor->ne[1];
if (nx % QK_K != 0) {
@ -10273,6 +10275,7 @@ static ggml_type get_k_quant_type(quantize_state_internal & qs, ggml_type new_ty
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_Q2_K: new_type = GGML_TYPE_Q4_0; break;
case GGML_TYPE_Q3_K: new_type = GGML_TYPE_Q4_1; break;
case GGML_TYPE_Q4_K: new_type = GGML_TYPE_Q5_0; break;
@ -10315,6 +10318,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: quantized_type = GGML_TYPE_IQ2_XXS; break;
case LLAMA_FTYPE_MOSTLY_IQ2_XS: quantized_type = GGML_TYPE_IQ2_XS; break;
case LLAMA_FTYPE_MOSTLY_IQ3_XXS: quantized_type = GGML_TYPE_IQ3_XXS; break;
case LLAMA_FTYPE_MOSTLY_IQ1_S: quantized_type = GGML_TYPE_IQ1_S ; break;
default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
}
@ -10488,6 +10492,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
}
if ((new_type == GGML_TYPE_IQ2_XXS ||
new_type == GGML_TYPE_IQ2_XS ||
new_type == GGML_TYPE_IQ1_S ||
(new_type == GGML_TYPE_Q2_K && params->ftype == LLAMA_FTYPE_MOSTLY_Q2_K_S && strcmp(tensor->name, "token_embd.weight") != 0)) && !imatrix) {
LLAMA_LOG_ERROR("\n\n============================================================\n");
LLAMA_LOG_ERROR("Missing importance matrix for tensor %s in a very low-bit quantization\n", tensor->name);

View File

@ -100,6 +100,7 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_Q2_K_S = 21, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q3_K_XS = 22, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ3_XXS = 23, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ1_S = 24, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};