ggml-quants : 1.625 bpw ternary packing for BitNet 1.58b

Francis Couture-Harpin 2024-06-19 12:21:08 -04:00
parent ac146628e4
commit bd807499f7
11 changed files with 594 additions and 4 deletions

View File

@@ -296,12 +296,27 @@ class Model:
     ))
     if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
-        if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
+        if self.ftype == gguf.LlamaFileType.MOSTLY_Q1_3 and not any(
+            self.match_model_tensor_name(new_name, key, None)
+            for key in [
+                gguf.MODEL_TENSOR.TOKEN_EMBD,
+                gguf.MODEL_TENSOR.OUTPUT,
+            ]
+        ):
+            data = gguf.quantize_q1_3(data)
+            assert data.dtype == np.uint8
+            data_qtype = gguf.GGMLQuantizationType.Q1_3
+        elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
             data = gguf.quantize_bf16(data)
             assert data.dtype == np.int16
             data_qtype = gguf.GGMLQuantizationType.BF16
-        elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
+        elif (
+            self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0
+            or self.ftype == gguf.LlamaFileType.MOSTLY_Q1_3
+            and gguf.can_quantize_to_q8_0(data)
+        ):
             data = gguf.quantize_q8_0(data)
             assert data.dtype == np.uint8
             data_qtype = gguf.GGMLQuantizationType.Q8_0
@@ -1412,6 +1427,12 @@ class LlamaModel(Model):
 class BitnetModel(Model):
     model_arch = gguf.MODEL_ARCH.BITNET
 
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, *args, **kwargs):
+        if ftype == gguf.LlamaFileType.GUESSED:
+            ftype = gguf.LlamaFileType.MOSTLY_Q1_3
+
+        super().__init__(dir_model, ftype, *args, **kwargs)
+
     def set_vocab(self):
         self._set_vocab_sentencepiece()
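Net effect of the two hunks above: a BitNet model converted without an explicit ftype defaults to Q1_3, while the token embedding and output tensors are steered to Q8_0, since the scale-less ternary format is only meant for the pre-scaled weight matrices. A minimal sketch of the resulting dispatch (the boolean parameter is hypothetical and stands in for the match_model_tensor_name() checks):

# hedged sketch, not the converter itself; the Q8_0 fallback additionally
# requires can_quantize_to_q8_0() in the real code
def pick_qtype(ftype: str, is_embd_or_output: bool) -> str:
    if ftype == "Q1_3" and not is_embd_or_output:
        return "Q1_3"              # scale-less ternary packing
    if ftype == "BF16":
        return "BF16"
    if ftype in ("Q8_0", "Q1_3"):  # Q1_3 falls back to Q8_0 here
        return "Q8_0"
    return ftype

assert pick_qtype("Q1_3", False) == "Q1_3"
assert pick_qtype("Q1_3", True)  == "Q8_0"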

View File

@@ -26,6 +26,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
{ "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", },
{ "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", },
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
{ "Q1_3", LLAMA_FTYPE_MOSTLY_Q1_3, " 1.63 bpw for BitNet 1.58b", },
{ "Q2_2", LLAMA_FTYPE_MOSTLY_Q2_2, " 2.00 bpw for BitNet 1.58b", },
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.96G, +3.5199 ppl @ Llama-3-8B", },
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.96G, +3.1836 ppl @ Llama-3-8B", },
{ "IQ3_XXS",LLAMA_FTYPE_MOSTLY_IQ3_XXS," 3.06 bpw quantization", },

View File

@ -383,6 +383,8 @@ extern "C" {
GGML_TYPE_F64 = 28,
GGML_TYPE_IQ1_M = 29,
GGML_TYPE_BF16 = 30,
GGML_TYPE_Q2_2 = 31,
GGML_TYPE_Q1_3 = 32,
GGML_TYPE_COUNT,
};

View File

@@ -137,6 +137,20 @@ typedef sycl::half2 ggml_half2;
#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
// 1.625 bpw for BitNet 1.58b models
#define QK1_3 64
typedef struct {
    uint8_t q[(QK1_3 - 4*QK1_3/64)/5]; // 5 elements per byte (3^5 = 243 < 256)
    uint8_t qs[QK1_3/64];              // 4 elements per byte
} block_q1_3;
static_assert(sizeof(block_q1_3) == (QK1_3 - 4*QK1_3/64)/5 + QK1_3/64, "wrong q1_3 block size/padding");

#define QK2_2 32
typedef struct {
    uint8_t qs[QK2_2 / 4]; // 4 elements per byte, 2 bits each
} block_q2_2;
static_assert(sizeof(block_q2_2) == QK2_2 / 4, "wrong q2_2 block size/padding");
#define QK4_0 32
typedef struct {
ggml_half d; // delta
@@ -333,6 +347,7 @@ typedef struct {
} block_iq3_s;
static_assert(sizeof(block_iq3_s) == sizeof(ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding");
// 1.5625 bpw
typedef struct {
ggml_half d;
uint8_t qs[QK_K/8];
@@ -1022,6 +1037,108 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
GGML_TABLE_END()
GGML_TABLE_BEGIN(uint32_t, q22_grid, 256)
0x00000000, 0x01000000, 0x00000000, 0xff000000,
0x00010000, 0x01010000, 0x00010000, 0xff010000,
0x00000000, 0x01000000, 0x00000000, 0xff000000,
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
0x00000100, 0x01000100, 0x00000100, 0xff000100,
0x00010100, 0x01010100, 0x00010100, 0xff010100,
0x00000100, 0x01000100, 0x00000100, 0xff000100,
0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100,
0x00000000, 0x01000000, 0x00000000, 0xff000000,
0x00010000, 0x01010000, 0x00010000, 0xff010000,
0x00000000, 0x01000000, 0x00000000, 0xff000000,
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00,
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00,
0x00000001, 0x01000001, 0x00000001, 0xff000001,
0x00010001, 0x01010001, 0x00010001, 0xff010001,
0x00000001, 0x01000001, 0x00000001, 0xff000001,
0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001,
0x00000101, 0x01000101, 0x00000101, 0xff000101,
0x00010101, 0x01010101, 0x00010101, 0xff010101,
0x00000101, 0x01000101, 0x00000101, 0xff000101,
0x00ff0101, 0x01ff0101, 0x00ff0101, 0xffff0101,
0x00000001, 0x01000001, 0x00000001, 0xff000001,
0x00010001, 0x01010001, 0x00010001, 0xff010001,
0x00000001, 0x01000001, 0x00000001, 0xff000001,
0x00ff0001, 0x01ff0001, 0x00ff0001, 0xffff0001,
0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01,
0x0001ff01, 0x0101ff01, 0x0001ff01, 0xff01ff01,
0x0000ff01, 0x0100ff01, 0x0000ff01, 0xff00ff01,
0x00ffff01, 0x01ffff01, 0x00ffff01, 0xffffff01,
0x00000000, 0x01000000, 0x00000000, 0xff000000,
0x00010000, 0x01010000, 0x00010000, 0xff010000,
0x00000000, 0x01000000, 0x00000000, 0xff000000,
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
0x00000100, 0x01000100, 0x00000100, 0xff000100,
0x00010100, 0x01010100, 0x00010100, 0xff010100,
0x00000100, 0x01000100, 0x00000100, 0xff000100,
0x00ff0100, 0x01ff0100, 0x00ff0100, 0xffff0100,
0x00000000, 0x01000000, 0x00000000, 0xff000000,
0x00010000, 0x01010000, 0x00010000, 0xff010000,
0x00000000, 0x01000000, 0x00000000, 0xff000000,
0x00ff0000, 0x01ff0000, 0x00ff0000, 0xffff0000,
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
0x0001ff00, 0x0101ff00, 0x0001ff00, 0xff01ff00,
0x0000ff00, 0x0100ff00, 0x0000ff00, 0xff00ff00,
0x00ffff00, 0x01ffff00, 0x00ffff00, 0xffffff00,
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff,
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff,
0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff,
0x000101ff, 0x010101ff, 0x000101ff, 0xff0101ff,
0x000001ff, 0x010001ff, 0x000001ff, 0xff0001ff,
0x00ff01ff, 0x01ff01ff, 0x00ff01ff, 0xffff01ff,
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
0x000100ff, 0x010100ff, 0x000100ff, 0xff0100ff,
0x000000ff, 0x010000ff, 0x000000ff, 0xff0000ff,
0x00ff00ff, 0x01ff00ff, 0x00ff00ff, 0xffff00ff,
0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff,
0x0001ffff, 0x0101ffff, 0x0001ffff, 0xff01ffff,
0x0000ffff, 0x0100ffff, 0x0000ffff, 0xff00ffff,
0x00ffffff, 0x01ffffff, 0x00ffffff, 0xffffffff,
GGML_TABLE_END()
GGML_TABLE_BEGIN(uint32_t, q1_3_grid, 256)
0xffffffff, 0xffffffff, 0xffffff00, 0xffffff01, 0xffff00ff, 0xffff0000, 0xffff0001, 0xffff01ff,
0xffff0100, 0xffff0101, 0xff00ffff, 0xff00ff00, 0xff00ff01, 0xff0000ff, 0xff000000, 0xff000001,
0xff0001ff, 0xff000100, 0xff000101, 0xff01ffff, 0xff01ffff, 0xff01ff00, 0xff01ff01, 0xff0100ff,
0xff010000, 0xff010001, 0xff0101ff, 0xff010100, 0xff010101, 0x00ffffff, 0x00ffff00, 0x00ffff01,
0x00ff00ff, 0x00ff0000, 0x00ff0001, 0x00ff01ff, 0x00ff0100, 0x00ff0101, 0x0000ffff, 0x0000ff00,
0x0000ff00, 0x0000ff01, 0x000000ff, 0x00000000, 0x00000001, 0x000001ff, 0x00000100, 0x00000101,
0x0001ffff, 0x0001ff00, 0x0001ff01, 0x000100ff, 0x00010000, 0x00010001, 0x000101ff, 0x00010100,
0x00010101, 0x01ffffff, 0x01ffff00, 0x01ffff01, 0x01ffff01, 0x01ff00ff, 0x01ff0000, 0x01ff0001,
0x01ff01ff, 0x01ff0100, 0x01ff0101, 0x0100ffff, 0x0100ff00, 0x0100ff01, 0x010000ff, 0x01000000,
0x01000001, 0x010001ff, 0x01000100, 0x01000101, 0x0101ffff, 0x0101ff00, 0x0101ff01, 0x0101ff01,
0x010100ff, 0x01010000, 0x01010001, 0x010101ff, 0x01010100, 0x01010101, 0xffffffff, 0xffffff00,
0xffffff01, 0xffff00ff, 0xffff0000, 0xffff0001, 0xffff01ff, 0xffff0100, 0xffff0101, 0xff00ffff,
0xff00ff00, 0xff00ff01, 0xff0000ff, 0xff0000ff, 0xff000000, 0xff000001, 0xff0001ff, 0xff000100,
0xff000101, 0xff01ffff, 0xff01ff00, 0xff01ff01, 0xff0100ff, 0xff010000, 0xff010001, 0xff0101ff,
0xff010100, 0xff010101, 0x00ffffff, 0x00ffff00, 0x00ffff01, 0x00ff00ff, 0x00ff0000, 0x00ff0000,
0x00ff0001, 0x00ff01ff, 0x00ff0100, 0x00ff0101, 0x0000ffff, 0x0000ff00, 0x0000ff01, 0x000000ff,
0x00000000, 0x00000001, 0x000001ff, 0x00000100, 0x00000101, 0x0001ffff, 0x0001ff00, 0x0001ff01,
0x000100ff, 0x00010000, 0x00010000, 0x00010001, 0x000101ff, 0x00010100, 0x00010101, 0x01ffffff,
0x01ffff00, 0x01ffff01, 0x01ff00ff, 0x01ff0000, 0x01ff0001, 0x01ff01ff, 0x01ff0100, 0x01ff0101,
0x0100ffff, 0x0100ff00, 0x0100ff01, 0x010000ff, 0x01000000, 0x01000001, 0x01000001, 0x010001ff,
0x01000100, 0x01000101, 0x0101ffff, 0x0101ff00, 0x0101ff01, 0x010100ff, 0x01010000, 0x01010001,
0x010101ff, 0x01010100, 0x01010101, 0xffffffff, 0xffffff00, 0xffffff01, 0xffff00ff, 0xffff0000,
0xffff0001, 0xffff01ff, 0xffff01ff, 0xffff0100, 0xffff0101, 0xff00ffff, 0xff00ff00, 0xff00ff01,
0xff0000ff, 0xff000000, 0xff000001, 0xff0001ff, 0xff000100, 0xff000101, 0xff01ffff, 0xff01ff00,
0xff01ff01, 0xff0100ff, 0xff010000, 0xff010001, 0xff0101ff, 0xff0101ff, 0xff010100, 0xff010101,
0x00ffffff, 0x00ffff00, 0x00ffff01, 0x00ff00ff, 0x00ff0000, 0x00ff0001, 0x00ff01ff, 0x00ff0100,
0x00ff0101, 0x0000ffff, 0x0000ff00, 0x0000ff01, 0x000000ff, 0x00000000, 0x00000001, 0x000001ff,
0x00000100, 0x00000100, 0x00000101, 0x0001ffff, 0x0001ff00, 0x0001ff01, 0x000100ff, 0x00010000,
0x00010001, 0x000101ff, 0x00010100, 0x00010101, 0x01ffffff, 0x01ffff00, 0x01ffff01, 0x01ff00ff,
0x01ff0000, 0x01ff0001, 0x01ff01ff, 0x01ff0100, 0x01ff0101, 0x01ff0101, 0x0100ffff, 0x0100ff00,
0x0100ff01, 0x010000ff, 0x01000000, 0x01000001, 0x010001ff, 0x01000100, 0x01000101, 0x0101ffff,
0x0101ff00, 0x0101ff01, 0x010100ff, 0x01010000, 0x01010001, 0x010101ff, 0x01010100, 0x01010101,
GGML_TABLE_END()
#define NGRID_IQ1S 2048
#define IQ1S_DELTA 0.125f
#define IQ1M_DELTA 0.125f
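Two things in this header are worth unpacking. First, the bpw figures: a block_q1_3 stores 64 weights in 12 + 1 = 13 bytes (13 · 8 / 64 = 1.625 bpw) and a block_q2_2 stores 32 weights in 8 bytes (2.0 bpw). Second, the lookup tables follow the same convention as the other ggml grids: each uint32 entry packs four int8 weights (0x00 → 0, 0x01 → +1, 0xff → −1) and is indexed by one packed byte, so a single table load dequantizes four elements. A sketch checking both, assuming little-endian lane order as in the scalar paths further down:

import struct

# block sizes as laid out above: 12 packed bytes + 1, and 8 bytes
assert ((64 - 4 * 64 // 64) // 5 + 64 // 64) * 8 / 64 == 1.625  # q1_3 bpw
assert (32 // 4) * 8 / 32 == 2.0                                # q2_2 bpw

# one uint32 grid entry -> four int8 weights
def grid_weights(entry: int) -> list[int]:
    return list(struct.unpack("<4b", struct.pack("<I", entry)))

assert grid_weights(0x00000000) == [0, 0, 0, 0]
assert grid_weights(0xff000001) == [1, 0, 0, -1]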

View File

@@ -1630,7 +1630,7 @@ void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int6
 // ===================== Helper functions
 //
 static inline int nearest_int(float fval) {
-    assert(fval <= 4194303.f);
+    assert(fabsf(fval) <= 4194303.f);
     float val = fval + 12582912.f;
     int i; memcpy(&i, &val, sizeof(int));
     return (i & 0x007fffff) - 0x00400000;
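For context, 12582912.0f is 1.5 · 2^23: adding it forces the float's mantissa to hold the rounded integer (round-to-nearest-even), which the mask then extracts. The trick only works while |fval| fits in the 22-bit range, so the old assert let large negative inputs through silently; the fix guards both signs. A model of the same bit trick:

import struct

# model of nearest_int(): 12582912.0 == 1.5 * 2**23 pins the binary point
def nearest_int(fval: float) -> int:
    assert abs(fval) <= 4194303.0        # 2**22 - 1, the usable range
    val = struct.pack("<f", fval + 12582912.0)
    (i,) = struct.unpack("<i", val)
    return (i & 0x007FFFFF) - 0x00400000

assert nearest_int(2.5) == 2    # ties round to even
assert nearest_int(3.5) == 4
assert nearest_int(-1.2) == -1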
@@ -3306,6 +3306,140 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
return nrow * row_size;
}
size_t quantize_q2_2(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    (void)quant_weights; // not used

    const size_t row_size = ggml_row_size(GGML_TYPE_Q2_2, n_per_row);
    quantize_row_q2_2_reference(src, dst, (int64_t)nrow*n_per_row);
    return nrow * row_size;
}
// ====================== 1.625 bpw (de)-quantization (BitNet 1.58b)

void quantize_row_q1_3_reference(const float * restrict x, block_q1_3 * restrict y, int64_t k) {
    assert(k % QK1_3 == 0);
    const int64_t nb = k / QK1_3;
    static_assert(sizeof(y->q) % 4 == 0, "bad block_q1_3.q size");

    const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};

    for (int64_t i = 0; i < nb; ++i) {
        uint8_t q[sizeof(y->q)] = {0};
        for (size_t j = 0; j < sizeof(y->q); ++j) {
            for (size_t m = 0; m < 4; ++m) {
                int xi = nearest_int(x[m]);
                uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2;
                q[j] += xt * pow3[m];
            }
            x += 4;
        }
        for (size_t j = 0; j < sizeof(y->q); ++j) {
            int xi = nearest_int(x[j]);
            uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2;
            q[j] += xt * pow3[4];
            q[j] = ((uint16_t)q[j] * 256) / pow3[5];
            q[j] += (uint8_t)(q[j] != 0);
            y[i].q[j] = q[j];
        }
        x += sizeof(y->q);

        for (size_t j = 0; j < sizeof(y->qs); ++j) {
            uint8_t qb = 0;
            for (size_t m = 0; m < 4; ++m) {
                int xi = nearest_int(x[m]);
                uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2;
                qb += xt * pow3[m];
            }
            x += 4;
            qb = ((uint16_t)qb * 256) / pow3[5];
            qb += (uint8_t)(qb != 0);
            y[i].qs[j] = qb;
        }
    }
}

void quantize_row_q1_3(const float * restrict x, void * restrict vy, int64_t k) {
    assert(k % QK1_3 == 0);
    block_q1_3 * restrict y = vy;
    quantize_row_q1_3_reference(x, y, k);
}

size_t quantize_q1_3(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
    (void)quant_weights; // not used

    const size_t row_size = ggml_row_size(GGML_TYPE_Q1_3, n_per_row);
    quantize_row_q1_3(src, dst, (int64_t)nrow*n_per_row);
    return nrow * row_size;
}
void dequantize_row_q1_3(const block_q1_3 * restrict x, float * restrict y, int64_t k) {
    assert(k % QK1_3 == 0);
    const int64_t nb = k / QK1_3;
    static_assert(sizeof(x->q) % 4 == 0, "bad block_q1_3.q size");

// #if defined(__SSE2__)
//     __m128 vscale = _mm_set1_ps(scale);
//
//     for (int64_t i = 0; i < nb; ++i) {
//         for (size_t j = 0; j < sizeof(x->q); j += 4) {
//             __m128 q1 = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].q[j + 0]]));
//             __m128 q2 = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].q[j + 1]]));
//             __m128 q3 = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].q[j + 2]]));
//             __m128 q4 = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].q[j + 3]]));
//             q1 = _mm_mul_ps(q1, vscale);
//             q2 = _mm_mul_ps(q2, vscale);
//             q3 = _mm_mul_ps(q3, vscale);
//             q4 = _mm_mul_ps(q4, vscale);
//
//             _mm_store_ps(y +  0, q1);
//             _mm_store_ps(y +  4, q2);
//             _mm_store_ps(y +  8, q3);
//             _mm_store_ps(y + 12, q4);
//             y += 16;
//         }
//
//         for (size_t j = 0; j < sizeof(x->q); j += 4) {
//             __m128i q5i = _mm_loadu_si32(x[i].q + j);
//             q5i = _mm_cvtepi8_epi16(q5i);
//             q5i = _mm_add_epi16(q5i, _mm_add_epi16(q5i, q5i));
//             q5i = _mm_srli_epi16(q5i, 8);
//             q5i = _mm_sub_epi16(q5i, _mm_set1_epi16(1));
//             __m128 q5 = _mm_cvtepi32_ps(_mm_cvtepi16_epi32(q5i));
//             q5 = _mm_mul_ps(q5, vscale);
//             _mm_store_ps(y, q5);
//             y += 4;
//         }
//
//         for (size_t j = 0; j < sizeof(x->qs); ++j) {
//             __m128 q = _mm_cvtpi8_ps(_m_from_int(q1_3_grid[x[i].qs[j]]));
//             q = _mm_mul_ps(q, vscale);
//             _mm_store_ps(y, q);
//             y += 4;
//         }
//     }
// #else
    for (int64_t i = 0; i < nb; ++i) {
        for (size_t j = 0; j < sizeof(x->q); ++j) {
            const int8_t * q = (const int8_t *) (q1_3_grid + x[i].q[j]);
            for (int m = 0; m < 4; ++m) {
                *y++ = (float) q[m];
            }
        }

        for (size_t j = 0; j < sizeof(x->q); ++j) {
            uint16_t q = x[i].q[j];
            *y++ = (float) ((int16_t)((q * 3) >> 8) - 1);
        }

        for (size_t j = 0; j < sizeof(x->qs); ++j) {
            const int8_t * q = (const int8_t *) (q1_3_grid + x[i].qs[j]);
            for (int m = 0; m < 4; ++m) {
                *y++ = (float) q[m];
            }
        }
    }
// #endif
}
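The byte-level scaling in the quantizer ((q * 256) / 243, then +1 when nonzero) is what makes the cheap decode above work: it stretches the 243 possible trit combinations across the full byte range so that the fixed-point product (q * 3) >> 8 yields the most significant trit directly, and multiplying the low bits by 3 again peels off the next one. A self-contained model of the 5-trits-per-byte round trip, assuming weights already scaled to {-1, 0, 1}:

import itertools

def pack5(trits):  # five values in {-1, 0, 1}, least significant first
    t = sum((v + 1) * 3**m for m, v in enumerate(trits))
    q = (t * 256) // 243          # stretch 0..242 over the byte range
    return q + (q != 0)           # bias so the decode below rounds right

def unpack5(q):
    out = []
    for _ in range(5):            # each step pops the most significant trit
        q *= 3
        out.append((q >> 8) - 1)  # same trick as (q * 3) >> 8 above
        q &= 0xFF
    return out[::-1]

# exhaustive round-trip check over all 3^5 = 243 combinations
for trits in itertools.product((-1, 0, 1), repeat=5):
    assert unpack5(pack5(trits)) == list(trits)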
// ====================== "True" 2-bit (de)-quantization
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
@@ -3726,6 +3860,122 @@ static inline __m128i get_scale_shuffle(int i) {
}
#endif
void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
assert(n % qk == 0);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
const block_q2_2 * restrict x = vx;
const block_q8_0 * restrict y = vy;
#if defined(__AVX2__)
__m256 acc = _mm256_setzero_ps();
int leftovers = nb % 2;
for (int i = 0; i < nb - leftovers; i += 2) {
const __m256 d0 = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i + 0].d) );
const __m256 d1 = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i + 1].d) );
// assuming two consecutive blocks are contiguous AND aligned
__m128i xq16b = _mm_load_si128((const __m128i *) (x[i].qs));
__m256i xq16 = MM256_SET_M128I(xq16b, xq16b);
__m256i xq8l0 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1,
4, -1, 4, -1, 4, -1, 4, -1,
1, -1, 1, -1, 1, -1, 1, -1,
0, -1, 0, -1, 0, -1, 0, -1));
__m256i xq8h0 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1,
6, -1, 6, -1, 6, -1, 6, -1,
3, -1, 3, -1, 3, -1, 3, -1,
2, -1, 2, -1, 2, -1, 2, -1));
__m256i xq8l1 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(13, -1, 13, -1, 13, -1, 13, -1,
12, -1, 12, -1, 12, -1, 12, -1,
9, -1, 9, -1, 9, -1, 9, -1,
8, -1, 8, -1, 8, -1, 8, -1));
__m256i xq8h1 = _mm256_shuffle_epi8(xq16, _mm256_set_epi8(15, -1, 15, -1, 15, -1, 15, -1,
14, -1, 14, -1, 14, -1, 14, -1,
11, -1, 11, -1, 11, -1, 11, -1,
10, -1, 10, -1, 10, -1, 10, -1));
__m256i shift = _mm256_set_epi16(64, 16, 4, 1,
64, 16, 4, 1,
64, 16, 4, 1,
64, 16, 4, 1);
xq8l0 = _mm256_mullo_epi16(xq8l0, shift);
xq8h0 = _mm256_mullo_epi16(xq8h0, shift);
xq8l1 = _mm256_mullo_epi16(xq8l1, shift);
xq8h1 = _mm256_mullo_epi16(xq8h1, shift);
xq8l0 = _mm256_srai_epi16(xq8l0, 14);
xq8h0 = _mm256_srai_epi16(xq8h0, 14);
xq8l1 = _mm256_srai_epi16(xq8l1, 14);
xq8h1 = _mm256_srai_epi16(xq8h1, 14);
__m256i xq8_0 = _mm256_packs_epi16(xq8l0, xq8h0);
__m256i xq8_1 = _mm256_packs_epi16(xq8l1, xq8h1);
__m256i yq8_0 = _mm256_lddqu_si256((const __m256i *) (y[i + 0].qs));
__m256i yq8_1 = _mm256_lddqu_si256((const __m256i *) (y[i + 1].qs));
const __m256 q0 = mul_sum_i8_pairs_float(xq8_0, yq8_0);
const __m256 q1 = mul_sum_i8_pairs_float(xq8_1, yq8_1);
acc = _mm256_fmadd_ps( d0, q0, acc );
acc = _mm256_fmadd_ps( d1, q1, acc );
}
for (int i = nb - leftovers; i < nb; ++i) {
const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i].d) );
__m128i xq8b = _mm_loadu_si64(x[i].qs);
__m256i xq8 = MM256_SET_M128I(xq8b, xq8b);
__m256i xq8l = _mm256_shuffle_epi8(xq8, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1,
4, -1, 4, -1, 4, -1, 4, -1,
1, -1, 1, -1, 1, -1, 1, -1,
0, -1, 0, -1, 0, -1, 0, -1));
__m256i xq8h = _mm256_shuffle_epi8(xq8, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1,
6, -1, 6, -1, 6, -1, 6, -1,
3, -1, 3, -1, 3, -1, 3, -1,
2, -1, 2, -1, 2, -1, 2, -1));
__m256i shift = _mm256_set_epi16(64, 16, 4, 1,
64, 16, 4, 1,
64, 16, 4, 1,
64, 16, 4, 1);
xq8l = _mm256_mullo_epi16(xq8l, shift);
xq8h = _mm256_mullo_epi16(xq8h, shift);
xq8l = _mm256_srai_epi16(xq8l, 14);
xq8h = _mm256_srai_epi16(xq8h, 14);
xq8 = _mm256_packs_epi16(xq8l, xq8h);
__m256i yq8 = _mm256_lddqu_si256((const __m256i *) (y[i].qs));
const __m256 q = mul_sum_i8_pairs_float(xq8, yq8);
acc = _mm256_fmadd_ps( d, q, acc );
}
*s = hsum_float_8(acc);
#else
float sumf = 0.0;
for (int i = 0; i < nb; i++) {
int sumi = 0;
for (int j = 0; j < qk / 4; j++) {
const int8_t* weight = (const int8_t *)(q22_grid + x[i].qs[j]);
sumi += (int)y[i].qs[4*j+0] * weight[0];
sumi += (int)y[i].qs[4*j+1] * weight[1];
sumi += (int)y[i].qs[4*j+2] * weight[2];
sumi += (int)y[i].qs[4*j+3] * weight[3];
}
sumf += (float)(sumi)*(GGML_FP16_TO_FP32(y[i].d));
}
*s = sumf;
#endif
}
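The mullo/srai pairs above are a branch-free way to sign-extend each packed 2-bit field: the shuffle places the packed byte in the high half of a 16-bit lane, the per-lane multiplier from `shift` slides the wanted field to bits 14..15, and the arithmetic shift right by 14 sign-extends it (00 → 0, 01 → 1, 10 → −2, 11 → −1; the scalar q22_grid path decodes the 10 pattern to 0 instead, which appears to be unused by the ternary encoder). A scalar model of one byte:

# scalar model of the mullo/srai idiom; returns weight[0] first, matching
# the order the scalar fallback reads them
def unpack_q2_2_byte(b: int) -> list[int]:
    weights = []
    for mul in (1, 4, 16, 64):           # per-lane multipliers from `shift`
        lane = (b << 8) * mul & 0xFFFF   # the shuffle puts b in the high byte
        if lane & 0x8000:                # interpret as signed 16-bit
            lane -= 0x10000
        weights.append(lane >> 14)       # arithmetic shift, sign preserved
    return weights

assert unpack_q2_2_byte(0b01_00_00_11) == [1, 0, 0, -1]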
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0;
const int nb = n / qk;
@@ -11102,6 +11352,105 @@ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
}
#endif
void ggml_vec_dot_q1_3_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
UNUSED(bx);
UNUSED(by);
UNUSED(bs);
// assumed by the code below
assert(n % QK1_3 == 0);
static_assert(QK1_3 == 2 * QK8_0, "QK1_3 must be 2 times bigger than QK8_0");
const block_q1_3 * restrict x = vx;
const block_q8_0 * restrict y = vy;
const int nb = n / QK1_3;
#if defined(__AVX2__)
__m256 accumf = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
{
__m256i x0 = _mm256_set_epi32(q1_3_grid[x[i].q[7]], q1_3_grid[x[i].q[6]],
q1_3_grid[x[i].q[5]], q1_3_grid[x[i].q[4]],
q1_3_grid[x[i].q[3]], q1_3_grid[x[i].q[2]],
q1_3_grid[x[i].q[1]], q1_3_grid[x[i].q[0]]);
__m256i y0 = _mm256_lddqu_si256((const __m256i_u *) (y[2*i].qs));
__m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i].d));
__m256 q = mul_sum_i8_pairs_float(x0, y0);
accumf = _mm256_fmadd_ps(d, q, accumf);
}
{
__m256i x1 = _mm256_castsi128_si256(_mm_set_epi32(q1_3_grid[x[i].q[11]], q1_3_grid[x[i].q[10]],
q1_3_grid[x[i].q[9]], q1_3_grid[x[i].q[8]]));
__m256i x2 = _mm256_cvtepu8_epi16(_mm_maskload_epi32((const int32_t *) x[i].q, _mm_set_epi32(0, -1, -1, -1)));
__m256i y1 = _mm256_lddqu_si256((const __m256i_u *) (y[2*i + 1].qs));
x2 = _mm256_mulhi_epu16(x2, _mm256_set1_epi16(3 << 8));
x2 = _mm256_sub_epi16(x2, _mm256_set1_epi16(1));
// TODO: reduce shuffling
x2 = _mm256_packs_epi16(x2, _mm256_setzero_si256());
x2 = _mm256_permute4x64_epi64(x2, _MM_SHUFFLE(3, 1, 2, 0));
__m128i x2_l = _mm_insert_epi32(_mm256_castsi256_si128(x2), q1_3_grid[x[i].qs[0]], 3);
x1 = _mm256_inserti128_si256(x1, x2_l, 1);
__m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i + 1].d));
__m256 q = mul_sum_i8_pairs_float(x1, y1);
accumf = _mm256_fmadd_ps(d, q, accumf);
}
}
*s = hsum_float_8(accumf);
#else
float sumf = 0.0f;
for (int i = 0; i < nb; ++i) {
int sum = 0;
for (int j = 0; j < 8; ++j) {
const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].q[j]);
for (int k = 0; k < 4; ++k) {
sum += xj[k] * (int16_t) y[2*i].qs[4*j + k];
}
}
sumf += GGML_FP16_TO_FP32(y[2*i].d) * sum;
sum = 0;
for (int j = 0; j < 4; ++j) {
const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].q[8 + j]);
for (int k = 0; k < 4; ++k) {
sum += xj[k] * (int16_t) y[2*i + 1].qs[4*j + k];
}
}
for (size_t j = 0; j < 12; ++j) {
uint16_t xj = x[i].q[j];
xj = (xj * 3) >> 8;
sum += ((int16_t) xj - 1) * (int16_t) y[2*i + 1].qs[16 + j];
}
{
const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].qs[0]);
for (int k = 0; k < 4; ++k) {
sum += (int16_t) xj[k] * (int16_t) y[2*i + 1].qs[28 + k];
}
}
sumf += GGML_FP16_TO_FP32(y[2*i + 1].d) * sum;
}
*s = sumf;
#endif
}
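Because QK1_3 (64) is twice QK8_0 (32), each q1_3 block is dotted against a pair of q8_0 blocks, and the scalar fallback spells out which packed bytes feed which half. The accounting, restated:

# element accounting for one 13-byte q1_3 block vs. two q8_0 blocks
first  = 8 * 4           # q[0..7]: grid lookups            -> y[2*i + 0]
second = 4 * 4 + 12 + 4  # q[8..11] grids, 12 fifth trits,
                         # and qs[0]                        -> y[2*i + 1]
assert first == 32 and second == 32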
void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
assert(n % QK_K == 0);
assert(nrc == 1);
@@ -14977,6 +15326,8 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
} break;
case GGML_TYPE_Q1_3:
case GGML_TYPE_Q2_2:
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:

View File

@ -12,6 +12,8 @@ extern "C" {
#endif
// Quantization
void quantize_row_q1_3_reference(const float * GGML_RESTRICT x, block_q1_3 * GGML_RESTRICT y, int64_t k);
void quantize_row_q2_2_reference(const float * GGML_RESTRICT x, block_q2_2 * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k);
@@ -32,6 +34,8 @@ void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs
void quantize_row_iq3_s_reference (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k);
void quantize_row_iq2_s_reference (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);
void quantize_row_q1_3(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q2_2(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
@@ -53,6 +57,8 @@ void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y,
void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
// Dequantization
void dequantize_row_q1_3(const block_q1_3 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_q2_2(const block_q2_2 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
@@ -78,6 +84,8 @@ void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_
void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
// Dot product
void ggml_vec_dot_q1_3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q2_2_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
@@ -116,6 +124,8 @@ size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q1_3(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q2_2(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);

View File

@@ -823,6 +823,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
        .vec_dot_type = GGML_TYPE_Q8_K,
        .nrows = 1,
    },
    [GGML_TYPE_Q1_3] = {
        .type_name = "q1_3",
        .blck_size = QK1_3,
        .type_size = sizeof(block_q1_3),
        .is_quantized = true,
        .to_float = (ggml_to_float_t) dequantize_row_q1_3,
        .from_float = quantize_row_q1_3,
        .from_float_reference = (ggml_from_float_t) quantize_row_q1_3_reference,
        .vec_dot = ggml_vec_dot_q1_3_q8_0,
        .vec_dot_type = GGML_TYPE_Q8_0,
        .nrows = 1,
    },
    [GGML_TYPE_IQ1_S] = {
        .type_name = "iq1_s",
        .blck_size = QK_K,
@@ -10041,7 +10053,16 @@ static void ggml_compute_forward_mul_f32(
     GGML_ASSERT( nb0 == sizeof(float));
     GGML_ASSERT(nb00 == sizeof(float));
 
-    if (nb10 == sizeof(float)) {
+    if (ggml_nelements(src1) == 1) {
+        float scale = ((float *) src1->data)[0];
+        for (int64_t ir = ith; ir < nr; ir += nth) {
+            if (dst->data != src0->data) {
+                // src0 is same shape as dst => same indices
+                memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float));
+            }
+            ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale);
+        }
+    } else if (nb10 == sizeof(float)) {
         for (int64_t ir = ith; ir < nr; ir += nth) {
             // src0 and dst are same shape => same indices
             const int64_t i03 = ir/(ne02*ne01);
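This fast path matters here because BitNet rescales activations by a single-element tensor; broadcasting a 1-element src1 turns the general element-wise mul into one memcpy plus ggml_vec_scale_f32 per row. A numpy model of the semantics:

import numpy as np

# mul by a 1-element tensor is a pure scale; same result, cheaper path
src0 = np.arange(12, dtype=np.float32).reshape(3, 4)
src1 = np.array([0.25], dtype=np.float32)  # ggml_nelements(src1) == 1
assert np.array_equal(src0 * src1, src0 * 0.25)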
@@ -13713,6 +13734,8 @@ static void ggml_compute_forward_clamp(
} break;
case GGML_TYPE_F16:
case GGML_TYPE_BF16:
case GGML_TYPE_Q1_3:
case GGML_TYPE_Q2_2:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
@@ -20438,6 +20461,8 @@ size_t ggml_quantize_chunk(
size_t result = 0;
switch (type) {
case GGML_TYPE_Q1_3: result = quantize_q1_3(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q2_2: result = quantize_q2_2(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;

View File

@@ -1023,6 +1023,8 @@ class GGMLQuantizationType(IntEnum):
    F64    = 28
    IQ1_M  = 29
    BF16   = 30
    Q2_2   = 31
    Q1_3   = 32


# TODO: add GGMLFileType from ggml_ftype in ggml.h
@@ -1064,6 +1066,8 @@ class LlamaFileType(IntEnum):
    MOSTLY_IQ4_XS = 30  # except 1d tensors
    MOSTLY_IQ1_M  = 31  # except 1d tensors
    MOSTLY_BF16   = 32  # except 1d tensors
    MOSTLY_Q2_2   = 33  # except 1d tensors
    MOSTLY_Q1_3   = 34  # except 1d tensors

    GUESSED = 1024  # not specified in the model file
@@ -1137,6 +1141,8 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
    GGMLQuantizationType.F64:   (1, 8),
    GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32),
    GGMLQuantizationType.BF16:  (1, 2),
    GGMLQuantizationType.Q2_2:  (32, 8),
    GGMLQuantizationType.Q1_3:  (64, 12 + 1),
}

View File

@@ -121,3 +121,53 @@ def quantize_q8_0(data: np.ndarray):
        return __quantize_q8_0_lazy(data)
    else:
        return __quantize_q8_0_array(data)


__q1_3_block_size, __q1_3_type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q1_3]


def __quantize_q1_3_shape_change(s: tuple[int, ...]) -> tuple[int, ...]:
    return (*s[:-1], s[-1] // __q1_3_block_size * __q1_3_type_size)


def __quantize_q1_3_rows(n: np.ndarray) -> np.ndarray:
    shape = n.shape
    assert shape[-1] % __q1_3_block_size == 0

    n_blocks = n.size // __q1_3_block_size

    blocks = n.reshape((n_blocks, __q1_3_block_size)).astype(np.float32, copy=False)

    # assuming the weights are pre-scaled
    blocks = (np.sign(blocks).astype(np.int8) + 1).view(np.uint8)

    q48, rest = np.hsplit(blocks, (48,))
    q12, q4 = np.hsplit(rest, (12,))

    pow3 = np.array([1, 3, 9, 27])
    q48 = q48.reshape((n_blocks, 12, 4))
    q48 = np.sum(q48 * pow3.reshape((1, 1, 4)), axis=2, keepdims=True).reshape((n_blocks, 12))
    q4 = np.sum(q4 * pow3.reshape((1, 4)), axis=1, keepdims=True)

    q48 = q48 + (q12 * 81)
    q = np.concatenate([q48, q4], axis=1)
    q = ((q.astype(np.uint16) * 256) // 243).astype(np.uint8)
    q = np.where(q != 0, q + 1, 0)

    return q.reshape(__quantize_q1_3_shape_change(shape))


def __quantize_q1_3_array(n: np.ndarray) -> np.ndarray:
    return __apply_over_grouped_rows(__quantize_q1_3_rows, arr=n, otype=np.uint8, oshape=__quantize_q1_3_shape_change(n.shape))


__quantize_q1_3_lazy = LazyNumpyTensor._wrap_fn(
    __quantize_q1_3_array,
    meta_noop=(np.uint8, __quantize_q1_3_shape_change),
)


def quantize_q1_3(data: np.ndarray):
    if type(data) is LazyNumpyTensor:
        return __quantize_q1_3_lazy(data)
    else:
        return __quantize_q1_3_array(data)
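A quick way to sanity-check the new numpy quantizer (hypothetical shapes; assumes this gguf-py is importable as gguf):

import numpy as np
import gguf

# ternary, pre-scaled weights; row length must be a multiple of QK1_3 = 64
w = np.sign(np.random.randn(4, 256)).astype(np.float32)
q = gguf.quantize_q1_3(w)
assert q.dtype == np.uint8
assert q.shape == (4, 256 // 64 * 13)  # 13 bytes per 64-weight block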

View File

@ -158,6 +158,8 @@ extern "C" {
LLAMA_FTYPE_MOSTLY_IQ4_XS = 30, // except 1d tensors
LLAMA_FTYPE_MOSTLY_IQ1_M = 31, // except 1d tensors
LLAMA_FTYPE_MOSTLY_BF16 = 32, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q2_2 = 33, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q1_3 = 34, // except 1d tensors
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
};

View File

@@ -4186,6 +4186,8 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
case LLAMA_FTYPE_ALL_F32: return "all F32";
case LLAMA_FTYPE_MOSTLY_F16: return "F16";
case LLAMA_FTYPE_MOSTLY_BF16: return "BF16";
case LLAMA_FTYPE_MOSTLY_Q1_3: return "Q1_3 - 1.625 bpw for BitNet 1.58b";
case LLAMA_FTYPE_MOSTLY_Q2_2: return "Q2_2 - 2.000 bpw for BitNet 1.58b";
case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0";
case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1";
case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
@ -16287,6 +16289,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
llama_ftype ftype = params->ftype;
switch (params->ftype) {
case LLAMA_FTYPE_MOSTLY_Q1_3: default_type = GGML_TYPE_Q1_3; break;
case LLAMA_FTYPE_MOSTLY_Q2_2: default_type = GGML_TYPE_Q2_2; break;
case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break;
case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;