From 9bc6db28d011d47a5f318dc4aebbe7927fac4629 Mon Sep 17 00:00:00 2001
From: compilade
Date: Thu, 5 Sep 2024 21:48:47 -0400
Subject: [PATCH] ggml-quants : ternary packing for TriLMs and BitNet b1.58 (#8151)

* ggml-quants : 1.625 bpw ternary packing for BitNet 1.58b

* ggml-quants : faster 1.625 bpw AVX2 vec_dot

Not using a lookup table anymore makes it match q4_0 speed.

* gguf-py : fix formatting

* llama : remove spaces on empty line

* ggml-quants : subtract 1 when back in epi8

This makes the 1.625 bpw type go faster than q4_0. Still not the fastest.

* ggml-quants : Q2_2 now faster than Q4_K on AVX2

* ggml-quants : cleanup Q1_3 code formatting

* ggml-quants : ARM NEON vec_dot for q2_2 and q1_3

* ggml-quants : use ceiling division when quantizing q1_3

* convert-hf : simplify BitNet pre-quantization

This still results in the exact same tensor weights and scales, but it reveals some weirdness in the current algorithm.

* convert-hf : allow converting the weird BitNet 1.3B

Its FFN size is 5460 which is not convenient. The offending tensors are kept in F16, which makes the final model 5.01 bpw.

* bitnet : replace 1.58b with b1.58, as in the paper

* ggml-quants : fix build failure on Windows

* ggml-quants : attempt to fix Arm 32-bit support

* ggml : add some informative comments in q1_3 vec_dot

* ggml : add TQ1_0 and TQ2_0 ternary quantization types

* ggml : even faster TQ2_0

* ggml : also faster TQ1_0

Same optimization as for TQ2_0 by offsetting the sum instead of the weights. This makes TQ1_0 almost as fast as Q8_0 on AVX2.

* ggml : fix build issues in certain environments

* ggml : add NEON vec_dot implementation for TQ1_0 and TQ2_0

* ggml : avoid directly using vmlal_high_s8, for 32-bit ARM compat

The compiler seems smart enough to use the same instruction even when using vget_high_s8 instead.

* ggml : remove q1_3 and q2_2

No more 1.625 bpw and 2.000 bpw, now instead using 1.6875 bpw and 2.0625 bpw with TQ1_0 and TQ2_0, respectively.

* llama : remove the separate scale tensors of BitNet b1.58

They won't be needed, since the remaining ternary quant types have built-in scales.

* ggml-quants : rename fields of TQ1_0 and TQ2_0 structs for consistency

* ggml-quants : allow using vdotq_s32 in TQ2_0 vec_dot

Not yet tested on hardware which supports it, might not work or might not even compile. But also it might. It should make the performance better on recent ARM CPUs.

* ggml-quants : remove comment about possible format change of TQ2_0

Making it slightly more convenient for AVX512 but less convenient for everything else is not worth the trouble.

* gguf-py : Numpy (de)quantization for TQ1_0 and TQ2_0

* ggml-quants : use roundf instead of nearest_int for TQ1_0 and TQ2_0

This does not change anything for ternary models, since their values should never end up being in halfway cases anyway.

* convert : allow direct conversion to TQ1_0 and TQ2_0

The token embeddings and output tensors are kept in F16 to allow quantizing them to Q4_K and Q6_K with llama-quantize.

* llama : handle fallback for TQ1_0 and TQ2_0 with Q4_0

Q4_0 is not completely symmetric (so not lossless for ternary models), but it should be good enough.

* ggml-quants : allow using ARM dot product instructions for TQ1_0

* ggml-quants : deduplicate TQ1_0 and TQ2_0 __ARM_FEATURE_DOTPROD support

* ggml : remove unused ggml_mul special case

It would otherwise conflict with the more general optimization coming with Mamba-2.
* ggml : handle TQ1_0 and TQ2_0 in dequantization-based operators * test-backend-ops : add TQ1_0 and TQ2_0 comments for later Not yet adding uncommented, because some backends like SYCL and Metal do not properly handle unknown types in supports_op for GGML_OP_MUL_MAT. (and Metal also doesn't handle it with GGML_OP_GET_ROWS) Support for TQ1_0 and TQ2_0 for other backends than CPU will be added in follow-up pull requests. --- convert_hf_to_gguf.py | 47 ++- examples/quantize/quantize.cpp | 2 + ggml/include/ggml.h | 2 + ggml/src/ggml-common.h | 20 + ggml/src/ggml-impl.h | 11 +- ggml/src/ggml-quants.c | 690 ++++++++++++++++++++++++++++++++- ggml/src/ggml-quants.h | 15 + ggml/src/ggml.c | 42 +- gguf-py/gguf/constants.py | 6 + gguf-py/gguf/quants.py | 81 ++++ gguf-py/tests/test_quants.py | 1 + include/llama.h | 2 + src/llama.cpp | 45 ++- tests/test-backend-ops.cpp | 2 + tests/test-quantize-fns.cpp | 6 + 15 files changed, 937 insertions(+), 35 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 27ac34b81..0a9bbc829 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -308,6 +308,20 @@ class Model: ): data_qtype = gguf.GGMLQuantizationType.F32 + if data_qtype is False and any( + self.match_model_tensor_name(new_name, key, bid) + for key in ( + gguf.MODEL_TENSOR.TOKEN_EMBD, + gguf.MODEL_TENSOR.OUTPUT, + ) + ): + if self.ftype in ( + gguf.LlamaFileType.MOSTLY_TQ1_0, + gguf.LlamaFileType.MOSTLY_TQ2_0, + ): + # TODO: use Q4_K and Q6_K + data_qtype = gguf.GGMLQuantizationType.F16 + # No override (data_qtype is False), or wants to be quantized (data_qtype is True) if isinstance(data_qtype, bool): if self.ftype == gguf.LlamaFileType.ALL_F32: @@ -318,6 +332,10 @@ class Model: data_qtype = gguf.GGMLQuantizationType.BF16 elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0: data_qtype = gguf.GGMLQuantizationType.Q8_0 + elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ1_0: + data_qtype = gguf.GGMLQuantizationType.TQ1_0 + elif self.ftype == gguf.LlamaFileType.MOSTLY_TQ2_0: + data_qtype = gguf.GGMLQuantizationType.TQ2_0 else: raise ValueError(f"Unknown file type: {self.ftype.name}") @@ -1623,15 +1641,16 @@ class BitnetModel(Model): self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) self.gguf_writer.add_rope_scaling_factor(1.0) - def weight_quant(self, weight): + def weight_quant(self, weight: Tensor) -> Tensor: dtype = weight.dtype weight = weight.float() - s = 1 / weight.abs().mean().clamp(min=1e-5) - weight = (weight * s).round().clamp(-1, 1) / s - scale = weight.abs().max().unsqueeze(0) - weight = torch.where(weight.abs().less(1e-6), 0, weight).type(dtype) - weight = torch.sign(weight).type(dtype) - return weight.type(dtype), scale.type(torch.float32) + scale = weight.abs().mean().clamp(min=1e-5) + iscale = 1 / scale + # TODO: multiply by the scale directly instead of inverting it twice + # (this is also unnecessarily doubly inverted upstream) + # ref: https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/af89e318d78a70802061246bf037199d2fb97020/utils_quant.py#L10 + result = (weight * iscale).round().clamp(-1, 1) / iscale + return result.type(dtype) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: new_name = self.map_tensor_name(name) @@ -1646,11 +1665,9 @@ class BitnetModel(Model): gguf.MODEL_TENSOR.FFN_GATE, ]): # transform weight into 1/0/-1 (in fp32) - weight_torch, scale_torch = self.weight_quant(data_torch) - yield (new_name, weight_torch) - yield (new_name.removesuffix(".weight") + ".scale", 
scale_torch) - else: - yield (new_name, data_torch) + data_torch = self.weight_quant(data_torch) + + yield (new_name, data_torch) @Model.register("GrokForCausalLM") @@ -4011,8 +4028,8 @@ def parse_args() -> argparse.Namespace: help="path to write to; default: based on input. {ftype} will be replaced by the outtype.", ) parser.add_argument( - "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16", - help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", + "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "tq1_0", "tq2_0", "auto"], default="f16", + help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, tq1_0 or tq2_0 for ternary, and auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type", ) parser.add_argument( "--bigendian", action="store_true", @@ -4099,6 +4116,8 @@ def main() -> None: "f16": gguf.LlamaFileType.MOSTLY_F16, "bf16": gguf.LlamaFileType.MOSTLY_BF16, "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0, + "tq1_0": gguf.LlamaFileType.MOSTLY_TQ1_0, + "tq2_0": gguf.LlamaFileType.MOSTLY_TQ2_0, "auto": gguf.LlamaFileType.GUESSED, } diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 202346310..a23bfb86b 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -26,6 +26,8 @@ static const std::vector QUANT_OPTIONS = { { "IQ2_M", LLAMA_FTYPE_MOSTLY_IQ2_M, " 2.7 bpw quantization", }, { "IQ1_S", LLAMA_FTYPE_MOSTLY_IQ1_S, " 1.56 bpw quantization", }, { "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", }, + { "TQ1_0", LLAMA_FTYPE_MOSTLY_TQ1_0, " 1.69 bpw ternarization", }, + { "TQ2_0", LLAMA_FTYPE_MOSTLY_TQ2_0, " 2.06 bpw ternarization", }, { "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.96G, +3.5199 ppl @ Llama-3-8B", }, { "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.96G, +3.1836 ppl @ Llama-3-8B", }, { "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", }, diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 3fb680360..09c72b095 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -395,6 +395,8 @@ extern "C" { GGML_TYPE_Q4_0_4_4 = 31, GGML_TYPE_Q4_0_4_8 = 32, GGML_TYPE_Q4_0_8_8 = 33, + GGML_TYPE_TQ1_0 = 34, + GGML_TYPE_TQ2_0 = 35, GGML_TYPE_COUNT, }; diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index e40057632..050161393 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -227,6 +227,25 @@ typedef struct { } block_q8_0x8; static_assert(sizeof(block_q8_0x8) == 8 * sizeof(ggml_half) + QK8_0 * 8, "wrong q8_0x8 block size/padding"); +// +// Ternary quantization +// + +// 1.6875 bpw +typedef struct { + uint8_t qs[(QK_K - 4 * QK_K / 64) / 5]; // 5 elements per byte (3^5 = 243 < 256) + uint8_t qh[QK_K/64]; // 4 elements per byte + ggml_half d; +} block_tq1_0; +static_assert(sizeof(block_tq1_0) == sizeof(ggml_half) + QK_K / 64 + (QK_K - 4 * QK_K / 64) / 5, "wrong tq1_0 block size/padding"); + +// 2.0625 bpw +typedef struct { + uint8_t qs[QK_K/4]; // 2 bits per element + ggml_half d; +} block_tq2_0; +static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 block size/padding"); + // // Super-block quantization structures // @@ -361,6 +380,7 @@ typedef struct { } block_iq3_s; static_assert(sizeof(block_iq3_s) == sizeof(ggml_half) + 13*(QK_K/32) + IQ3S_N_SCALE, "wrong iq3_s block size/padding"); +// 1.5625 bpw typedef struct { 
ggml_half d; uint8_t qs[QK_K/8]; diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 190af0810..961f3c67b 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -175,7 +175,7 @@ typedef __fp16 ggml_fp16_internal_t; // 32-bit ARM compatibility -// vaddvq_s16 +// vaddlvq_s16 // vpaddq_s16 // vpaddq_s32 // vaddvq_s32 @@ -185,12 +185,9 @@ typedef __fp16 ggml_fp16_internal_t; // vzip1_u8 // vzip2_u8 -inline static int32_t vaddvq_s16(int16x8_t v) { - return - (int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) + - (int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) + - (int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) + - (int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7); +inline static int32_t vaddlvq_s16(int16x8_t v) { + int32x4_t v0 = vreinterpretq_s32_s64(vpaddlq_s32(vpaddlq_s16(v))); + return vgetq_lane_s32(v0, 0) + vgetq_lane_s32(v0, 2); } inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) { diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 48b90f01b..8c31e2cca 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -1630,7 +1630,7 @@ void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int6 // ===================== Helper functions // static inline int nearest_int(float fval) { - assert(fval <= 4194303.f); + assert(fabsf(fval) <= 4194303.f); float val = fval + 12582912.f; int i; memcpy(&i, &val, sizeof(int)); return (i & 0x007fffff) - 0x00400000; @@ -3306,6 +3306,191 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr return nrow * row_size; } +// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs) + +void quantize_row_tq1_0_ref(const float * restrict x, block_tq1_0 * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int64_t i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK_K; j++) { + const float v = x[j]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax; + const float id = d ? 
1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + // 5 elements per byte, along 32 bytes + for (size_t j = 0; j < sizeof(y->qs) - sizeof(y->qs) % 32; j += 32) { + for (size_t m = 0; m < 32; ++m) { + uint8_t q = 0; + for (size_t n = 0; n < 5; ++n) { + int xi = lroundf(x[m + n*32] * id) + 1; // -1, 0, 1 -> 0, 1, 2 + q *= 3; + q += xi; + } + // ceiling division (243 == pow(3, 5)) + q = ((uint16_t)q * 256 + (243 - 1)) / 243; + y[i].qs[j + m] = q; + } + x += 5*32; + } + // along 16 bytes + for (size_t j = sizeof(y->qs) - sizeof(y->qs) % 32; j < sizeof(y->qs); j += 16) { + for (size_t m = 0; m < 16; ++m) { + uint8_t q = 0; + for (size_t n = 0; n < 5; ++n) { + int xi = lroundf(x[m + n*16] * id) + 1; // -1, 0, 1 -> 0, 1, 2 + q *= 3; + q += xi; + } + // ceiling division (243 == pow(3, 5)) + q = ((uint16_t)q * 256 + (243 - 1)) / 243; + y[i].qs[j + m] = q; + } + x += 5*16; + } + // 4 elements per byte + for (size_t j = 0; j < sizeof(y->qh); ++j) { + uint8_t q = 0; + for (size_t m = 0; m < 4; ++m) { + // -1, 0, 1 -> 0, 1, 2 + int xi = lroundf(x[j + m*sizeof(y->qh)] * id) + 1; + q *= 3; + q += xi; + } + // shift the first value to the most significant trit + q *= 3; + // ceiling division (243 == pow(3, 5)) + q = ((uint16_t)q * 256 + (243 - 1)) / 243; + y[i].qh[j] = q; + } + x += 4*sizeof(y->qh); + } +} + +void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int64_t i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK_K; j++) { + const float v = x[j]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax; + const float id = d ? 1.0f/d : 0.0f; + + y[i].d = GGML_FP32_TO_FP16(d); + + for (size_t j = 0; j < sizeof(y->qs); j += 32) { + for (size_t m = 0; m < 32; ++m) { + uint8_t q = 0; + for (size_t n = 0; n < 4; ++n) { + // -1, 0, 1 -> 0, 1, 2 + int xi = lroundf(x[m + n*32] * id) + 1; + q += (xi & 3) << (2*n); + } + y[i].qs[j + m] = q; + } + x += 4*32; + } + } +} + +void quantize_row_tq1_0(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_tq1_0 * restrict y = vy; + quantize_row_tq1_0_ref(x, y, k); +} + +void quantize_row_tq2_0(const float * restrict x, void * restrict vy, int64_t k) { + assert(k % QK_K == 0); + block_tq2_0 * restrict y = vy; + quantize_row_tq2_0_ref(x, y, k); +} + +size_t quantize_tq1_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + (void)quant_weights; // not used + const size_t row_size = ggml_row_size(GGML_TYPE_TQ1_0, n_per_row); + quantize_row_tq1_0(src, dst, (int64_t)nrow*n_per_row); + return nrow * row_size; +} + +size_t quantize_tq2_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + (void)quant_weights; // not used + const size_t row_size = ggml_row_size(GGML_TYPE_TQ2_0, n_per_row); + quantize_row_tq2_0(src, dst, (int64_t)nrow*n_per_row); + return nrow * row_size; +} + + +void dequantize_row_tq1_0(const block_tq1_0 * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243}; + + for (int64_t i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) { + for (size_t n = 0; n < 5; ++n) { + for (size_t m = 0; m < 32; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[n]; + int16_t xi = ((uint16_t) q * 3) 
>> 8; + *y++ = (float) (xi - 1) * d; + } + } + } + for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) { + for (size_t n = 0; n < 5; ++n) { + for (size_t m = 0; m < 16; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[n]; + int16_t xi = ((uint16_t) q * 3) >> 8; + *y++ = (float) (xi - 1) * d; + } + } + } + + for (size_t n = 0; n < 4; ++n) { + for (size_t j = 0; j < sizeof(x->qh); ++j) { + uint8_t q = x[i].qh[j] * pow3[n]; + int16_t xi = ((uint16_t) q * 3) >> 8; + *y++ = (float) (xi - 1) * d; + } + } + } +} + +void dequantize_row_tq2_0(const block_tq2_0 * restrict x, float * restrict y, int64_t k) { + assert(k % QK_K == 0); + const int64_t nb = k / QK_K; + + for (int64_t i = 0; i < nb; ++i) { + + const float d = GGML_FP16_TO_FP32(x[i].d); + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + for (size_t l = 0; l < 4; ++l) { + for (size_t m = 0; m < 32; ++m) { + int8_t q = (x[i].qs[j + m] >> (l*2)) & 3; + *y++ = (float) (q - 1) * d; + } + } + } + } +} + // ====================== "True" 2-bit (de)-quantization void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) { @@ -5470,6 +5655,501 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r *s = sumf; } +void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq1_0 * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + float sumf = 0.0f; + + uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27}; + + const uint8x16_t shift = vld1q_u8(k_shift); + + for (int i = 0; i < nb; ++i) { +#if defined(__ARM_FEATURE_DOTPROD) + int32x4_t sumi0 = vdupq_n_s32(0); + int32x4_t sumi1 = vdupq_n_s32(0); +#else + int16x8_t sumi0 = vdupq_n_s16(0); + int16x8_t sumi1 = vdupq_n_s16(0); +#endif + + // first 32 bytes of 5 elements + { + uint8x16_t qx0 = vld1q_u8(x[i].qs + 0); + uint8x16_t qx1 = vld1q_u8(x[i].qs + 16); + uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3)); + uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3)); + uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9)); + uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9)); + uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27)); + uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27)); + uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81)); + uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81)); + + // multiply by 3 and keep the 2 bits above 8 bits + int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); + int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); + int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); + int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); + int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); + int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); + int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6)); + int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6)); + int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6)); + int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6)); + + const int8x16_t qy0 = vld1q_s8(y[i].qs + 0); + const int8x16_t 
qy1 = vld1q_s8(y[i].qs + 16); + const int8x16_t qy2 = vld1q_s8(y[i].qs + 32); + const int8x16_t qy3 = vld1q_s8(y[i].qs + 48); + const int8x16_t qy4 = vld1q_s8(y[i].qs + 64); + const int8x16_t qy5 = vld1q_s8(y[i].qs + 80); + const int8x16_t qy6 = vld1q_s8(y[i].qs + 96); + const int8x16_t qy7 = vld1q_s8(y[i].qs + 112); + const int8x16_t qy8 = vld1q_s8(y[i].qs + 128); + const int8x16_t qy9 = vld1q_s8(y[i].qs + 144); + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vdotq_s32(sumi0, sqx0, qy0); + sumi1 = vdotq_s32(sumi1, sqx1, qy1); + sumi0 = vdotq_s32(sumi0, sqx2, qy2); + sumi1 = vdotq_s32(sumi1, sqx3, qy3); + sumi0 = vdotq_s32(sumi0, sqx4, qy4); + sumi1 = vdotq_s32(sumi1, sqx5, qy5); + sumi0 = vdotq_s32(sumi0, sqx6, qy6); + sumi1 = vdotq_s32(sumi1, sqx7, qy7); + sumi0 = vdotq_s32(sumi0, sqx8, qy8); + sumi1 = vdotq_s32(sumi1, sqx9, qy9); +#else + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx8), vget_low_s8(qy8)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9)); +#endif + } + + // last 16 bytes of 5-element, along with the 4 bytes of 4 elements + { + uint8x16_t qx0 = vld1q_u8(x[i].qs + 32); + uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3)); + uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9)); + uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27)); + uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81)); + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned + uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh)); + qx5 = vmulq_u8(qx5, shift); + + // multiply by 3 and keep the 2 bits above 8 bits + int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6)); + int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6)); + int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6)); + int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6)); + int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6)); + int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6)); + + const int8x16_t qy0 = vld1q_s8(y[i].qs + 160); + const int8x16_t qy1 = vld1q_s8(y[i].qs + 176); + const int8x16_t qy2 = vld1q_s8(y[i].qs + 192); + const int8x16_t qy3 = vld1q_s8(y[i].qs + 208); + const int8x16_t qy4 = vld1q_s8(y[i].qs + 224); + const int8x16_t qy5 = vld1q_s8(y[i].qs + 240); + +#if 
defined(__ARM_FEATURE_DOTPROD) + sumi0 = vdotq_s32(sumi0, sqx0, qy0); + sumi1 = vdotq_s32(sumi1, sqx1, qy1); + sumi0 = vdotq_s32(sumi0, sqx2, qy2); + sumi1 = vdotq_s32(sumi1, sqx3, qy3); + sumi0 = vdotq_s32(sumi0, sqx4, qy4); + sumi1 = vdotq_s32(sumi1, sqx5, qy5); +#else + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); +#endif + } + + const int16x8_t ysum0 = vld1q_s16(y[i].bsums); + const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vaddq_s32(sumi0, sumi1); + sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); + + sumf += d * (float) vaddvq_s32(sumi0); +#else + sumi0 = vaddq_s16(sumi0, sumi1); + sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); + + sumf += d * (float) vaddlvq_s16(sumi0); +#endif + } + + *s = sumf; + +#elif defined(__AVX2__) + __m256 sumf = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + // 16-bit sums + __m256i sumi0 = _mm256_setzero_si256(); + __m256i sumi1 = _mm256_setzero_si256(); + __m256i sumi2 = _mm256_setzero_si256(); + + // first 32 bytes of 5 elements + { + __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs)); + // 8-bit multiplies with shifts, masks and adds + __m256i qx1 = _mm256_add_epi8(qx0, _mm256_add_epi8(qx0, qx0)); // 1 * 3 + __m256i qx2 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx0, 3), _mm256_set1_epi8(-8)), qx0); // 1 * 9 + __m256i qx3 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx1, 3), _mm256_set1_epi8(-8)), qx1); // 3 * 9 + __m256i qx4 = _mm256_add_epi8(_mm256_and_si256(_mm256_slli_epi16(qx2, 3), _mm256_set1_epi8(-8)), qx2); // 9 * 9 + + // TODO: can _mm256_mulhi_epu16 be faster even if 16-bits? 
+ + // Cancel the +1 from avg so that it behaves like a halving add + qx0 = _mm256_subs_epu8(qx0, _mm256_set1_epi8(1)); + qx1 = _mm256_subs_epu8(qx1, _mm256_set1_epi8(1)); + qx2 = _mm256_subs_epu8(qx2, _mm256_set1_epi8(1)); + qx3 = _mm256_subs_epu8(qx3, _mm256_set1_epi8(1)); + qx4 = _mm256_subs_epu8(qx4, _mm256_set1_epi8(1)); + // Multiply by 3 and get the top 2 bits + qx0 = _mm256_avg_epu8(qx0, _mm256_avg_epu8(qx0, _mm256_setzero_si256())); + qx1 = _mm256_avg_epu8(qx1, _mm256_avg_epu8(qx1, _mm256_setzero_si256())); + qx2 = _mm256_avg_epu8(qx2, _mm256_avg_epu8(qx2, _mm256_setzero_si256())); + qx3 = _mm256_avg_epu8(qx3, _mm256_avg_epu8(qx3, _mm256_setzero_si256())); + qx4 = _mm256_avg_epu8(qx4, _mm256_avg_epu8(qx4, _mm256_setzero_si256())); + qx0 = _mm256_and_si256(_mm256_srli_epi16(qx0, 6), _mm256_set1_epi8(3)); + qx1 = _mm256_and_si256(_mm256_srli_epi16(qx1, 6), _mm256_set1_epi8(3)); + qx2 = _mm256_and_si256(_mm256_srli_epi16(qx2, 6), _mm256_set1_epi8(3)); + qx3 = _mm256_and_si256(_mm256_srli_epi16(qx3, 6), _mm256_set1_epi8(3)); + qx4 = _mm256_and_si256(_mm256_srli_epi16(qx4, 6), _mm256_set1_epi8(3)); + + const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 0)); + const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 32)); + const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 64)); + const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 96)); + const __m256i qy4 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 128)); + + qx0 = _mm256_maddubs_epi16(qx0, qy0); + qx1 = _mm256_maddubs_epi16(qx1, qy1); + qx2 = _mm256_maddubs_epi16(qx2, qy2); + qx3 = _mm256_maddubs_epi16(qx3, qy3); + qx4 = _mm256_maddubs_epi16(qx4, qy4); + + sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1)); + sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3)); + sumi2 = _mm256_add_epi16(sumi2, qx4); + } + + // last 16 bytes of 5-element, along with the 4 bytes of 4 elements + { + __m128i qx0 = _mm_loadu_si128((const __m128i *) (x[i].qs + 32)); + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned + __m256i qx5_l = _mm256_cvtepu8_epi16(_mm_set1_epi32(qh)); + __m128i qx1 = _mm_add_epi8(qx0, _mm_add_epi8(qx0, qx0)); // 1 * 3 + __m128i qx2 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx0, 3), _mm_set1_epi8(-8)), qx0); // 1 * 9 + __m128i qx3 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx1, 3), _mm_set1_epi8(-8)), qx1); // 3 * 9 + __m128i qx4 = _mm_add_epi8(_mm_and_si128(_mm_slli_epi16(qx2, 3), _mm_set1_epi8(-8)), qx2); // 9 * 9 + __m256i qx01 = MM256_SET_M128I(qx1, qx0); + __m256i qx23 = MM256_SET_M128I(qx3, qx2); + + // avx2 does not have 8-bit multiplies, so 16-bit it is. 
+ qx5_l = _mm256_mullo_epi16(qx5_l, _mm256_set_epi16(27, 27, 27, 27, 9, 9, 9, 9, 3, 3, 3, 3, 1, 1, 1, 1)); + qx5_l = _mm256_and_si256(qx5_l, _mm256_set1_epi16(0xFF)); + __m128i qx5 = _mm_packus_epi16(_mm256_castsi256_si128(qx5_l), _mm256_extracti128_si256(qx5_l, 1)); + + __m256i qx45 = MM256_SET_M128I(qx5, qx4); + + // Cancel the +1 from avg so that it behaves like a halving add + qx01 = _mm256_subs_epu8(qx01, _mm256_set1_epi8(1)); + qx23 = _mm256_subs_epu8(qx23, _mm256_set1_epi8(1)); + qx45 = _mm256_subs_epu8(qx45, _mm256_set1_epi8(1)); + // Multiply by 3 and get the top 2 bits + qx01 = _mm256_avg_epu8(qx01, _mm256_avg_epu8(qx01, _mm256_setzero_si256())); + qx23 = _mm256_avg_epu8(qx23, _mm256_avg_epu8(qx23, _mm256_setzero_si256())); + qx45 = _mm256_avg_epu8(qx45, _mm256_avg_epu8(qx45, _mm256_setzero_si256())); + qx01 = _mm256_and_si256(_mm256_srli_epi16(qx01, 6), _mm256_set1_epi8(3)); + qx23 = _mm256_and_si256(_mm256_srli_epi16(qx23, 6), _mm256_set1_epi8(3)); + qx45 = _mm256_and_si256(_mm256_srli_epi16(qx45, 6), _mm256_set1_epi8(3)); + + const __m256i qy01 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 160)); + const __m256i qy23 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 192)); + const __m256i qy45 = _mm256_loadu_si256((const __m256i *) (y[i].qs + 224)); + + qx01 = _mm256_maddubs_epi16(qx01, qy01); + qx23 = _mm256_maddubs_epi16(qx23, qy23); + qx45 = _mm256_maddubs_epi16(qx45, qy45); + + sumi0 = _mm256_add_epi16(sumi0, qx01); + sumi1 = _mm256_add_epi16(sumi1, qx23); + sumi2 = _mm256_add_epi16(sumi2, qx45); + } + + const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums); + const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)); + + sumi0 = _mm256_sub_epi16(sumi0, ysum); + sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(sumi1, sumi2)); + sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1)); + + sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf); + } + + *s = hsum_float_8(sumf); + +#else + const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243}; + + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + int sum = 0; + + for (size_t j = 0; j < sizeof(x->qs) - sizeof(x->qs) % 32; j += 32) { + for (size_t l = 0; l < 5; ++l) { + for (size_t m = 0; m < 32; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[j*5 + l*32 + m]; + } + } + } + for (size_t j = sizeof(x->qs) - sizeof(x->qs) % 32; j < sizeof(x->qs); j += 16) { + for (size_t l = 0; l < 5; ++l) { + for (size_t m = 0; m < 16; ++m) { + uint8_t q = x[i].qs[j + m] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[j*5 + l*16 + m]; + } + } + } + + for (size_t l = 0; l < 4; ++l) { + for (size_t j = 0; j < sizeof(x->qh); ++j) { + uint8_t q = x[i].qh[j] * pow3[l]; + uint16_t xi = ((uint16_t) q * 3) >> 8; + sum += (xi - 1) * y[i].qs[sizeof(x->qs)*5 + l*sizeof(x->qh) + j]; + } + } + + sumf += (float) sum * (GGML_FP16_TO_FP32(x[i].d) * y[i].d); + } + + *s = sumf; +#endif +} + +void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + assert(nrc == 1); + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_tq2_0 * restrict x = vx; + const block_q8_K * restrict y = vy; + + const int nb = n / QK_K; + +#if defined(__ARM_NEON) + float sumf = 0.0f; + + const uint8x16_t m3 = vdupq_n_u8(3); + + for (int i = 0; i < nb; ++i) { +#if defined(__ARM_FEATURE_DOTPROD) + int32x4_t sumi0 = vdupq_n_s32(0); + 
int32x4_t sumi1 = vdupq_n_s32(0); +#else + int16x8_t sumi0 = vdupq_n_s16(0); + int16x8_t sumi1 = vdupq_n_s16(0); +#endif + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + uint8x16_t qx0 = vld1q_u8(x[i].qs + j); + uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16); + uint8x16_t qx2 = vshrq_n_u8(qx0, 2); + uint8x16_t qx3 = vshrq_n_u8(qx1, 2); + uint8x16_t qx4 = vshrq_n_u8(qx0, 4); + uint8x16_t qx5 = vshrq_n_u8(qx1, 4); + uint8x16_t qx6 = vshrq_n_u8(qx0, 6); + uint8x16_t qx7 = vshrq_n_u8(qx1, 6); + + int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3)); + int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3)); + int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3)); + int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3)); + int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3)); + int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3)); + int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3)); + int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3)); + + const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0); + const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16); + const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32); + const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48); + const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64); + const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80); + const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96); + const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112); + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vdotq_s32(sumi0, sqx0, qy0); + sumi1 = vdotq_s32(sumi1, sqx1, qy1); + sumi0 = vdotq_s32(sumi0, sqx2, qy2); + sumi1 = vdotq_s32(sumi1, sqx3, qy3); + sumi0 = vdotq_s32(sumi0, sqx4, qy4); + sumi1 = vdotq_s32(sumi1, sqx5, qy5); + sumi0 = vdotq_s32(sumi0, sqx6, qy6); + sumi1 = vdotq_s32(sumi1, sqx7, qy7); +#else + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx1), vget_high_s8(qy1)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx2), vget_low_s8(qy2)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx2), vget_high_s8(qy2)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx3), vget_low_s8(qy3)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx3), vget_high_s8(qy3)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx4), vget_low_s8(qy4)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx6), vget_low_s8(qy6)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6)); + sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7)); + sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7)); +#endif + } + + const int16x8_t ysum0 = vld1q_s16(y[i].bsums); + const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8); + + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + +#if defined(__ARM_FEATURE_DOTPROD) + sumi0 = vaddq_s32(sumi0, sumi1); + sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1))); + + sumf += d * (float) vaddvq_s32(sumi0); +#else + sumi0 = vaddq_s16(sumi0, sumi1); + sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1)); + + sumf += d * (float) vaddlvq_s16(sumi0); +#endif + } + + *s = sumf; + +#elif defined(__AVX2__) + __m256 sumf = _mm256_setzero_ps(); + + for (int i = 0; i < nb; ++i) { + // 16-bit sums, because 256*127 still fits + __m256i sumi0 = _mm256_setzero_si256(); + __m256i sumi1 = _mm256_setzero_si256(); + + 
for (size_t j = 0; j < sizeof(x->qs); j += 32) { + __m256i qx0 = _mm256_loadu_si256((const __m256i *) (x[i].qs + j)); + __m256i qx1 = _mm256_srli_epi16(qx0, 2); + __m256i qx2 = _mm256_srli_epi16(qx0, 4); + __m256i qx3 = _mm256_srli_epi16(qx0, 6); + + // 0, 1, 2 (should not be 3) + qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3)); + qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3)); + qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3)); + qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3)); + + const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0)); + const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32)); + const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64)); + const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96)); + + qx0 = _mm256_maddubs_epi16(qx0, qy0); + qx1 = _mm256_maddubs_epi16(qx1, qy1); + qx2 = _mm256_maddubs_epi16(qx2, qy2); + qx3 = _mm256_maddubs_epi16(qx3, qy3); + + sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1)); + sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3)); + } + + const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums); + const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)); + + sumi0 = _mm256_add_epi16(sumi0, sumi1); + sumi0 = _mm256_sub_epi16(sumi0, ysum); + sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1)); + + sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf); + } + + *s = hsum_float_8(sumf); + +#else + float sumf = 0.0f; + + for (int i = 0; i < nb; ++i) { + int32_t sumi = 0; + + for (size_t j = 0; j < sizeof(x->qs); j += 32) { + for (size_t l = 0; l < 4; ++l) { + for (size_t k = 0; k < 32; ++k) { + sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].qs[j + k] >> (l*2)) & 3) - 1); + } + } + } + + const float d = y[i].d * GGML_FP16_TO_FP32(x[i].d); + + sumf += (float) sumi * d; + } + + *s = sumf; +#endif +} + void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -14800,6 +15480,14 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte } } } break; + case GGML_TYPE_TQ1_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_tq1_0, data, nb); + } break; + case GGML_TYPE_TQ2_0: + { + VALIDATE_ROW_DATA_D_F16_IMPL(block_tq2_0, data, nb); + } break; case GGML_TYPE_IQ1_S: { VALIDATE_ROW_DATA_D_F16_IMPL(block_iq1_s, data, nb); diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 525d5ee30..e96ce2b5e 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -26,6 +26,9 @@ void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_REST void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k); void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k); +void quantize_row_tq1_0_ref(const float * GGML_RESTRICT x, block_tq1_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_tq2_0_ref(const float * GGML_RESTRICT x, block_tq2_0 * GGML_RESTRICT y, int64_t k); + void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k); void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k); void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k); @@ -46,6 +49,9 @@ void quantize_row_q5_K(const float * GGML_RESTRICT x, void * 
GGML_RESTRICT y, in void quantize_row_q6_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_tq1_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_tq2_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); + void quantize_row_iq3_xxs(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq4_nl (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_iq4_xs (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -67,6 +73,9 @@ void dequantize_row_q5_K(const block_q5_K * GGML_RESTRICT x, float * GGML_RESTRI void dequantize_row_q6_K(const block_q6_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q8_K(const block_q8_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_tq1_0(const block_tq1_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_tq2_0(const block_tq2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); + void dequantize_row_iq2_xxs(const block_iq2_xxs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_iq2_xs (const block_iq2_xs * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_iq2_s (const block_iq2_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -90,6 +99,9 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_tq1_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_tq2_0_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); + void ggml_vec_dot_iq2_xxs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq2_xs_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_iq2_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -111,6 +123,9 @@ size_t quantize_iq4_nl (const float * GGML_RESTRICT src, void * GGML_RESTRICT ds size_t quantize_iq4_xs (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_iq3_s (const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_tq1_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_tq2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); + size_t quantize_q2_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t 
quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 6e2ebf283..c98ca32bd 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1054,7 +1054,31 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .ncols = 8, .gemv = ggml_gemv_q4_0_8x8_q8_0, .gemm = ggml_gemm_q4_0_8x8_q8_0, - } + }, + [GGML_TYPE_TQ1_0] = { + .type_name = "tq1_0", + .blck_size = QK_K, + .type_size = sizeof(block_tq1_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_tq1_0, + .from_float = quantize_row_tq1_0, + .from_float_ref = (ggml_from_float_t) quantize_row_tq1_0_ref, + .vec_dot = ggml_vec_dot_tq1_0_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, + [GGML_TYPE_TQ2_0] = { + .type_name = "tq2_0", + .blck_size = QK_K, + .type_size = sizeof(block_tq2_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_tq2_0, + .from_float = quantize_row_tq2_0, + .from_float_ref = (ggml_from_float_t) quantize_row_tq2_0_ref, + .vec_dot = ggml_vec_dot_tq2_0_q8_K, + .vec_dot_type = GGML_TYPE_Q8_K, + .nrows = 1, + }, }; // For internal test use @@ -9897,6 +9921,8 @@ static void ggml_compute_forward_add( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -10275,6 +10301,8 @@ static void ggml_compute_forward_add1( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -10403,6 +10431,8 @@ static void ggml_compute_forward_acc( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -13386,6 +13416,8 @@ static void ggml_compute_forward_out_prod( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -13574,6 +13606,8 @@ static void ggml_compute_forward_set( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -13836,6 +13870,8 @@ static void ggml_compute_forward_get_rows( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -14425,6 +14461,8 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ3_XXS: @@ -21868,6 +21906,8 @@ size_t ggml_quantize_chunk( case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_K: result = quantize_q5_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q6_K: result = quantize_q6_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_TQ1_0: result = quantize_tq1_0(src 
+ start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_TQ2_0: result = quantize_tq2_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XXS: result = quantize_iq2_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ2_XS: result = quantize_iq2_xs (src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_IQ3_XXS: result = quantize_iq3_xxs(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index a48c4fb67..c87d08782 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1291,6 +1291,8 @@ class GGMLQuantizationType(IntEnum): Q4_0_4_4 = 31 Q4_0_4_8 = 32 Q4_0_8_8 = 33 + TQ1_0 = 34 + TQ2_0 = 35 # TODO: add GGMLFileType from ggml_ftype in ggml.h @@ -1335,6 +1337,8 @@ class LlamaFileType(IntEnum): MOSTLY_Q4_0_4_4 = 33 # except 1d tensors MOSTLY_Q4_0_4_8 = 34 # except 1d tensors MOSTLY_Q4_0_8_8 = 35 # except 1d tensors + MOSTLY_TQ1_0 = 36 # except 1d tensors + MOSTLY_TQ2_0 = 37 # except 1d tensors GUESSED = 1024 # not specified in the model file @@ -1411,6 +1415,8 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { GGMLQuantizationType.Q4_0_4_4:(32, 2 + 16), GGMLQuantizationType.Q4_0_4_8:(32, 2 + 16), GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16), + GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13), + GGMLQuantizationType.TQ2_0: (256, 2 + 64), } diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py index ff589b852..3c8ba82e1 100644 --- a/gguf-py/gguf/quants.py +++ b/gguf-py/gguf/quants.py @@ -574,6 +574,87 @@ class Q6_K(__Quant, qtype=GGMLQuantizationType.Q6_K): return (d * q).reshape((n_blocks, QK_K)) +class TQ1_0(__Quant, qtype=GGMLQuantizationType.TQ1_0): + @classmethod + def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d = abs(blocks).max(axis=-1, keepdims=True) + with np.errstate(divide="ignore"): + id = np.where(d == 0, 0, 1 / d) + qs = np_roundf(blocks * id) + qs = (qs.astype(np.int8) + np.int8(1)).astype(np.uint8) + + qs0, qs1, qh = qs[..., :(32 * 5)], qs[..., (32 * 5):(48 * 5)], qs[..., (48 * 5):] + qs0 = qs0.reshape((n_blocks, -1, 5, 32)) * np.array([81, 27, 9, 3, 1], dtype=np.uint8).reshape((1, 1, 5, 1)) + qs0 = np.sum(qs0, axis=-2).reshape((n_blocks, -1)) + qs1 = qs1.reshape((n_blocks, -1, 5, 16)) * np.array([81, 27, 9, 3, 1], dtype=np.uint8).reshape((1, 1, 5, 1)) + qs1 = np.sum(qs1, axis=-2).reshape((n_blocks, -1)) + qh = qh.reshape((n_blocks, -1, 4, 4)) * np.array([81, 27, 9, 3], dtype=np.uint8).reshape((1, 1, 4, 1)) + qh = np.sum(qh, axis=-2).reshape((n_blocks, -1)) + qs = np.concatenate([qs0, qs1, qh], axis=-1) + qs = (qs.astype(np.uint16) * 256 + (243 - 1)) // 243 + + qs = qs.astype(np.uint8) + d = d.astype(np.float16).view(np.uint8) + + return np.concatenate([qs, d], axis=-1) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + qs, rest = np.hsplit(blocks, [(QK_K - 4 * QK_K // 64) // 5]) + qh, d = np.hsplit(rest, [QK_K // 64]) + + d = d.view(np.float16).astype(np.float32) + + qs0, qs1 = qs[..., :32], qs[..., 32:] + qs0 = qs0.reshape((n_blocks, -1, 1, 32)) * np.array([1, 3, 9, 27, 81], dtype=np.uint8).reshape((1, 1, 5, 1)) + qs0 = qs0.reshape((n_blocks, -1)) + qs1 = qs1.reshape((n_blocks, -1, 1, 16)) * np.array([1, 3, 9, 27, 81], 
dtype=np.uint8).reshape((1, 1, 5, 1)) + qs1 = qs1.reshape((n_blocks, -1)) + qh = qh.reshape((n_blocks, -1, 1, 4)) * np.array([1, 3, 9, 27], dtype=np.uint8).reshape((1, 1, 4, 1)) + qh = qh.reshape((n_blocks, -1)) + qs = np.concatenate([qs0, qs1, qh], axis=-1) + qs = ((qs.astype(np.uint16) * 3) >> 8).astype(np.int8) - np.int8(1) + + return (d * qs.astype(np.float32)) + + +class TQ2_0(__Quant, qtype=GGMLQuantizationType.TQ2_0): + @classmethod + def quantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + d = abs(blocks).max(axis=-1, keepdims=True) + with np.errstate(divide="ignore"): + id = np.where(d == 0, 0, 1 / d) + qs = np_roundf(blocks * id) + qs = (qs.astype(np.int8) + np.int8(1)).astype(np.uint8) + + qs = qs.reshape((n_blocks, -1, 4, 32)) << np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1)) + qs = qs[..., 0, :] | qs[..., 1, :] | qs[..., 2, :] | qs[..., 3, :] + qs = qs.reshape((n_blocks, -1)) + + d = d.astype(np.float16).view(np.uint8) + + return np.concatenate([qs, d], axis=-1) + + @classmethod + def dequantize_blocks(cls, blocks: np.ndarray) -> np.ndarray: + n_blocks = blocks.shape[0] + + qs, d = np.hsplit(blocks, [QK_K // 4]) + + d = d.view(np.float16).astype(np.float32) + + qs = qs.reshape((n_blocks, -1, 1, 32)) >> np.array([0, 2, 4, 6], dtype=np.uint8).reshape((1, 1, 4, 1)) + qs = (qs & 0x03).reshape((n_blocks, -1)).astype(np.int8) - np.int8(1) + + return (d * qs.astype(np.float32)) + + class IQ2_XXS(__Quant, qtype=GGMLQuantizationType.IQ2_XXS): ksigns: bytes = ( b"\x00\x81\x82\x03\x84\x05\x06\x87\x88\x09\x0a\x8b\x0c\x8d\x8e\x0f" diff --git a/gguf-py/tests/test_quants.py b/gguf-py/tests/test_quants.py index 8b7a85c2c..762067814 100755 --- a/gguf-py/tests/test_quants.py +++ b/gguf-py/tests/test_quants.py @@ -66,6 +66,7 @@ class GGMLQuants: for t in ( "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "q2_K", "q3_K", "q4_K", "q5_K", "q6_K", + "tq1_0", "tq2_0", "iq2_xxs", "iq2_xs", "iq2_s", "iq3_xxs", "iq3_s", "iq1_s", "iq1_m", "iq4_nl", "iq4_xs", ): diff --git a/include/llama.h b/include/llama.h index bfc37e88b..a495e866d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -167,6 +167,8 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors + LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors + LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/src/llama.cpp b/src/llama.cpp index c3669eb28..1a78112a3 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4444,6 +4444,8 @@ struct llama_model_loader { case GGML_TYPE_Q4_K: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; break; case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break; case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; + case GGML_TYPE_TQ1_0: ftype = LLAMA_FTYPE_MOSTLY_TQ1_0; break; + case GGML_TYPE_TQ2_0: ftype = LLAMA_FTYPE_MOSTLY_TQ2_0; break; case GGML_TYPE_IQ2_XXS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XXS; break; case GGML_TYPE_IQ2_XS: ftype = LLAMA_FTYPE_MOSTLY_IQ2_XS; break; case GGML_TYPE_IQ2_S: ftype = LLAMA_FTYPE_MOSTLY_IQ2_S; break; @@ -5137,6 +5139,8 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small"; case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K"; + case LLAMA_FTYPE_MOSTLY_TQ1_0: return "TQ1_0 - 1.69 bpw ternary"; + case LLAMA_FTYPE_MOSTLY_TQ2_0: return "TQ2_0 
- 2.06 bpw ternary"; case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw"; case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw"; @@ -8118,23 +8122,23 @@ static bool llm_load_tensors( layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd}); layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); - layer.wq_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}); + layer.wq_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); - layer.wk_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "scale", i), {1}); + layer.wk_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); - layer.wv_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "scale", i), {1}); + layer.wv_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); - layer.wo_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}); + layer.wo_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff}); layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); - layer.ffn_gate_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}); + layer.ffn_gate_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); - layer.ffn_down_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}); + layer.ffn_down_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); - layer.ffn_up_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}); + layer.ffn_up_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED); } } break; case LLM_ARCH_T5: @@ -14177,7 +14181,9 @@ struct llm_build_context { { // compute Q and K and RoPE them struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); - Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale); + if (model.layers[il].wq_scale) { + Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale); + } cb(Qcur, "Qcur", il); if (model.layers[il].bq) { Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); @@ -14186,7 +14192,9 @@ struct llm_build_context { // B1.K struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); - Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale); + if (model.layers[il].wk_scale) { + Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale); + } 
cb(Kcur, "Kcur", il); if (model.layers[il].bk) { Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); @@ -14195,7 +14203,9 @@ struct llm_build_context { // B1.V struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); - Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale); + if (model.layers[il].wv_scale) { + Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale); + } cb(Vcur, "Vcur", il); if (model.layers[il].bv) { Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); @@ -14226,7 +14236,9 @@ struct llm_build_context { cb(cur, "attn_sub_norm", il); cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur); - cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); + if (model.layers[il].wo_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale); + } if (model.layers[il].bo) { cur = ggml_add(ctx0, cur, model.layers[il].bo); } @@ -14263,7 +14275,9 @@ struct llm_build_context { cb(cur, "ffn_sub_norm", il); cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_down, cur); - cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); + if (model.layers[il].ffn_down_scale) { + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale); + } cb(cur, "ffn_down", il); cur = ggml_add(ctx0, cur, ffn_inp); @@ -16933,6 +16947,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n new_type == GGML_TYPE_Q4_0_8_8) { new_type = GGML_TYPE_Q4_0; } + else if (ftype == LLAMA_FTYPE_MOSTLY_TQ1_0 || ftype == LLAMA_FTYPE_MOSTLY_TQ2_0) { + new_type = GGML_TYPE_Q4_K; + } } } else if (ftype == LLAMA_FTYPE_MOSTLY_IQ2_XXS || ftype == LLAMA_FTYPE_MOSTLY_IQ2_XS || ftype == LLAMA_FTYPE_MOSTLY_IQ1_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_S || ftype == LLAMA_FTYPE_MOSTLY_IQ2_M || ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { @@ -17132,6 +17149,8 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n } if (convert_incompatible_tensor) { switch (new_type) { + case GGML_TYPE_TQ1_0: + case GGML_TYPE_TQ2_0: new_type = GGML_TYPE_Q4_0; break; // TODO: use a symmetric type instead case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_S: @@ -17237,6 +17256,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s case LLAMA_FTYPE_MOSTLY_Q5_K_S: case LLAMA_FTYPE_MOSTLY_Q5_K_M: default_type = GGML_TYPE_Q5_K; break; case LLAMA_FTYPE_MOSTLY_Q6_K: default_type = GGML_TYPE_Q6_K; break; + case LLAMA_FTYPE_MOSTLY_TQ1_0: default_type = GGML_TYPE_TQ1_0; break; + case LLAMA_FTYPE_MOSTLY_TQ2_0: default_type = GGML_TYPE_TQ2_0; break; case LLAMA_FTYPE_MOSTLY_IQ2_XXS: default_type = GGML_TYPE_IQ2_XXS; break; case LLAMA_FTYPE_MOSTLY_IQ2_XS: default_type = GGML_TYPE_IQ2_XS; break; case LLAMA_FTYPE_MOSTLY_IQ2_S: default_type = GGML_TYPE_IQ2_XS; break; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index c832bc956..bd65e8cb3 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2200,6 +2200,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K, + // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // TODO: implement for all backends GGML_TYPE_IQ2_XXS, GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, @@ -2219,6 +2220,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, GGML_TYPE_Q5_K, GGML_TYPE_Q6_K, + // GGML_TYPE_TQ1_0, GGML_TYPE_TQ2_0, // 
TODO: implement for all backends GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S, GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M, GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS, diff --git a/tests/test-quantize-fns.cpp b/tests/test-quantize-fns.cpp index c97458d1d..ccf5721a3 100644 --- a/tests/test-quantize-fns.cpp +++ b/tests/test-quantize-fns.cpp @@ -15,11 +15,13 @@ constexpr float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR = 0.002f; +constexpr float MAX_QUANTIZATION_TOTAL_ERROR_TERNARY = 0.01f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f; constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f; constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f; constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f; +constexpr float MAX_DOT_PRODUCT_ERROR_TERNARY = 0.15f; static const char* RESULT_STR[] = {"ok", "FAILED"}; @@ -144,6 +146,8 @@ int main(int argc, char * argv[]) { if (qfns.from_float && qfns.to_float) { const float total_error = total_quantization_error(qfns, test_size, test_data.data()); const float max_quantization_error = + type == GGML_TYPE_TQ1_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY : + type == GGML_TYPE_TQ2_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY : type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS : type == GGML_TYPE_IQ2_S ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS : type == GGML_TYPE_Q3_K ? MAX_QUANTIZATION_TOTAL_ERROR_3BITS : @@ -166,6 +170,8 @@ int main(int argc, char * argv[]) { const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS || type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S ? MAX_DOT_PRODUCT_ERROR_LOWBIT + : type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0 + ? MAX_DOT_PRODUCT_ERROR_TERNARY : MAX_DOT_PRODUCT_ERROR; failed = !(vec_dot_error < max_allowed_error); num_failed += failed;
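
For reference, the bits-per-weight figures quoted above follow directly from the new block layouts in ggml-common.h (QK_K = 256 elements per block, ggml_half d = 2 bytes). A quick arithmetic check as a Python sketch; this is illustrative only and not part of the patch:

```python
# Bits per weight implied by the block_tq1_0 / block_tq2_0 layouts
QK_K = 256
tq1_0_size = 2 + QK_K // 64 + (QK_K - 4 * QK_K // 64) // 5  # d + qh + qs = 54 bytes
tq2_0_size = 2 + QK_K // 4                                  # d + qs      = 66 bytes
assert tq1_0_size * 8 / QK_K == 1.6875
assert tq2_0_size * 8 / QK_K == 2.0625
```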
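
TQ1_0 packs 5 ternary values per byte by treating them as base-3 digits (3^5 = 243 < 256) and then mapping the range 0..242 onto 0..255 with a ceiling division, so that each digit can later be recovered with only a multiply and a shift. A minimal Python sketch of the per-byte round trip used by quantize_row_tq1_0_ref and dequantize_row_tq1_0; the helper names pack5/unpack5 are illustrative, not from the patch:

```python
from itertools import product

def pack5(trits):
    # trits: 5 values in {-1, 0, 1}; the first becomes the most significant base-3 digit
    q = 0
    for t in trits:
        q = q * 3 + (t + 1)              # base-3 accumulation, q in 0..242
    return (q * 256 + 243 - 1) // 243    # ceiling division: map 0..242 onto 0..255

def unpack5(byte):
    out = []
    for n in range(5):
        q = (byte * 3**n) % 256          # rotate digit n into the top position
        out.append(((q * 3) >> 8) - 1)   # extract the top base-3 digit, back to -1..1
    return out

# exhaustive round-trip check over all 3^5 = 243 combinations
assert all(list(c) == unpack5(pack5(c)) for c in product((-1, 0, 1), repeat=5))
```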
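
TQ2_0 is simpler: each ternary value is offset to 0/1/2 and stored as two bits, four values per byte, and dequantized as d * (q - 1) where d is the block's absolute maximum. A minimal sketch of one byte's worth, mirroring quantize_row_tq2_0_ref / dequantize_row_tq2_0 (helper names are illustrative):

```python
def pack4(trits):
    # trits: 4 values in {-1, 0, 1}; element n goes into bits 2n..2n+1
    q = 0
    for n, t in enumerate(trits):
        q |= ((t + 1) & 3) << (2 * n)
    return q

def unpack4(byte):
    return [((byte >> (2 * n)) & 3) - 1 for n in range(4)]

assert unpack4(pack4([-1, 0, 1, 1])) == [-1, 0, 1, 1]
```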
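
The "offsetting the sum instead of the weights" optimization mentioned in the log works because the stored values are the ternary weights plus one: sum((t+1)*y) - sum(y) == sum(t*y), and Q8_K already carries per-group sums of y in bsums, so the kernels can accumulate with the unsigned 0/1/2 values and subtract once per block. A small sanity check (illustrative only):

```python
import random

t = [random.choice((-1, 0, 1)) for _ in range(256)]   # ternary weights of one block
y = [random.randint(-128, 127) for _ in range(256)]   # int8 activations (as in Q8_K)

unsigned_dot = sum((ti + 1) * yi for ti, yi in zip(t, y))  # dot with stored 0/1/2 values
assert unsigned_dot - sum(y) == sum(ti * yi for ti, yi in zip(t, y))
```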