diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index e93c3829a..a137c157e 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -284,9 +284,6 @@ class Model:
             for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
                 data: np.ndarray  # type hint
-                if len(data.shape) == 0:
-                    # otherwise single-value tensors get squeezed
-                    data = data.reshape((1,))
                 n_dims = len(data.shape)
                 data_dtype = data.dtype
                 data_qtype: gguf.GGMLQuantizationType | None = None
@@ -317,33 +314,12 @@ class Model:
                 ))

                 if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
-                    # TODO: cleaner model-specific per-tensor types
-                    # NOTE: Q1_3 is only relevant for BitNet b1.58
-                    if (
-                        self.ftype == gguf.LlamaFileType.MOSTLY_Q1_3
-                        and gguf.can_quantize_to_q1_3(data)
-                        and not any(
-                            self.match_model_tensor_name(new_name, key, None)
-                            for key in [
-                                gguf.MODEL_TENSOR.TOKEN_EMBD,
-                                gguf.MODEL_TENSOR.OUTPUT,
-                            ]
-                        )
-                    ):
-                        data = gguf.quantize_q1_3(data)
-                        assert data.dtype == np.uint8
-                        data_qtype = gguf.GGMLQuantizationType.Q1_3
-
-                    elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
+                    if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
                         data = gguf.quantize_bf16(data)
                         assert data.dtype == np.int16
                         data_qtype = gguf.GGMLQuantizationType.BF16

-                    elif (
-                        self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0
-                        or self.ftype == gguf.LlamaFileType.MOSTLY_Q1_3
-                        and gguf.can_quantize_to_q8_0(data)
-                    ):
+                    elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
                         data = gguf.quantize_q8_0(data)
                         assert data.dtype == np.uint8
                         data_qtype = gguf.GGMLQuantizationType.Q8_0
@@ -1635,12 +1611,6 @@ class LlamaModel(Model):
 class BitnetModel(Model):
     model_arch = gguf.MODEL_ARCH.BITNET

-    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, *args, **kwargs):
-        if ftype == gguf.LlamaFileType.GUESSED:
-            ftype = gguf.LlamaFileType.MOSTLY_Q1_3
-
-        super().__init__(dir_model, ftype, *args, **kwargs)
-
     def set_vocab(self):
         self._set_vocab_sentencepiece()
@@ -1649,16 +1619,16 @@ class BitnetModel(Model):
         self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
         self.gguf_writer.add_rope_scaling_factor(1.0)

-    def weight_quant(self, weight):
+    def weight_quant(self, weight: Tensor) -> Tensor:
         dtype = weight.dtype
         weight = weight.float()
         scale = weight.abs().mean().clamp(min=1e-5)
         iscale = 1 / scale
-        weight = (weight * iscale).round().clamp(-1, 1)
-        # TODO: use the scale directly instead of inverting it twice
+        # TODO: multiply by the scale directly instead of inverting it twice
         # (this is also unnecessarily doubly inverted upstream)
         # ref: https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/af89e318d78a70802061246bf037199d2fb97020/utils_quant.py#L10
-        return weight.type(dtype), (1 / iscale).type(torch.float32)
+        result = (weight * iscale).round().clamp(-1, 1) / iscale
+        return result.type(dtype)

     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         new_name = self.map_tensor_name(name)
@@ -1673,11 +1643,9 @@ class BitnetModel(Model):
             gguf.MODEL_TENSOR.FFN_GATE,
         ]):
             # transform weight into 1/0/-1 (in fp32)
-            weight_torch, scale_torch = self.weight_quant(data_torch)
-            yield (new_name, weight_torch)
-            yield (new_name.removesuffix(".weight") + ".scale", scale_torch)
-        else:
-            yield (new_name, data_torch)
+            data_torch = self.weight_quant(data_torch)
+
+        yield (new_name, data_torch)

 @Model.register("GrokForCausalLM")
diff --git a/examples/quantize/quantize.cpp
b/examples/quantize/quantize.cpp index 1086cc9ed..7f8f724b7 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -28,8 +28,6 @@ static const std::vector QUANT_OPTIONS = { { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, { "TQ1_0", LLAMA_FTYPE_MOSTLY_TQ1_0, " 1.69 bpw ternarization", }, { "TQ2_0", LLAMA_FTYPE_MOSTLY_TQ2_0, " 2.06 bpw ternarization", }, - { "Q1_3", LLAMA_FTYPE_MOSTLY_Q1_3, " 1.63 bpw for BitNet b1.58", }, - { "Q2_2", LLAMA_FTYPE_MOSTLY_Q2_2, " 2.00 bpw for BitNet b1.58", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.96G, +3.5199 ppl @ Llama-3-8B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.96G, +3.1836 ppl @ Llama-3-8B", }, { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 03884cba4..3950a4a07 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -392,8 +392,6 @@ extern "C" { GGML_TYPE_Q4_0_8_8 = 33, GGML_TYPE_TQ1_0 = 34, GGML_TYPE_TQ2_0 = 35, - GGML_TYPE_Q2_2 = 36, - GGML_TYPE_Q1_3 = 37, GGML_TYPE_COUNT, }; diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 3be4dd4ca..c65614696 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -141,20 +141,6 @@ typedef sycl::half2 ggml_half2; #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP -// 1.625 bpw for BitNet b1.58 models -#define QK1_3 64 -typedef struct { - uint8_t q[(QK1_3 - 4*QK1_3/64)/5]; // 5 elements per byte (3^5 = 243 < 256) - uint8_t qs[QK1_3/64]; // 4 elements per byte -} block_q1_3; -static_assert(sizeof(block_q1_3) == (QK1_3 - 4*QK1_3/64)/5 + QK1_3/64, "wrong q1_3 block size/padding"); - -#define QK2_2 32 -typedef struct { - uint8_t qs[QK2_2 / 4]; // nibbles / quants -} block_q2_2; -static_assert(sizeof(block_q2_2) == QK2_2 / 4, "wrong q2_2 block size/padding"); - #define QK4_0 32 typedef struct { ggml_half d; // delta @@ -1084,41 +1070,6 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512) 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, GGML_TABLE_END() -GGML_TABLE_BEGIN(uint32_t, q1_3_grid, 256) - 0xffffffff, 0xffffffff, 0xffffff00, 0xffffff01, 0xffff00ff, 0xffff0000, 0xffff0001, 0xffff01ff, - 0xffff0100, 0xffff0101, 0xff00ffff, 0xff00ff00, 0xff00ff01, 0xff0000ff, 0xff000000, 0xff000001, - 0xff0001ff, 0xff000100, 0xff000101, 0xff01ffff, 0xff01ffff, 0xff01ff00, 0xff01ff01, 0xff0100ff, - 0xff010000, 0xff010001, 0xff0101ff, 0xff010100, 0xff010101, 0x00ffffff, 0x00ffff00, 0x00ffff01, - 0x00ff00ff, 0x00ff0000, 0x00ff0001, 0x00ff01ff, 0x00ff0100, 0x00ff0101, 0x0000ffff, 0x0000ff00, - 0x0000ff00, 0x0000ff01, 0x000000ff, 0x00000000, 0x00000001, 0x000001ff, 0x00000100, 0x00000101, - 0x0001ffff, 0x0001ff00, 0x0001ff01, 0x000100ff, 0x00010000, 0x00010001, 0x000101ff, 0x00010100, - 0x00010101, 0x01ffffff, 0x01ffff00, 0x01ffff01, 0x01ffff01, 0x01ff00ff, 0x01ff0000, 0x01ff0001, - 0x01ff01ff, 0x01ff0100, 0x01ff0101, 0x0100ffff, 0x0100ff00, 0x0100ff01, 0x010000ff, 0x01000000, - 0x01000001, 0x010001ff, 0x01000100, 0x01000101, 0x0101ffff, 0x0101ff00, 0x0101ff01, 0x0101ff01, - 0x010100ff, 0x01010000, 0x01010001, 0x010101ff, 0x01010100, 0x01010101, 0xffffffff, 0xffffff00, - 0xffffff01, 0xffff00ff, 0xffff0000, 0xffff0001, 0xffff01ff, 0xffff0100, 0xffff0101, 0xff00ffff, - 0xff00ff00, 0xff00ff01, 0xff0000ff, 0xff0000ff, 0xff000000, 0xff000001, 0xff0001ff, 0xff000100, - 0xff000101, 0xff01ffff, 0xff01ff00, 0xff01ff01, 0xff0100ff, 0xff010000, 0xff010001, 0xff0101ff, - 0xff010100, 0xff010101, 0x00ffffff, 0x00ffff00, 0x00ffff01, 
0x00ff00ff, 0x00ff0000, 0x00ff0000, - 0x00ff0001, 0x00ff01ff, 0x00ff0100, 0x00ff0101, 0x0000ffff, 0x0000ff00, 0x0000ff01, 0x000000ff, - 0x00000000, 0x00000001, 0x000001ff, 0x00000100, 0x00000101, 0x0001ffff, 0x0001ff00, 0x0001ff01, - 0x000100ff, 0x00010000, 0x00010000, 0x00010001, 0x000101ff, 0x00010100, 0x00010101, 0x01ffffff, - 0x01ffff00, 0x01ffff01, 0x01ff00ff, 0x01ff0000, 0x01ff0001, 0x01ff01ff, 0x01ff0100, 0x01ff0101, - 0x0100ffff, 0x0100ff00, 0x0100ff01, 0x010000ff, 0x01000000, 0x01000001, 0x01000001, 0x010001ff, - 0x01000100, 0x01000101, 0x0101ffff, 0x0101ff00, 0x0101ff01, 0x010100ff, 0x01010000, 0x01010001, - 0x010101ff, 0x01010100, 0x01010101, 0xffffffff, 0xffffff00, 0xffffff01, 0xffff00ff, 0xffff0000, - 0xffff0001, 0xffff01ff, 0xffff01ff, 0xffff0100, 0xffff0101, 0xff00ffff, 0xff00ff00, 0xff00ff01, - 0xff0000ff, 0xff000000, 0xff000001, 0xff0001ff, 0xff000100, 0xff000101, 0xff01ffff, 0xff01ff00, - 0xff01ff01, 0xff0100ff, 0xff010000, 0xff010001, 0xff0101ff, 0xff0101ff, 0xff010100, 0xff010101, - 0x00ffffff, 0x00ffff00, 0x00ffff01, 0x00ff00ff, 0x00ff0000, 0x00ff0001, 0x00ff01ff, 0x00ff0100, - 0x00ff0101, 0x0000ffff, 0x0000ff00, 0x0000ff01, 0x000000ff, 0x00000000, 0x00000001, 0x000001ff, - 0x00000100, 0x00000100, 0x00000101, 0x0001ffff, 0x0001ff00, 0x0001ff01, 0x000100ff, 0x00010000, - 0x00010001, 0x000101ff, 0x00010100, 0x00010101, 0x01ffffff, 0x01ffff00, 0x01ffff01, 0x01ff00ff, - 0x01ff0000, 0x01ff0001, 0x01ff01ff, 0x01ff0100, 0x01ff0101, 0x01ff0101, 0x0100ffff, 0x0100ff00, - 0x0100ff01, 0x010000ff, 0x01000000, 0x01000001, 0x010001ff, 0x01000100, 0x01000101, 0x0101ffff, - 0x0101ff00, 0x0101ff01, 0x010100ff, 0x01010000, 0x01010001, 0x010101ff, 0x01010100, 0x01010101, -GGML_TABLE_END() - #define NGRID_IQ1S 2048 #define IQ1S_DELTA 0.125f #define IQ1M_DELTA 0.125f diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index e66d9be25..a2fd0563c 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -657,39 +657,6 @@ static inline __m128i packNibbles( __m256i bytes ) { } #endif //__loongarch_asx -void quantize_row_q2_2_ref(const float * restrict x, block_q2_2 * restrict y, int64_t k) { - static const int qk = QK2_2; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - - for (int j = 0; j < qk/4; ++j) { - int8_t x0 = (int8_t)x[i*qk + 0 + j]; - int8_t x1 = (int8_t)x[i*qk + 1*qk/4 + j]; - int8_t x2 = (int8_t)x[i*qk + 2*qk/4 + j]; - int8_t x3 = (int8_t)x[i*qk + 3*qk/4 + j]; - - const uint8_t xi0 = x0 < 0 ? 1 : x0 == 0 ? 2 : 3; - const uint8_t xi1 = x1 < 0 ? 1 : x1 == 0 ? 2 : 3; - const uint8_t xi2 = x2 < 0 ? 1 : x2 == 0 ? 2 : 3; - const uint8_t xi3 = x3 < 0 ? 1 : x3 == 0 ? 
2 : 3; - - y[i].qs[j] = 0; - y[i].qs[j] |= (xi0 << 0); - y[i].qs[j] |= (xi1 << 2); - y[i].qs[j] |= (xi2 << 4); - y[i].qs[j] |= (xi3 << 6); - } - } -} - -void quantize_row_q2_2(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q2_2_ref(x, y, k); -} - // reference implementation for deterministic creation of model files void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; @@ -1545,26 +1512,6 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k) #endif } -void dequantize_row_q2_2(const block_q2_2 * restrict x, float * restrict y, int64_t k) { - static const int qk = QK2_2; - - assert(k % qk == 0); - - const int nb = k / qk; - - for (int i = 0; i < nb; i++) { - - for (int j = 0; j < qk/4; ++j) { - const int8_t q = x[i].qs[j]; - - y[i*qk + j + 0 ] = (float) (((q >> 0) & 3) - 2); - y[i*qk + j + 1*qk/4] = (float) (((q >> 2) & 3) - 2); - y[i*qk + j + 2*qk/4] = (float) (((q >> 4) & 3) - 2); - y[i*qk + j + 3*qk/4] = (float) (((q >> 6) & 3) - 2); - } - } -} - void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) { static const int qk = QK4_0; @@ -3359,13 +3306,6 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr return nrow * row_size; } -size_t quantize_q2_2(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - (void)quant_weights; // not used - const size_t row_size = ggml_row_size(GGML_TYPE_Q2_2, n_per_row); - quantize_row_q2_2_ref(src, dst, (int64_t)nrow*n_per_row); - return nrow * row_size; -} - // ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs) void quantize_row_tq1_0_ref(const float * restrict x, block_tq1_0 * restrict y, int64_t k) { @@ -3552,89 +3492,6 @@ void dequantize_row_tq2_0(const block_tq2_0 * restrict x, float * restrict y, in } } -void quantize_row_q1_3_ref(const float * restrict x, block_q1_3 * restrict y, int64_t k) { - assert(k % QK1_3 == 0); - const int64_t nb = k / QK1_3; - static_assert(sizeof(y->q) % 4 == 0, "bad block_q1_3.q size"); - - const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243}; - - for (int64_t i = 0; i < nb; ++i) { - uint8_t q[sizeof(y->q)] = {0}; - for (size_t j = 0; j < sizeof(y->q); ++j) { - for (size_t m = 0; m < 4; ++m) { - int xi = nearest_int(x[m]); - uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2; - q[j] += xt * pow3[m]; - } - x += 4; - } - for (size_t j = 0; j < sizeof(y->q); ++j) { - int xi = nearest_int(x[j]); - uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2; - q[j] += xt * pow3[4]; - // ceiling division - q[j] = ((uint16_t)q[j] * 256 + (pow3[5] - 1)) / pow3[5]; - y[i].q[j] = q[j]; - } - x += sizeof(y->q); - - for (size_t j = 0; j < sizeof(y->qs); ++j) { - uint8_t qb = 0; - for (size_t m = 0; m < 4; ++m) { - int xi = nearest_int(x[m]); - uint8_t xt = xi < 0 ? 0 : xi == 0 ? 
1 : 2; - qb += xt * pow3[m]; - } - x += 4; - // ceiling division - qb = ((uint16_t)qb * 256 + (pow3[5] - 1)) / pow3[5]; - y[i].qs[j] = qb; - } - } -} - -void quantize_row_q1_3(const float * restrict x, void * restrict vy, int64_t k) { - assert(k % QK1_3 == 0); - block_q1_3 * restrict y = vy; - quantize_row_q1_3_ref(x, y, k); -} - -size_t quantize_q1_3(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { - (void)quant_weights; // not used - const size_t row_size = ggml_row_size(GGML_TYPE_Q1_3, n_per_row); - quantize_row_q1_3(src, dst, (int64_t)nrow*n_per_row); - return nrow * row_size; -} - -void dequantize_row_q1_3(const block_q1_3 * restrict x, float * restrict y, int64_t k) { - assert(k % QK1_3 == 0); - const int64_t nb = k / QK1_3; - static_assert(sizeof(x->q) % 4 == 0, "bad block_q1_3.q size"); - - for (int64_t i = 0; i < nb; ++i) { - for (size_t j = 0; j < sizeof(x->q); ++j) { - const int8_t * q = (const int8_t *) (q1_3_grid + x[i].q[j]); - for (int m = 0; m < 4; ++m) { - *y++ = (float) q[m]; - } - } - - for (size_t j = 0; j < sizeof(x->q); ++j) { - uint16_t q = x[i].q[j]; - int16_t qi = (q * 3) >> 8; - *y++ = (float) (qi - 1); - } - - for (size_t j = 0; j < sizeof(x->qs); ++j) { - const int8_t * q = (const int8_t *) (q1_3_grid + x[i].qs[j]); - for (int m = 0; m < 4; ++m) { - *y++ = (float) q[m]; - } - } - } -} - // ====================== "True" 2-bit (de)-quantization void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) { @@ -4055,122 +3912,6 @@ static inline __m128i get_scale_shuffle(int i) { } #endif -void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - const int qk = QK8_0; - const int nb = n / qk; - - assert(n % qk == 0); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - - const block_q2_2 * restrict x = vx; - const block_q8_0 * restrict y = vy; - -#if defined(__AVX2__) - __m256 acc = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - - const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i].d) ); - - // assuming this is always aligned - __m256i xq8 = _mm256_set1_epi64x(*(const int64_t *) x[i].qs); - xq8 = _mm256_srlv_epi64(xq8, _mm256_set_epi64x(6, 4, 2, 0)); - xq8 = _mm256_and_si256(xq8, _mm256_set1_epi8(0x03)); - // stangely enough, this is much slower with 1 instead of 2 - xq8 = _mm256_sub_epi8(xq8, _mm256_set1_epi8(2)); - - const __m256i yq8 = _mm256_loadu_si256((const __m256i *) (y[i].qs)); - const __m256 q = mul_sum_i8_pairs_float(xq8, yq8); - - acc = _mm256_fmadd_ps( d, q, acc ); - } - - *s = hsum_float_8(acc); -#elif defined(__ARM_NEON) - float sumf0 = 0.0f; - float sumf1 = 0.0f; - - const uint8x8_t mask = vdup_n_u8(3); - const int8x8_t offset = vdup_n_s8(2); - - const int leftovers = nb % 2; - - for (int i = 0; i < nb - leftovers; i += 2) { - const uint8x8_t xq8_0 = vld1_u8(x[0].qs); - const uint8x8_t xq8_1 = vld1_u8(x[1].qs); - - const int8x8_t xq8_0_0 = vsub_s8(vreinterpret_s8_u8(vand_u8(xq8_0, mask)), offset); - const int8x8_t xq8_0_1 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_0, 2), mask)), offset); - const int8x8_t xq8_0_2 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_0, 4), mask)), offset); - const int8x8_t xq8_0_3 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_0, 6), mask)), offset); - const int8x8_t xq8_1_0 = vsub_s8(vreinterpret_s8_u8(vand_u8(xq8_1, mask)), offset); - const int8x8_t xq8_1_1 = 
vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_1, 2), mask)), offset); - const int8x8_t xq8_1_2 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_1, 4), mask)), offset); - const int8x8_t xq8_1_3 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_1, 6), mask)), offset); - - const int8x16_t xq8_0_l = vcombine_s8(xq8_0_0, xq8_0_1); - const int8x16_t xq8_0_h = vcombine_s8(xq8_0_2, xq8_0_3); - const int8x16_t xq8_1_l = vcombine_s8(xq8_1_0, xq8_1_1); - const int8x16_t xq8_1_h = vcombine_s8(xq8_1_2, xq8_1_3); - - const int8x16_t yq8_0_l = vld1q_s8(y[0].qs + 0); - const int8x16_t yq8_0_h = vld1q_s8(y[0].qs + 16); - const int8x16_t yq8_1_l = vld1q_s8(y[1].qs + 0); - const int8x16_t yq8_1_h = vld1q_s8(y[1].qs + 16); - - const int16x8_t dot0 = vaddq_s16(vpaddlq_s8(vmulq_s8(xq8_0_l, yq8_0_l)), vpaddlq_s8(vmulq_s8(xq8_0_h, yq8_0_h))); - const int16x8_t dot1 = vaddq_s16(vpaddlq_s8(vmulq_s8(xq8_1_l, yq8_1_l)), vpaddlq_s8(vmulq_s8(xq8_1_h, yq8_1_h))); - - sumf0 += GGML_FP16_TO_FP32(y[0].d) * (float) vaddlvq_s16(dot0); - sumf1 += GGML_FP16_TO_FP32(y[1].d) * (float) vaddlvq_s16(dot1); - x += 2; - y += 2; - } - - // one block at a time - for (int i = nb - leftovers; i < nb; ++i) { - const uint8x8_t xq8 = vld1_u8(x->qs); - const int8x8_t xq8_0 = vsub_s8(vreinterpret_s8_u8(vand_u8(xq8, mask)), offset); - const int8x8_t xq8_1 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8, 2), mask)), offset); - const int8x8_t xq8_2 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8, 4), mask)), offset); - const int8x8_t xq8_3 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8, 6), mask)), offset); - - const int8x16_t xq8_l = vcombine_s8(xq8_0, xq8_1); - const int8x16_t xq8_h = vcombine_s8(xq8_2, xq8_3); - - const int8x16_t yq8_l = vld1q_s8(y->qs + 0); - const int8x16_t yq8_h = vld1q_s8(y->qs + 16); - - const int16x8_t dot0 = vpaddlq_s8(vmulq_s8(xq8_l, yq8_l)); - const int16x8_t dot1 = vpaddlq_s8(vmulq_s8(xq8_h, yq8_h)); - - sumf0 += GGML_FP16_TO_FP32(y->d) * (float) vaddlvq_s16(vaddq_s16(dot0, dot1)); - x += 1; - y += 1; - } - - *s = sumf0 + sumf1; -#else - - float sumf = 0.0f; - for (int i = 0; i < nb; i++) { - int sumi = 0; - for (int j = 0; j < qk / 4; j++) { - const uint8_t weight = x[i].qs[j]; - sumi += (int)y[i].qs[j + 0*qk/4] * (((weight >> 0) & 3) - 2); - sumi += (int)y[i].qs[j + 1*qk/4] * (((weight >> 2) & 3) - 2); - sumi += (int)y[i].qs[j + 2*qk/4] * (((weight >> 4) & 3) - 2); - sumi += (int)y[i].qs[j + 3*qk/4] * (((weight >> 6) & 3) - 2); - } - sumf += (float)(sumi)*(GGML_FP16_TO_FP32(y[i].d)); - } - *s = sumf; -#endif -} - void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const int qk = QK8_0; const int nb = n / qk; @@ -11855,225 +11596,6 @@ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { } #endif -void ggml_vec_dot_q1_3_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { - assert(nrc == 1); - UNUSED(nrc); - UNUSED(bx); - UNUSED(by); - UNUSED(bs); - // assumed by the code below - assert(n % QK1_3 == 0); - static_assert(QK1_3 == 2 * QK8_0, "QK1_3 must be 2 times bigger than QK8_0"); - - const block_q1_3 * restrict x = vx; - const block_q8_0 * restrict y = vy; - - const int nb = n / QK1_3; - -#if defined(__AVX2__) - __m256 accumf = _mm256_setzero_ps(); - - for (int i = 0; i < nb; ++i) { - // const __m128i x12a = _mm_maskload_epi32((const int32_t *) x, _mm_set_epi32(0, -1, -1, -1)); - // const __m128i x13b = 
_mm_insert_epi8(x12a, x->qs[0], 12); - // WARNING: reading 3 bytes further than necessary. - // It's measurably faster than a masked load on an Intel Core m3-8100Y - const __m128i x13b = _mm_loadu_si128((const __m128i *) x); - const __m256i x13 = MM256_SET_M128I(x13b, x13b); - - { - // pre-shift the values by 8 bits, and prepare the layout for later packing - __m256i x0l = _mm256_shuffle_epi8(x13, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1, - 4, -1, 4, -1, 4, -1, 4, -1, - 1, -1, 1, -1, 1, -1, 1, -1, - 0, -1, 0, -1, 0, -1, 0, -1)); - __m256i x0h = _mm256_shuffle_epi8(x13, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1, - 6, -1, 6, -1, 6, -1, 6, -1, - 3, -1, 3, -1, 3, -1, 3, -1, - 2, -1, 2, -1, 2, -1, 2, -1)); - __m256i x1l = _mm256_shuffle_epi8(x13, _mm256_set_epi8(7, -1, 6, -1, 5, -1, 4, -1, - 3, -1, 2, -1, 1, -1, 0, -1, - 9, -1, 9, -1, 9, -1, 9, -1, - 8, -1, 8, -1, 8, -1, 8, -1)); - __m256i x1h = _mm256_shuffle_epi8(x13, _mm256_set_epi8(12, -1, 12, -1, 12, -1, 12, -1, - 11, -1, 10, -1, 9, -1, 8, -1, - 11, -1, 11, -1, 11, -1, 11, -1, - 10, -1, 10, -1, 10, -1, 10, -1)); - const __m256i shift0 = _mm256_set_epi16(3, 9, 27, 81, - 3, 9, 27, 81, - 3, 9, 27, 81, - 3, 9, 27, 81); - const __m256i shift1l = _mm256_set_epi16(1, 1, 1, 1, - 1, 1, 1, 1, - 3, 9, 27, 81, - 3, 9, 27, 81); - const __m256i shift1h = _mm256_set_epi16(3, 9, 27, 81, - 1, 1, 1, 1, - 3, 9, 27, 81, - 3, 9, 27, 81); - // extract ternary values - // first by shifting the numbers to make each one the next significant digit - x0l = _mm256_mullo_epi16(x0l, shift0); - x0h = _mm256_mullo_epi16(x0h, shift0); - x1l = _mm256_mullo_epi16(x1l, shift1l); - x1h = _mm256_mullo_epi16(x1h, shift1h); - // then by extracting each of these most significant digits - x0l = _mm256_mulhi_epu16(x0l, _mm256_set1_epi16(3)); - x0h = _mm256_mulhi_epu16(x0h, _mm256_set1_epi16(3)); - x1l = _mm256_mulhi_epu16(x1l, _mm256_set1_epi16(3)); - x1h = _mm256_mulhi_epu16(x1h, _mm256_set1_epi16(3)); - - __m256i x0 = _mm256_packs_epi16(x0l, x0h); - __m256i x1 = _mm256_packs_epi16(x1l, x1h); - - // 0, 1, 2 => -1, 0, 1 - x0 = _mm256_sub_epi8(x0, _mm256_set1_epi8(1)); - x1 = _mm256_sub_epi8(x1, _mm256_set1_epi8(1)); - - const __m256i y0 = _mm256_loadu_si256((const __m256i *) (y[0].qs)); - const __m256i y1 = _mm256_loadu_si256((const __m256i *) (y[1].qs)); - - const __m256 d0 = _mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d)); - const __m256 d1 = _mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d)); - - const __m256 q0 = mul_sum_i8_pairs_float(y0, x0); - const __m256 q1 = mul_sum_i8_pairs_float(y1, x1); - - accumf = _mm256_fmadd_ps(d0, q0, accumf); - accumf = _mm256_fmadd_ps(d1, q1, accumf); - } - - x += 1; - y += 2; - } - - *s = hsum_float_8(accumf); -#elif defined(__ARM_NEON) - - static const uint8_t k_mask0[16] = {0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}; - static const uint8_t k_mask1[16] = {4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7}; - static const uint8_t k_mask2[16] = {8, 8, 8, 8, 9, 9, 9, 9, 10, 10, 10, 10, 11, 11, 11, 11}; - static const uint8_t k_mask3[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 12, 12, 12}; - - static const uint8_t k_shift0[16] = {81, 27, 9, 3, 81, 27, 9, 3, 81, 27, 9, 3, 81, 27, 9, 3}; - static const uint8_t k_shift3[16] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 81, 27, 9, 3}; - - // float32x4_t sumv0 = vdupq_n_f32(0.0f); - // float32x4_t sumv1 = vdupq_n_f32(0.0f); - - float sumf0 = 0.0f; - float sumf1 = 0.0f; - - const uint8x16_t mask0 = vld1q_u8(k_mask0); - const uint8x16_t mask1 = vld1q_u8(k_mask1); - const uint8x16_t mask2 = vld1q_u8(k_mask2); - 
const uint8x16_t mask3 = vld1q_u8(k_mask3); - - const uint8x16_t shift0 = vld1q_u8(k_shift0); - const uint8x16_t shift3 = vld1q_u8(k_shift3); - - const int8x16_t one = vdupq_n_s8(1); - - for (int i = 0; i < nb; ++i) { - // WARNING: reading 3 bytes further than necessary - const uint8x16_t x13b = vld1q_u8((const uint8_t *) x); - - uint8x16_t x0 = ggml_vqtbl1q_u8(x13b, mask0); - uint8x16_t x1 = ggml_vqtbl1q_u8(x13b, mask1); - uint8x16_t x2 = ggml_vqtbl1q_u8(x13b, mask2); - uint8x16_t x3 = ggml_vqtbl1q_u8(x13b, mask3); - - x0 = vmulq_u8(x0, shift0); - x1 = vmulq_u8(x1, shift0); - x2 = vmulq_u8(x2, shift0); - x3 = vmulq_u8(x3, shift3); - - // multiply by 3 and keep the 2 bits above 8 bits - x0 = vshrq_n_u8(vhaddq_u8(x0, vshrq_n_u8(x0, 1)), 6); - x1 = vshrq_n_u8(vhaddq_u8(x1, vshrq_n_u8(x1, 1)), 6); - x2 = vshrq_n_u8(vhaddq_u8(x2, vshrq_n_u8(x2, 1)), 6); - x3 = vshrq_n_u8(vhaddq_u8(x3, vshrq_n_u8(x3, 1)), 6); - - // 0, 1, 2 => -1, 0, 1 - int8x16_t x0i = vsubq_s8(vreinterpretq_s8_u8(x0), one); - int8x16_t x1i = vsubq_s8(vreinterpretq_s8_u8(x1), one); - int8x16_t x2i = vsubq_s8(vreinterpretq_s8_u8(x2), one); - int8x16_t x3i = vsubq_s8(vreinterpretq_s8_u8(x3), one); - - const int8x16_t y0 = vld1q_s8(y[0].qs + 0); - const int8x16_t y1 = vld1q_s8(y[0].qs + 16); - const int8x16_t y2 = vld1q_s8(y[1].qs + 0); - const int8x16_t y3 = vld1q_s8(y[1].qs + 16); - - // const int32x4_t p0 = vpaddlq_s16(vaddq_s16(vpaddlq_s8(x0i), vpaddlq_s8(x1i))); - // const int32x4_t p1 = vpaddlq_s16(vaddq_s16(vpaddlq_s8(x2i), vpaddlq_s8(x3i))); - - // there's no direct equivalent to _mm_sign_epi8, unfortunately - x0i = vmulq_s8(x0i, y0); - x1i = vmulq_s8(x1i, y1); - x2i = vmulq_s8(x2i, y2); - x3i = vmulq_s8(x3i, y3); - - // overall 18.5% faster than with vector sums on a cortex-A72 - sumf0 += GGML_FP16_TO_FP32(y[0].d) * (float) vaddlvq_s16(vaddq_s16(vpaddlq_s8(x0i), vpaddlq_s8(x1i))); - sumf1 += GGML_FP16_TO_FP32(y[1].d) * (float) vaddlvq_s16(vaddq_s16(vpaddlq_s8(x2i), vpaddlq_s8(x3i))); - - // const int32x4_t p0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), x0i, y0), x1i, y1); - // const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), x2i, y2), x3i, y3); - - // sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p0), GGML_FP16_TO_FP32(y[0].d)); - // sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p1), GGML_FP16_TO_FP32(y[1].d)); - - y += 2; - x += 1; - } - - // *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); - *s = sumf0 + sumf1; -#else - float sumf = 0.0f; - - for (int i = 0; i < nb; ++i) { - int sum = 0; - for (int j = 0; j < 8; ++j) { - const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].q[j]); - for (int k = 0; k < 4; ++k) { - sum += xj[k] * (int16_t) y->qs[4*j + k]; - } - } - - sumf += GGML_FP16_TO_FP32(y->d) * sum; - y += 1; - sum = 0; - - for (int j = 0; j < 4; ++j) { - const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].q[8 + j]); - for (int k = 0; k < 4; ++k) { - sum += xj[k] * (int16_t) y->qs[4*j + k]; - } - } - - for (size_t j = 0; j < 12; ++j) { - uint16_t xj = x[i].q[j]; - xj = (xj * 3) >> 8; - sum += ((int16_t) xj - 1) * (int16_t) y->qs[16 + j]; - } - - { - const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].qs[0]); - for (int k = 0; k < 4; ++k) { - sum += (int16_t) xj[k] * (int16_t) y->qs[28 + k]; - } - } - - sumf += GGML_FP16_TO_FP32(y->d) * sum; - y += 1; - } - - *s = sumf; -#endif -} - void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(n % QK_K == 0); assert(nrc == 1); @@ -15964,8 +15486,7 @@ bool 
ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x8, data, nbytes / sizeof(block_q4_0x8), 8); } break; - case GGML_TYPE_Q1_3: - case GGML_TYPE_Q2_2: + case GGML_TYPE_I8: case GGML_TYPE_I16: case GGML_TYPE_I32: diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 238cfd3fb..df9c4b24a 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -12,8 +12,6 @@ extern "C" { #endif // Quantization -void quantize_row_q1_3_ref(const float * GGML_RESTRICT x, block_q1_3 * GGML_RESTRICT y, int64_t k); -void quantize_row_q2_2_ref(const float * GGML_RESTRICT x, block_q2_2 * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); @@ -37,8 +35,6 @@ void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGM void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k); void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); -void quantize_row_q1_3(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); -void quantize_row_q2_2(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -63,8 +59,6 @@ void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); // Dequantization -void dequantize_row_q1_3(const block_q1_3 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); -void dequantize_row_q2_2(const block_q2_2 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -93,8 +87,6 @@ void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_ void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); // Dot product -void ggml_vec_dot_q1_3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); -void ggml_vec_dot_q2_2_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -139,8 +131,6 @@ size_t quantize_q3_K(const float * 
GGML_RESTRICT src, void * GGML_RESTRICT dst, size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q1_3(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); -size_t quantize_q2_2(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 5bbe0e4a8..b56ebcb7c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -856,30 +856,6 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .vec_dot_type = GGML_TYPE_Q8_K, .nrows = 1, }, - [GGML_TYPE_Q2_2] = { - .type_name = "q2_2", - .blck_size = QK2_2, - .type_size = sizeof(block_q2_2), - .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q2_2, - .from_float = quantize_row_q2_2, - .from_float_ref = (ggml_from_float_t) quantize_row_q2_2_ref, - .vec_dot = ggml_vec_dot_q2_2_q8_0, - .vec_dot_type = GGML_TYPE_Q8_0, - .nrows = 1, - }, - [GGML_TYPE_Q1_3] = { - .type_name = "q1_3", - .blck_size = QK1_3, - .type_size = sizeof(block_q1_3), - .is_quantized = true, - .to_float = (ggml_to_float_t) dequantize_row_q1_3, - .from_float = quantize_row_q1_3, - .from_float_ref = (ggml_from_float_t) quantize_row_q1_3_ref, - .vec_dot = ggml_vec_dot_q1_3_q8_0, - .vec_dot_type = GGML_TYPE_Q8_0, - .nrows = 1, - }, [GGML_TYPE_IQ1_S] = { .type_name = "iq1_s", .blck_size = QK_K, @@ -13936,8 +13912,6 @@ static void ggml_compute_forward_clamp( } break; case GGML_TYPE_F16: case GGML_TYPE_BF16: - case GGML_TYPE_Q1_3: - case GGML_TYPE_Q2_2: case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: @@ -20638,8 +20612,6 @@ size_t ggml_quantize_chunk( size_t result = 0; switch (type) { - case GGML_TYPE_Q1_3: result = quantize_q1_3(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; - case GGML_TYPE_Q2_2: result = quantize_q2_2(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 32e73e56a..57aa53b41 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1150,8 +1150,6 @@ class GGMLQuantizationType(IntEnum): Q4_0_8_8 = 33 TQ1_0 = 34 TQ2_0 = 35 - Q1_3 = 36 - Q2_2 = 37 # TODO: add GGMLFileType from ggml_ftype in ggml.h @@ -1193,8 +1191,11 @@ class LlamaFileType(IntEnum): 
MOSTLY_IQ4_XS = 30 # except 1d tensors MOSTLY_IQ1_M = 31 # except 1d tensors MOSTLY_BF16 = 32 # except 1d tensors - MOSTLY_Q2_2 = 33 # except 1d tensors - MOSTLY_Q1_3 = 34 # except 1d tensors + MOSTLY_Q4_0_4_4 = 33 # except 1d tensors + MOSTLY_Q4_0_4_8 = 34 # except 1d tensors + MOSTLY_Q4_0_8_8 = 35 # except 1d tensors + MOSTLY_TQ1_0 = 36 # except 1d tensors + MOSTLY_TQ2_0 = 37 # except 1d tensors GUESSED = 1024 # not specified in the model file @@ -1268,8 +1269,11 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { GGMLQuantizationType.F64: (1, 8), GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32), GGMLQuantizationType.BF16: (1, 2), - GGMLQuantizationType.Q2_2: (32, 8), - GGMLQuantizationType.Q1_3: (64, 12 + 1), + GGMLQuantizationType.Q4_0_4_4:(32, 2 + 16), + GGMLQuantizationType.Q4_0_4_8:(32, 2 + 16), + GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16), + GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13), + GGMLQuantizationType.TQ2_0: (256, 2 + 64), } diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py index 8f7fd0232..16e0a9aaa 100644 --- a/gguf-py/gguf/quants.py +++ b/gguf-py/gguf/quants.py @@ -121,55 +121,3 @@ def quantize_q8_0(data: np.ndarray): return __quantize_q8_0_lazy(data) else: return __quantize_q8_0_array(data) - - -__q1_3_block_size, __q1_3_type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q1_3] - - -def can_quantize_to_q1_3(n: np.ndarray) -> bool: - return n.shape[-1] % __q1_3_block_size == 0 - - -def __quantize_q1_3_shape_change(s: tuple[int, ...]) -> tuple[int, ...]: - return (*s[:-1], s[-1] // __q1_3_block_size * __q1_3_type_size) - - -def __quantize_q1_3_rows(n: np.ndarray) -> np.ndarray: - shape = n.shape - assert shape[-1] % __q1_3_block_size == 0 - - n_blocks = n.size // __q1_3_block_size - - blocks = n.reshape((n_blocks, __q1_3_block_size)).astype(np.float32, copy=False) - - # assuming the weights are pre-scaled - blocks = (np.sign(blocks).astype(np.int8) + 1).view(np.uint8) - q48, rest = np.hsplit(blocks, (48,)) - q12, q4 = np.hsplit(rest, (12,)) - - pow3 = np.array([1, 3, 9, 27]) - q48 = q48.reshape((n_blocks, 12, 4)) - q48 = np.sum(q48 * pow3.reshape((1, 1, 4)), axis=2, keepdims=True).reshape((n_blocks, 12)) - q4 = np.sum(q4 * pow3.reshape((1, 4)), axis=1, keepdims=True) - q48 = q48 + (q12 * 81) - q = np.concatenate([q48, q4], axis=1) - q = (((q.astype(np.uint16) * 256) + (243 - 1)) // 243).astype(np.uint8) - - return q.reshape(__quantize_q1_3_shape_change(shape)) - - -def __quantize_q1_3_array(n: np.ndarray) -> np.ndarray: - return __apply_over_grouped_rows(__quantize_q1_3_rows, arr=n, otype=np.uint8, oshape=__quantize_q1_3_shape_change(n.shape)) - - -__quantize_q1_3_lazy = LazyNumpyTensor._wrap_fn( - __quantize_q1_3_array, - meta_noop=(np.uint8, __quantize_q1_3_shape_change), -) - - -def quantize_q1_3(data: np.ndarray): - if type(data) is LazyNumpyTensor: - return __quantize_q1_3_lazy(data) - else: - return __quantize_q1_3_array(data) diff --git a/include/llama.h b/include/llama.h index 7dcc260e8..be039e45f 100644 --- a/include/llama.h +++ b/include/llama.h @@ -168,8 +168,6 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q1_3 = 38, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q2_2 = 39, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/src/llama.cpp b/src/llama.cpp index 7ab2b47cd..bd52975db 100644 --- a/src/llama.cpp +++ 
b/src/llama.cpp @@ -4451,8 +4451,6 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_ALL_F32: return "all F32"; case LLAMA_FTYPE_MOSTLY_F16: return "F16"; case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; - case LLAMA_FTYPE_MOSTLY_Q1_3: return "Q1_3 - 1.625 bpw for BitNet b1.58"; - case LLAMA_FTYPE_MOSTLY_Q2_2: return "Q2_2 - 2.000 bpw for BitNet b1.58"; case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; @@ -7347,23 +7345,23 @@ static bool llm_load_tensors( layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wq_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}); + layer.wq_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wk_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "scale", i), {1}); + layer.wk_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wv_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "scale", i), {1}); + layer.wv_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.wo_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}); + layer.wo_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_gate_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}); + layer.ffn_gate_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}); + layer.ffn_down_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}); + layer.ffn_up_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); } } break; case LLM_ARCH_T5: @@ -13028,7 +13026,9 @@ struct llm_build_context { { // compute Q and K and RoPE them struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); - Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale); + if (model.layers[il].wq_scale) { + Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale); + } cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, 
model.layers[il].bq); @@ -13037,7 +13037,9 @@ struct llm_build_context { // B1.K struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); - Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale); + if (model.layers[il].wk_scale) { + Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale); + } cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); @@ -13046,7 +13048,9 @@ struct llm_build_context { // B1.V struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); - Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale); + if (model.layers[il].wv_scale) { + Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale); + } cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -13077,7 +13081,9 @@ struct llm_build_context { cb(cur, "attn_sub_norm", il); cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); - cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); + if (model.layers[il].wo_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); + } if (model.layers[il].bo) { cur = ggml_add(ctx0, cur, model.layers[il].bo); } @@ -13114,7 +13120,9 @@ struct llm_build_context { cb(cur, "ffn_sub_norm", il); cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_down, cur); - cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); + if (model.layers[il].ffn_down_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); + } cb(cur, "ffn_down", il); cur = ggml_add(ctx0, cur, ffn_inp); @@ -15631,8 +15639,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s llama_ftype ftype = params->ftype; switch (params->ftype) { - case LLAMA_FTYPE_MOSTLY_Q1_3: default_type = GGML_TYPE_Q1_3; break; - case LLAMA_FTYPE_MOSTLY_Q2_2: default_type = GGML_TYPE_Q2_2; break; case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break; case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break; case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break; diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index a6e508d01..ccf5721a3 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -15,13 +15,13 @@ constexpr float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR = 0.002f; -constexpr float MAX_QUANTIZATION_TOTAL_ERROR_TERNARY = 0.015625f; // TODO: change to 0.01f +constexpr float MAX_QUANTIZATION_TOTAL_ERROR_TERNARY = 0.01f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f; constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f; constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f; -constexpr float MAX_DOT_PRODUCT_ERROR_TERNARY = 0.5f; // TODO: change to 0.15f +constexpr float MAX_DOT_PRODUCT_ERROR_TERNARY = 0.15f; static const char* RESULT_STR[] = {"ok", "FAILED"}; @@ -146,8 +146,6 @@ int main(int argc, char * argv[]) { if (qfns.from_float && qfns.to_float) { const float total_error = total_quantization_error(qfns, test_size, test_data.data()); const float max_quantization_error = - type == GGML_TYPE_Q1_3 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY : - type == GGML_TYPE_Q2_2 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY : type == GGML_TYPE_TQ1_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY : type == GGML_TYPE_TQ2_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY : type == GGML_TYPE_Q2_K ? 
MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
@@ -172,7 +170,7 @@ int main(int argc, char * argv[]) {
             const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
                                             type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
                                           ? MAX_DOT_PRODUCT_ERROR_LOWBIT
-                                          : type == GGML_TYPE_Q2_2 || type == GGML_TYPE_Q1_3 || type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0
+                                          : type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0
                                           ? MAX_DOT_PRODUCT_ERROR_TERNARY
                                           : MAX_DOT_PRODUCT_ERROR;
             failed = !(vec_dot_error < max_allowed_error);
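
Editor's note (not part of the patch): with the separate ".scale" tensors gone, the converter's updated BitnetModel.weight_quant folds the per-tensor scale back into the weights before export. A minimal standalone sketch of that behaviour, assuming a plain PyTorch tensor as input (names other than weight_quant are illustrative only):

    import torch

    def weight_quant(weight: torch.Tensor) -> torch.Tensor:
        # per-tensor absmean scale, following the BitNet b1.58 recipe referenced in the patch
        dtype = weight.dtype
        weight = weight.float()
        scale = weight.abs().mean().clamp(min=1e-5)
        iscale = 1 / scale
        # round to ternary {-1, 0, +1}, then multiply the scale back in,
        # so no separate ".scale" tensor needs to be written to the GGUF
        result = (weight * iscale).round().clamp(-1, 1) / iscale
        return result.type(dtype)

Because the scale is re-applied to the exported fp32 weights, a later row-wise ternary quantization (TQ1_0/TQ2_0) can still pick up its own scales, and existing GGUF files that do carry ".scale" tensors keep loading thanks to the TENSOR_NOT_REQUIRED and conditional ggml_mul changes in src/llama.cpp above.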