mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 19:34:35 +00:00
ggml : remove q1_3 and q2_2
* llama : remove the separate scale tensors of BitNet b1.58 They won't be needed, since the remaining ternary quant types have built-in scales.
This commit is contained in:
parent
45719a2472
commit
04eec58112
@ -284,9 +284,6 @@ class Model:
|
||||
|
||||
for new_name, data in ((n, d.squeeze().numpy()) for n, d in self.modify_tensors(data_torch, name, bid)):
|
||||
data: np.ndarray # type hint
|
||||
if len(data.shape) == 0:
|
||||
# otherwise single-value tensors get squeezed
|
||||
data = data.reshape((1,))
|
||||
n_dims = len(data.shape)
|
||||
data_dtype = data.dtype
|
||||
data_qtype: gguf.GGMLQuantizationType | None = None
|
||||
@ -317,33 +314,12 @@ class Model:
|
||||
))
|
||||
|
||||
if self.ftype != gguf.LlamaFileType.ALL_F32 and extra_f16 and not extra_f32:
|
||||
# TODO: cleaner model-specific per-tensor types
|
||||
# NOTE: Q1_3 is only relevant for BitNet b1.58
|
||||
if (
|
||||
self.ftype == gguf.LlamaFileType.MOSTLY_Q1_3
|
||||
and gguf.can_quantize_to_q1_3(data)
|
||||
and not any(
|
||||
self.match_model_tensor_name(new_name, key, None)
|
||||
for key in [
|
||||
gguf.MODEL_TENSOR.TOKEN_EMBD,
|
||||
gguf.MODEL_TENSOR.OUTPUT,
|
||||
]
|
||||
)
|
||||
):
|
||||
data = gguf.quantize_q1_3(data)
|
||||
assert data.dtype == np.uint8
|
||||
data_qtype = gguf.GGMLQuantizationType.Q1_3
|
||||
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
|
||||
if self.ftype == gguf.LlamaFileType.MOSTLY_BF16:
|
||||
data = gguf.quantize_bf16(data)
|
||||
assert data.dtype == np.int16
|
||||
data_qtype = gguf.GGMLQuantizationType.BF16
|
||||
|
||||
elif (
|
||||
self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0
|
||||
or self.ftype == gguf.LlamaFileType.MOSTLY_Q1_3
|
||||
and gguf.can_quantize_to_q8_0(data)
|
||||
):
|
||||
elif self.ftype == gguf.LlamaFileType.MOSTLY_Q8_0 and gguf.can_quantize_to_q8_0(data):
|
||||
data = gguf.quantize_q8_0(data)
|
||||
assert data.dtype == np.uint8
|
||||
data_qtype = gguf.GGMLQuantizationType.Q8_0
|
||||
@ -1635,12 +1611,6 @@ class LlamaModel(Model):
|
||||
class BitnetModel(Model):
|
||||
model_arch = gguf.MODEL_ARCH.BITNET
|
||||
|
||||
def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, *args, **kwargs):
|
||||
if ftype == gguf.LlamaFileType.GUESSED:
|
||||
ftype = gguf.LlamaFileType.MOSTLY_Q1_3
|
||||
|
||||
super().__init__(dir_model, ftype, *args, **kwargs)
|
||||
|
||||
def set_vocab(self):
|
||||
self._set_vocab_sentencepiece()
|
||||
|
||||
@ -1649,16 +1619,16 @@ class BitnetModel(Model):
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
|
||||
self.gguf_writer.add_rope_scaling_factor(1.0)
|
||||
|
||||
def weight_quant(self, weight):
|
||||
def weight_quant(self, weight: Tensor) -> Tensor:
|
||||
dtype = weight.dtype
|
||||
weight = weight.float()
|
||||
scale = weight.abs().mean().clamp(min=1e-5)
|
||||
iscale = 1 / scale
|
||||
weight = (weight * iscale).round().clamp(-1, 1)
|
||||
# TODO: use the scale directly instead of inverting it twice
|
||||
# TODO: multiply by the scale directly instead of inverting it twice
|
||||
# (this is also unnecessarily doubly inverted upstream)
|
||||
# ref: https://huggingface.co/1bitLLM/bitnet_b1_58-3B/blob/af89e318d78a70802061246bf037199d2fb97020/utils_quant.py#L10
|
||||
return weight.type(dtype), (1 / iscale).type(torch.float32)
|
||||
result = (weight * iscale).round().clamp(-1, 1) / iscale
|
||||
return result.type(dtype)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
new_name = self.map_tensor_name(name)
|
||||
@ -1673,11 +1643,9 @@ class BitnetModel(Model):
|
||||
gguf.MODEL_TENSOR.FFN_GATE,
|
||||
]):
|
||||
# transform weight into 1/0/-1 (in fp32)
|
||||
weight_torch, scale_torch = self.weight_quant(data_torch)
|
||||
yield (new_name, weight_torch)
|
||||
yield (new_name.removesuffix(".weight") + ".scale", scale_torch)
|
||||
else:
|
||||
yield (new_name, data_torch)
|
||||
data_torch = self.weight_quant(data_torch)
|
||||
|
||||
yield (new_name, data_torch)
|
||||
|
||||
|
||||
@Model.register("GrokForCausalLM")
|
||||
|
@ -28,8 +28,6 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
|
||||
{ "IQ1_M", LLAMA_FTYPE_MOSTLY_IQ1_M, " 1.75 bpw quantization", },
|
||||
{ "TQ1_0", LLAMA_FTYPE_MOSTLY_TQ1_0, " 1.69 bpw ternarization", },
|
||||
{ "TQ2_0", LLAMA_FTYPE_MOSTLY_TQ2_0, " 2.06 bpw ternarization", },
|
||||
{ "Q1_3", LLAMA_FTYPE_MOSTLY_Q1_3, " 1.63 bpw for BitNet b1.58", },
|
||||
{ "Q2_2", LLAMA_FTYPE_MOSTLY_Q2_2, " 2.00 bpw for BitNet b1.58", },
|
||||
{ "Q2_K", LLAMA_FTYPE_MOSTLY_Q2_K, " 2.96G, +3.5199 ppl @ Llama-3-8B", },
|
||||
{ "Q2_K_S", LLAMA_FTYPE_MOSTLY_Q2_K_S, " 2.96G, +3.1836 ppl @ Llama-3-8B", },
|
||||
{ "IQ3_XXS", LLAMA_FTYPE_MOSTLY_IQ3_XXS, " 3.06 bpw quantization", },
|
||||
|
@ -392,8 +392,6 @@ extern "C" {
|
||||
GGML_TYPE_Q4_0_8_8 = 33,
|
||||
GGML_TYPE_TQ1_0 = 34,
|
||||
GGML_TYPE_TQ2_0 = 35,
|
||||
GGML_TYPE_Q2_2 = 36,
|
||||
GGML_TYPE_Q1_3 = 37,
|
||||
GGML_TYPE_COUNT,
|
||||
};
|
||||
|
||||
|
@ -141,20 +141,6 @@ typedef sycl::half2 ggml_half2;
|
||||
|
||||
#endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
|
||||
|
||||
// 1.625 bpw for BitNet b1.58 models
|
||||
#define QK1_3 64
|
||||
typedef struct {
|
||||
uint8_t q[(QK1_3 - 4*QK1_3/64)/5]; // 5 elements per byte (3^5 = 243 < 256)
|
||||
uint8_t qs[QK1_3/64]; // 4 elements per byte
|
||||
} block_q1_3;
|
||||
static_assert(sizeof(block_q1_3) == (QK1_3 - 4*QK1_3/64)/5 + QK1_3/64, "wrong q1_3 block size/padding");
|
||||
|
||||
#define QK2_2 32
|
||||
typedef struct {
|
||||
uint8_t qs[QK2_2 / 4]; // nibbles / quants
|
||||
} block_q2_2;
|
||||
static_assert(sizeof(block_q2_2) == QK2_2 / 4, "wrong q2_2 block size/padding");
|
||||
|
||||
#define QK4_0 32
|
||||
typedef struct {
|
||||
ggml_half d; // delta
|
||||
@ -1084,41 +1070,6 @@ GGML_TABLE_BEGIN(uint32_t, iq3s_grid, 512)
|
||||
0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101,
|
||||
GGML_TABLE_END()
|
||||
|
||||
GGML_TABLE_BEGIN(uint32_t, q1_3_grid, 256)
|
||||
0xffffffff, 0xffffffff, 0xffffff00, 0xffffff01, 0xffff00ff, 0xffff0000, 0xffff0001, 0xffff01ff,
|
||||
0xffff0100, 0xffff0101, 0xff00ffff, 0xff00ff00, 0xff00ff01, 0xff0000ff, 0xff000000, 0xff000001,
|
||||
0xff0001ff, 0xff000100, 0xff000101, 0xff01ffff, 0xff01ffff, 0xff01ff00, 0xff01ff01, 0xff0100ff,
|
||||
0xff010000, 0xff010001, 0xff0101ff, 0xff010100, 0xff010101, 0x00ffffff, 0x00ffff00, 0x00ffff01,
|
||||
0x00ff00ff, 0x00ff0000, 0x00ff0001, 0x00ff01ff, 0x00ff0100, 0x00ff0101, 0x0000ffff, 0x0000ff00,
|
||||
0x0000ff00, 0x0000ff01, 0x000000ff, 0x00000000, 0x00000001, 0x000001ff, 0x00000100, 0x00000101,
|
||||
0x0001ffff, 0x0001ff00, 0x0001ff01, 0x000100ff, 0x00010000, 0x00010001, 0x000101ff, 0x00010100,
|
||||
0x00010101, 0x01ffffff, 0x01ffff00, 0x01ffff01, 0x01ffff01, 0x01ff00ff, 0x01ff0000, 0x01ff0001,
|
||||
0x01ff01ff, 0x01ff0100, 0x01ff0101, 0x0100ffff, 0x0100ff00, 0x0100ff01, 0x010000ff, 0x01000000,
|
||||
0x01000001, 0x010001ff, 0x01000100, 0x01000101, 0x0101ffff, 0x0101ff00, 0x0101ff01, 0x0101ff01,
|
||||
0x010100ff, 0x01010000, 0x01010001, 0x010101ff, 0x01010100, 0x01010101, 0xffffffff, 0xffffff00,
|
||||
0xffffff01, 0xffff00ff, 0xffff0000, 0xffff0001, 0xffff01ff, 0xffff0100, 0xffff0101, 0xff00ffff,
|
||||
0xff00ff00, 0xff00ff01, 0xff0000ff, 0xff0000ff, 0xff000000, 0xff000001, 0xff0001ff, 0xff000100,
|
||||
0xff000101, 0xff01ffff, 0xff01ff00, 0xff01ff01, 0xff0100ff, 0xff010000, 0xff010001, 0xff0101ff,
|
||||
0xff010100, 0xff010101, 0x00ffffff, 0x00ffff00, 0x00ffff01, 0x00ff00ff, 0x00ff0000, 0x00ff0000,
|
||||
0x00ff0001, 0x00ff01ff, 0x00ff0100, 0x00ff0101, 0x0000ffff, 0x0000ff00, 0x0000ff01, 0x000000ff,
|
||||
0x00000000, 0x00000001, 0x000001ff, 0x00000100, 0x00000101, 0x0001ffff, 0x0001ff00, 0x0001ff01,
|
||||
0x000100ff, 0x00010000, 0x00010000, 0x00010001, 0x000101ff, 0x00010100, 0x00010101, 0x01ffffff,
|
||||
0x01ffff00, 0x01ffff01, 0x01ff00ff, 0x01ff0000, 0x01ff0001, 0x01ff01ff, 0x01ff0100, 0x01ff0101,
|
||||
0x0100ffff, 0x0100ff00, 0x0100ff01, 0x010000ff, 0x01000000, 0x01000001, 0x01000001, 0x010001ff,
|
||||
0x01000100, 0x01000101, 0x0101ffff, 0x0101ff00, 0x0101ff01, 0x010100ff, 0x01010000, 0x01010001,
|
||||
0x010101ff, 0x01010100, 0x01010101, 0xffffffff, 0xffffff00, 0xffffff01, 0xffff00ff, 0xffff0000,
|
||||
0xffff0001, 0xffff01ff, 0xffff01ff, 0xffff0100, 0xffff0101, 0xff00ffff, 0xff00ff00, 0xff00ff01,
|
||||
0xff0000ff, 0xff000000, 0xff000001, 0xff0001ff, 0xff000100, 0xff000101, 0xff01ffff, 0xff01ff00,
|
||||
0xff01ff01, 0xff0100ff, 0xff010000, 0xff010001, 0xff0101ff, 0xff0101ff, 0xff010100, 0xff010101,
|
||||
0x00ffffff, 0x00ffff00, 0x00ffff01, 0x00ff00ff, 0x00ff0000, 0x00ff0001, 0x00ff01ff, 0x00ff0100,
|
||||
0x00ff0101, 0x0000ffff, 0x0000ff00, 0x0000ff01, 0x000000ff, 0x00000000, 0x00000001, 0x000001ff,
|
||||
0x00000100, 0x00000100, 0x00000101, 0x0001ffff, 0x0001ff00, 0x0001ff01, 0x000100ff, 0x00010000,
|
||||
0x00010001, 0x000101ff, 0x00010100, 0x00010101, 0x01ffffff, 0x01ffff00, 0x01ffff01, 0x01ff00ff,
|
||||
0x01ff0000, 0x01ff0001, 0x01ff01ff, 0x01ff0100, 0x01ff0101, 0x01ff0101, 0x0100ffff, 0x0100ff00,
|
||||
0x0100ff01, 0x010000ff, 0x01000000, 0x01000001, 0x010001ff, 0x01000100, 0x01000101, 0x0101ffff,
|
||||
0x0101ff00, 0x0101ff01, 0x010100ff, 0x01010000, 0x01010001, 0x010101ff, 0x01010100, 0x01010101,
|
||||
GGML_TABLE_END()
|
||||
|
||||
#define NGRID_IQ1S 2048
|
||||
#define IQ1S_DELTA 0.125f
|
||||
#define IQ1M_DELTA 0.125f
|
||||
|
@ -657,39 +657,6 @@ static inline __m128i packNibbles( __m256i bytes ) {
|
||||
}
|
||||
#endif //__loongarch_asx
|
||||
|
||||
void quantize_row_q2_2_ref(const float * restrict x, block_q2_2 * restrict y, int64_t k) {
|
||||
static const int qk = QK2_2;
|
||||
|
||||
assert(k % qk == 0);
|
||||
|
||||
const int nb = k / qk;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
|
||||
for (int j = 0; j < qk/4; ++j) {
|
||||
int8_t x0 = (int8_t)x[i*qk + 0 + j];
|
||||
int8_t x1 = (int8_t)x[i*qk + 1*qk/4 + j];
|
||||
int8_t x2 = (int8_t)x[i*qk + 2*qk/4 + j];
|
||||
int8_t x3 = (int8_t)x[i*qk + 3*qk/4 + j];
|
||||
|
||||
const uint8_t xi0 = x0 < 0 ? 1 : x0 == 0 ? 2 : 3;
|
||||
const uint8_t xi1 = x1 < 0 ? 1 : x1 == 0 ? 2 : 3;
|
||||
const uint8_t xi2 = x2 < 0 ? 1 : x2 == 0 ? 2 : 3;
|
||||
const uint8_t xi3 = x3 < 0 ? 1 : x3 == 0 ? 2 : 3;
|
||||
|
||||
y[i].qs[j] = 0;
|
||||
y[i].qs[j] |= (xi0 << 0);
|
||||
y[i].qs[j] |= (xi1 << 2);
|
||||
y[i].qs[j] |= (xi2 << 4);
|
||||
y[i].qs[j] |= (xi3 << 6);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_row_q2_2(const float * restrict x, void * restrict y, int64_t k) {
|
||||
quantize_row_q2_2_ref(x, y, k);
|
||||
}
|
||||
|
||||
// reference implementation for deterministic creation of model files
|
||||
void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
|
||||
static const int qk = QK4_0;
|
||||
@ -1545,26 +1512,6 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
|
||||
#endif
|
||||
}
|
||||
|
||||
void dequantize_row_q2_2(const block_q2_2 * restrict x, float * restrict y, int64_t k) {
|
||||
static const int qk = QK2_2;
|
||||
|
||||
assert(k % qk == 0);
|
||||
|
||||
const int nb = k / qk;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
|
||||
for (int j = 0; j < qk/4; ++j) {
|
||||
const int8_t q = x[i].qs[j];
|
||||
|
||||
y[i*qk + j + 0 ] = (float) (((q >> 0) & 3) - 2);
|
||||
y[i*qk + j + 1*qk/4] = (float) (((q >> 2) & 3) - 2);
|
||||
y[i*qk + j + 2*qk/4] = (float) (((q >> 4) & 3) - 2);
|
||||
y[i*qk + j + 3*qk/4] = (float) (((q >> 6) & 3) - 2);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) {
|
||||
static const int qk = QK4_0;
|
||||
|
||||
@ -3359,13 +3306,6 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr
|
||||
return nrow * row_size;
|
||||
}
|
||||
|
||||
size_t quantize_q2_2(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
||||
(void)quant_weights; // not used
|
||||
const size_t row_size = ggml_row_size(GGML_TYPE_Q2_2, n_per_row);
|
||||
quantize_row_q2_2_ref(src, dst, (int64_t)nrow*n_per_row);
|
||||
return nrow * row_size;
|
||||
}
|
||||
|
||||
// ====================== Ternary (de)-quantization (BitNet b1.58 and TriLMs)
|
||||
|
||||
void quantize_row_tq1_0_ref(const float * restrict x, block_tq1_0 * restrict y, int64_t k) {
|
||||
@ -3552,89 +3492,6 @@ void dequantize_row_tq2_0(const block_tq2_0 * restrict x, float * restrict y, in
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_row_q1_3_ref(const float * restrict x, block_q1_3 * restrict y, int64_t k) {
|
||||
assert(k % QK1_3 == 0);
|
||||
const int64_t nb = k / QK1_3;
|
||||
static_assert(sizeof(y->q) % 4 == 0, "bad block_q1_3.q size");
|
||||
|
||||
const uint8_t pow3[6] = {1, 3, 9, 27, 81, 243};
|
||||
|
||||
for (int64_t i = 0; i < nb; ++i) {
|
||||
uint8_t q[sizeof(y->q)] = {0};
|
||||
for (size_t j = 0; j < sizeof(y->q); ++j) {
|
||||
for (size_t m = 0; m < 4; ++m) {
|
||||
int xi = nearest_int(x[m]);
|
||||
uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2;
|
||||
q[j] += xt * pow3[m];
|
||||
}
|
||||
x += 4;
|
||||
}
|
||||
for (size_t j = 0; j < sizeof(y->q); ++j) {
|
||||
int xi = nearest_int(x[j]);
|
||||
uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2;
|
||||
q[j] += xt * pow3[4];
|
||||
// ceiling division
|
||||
q[j] = ((uint16_t)q[j] * 256 + (pow3[5] - 1)) / pow3[5];
|
||||
y[i].q[j] = q[j];
|
||||
}
|
||||
x += sizeof(y->q);
|
||||
|
||||
for (size_t j = 0; j < sizeof(y->qs); ++j) {
|
||||
uint8_t qb = 0;
|
||||
for (size_t m = 0; m < 4; ++m) {
|
||||
int xi = nearest_int(x[m]);
|
||||
uint8_t xt = xi < 0 ? 0 : xi == 0 ? 1 : 2;
|
||||
qb += xt * pow3[m];
|
||||
}
|
||||
x += 4;
|
||||
// ceiling division
|
||||
qb = ((uint16_t)qb * 256 + (pow3[5] - 1)) / pow3[5];
|
||||
y[i].qs[j] = qb;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void quantize_row_q1_3(const float * restrict x, void * restrict vy, int64_t k) {
|
||||
assert(k % QK1_3 == 0);
|
||||
block_q1_3 * restrict y = vy;
|
||||
quantize_row_q1_3_ref(x, y, k);
|
||||
}
|
||||
|
||||
size_t quantize_q1_3(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
|
||||
(void)quant_weights; // not used
|
||||
const size_t row_size = ggml_row_size(GGML_TYPE_Q1_3, n_per_row);
|
||||
quantize_row_q1_3(src, dst, (int64_t)nrow*n_per_row);
|
||||
return nrow * row_size;
|
||||
}
|
||||
|
||||
void dequantize_row_q1_3(const block_q1_3 * restrict x, float * restrict y, int64_t k) {
|
||||
assert(k % QK1_3 == 0);
|
||||
const int64_t nb = k / QK1_3;
|
||||
static_assert(sizeof(x->q) % 4 == 0, "bad block_q1_3.q size");
|
||||
|
||||
for (int64_t i = 0; i < nb; ++i) {
|
||||
for (size_t j = 0; j < sizeof(x->q); ++j) {
|
||||
const int8_t * q = (const int8_t *) (q1_3_grid + x[i].q[j]);
|
||||
for (int m = 0; m < 4; ++m) {
|
||||
*y++ = (float) q[m];
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t j = 0; j < sizeof(x->q); ++j) {
|
||||
uint16_t q = x[i].q[j];
|
||||
int16_t qi = (q * 3) >> 8;
|
||||
*y++ = (float) (qi - 1);
|
||||
}
|
||||
|
||||
for (size_t j = 0; j < sizeof(x->qs); ++j) {
|
||||
const int8_t * q = (const int8_t *) (q1_3_grid + x[i].qs[j]);
|
||||
for (int m = 0; m < 4; ++m) {
|
||||
*y++ = (float) q[m];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ====================== "True" 2-bit (de)-quantization
|
||||
|
||||
void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) {
|
||||
@ -4055,122 +3912,6 @@ static inline __m128i get_scale_shuffle(int i) {
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
|
||||
assert(n % qk == 0);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
const block_q2_2 * restrict x = vx;
|
||||
const block_q8_0 * restrict y = vy;
|
||||
|
||||
#if defined(__AVX2__)
|
||||
__m256 acc = _mm256_setzero_ps();
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(y[i].d) );
|
||||
|
||||
// assuming this is always aligned
|
||||
__m256i xq8 = _mm256_set1_epi64x(*(const int64_t *) x[i].qs);
|
||||
xq8 = _mm256_srlv_epi64(xq8, _mm256_set_epi64x(6, 4, 2, 0));
|
||||
xq8 = _mm256_and_si256(xq8, _mm256_set1_epi8(0x03));
|
||||
// stangely enough, this is much slower with 1 instead of 2
|
||||
xq8 = _mm256_sub_epi8(xq8, _mm256_set1_epi8(2));
|
||||
|
||||
const __m256i yq8 = _mm256_loadu_si256((const __m256i *) (y[i].qs));
|
||||
const __m256 q = mul_sum_i8_pairs_float(xq8, yq8);
|
||||
|
||||
acc = _mm256_fmadd_ps( d, q, acc );
|
||||
}
|
||||
|
||||
*s = hsum_float_8(acc);
|
||||
#elif defined(__ARM_NEON)
|
||||
float sumf0 = 0.0f;
|
||||
float sumf1 = 0.0f;
|
||||
|
||||
const uint8x8_t mask = vdup_n_u8(3);
|
||||
const int8x8_t offset = vdup_n_s8(2);
|
||||
|
||||
const int leftovers = nb % 2;
|
||||
|
||||
for (int i = 0; i < nb - leftovers; i += 2) {
|
||||
const uint8x8_t xq8_0 = vld1_u8(x[0].qs);
|
||||
const uint8x8_t xq8_1 = vld1_u8(x[1].qs);
|
||||
|
||||
const int8x8_t xq8_0_0 = vsub_s8(vreinterpret_s8_u8(vand_u8(xq8_0, mask)), offset);
|
||||
const int8x8_t xq8_0_1 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_0, 2), mask)), offset);
|
||||
const int8x8_t xq8_0_2 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_0, 4), mask)), offset);
|
||||
const int8x8_t xq8_0_3 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_0, 6), mask)), offset);
|
||||
const int8x8_t xq8_1_0 = vsub_s8(vreinterpret_s8_u8(vand_u8(xq8_1, mask)), offset);
|
||||
const int8x8_t xq8_1_1 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_1, 2), mask)), offset);
|
||||
const int8x8_t xq8_1_2 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_1, 4), mask)), offset);
|
||||
const int8x8_t xq8_1_3 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_1, 6), mask)), offset);
|
||||
|
||||
const int8x16_t xq8_0_l = vcombine_s8(xq8_0_0, xq8_0_1);
|
||||
const int8x16_t xq8_0_h = vcombine_s8(xq8_0_2, xq8_0_3);
|
||||
const int8x16_t xq8_1_l = vcombine_s8(xq8_1_0, xq8_1_1);
|
||||
const int8x16_t xq8_1_h = vcombine_s8(xq8_1_2, xq8_1_3);
|
||||
|
||||
const int8x16_t yq8_0_l = vld1q_s8(y[0].qs + 0);
|
||||
const int8x16_t yq8_0_h = vld1q_s8(y[0].qs + 16);
|
||||
const int8x16_t yq8_1_l = vld1q_s8(y[1].qs + 0);
|
||||
const int8x16_t yq8_1_h = vld1q_s8(y[1].qs + 16);
|
||||
|
||||
const int16x8_t dot0 = vaddq_s16(vpaddlq_s8(vmulq_s8(xq8_0_l, yq8_0_l)), vpaddlq_s8(vmulq_s8(xq8_0_h, yq8_0_h)));
|
||||
const int16x8_t dot1 = vaddq_s16(vpaddlq_s8(vmulq_s8(xq8_1_l, yq8_1_l)), vpaddlq_s8(vmulq_s8(xq8_1_h, yq8_1_h)));
|
||||
|
||||
sumf0 += GGML_FP16_TO_FP32(y[0].d) * (float) vaddlvq_s16(dot0);
|
||||
sumf1 += GGML_FP16_TO_FP32(y[1].d) * (float) vaddlvq_s16(dot1);
|
||||
x += 2;
|
||||
y += 2;
|
||||
}
|
||||
|
||||
// one block at a time
|
||||
for (int i = nb - leftovers; i < nb; ++i) {
|
||||
const uint8x8_t xq8 = vld1_u8(x->qs);
|
||||
const int8x8_t xq8_0 = vsub_s8(vreinterpret_s8_u8(vand_u8(xq8, mask)), offset);
|
||||
const int8x8_t xq8_1 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8, 2), mask)), offset);
|
||||
const int8x8_t xq8_2 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8, 4), mask)), offset);
|
||||
const int8x8_t xq8_3 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8, 6), mask)), offset);
|
||||
|
||||
const int8x16_t xq8_l = vcombine_s8(xq8_0, xq8_1);
|
||||
const int8x16_t xq8_h = vcombine_s8(xq8_2, xq8_3);
|
||||
|
||||
const int8x16_t yq8_l = vld1q_s8(y->qs + 0);
|
||||
const int8x16_t yq8_h = vld1q_s8(y->qs + 16);
|
||||
|
||||
const int16x8_t dot0 = vpaddlq_s8(vmulq_s8(xq8_l, yq8_l));
|
||||
const int16x8_t dot1 = vpaddlq_s8(vmulq_s8(xq8_h, yq8_h));
|
||||
|
||||
sumf0 += GGML_FP16_TO_FP32(y->d) * (float) vaddlvq_s16(vaddq_s16(dot0, dot1));
|
||||
x += 1;
|
||||
y += 1;
|
||||
}
|
||||
|
||||
*s = sumf0 + sumf1;
|
||||
#else
|
||||
|
||||
float sumf = 0.0f;
|
||||
for (int i = 0; i < nb; i++) {
|
||||
int sumi = 0;
|
||||
for (int j = 0; j < qk / 4; j++) {
|
||||
const uint8_t weight = x[i].qs[j];
|
||||
sumi += (int)y[i].qs[j + 0*qk/4] * (((weight >> 0) & 3) - 2);
|
||||
sumi += (int)y[i].qs[j + 1*qk/4] * (((weight >> 2) & 3) - 2);
|
||||
sumi += (int)y[i].qs[j + 2*qk/4] * (((weight >> 4) & 3) - 2);
|
||||
sumi += (int)y[i].qs[j + 3*qk/4] * (((weight >> 6) & 3) - 2);
|
||||
}
|
||||
sumf += (float)(sumi)*(GGML_FP16_TO_FP32(y[i].d));
|
||||
}
|
||||
*s = sumf;
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
const int qk = QK8_0;
|
||||
const int nb = n / qk;
|
||||
@ -11855,225 +11596,6 @@ static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_vec_dot_q1_3_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
// assumed by the code below
|
||||
assert(n % QK1_3 == 0);
|
||||
static_assert(QK1_3 == 2 * QK8_0, "QK1_3 must be 2 times bigger than QK8_0");
|
||||
|
||||
const block_q1_3 * restrict x = vx;
|
||||
const block_q8_0 * restrict y = vy;
|
||||
|
||||
const int nb = n / QK1_3;
|
||||
|
||||
#if defined(__AVX2__)
|
||||
__m256 accumf = _mm256_setzero_ps();
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
// const __m128i x12a = _mm_maskload_epi32((const int32_t *) x, _mm_set_epi32(0, -1, -1, -1));
|
||||
// const __m128i x13b = _mm_insert_epi8(x12a, x->qs[0], 12);
|
||||
// WARNING: reading 3 bytes further than necessary.
|
||||
// It's measurably faster than a masked load on an Intel Core m3-8100Y
|
||||
const __m128i x13b = _mm_loadu_si128((const __m128i *) x);
|
||||
const __m256i x13 = MM256_SET_M128I(x13b, x13b);
|
||||
|
||||
{
|
||||
// pre-shift the values by 8 bits, and prepare the layout for later packing
|
||||
__m256i x0l = _mm256_shuffle_epi8(x13, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1,
|
||||
4, -1, 4, -1, 4, -1, 4, -1,
|
||||
1, -1, 1, -1, 1, -1, 1, -1,
|
||||
0, -1, 0, -1, 0, -1, 0, -1));
|
||||
__m256i x0h = _mm256_shuffle_epi8(x13, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1,
|
||||
6, -1, 6, -1, 6, -1, 6, -1,
|
||||
3, -1, 3, -1, 3, -1, 3, -1,
|
||||
2, -1, 2, -1, 2, -1, 2, -1));
|
||||
__m256i x1l = _mm256_shuffle_epi8(x13, _mm256_set_epi8(7, -1, 6, -1, 5, -1, 4, -1,
|
||||
3, -1, 2, -1, 1, -1, 0, -1,
|
||||
9, -1, 9, -1, 9, -1, 9, -1,
|
||||
8, -1, 8, -1, 8, -1, 8, -1));
|
||||
__m256i x1h = _mm256_shuffle_epi8(x13, _mm256_set_epi8(12, -1, 12, -1, 12, -1, 12, -1,
|
||||
11, -1, 10, -1, 9, -1, 8, -1,
|
||||
11, -1, 11, -1, 11, -1, 11, -1,
|
||||
10, -1, 10, -1, 10, -1, 10, -1));
|
||||
const __m256i shift0 = _mm256_set_epi16(3, 9, 27, 81,
|
||||
3, 9, 27, 81,
|
||||
3, 9, 27, 81,
|
||||
3, 9, 27, 81);
|
||||
const __m256i shift1l = _mm256_set_epi16(1, 1, 1, 1,
|
||||
1, 1, 1, 1,
|
||||
3, 9, 27, 81,
|
||||
3, 9, 27, 81);
|
||||
const __m256i shift1h = _mm256_set_epi16(3, 9, 27, 81,
|
||||
1, 1, 1, 1,
|
||||
3, 9, 27, 81,
|
||||
3, 9, 27, 81);
|
||||
// extract ternary values
|
||||
// first by shifting the numbers to make each one the next significant digit
|
||||
x0l = _mm256_mullo_epi16(x0l, shift0);
|
||||
x0h = _mm256_mullo_epi16(x0h, shift0);
|
||||
x1l = _mm256_mullo_epi16(x1l, shift1l);
|
||||
x1h = _mm256_mullo_epi16(x1h, shift1h);
|
||||
// then by extracting each of these most significant digits
|
||||
x0l = _mm256_mulhi_epu16(x0l, _mm256_set1_epi16(3));
|
||||
x0h = _mm256_mulhi_epu16(x0h, _mm256_set1_epi16(3));
|
||||
x1l = _mm256_mulhi_epu16(x1l, _mm256_set1_epi16(3));
|
||||
x1h = _mm256_mulhi_epu16(x1h, _mm256_set1_epi16(3));
|
||||
|
||||
__m256i x0 = _mm256_packs_epi16(x0l, x0h);
|
||||
__m256i x1 = _mm256_packs_epi16(x1l, x1h);
|
||||
|
||||
// 0, 1, 2 => -1, 0, 1
|
||||
x0 = _mm256_sub_epi8(x0, _mm256_set1_epi8(1));
|
||||
x1 = _mm256_sub_epi8(x1, _mm256_set1_epi8(1));
|
||||
|
||||
const __m256i y0 = _mm256_loadu_si256((const __m256i *) (y[0].qs));
|
||||
const __m256i y1 = _mm256_loadu_si256((const __m256i *) (y[1].qs));
|
||||
|
||||
const __m256 d0 = _mm256_set1_ps(GGML_FP16_TO_FP32(y[0].d));
|
||||
const __m256 d1 = _mm256_set1_ps(GGML_FP16_TO_FP32(y[1].d));
|
||||
|
||||
const __m256 q0 = mul_sum_i8_pairs_float(y0, x0);
|
||||
const __m256 q1 = mul_sum_i8_pairs_float(y1, x1);
|
||||
|
||||
accumf = _mm256_fmadd_ps(d0, q0, accumf);
|
||||
accumf = _mm256_fmadd_ps(d1, q1, accumf);
|
||||
}
|
||||
|
||||
x += 1;
|
||||
y += 2;
|
||||
}
|
||||
|
||||
*s = hsum_float_8(accumf);
|
||||
#elif defined(__ARM_NEON)
|
||||
|
||||
static const uint8_t k_mask0[16] = {0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3};
|
||||
static const uint8_t k_mask1[16] = {4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7};
|
||||
static const uint8_t k_mask2[16] = {8, 8, 8, 8, 9, 9, 9, 9, 10, 10, 10, 10, 11, 11, 11, 11};
|
||||
static const uint8_t k_mask3[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 12, 12, 12};
|
||||
|
||||
static const uint8_t k_shift0[16] = {81, 27, 9, 3, 81, 27, 9, 3, 81, 27, 9, 3, 81, 27, 9, 3};
|
||||
static const uint8_t k_shift3[16] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 81, 27, 9, 3};
|
||||
|
||||
// float32x4_t sumv0 = vdupq_n_f32(0.0f);
|
||||
// float32x4_t sumv1 = vdupq_n_f32(0.0f);
|
||||
|
||||
float sumf0 = 0.0f;
|
||||
float sumf1 = 0.0f;
|
||||
|
||||
const uint8x16_t mask0 = vld1q_u8(k_mask0);
|
||||
const uint8x16_t mask1 = vld1q_u8(k_mask1);
|
||||
const uint8x16_t mask2 = vld1q_u8(k_mask2);
|
||||
const uint8x16_t mask3 = vld1q_u8(k_mask3);
|
||||
|
||||
const uint8x16_t shift0 = vld1q_u8(k_shift0);
|
||||
const uint8x16_t shift3 = vld1q_u8(k_shift3);
|
||||
|
||||
const int8x16_t one = vdupq_n_s8(1);
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
// WARNING: reading 3 bytes further than necessary
|
||||
const uint8x16_t x13b = vld1q_u8((const uint8_t *) x);
|
||||
|
||||
uint8x16_t x0 = ggml_vqtbl1q_u8(x13b, mask0);
|
||||
uint8x16_t x1 = ggml_vqtbl1q_u8(x13b, mask1);
|
||||
uint8x16_t x2 = ggml_vqtbl1q_u8(x13b, mask2);
|
||||
uint8x16_t x3 = ggml_vqtbl1q_u8(x13b, mask3);
|
||||
|
||||
x0 = vmulq_u8(x0, shift0);
|
||||
x1 = vmulq_u8(x1, shift0);
|
||||
x2 = vmulq_u8(x2, shift0);
|
||||
x3 = vmulq_u8(x3, shift3);
|
||||
|
||||
// multiply by 3 and keep the 2 bits above 8 bits
|
||||
x0 = vshrq_n_u8(vhaddq_u8(x0, vshrq_n_u8(x0, 1)), 6);
|
||||
x1 = vshrq_n_u8(vhaddq_u8(x1, vshrq_n_u8(x1, 1)), 6);
|
||||
x2 = vshrq_n_u8(vhaddq_u8(x2, vshrq_n_u8(x2, 1)), 6);
|
||||
x3 = vshrq_n_u8(vhaddq_u8(x3, vshrq_n_u8(x3, 1)), 6);
|
||||
|
||||
// 0, 1, 2 => -1, 0, 1
|
||||
int8x16_t x0i = vsubq_s8(vreinterpretq_s8_u8(x0), one);
|
||||
int8x16_t x1i = vsubq_s8(vreinterpretq_s8_u8(x1), one);
|
||||
int8x16_t x2i = vsubq_s8(vreinterpretq_s8_u8(x2), one);
|
||||
int8x16_t x3i = vsubq_s8(vreinterpretq_s8_u8(x3), one);
|
||||
|
||||
const int8x16_t y0 = vld1q_s8(y[0].qs + 0);
|
||||
const int8x16_t y1 = vld1q_s8(y[0].qs + 16);
|
||||
const int8x16_t y2 = vld1q_s8(y[1].qs + 0);
|
||||
const int8x16_t y3 = vld1q_s8(y[1].qs + 16);
|
||||
|
||||
// const int32x4_t p0 = vpaddlq_s16(vaddq_s16(vpaddlq_s8(x0i), vpaddlq_s8(x1i)));
|
||||
// const int32x4_t p1 = vpaddlq_s16(vaddq_s16(vpaddlq_s8(x2i), vpaddlq_s8(x3i)));
|
||||
|
||||
// there's no direct equivalent to _mm_sign_epi8, unfortunately
|
||||
x0i = vmulq_s8(x0i, y0);
|
||||
x1i = vmulq_s8(x1i, y1);
|
||||
x2i = vmulq_s8(x2i, y2);
|
||||
x3i = vmulq_s8(x3i, y3);
|
||||
|
||||
// overall 18.5% faster than with vector sums on a cortex-A72
|
||||
sumf0 += GGML_FP16_TO_FP32(y[0].d) * (float) vaddlvq_s16(vaddq_s16(vpaddlq_s8(x0i), vpaddlq_s8(x1i)));
|
||||
sumf1 += GGML_FP16_TO_FP32(y[1].d) * (float) vaddlvq_s16(vaddq_s16(vpaddlq_s8(x2i), vpaddlq_s8(x3i)));
|
||||
|
||||
// const int32x4_t p0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), x0i, y0), x1i, y1);
|
||||
// const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), x2i, y2), x3i, y3);
|
||||
|
||||
// sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p0), GGML_FP16_TO_FP32(y[0].d));
|
||||
// sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p1), GGML_FP16_TO_FP32(y[1].d));
|
||||
|
||||
y += 2;
|
||||
x += 1;
|
||||
}
|
||||
|
||||
// *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
|
||||
*s = sumf0 + sumf1;
|
||||
#else
|
||||
float sumf = 0.0f;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
int sum = 0;
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].q[j]);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
sum += xj[k] * (int16_t) y->qs[4*j + k];
|
||||
}
|
||||
}
|
||||
|
||||
sumf += GGML_FP16_TO_FP32(y->d) * sum;
|
||||
y += 1;
|
||||
sum = 0;
|
||||
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].q[8 + j]);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
sum += xj[k] * (int16_t) y->qs[4*j + k];
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t j = 0; j < 12; ++j) {
|
||||
uint16_t xj = x[i].q[j];
|
||||
xj = (xj * 3) >> 8;
|
||||
sum += ((int16_t) xj - 1) * (int16_t) y->qs[16 + j];
|
||||
}
|
||||
|
||||
{
|
||||
const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].qs[0]);
|
||||
for (int k = 0; k < 4; ++k) {
|
||||
sum += (int16_t) xj[k] * (int16_t) y->qs[28 + k];
|
||||
}
|
||||
}
|
||||
|
||||
sumf += GGML_FP16_TO_FP32(y->d) * sum;
|
||||
y += 1;
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
#endif
|
||||
}
|
||||
|
||||
void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
|
||||
assert(n % QK_K == 0);
|
||||
assert(nrc == 1);
|
||||
@ -15964,8 +15486,7 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
|
||||
{
|
||||
VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x8, data, nbytes / sizeof(block_q4_0x8), 8);
|
||||
} break;
|
||||
case GGML_TYPE_Q1_3:
|
||||
case GGML_TYPE_Q2_2:
|
||||
|
||||
case GGML_TYPE_I8:
|
||||
case GGML_TYPE_I16:
|
||||
case GGML_TYPE_I32:
|
||||
|
@ -12,8 +12,6 @@ extern "C" {
|
||||
#endif
|
||||
|
||||
// Quantization
|
||||
void quantize_row_q1_3_ref(const float * GGML_RESTRICT x, block_q1_3 * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q2_2_ref(const float * GGML_RESTRICT x, block_q2_2 * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k);
|
||||
@ -37,8 +35,6 @@ void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGM
|
||||
void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
void quantize_row_q1_3(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q2_2(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
@ -63,8 +59,6 @@ void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y,
|
||||
void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
// Dequantization
|
||||
void dequantize_row_q1_3(const block_q1_3 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
void dequantize_row_q2_2(const block_q2_2 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
@ -93,8 +87,6 @@ void dequantize_row_iq4_xs (const block_iq4_xs * GGML_RESTRICT x, float * GGML_
|
||||
void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
|
||||
|
||||
// Dot product
|
||||
void ggml_vec_dot_q1_3_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q2_2_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc);
|
||||
@ -139,8 +131,6 @@ size_t quantize_q3_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
|
||||
size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
size_t quantize_q1_3(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
size_t quantize_q2_2(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
|
||||
|
@ -856,30 +856,6 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_Q2_2] = {
|
||||
.type_name = "q2_2",
|
||||
.blck_size = QK2_2,
|
||||
.type_size = sizeof(block_q2_2),
|
||||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q2_2,
|
||||
.from_float = quantize_row_q2_2,
|
||||
.from_float_ref = (ggml_from_float_t) quantize_row_q2_2_ref,
|
||||
.vec_dot = ggml_vec_dot_q2_2_q8_0,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_Q1_3] = {
|
||||
.type_name = "q1_3",
|
||||
.blck_size = QK1_3,
|
||||
.type_size = sizeof(block_q1_3),
|
||||
.is_quantized = true,
|
||||
.to_float = (ggml_to_float_t) dequantize_row_q1_3,
|
||||
.from_float = quantize_row_q1_3,
|
||||
.from_float_ref = (ggml_from_float_t) quantize_row_q1_3_ref,
|
||||
.vec_dot = ggml_vec_dot_q1_3_q8_0,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_IQ1_S] = {
|
||||
.type_name = "iq1_s",
|
||||
.blck_size = QK_K,
|
||||
@ -13936,8 +13912,6 @@ static void ggml_compute_forward_clamp(
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_BF16:
|
||||
case GGML_TYPE_Q1_3:
|
||||
case GGML_TYPE_Q2_2:
|
||||
case GGML_TYPE_Q4_0:
|
||||
case GGML_TYPE_Q4_1:
|
||||
case GGML_TYPE_Q5_0:
|
||||
@ -20638,8 +20612,6 @@ size_t ggml_quantize_chunk(
|
||||
size_t result = 0;
|
||||
|
||||
switch (type) {
|
||||
case GGML_TYPE_Q1_3: result = quantize_q1_3(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q2_2: result = quantize_q2_2(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
|
||||
|
@ -1150,8 +1150,6 @@ class GGMLQuantizationType(IntEnum):
|
||||
Q4_0_8_8 = 33
|
||||
TQ1_0 = 34
|
||||
TQ2_0 = 35
|
||||
Q1_3 = 36
|
||||
Q2_2 = 37
|
||||
|
||||
|
||||
# TODO: add GGMLFileType from ggml_ftype in ggml.h
|
||||
@ -1193,8 +1191,11 @@ class LlamaFileType(IntEnum):
|
||||
MOSTLY_IQ4_XS = 30 # except 1d tensors
|
||||
MOSTLY_IQ1_M = 31 # except 1d tensors
|
||||
MOSTLY_BF16 = 32 # except 1d tensors
|
||||
MOSTLY_Q2_2 = 33 # except 1d tensors
|
||||
MOSTLY_Q1_3 = 34 # except 1d tensors
|
||||
MOSTLY_Q4_0_4_4 = 33 # except 1d tensors
|
||||
MOSTLY_Q4_0_4_8 = 34 # except 1d tensors
|
||||
MOSTLY_Q4_0_8_8 = 35 # except 1d tensors
|
||||
MOSTLY_TQ1_0 = 36 # except 1d tensors
|
||||
MOSTLY_TQ2_0 = 37 # except 1d tensors
|
||||
|
||||
GUESSED = 1024 # not specified in the model file
|
||||
|
||||
@ -1268,8 +1269,11 @@ GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
|
||||
GGMLQuantizationType.F64: (1, 8),
|
||||
GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32),
|
||||
GGMLQuantizationType.BF16: (1, 2),
|
||||
GGMLQuantizationType.Q2_2: (32, 8),
|
||||
GGMLQuantizationType.Q1_3: (64, 12 + 1),
|
||||
GGMLQuantizationType.Q4_0_4_4:(32, 2 + 16),
|
||||
GGMLQuantizationType.Q4_0_4_8:(32, 2 + 16),
|
||||
GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16),
|
||||
GGMLQuantizationType.TQ1_0: (256, 2 + 4 * 13),
|
||||
GGMLQuantizationType.TQ2_0: (256, 2 + 64),
|
||||
}
|
||||
|
||||
|
||||
|
@ -121,55 +121,3 @@ def quantize_q8_0(data: np.ndarray):
|
||||
return __quantize_q8_0_lazy(data)
|
||||
else:
|
||||
return __quantize_q8_0_array(data)
|
||||
|
||||
|
||||
__q1_3_block_size, __q1_3_type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q1_3]
|
||||
|
||||
|
||||
def can_quantize_to_q1_3(n: np.ndarray) -> bool:
|
||||
return n.shape[-1] % __q1_3_block_size == 0
|
||||
|
||||
|
||||
def __quantize_q1_3_shape_change(s: tuple[int, ...]) -> tuple[int, ...]:
|
||||
return (*s[:-1], s[-1] // __q1_3_block_size * __q1_3_type_size)
|
||||
|
||||
|
||||
def __quantize_q1_3_rows(n: np.ndarray) -> np.ndarray:
|
||||
shape = n.shape
|
||||
assert shape[-1] % __q1_3_block_size == 0
|
||||
|
||||
n_blocks = n.size // __q1_3_block_size
|
||||
|
||||
blocks = n.reshape((n_blocks, __q1_3_block_size)).astype(np.float32, copy=False)
|
||||
|
||||
# assuming the weights are pre-scaled
|
||||
blocks = (np.sign(blocks).astype(np.int8) + 1).view(np.uint8)
|
||||
q48, rest = np.hsplit(blocks, (48,))
|
||||
q12, q4 = np.hsplit(rest, (12,))
|
||||
|
||||
pow3 = np.array([1, 3, 9, 27])
|
||||
q48 = q48.reshape((n_blocks, 12, 4))
|
||||
q48 = np.sum(q48 * pow3.reshape((1, 1, 4)), axis=2, keepdims=True).reshape((n_blocks, 12))
|
||||
q4 = np.sum(q4 * pow3.reshape((1, 4)), axis=1, keepdims=True)
|
||||
q48 = q48 + (q12 * 81)
|
||||
q = np.concatenate([q48, q4], axis=1)
|
||||
q = (((q.astype(np.uint16) * 256) + (243 - 1)) // 243).astype(np.uint8)
|
||||
|
||||
return q.reshape(__quantize_q1_3_shape_change(shape))
|
||||
|
||||
|
||||
def __quantize_q1_3_array(n: np.ndarray) -> np.ndarray:
|
||||
return __apply_over_grouped_rows(__quantize_q1_3_rows, arr=n, otype=np.uint8, oshape=__quantize_q1_3_shape_change(n.shape))
|
||||
|
||||
|
||||
__quantize_q1_3_lazy = LazyNumpyTensor._wrap_fn(
|
||||
__quantize_q1_3_array,
|
||||
meta_noop=(np.uint8, __quantize_q1_3_shape_change),
|
||||
)
|
||||
|
||||
|
||||
def quantize_q1_3(data: np.ndarray):
|
||||
if type(data) is LazyNumpyTensor:
|
||||
return __quantize_q1_3_lazy(data)
|
||||
else:
|
||||
return __quantize_q1_3_array(data)
|
||||
|
@ -168,8 +168,6 @@ extern "C" {
|
||||
LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_TQ1_0 = 36, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_TQ2_0 = 37, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q1_3 = 38, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q2_2 = 39, // except 1d tensors
|
||||
|
||||
LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
|
||||
};
|
||||
|
@ -4451,8 +4451,6 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
|
||||
case LLAMA_FTYPE_ALL_F32: return "all F32";
|
||||
case LLAMA_FTYPE_MOSTLY_F16: return "F16";
|
||||
case LLAMA_FTYPE_MOSTLY_BF16: return "BF16";
|
||||
case LLAMA_FTYPE_MOSTLY_Q1_3: return "Q1_3 - 1.625 bpw for BitNet b1.58";
|
||||
case LLAMA_FTYPE_MOSTLY_Q2_2: return "Q2_2 - 2.000 bpw for BitNet b1.58";
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0";
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1";
|
||||
case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0";
|
||||
@ -7347,23 +7345,23 @@ static bool llm_load_tensors(
|
||||
layer.attn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_SUB_NORM, "weight", i), {n_embd});
|
||||
|
||||
layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd});
|
||||
layer.wq_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1});
|
||||
layer.wq_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa});
|
||||
layer.wk_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "scale", i), {1});
|
||||
layer.wk_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
|
||||
layer.wv_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "scale", i), {1});
|
||||
layer.wv_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
|
||||
layer.wo_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1});
|
||||
layer.wo_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
|
||||
layer.ffn_sub_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_SUB_NORM, "weight", i), {n_ff});
|
||||
|
||||
layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff});
|
||||
layer.ffn_gate_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1});
|
||||
layer.ffn_gate_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
|
||||
layer.ffn_down_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1});
|
||||
layer.ffn_down_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff});
|
||||
layer.ffn_up_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "scale", i), {1});
|
||||
layer.ffn_up_scale = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP, "scale", i), {1}, llama_model_loader::TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_T5:
|
||||
@ -13028,7 +13026,9 @@ struct llm_build_context {
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur);
|
||||
Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
|
||||
if (model.layers[il].wq_scale) {
|
||||
Qcur = ggml_mul(ctx0, Qcur, model.layers[il].wq_scale);
|
||||
}
|
||||
cb(Qcur, "Qcur", il);
|
||||
if (model.layers[il].bq) {
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
@ -13037,7 +13037,9 @@ struct llm_build_context {
|
||||
|
||||
// B1.K
|
||||
struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur);
|
||||
Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
|
||||
if (model.layers[il].wk_scale) {
|
||||
Kcur = ggml_mul(ctx0, Kcur, model.layers[il].wk_scale);
|
||||
}
|
||||
cb(Kcur, "Kcur", il);
|
||||
if (model.layers[il].bk) {
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
@ -13046,7 +13048,9 @@ struct llm_build_context {
|
||||
|
||||
// B1.V
|
||||
struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur);
|
||||
Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
|
||||
if (model.layers[il].wv_scale) {
|
||||
Vcur = ggml_mul(ctx0, Vcur, model.layers[il].wv_scale);
|
||||
}
|
||||
cb(Vcur, "Vcur", il);
|
||||
if (model.layers[il].bv) {
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
@ -13077,7 +13081,9 @@ struct llm_build_context {
|
||||
cb(cur, "attn_sub_norm", il);
|
||||
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wo, cur);
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
|
||||
if (model.layers[il].wo_scale) {
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].wo_scale);
|
||||
}
|
||||
if (model.layers[il].bo) {
|
||||
cur = ggml_add(ctx0, cur, model.layers[il].bo);
|
||||
}
|
||||
@ -13114,7 +13120,9 @@ struct llm_build_context {
|
||||
cb(cur, "ffn_sub_norm", il);
|
||||
|
||||
cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].ffn_down, cur);
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
|
||||
if (model.layers[il].ffn_down_scale) {
|
||||
cur = ggml_mul(ctx0, cur, model.layers[il].ffn_down_scale);
|
||||
}
|
||||
cb(cur, "ffn_down", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
@ -15631,8 +15639,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||
llama_ftype ftype = params->ftype;
|
||||
|
||||
switch (params->ftype) {
|
||||
case LLAMA_FTYPE_MOSTLY_Q1_3: default_type = GGML_TYPE_Q1_3; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q2_2: default_type = GGML_TYPE_Q2_2; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break;
|
||||
case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break;
|
||||
|
@ -15,13 +15,13 @@
|
||||
|
||||
constexpr float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001f;
|
||||
constexpr float MAX_QUANTIZATION_TOTAL_ERROR = 0.002f;
|
||||
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_TERNARY = 0.015625f; // TODO: change to 0.01f
|
||||
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_TERNARY = 0.01f;
|
||||
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f;
|
||||
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f;
|
||||
constexpr float MAX_QUANTIZATION_TOTAL_ERROR_3BITS_XXS = 0.0050f;
|
||||
constexpr float MAX_DOT_PRODUCT_ERROR = 0.02f;
|
||||
constexpr float MAX_DOT_PRODUCT_ERROR_LOWBIT = 0.04f;
|
||||
constexpr float MAX_DOT_PRODUCT_ERROR_TERNARY = 0.5f; // TODO: change to 0.15f
|
||||
constexpr float MAX_DOT_PRODUCT_ERROR_TERNARY = 0.15f;
|
||||
|
||||
static const char* RESULT_STR[] = {"ok", "FAILED"};
|
||||
|
||||
@ -146,8 +146,6 @@ int main(int argc, char * argv[]) {
|
||||
if (qfns.from_float && qfns.to_float) {
|
||||
const float total_error = total_quantization_error(qfns, test_size, test_data.data());
|
||||
const float max_quantization_error =
|
||||
type == GGML_TYPE_Q1_3 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
|
||||
type == GGML_TYPE_Q2_2 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
|
||||
type == GGML_TYPE_TQ1_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
|
||||
type == GGML_TYPE_TQ2_0 ? MAX_QUANTIZATION_TOTAL_ERROR_TERNARY :
|
||||
type == GGML_TYPE_Q2_K ? MAX_QUANTIZATION_TOTAL_ERROR_2BITS :
|
||||
@ -172,7 +170,7 @@ int main(int argc, char * argv[]) {
|
||||
const float max_allowed_error = type == GGML_TYPE_Q2_K || type == GGML_TYPE_IQ2_XS || type == GGML_TYPE_IQ2_XXS ||
|
||||
type == GGML_TYPE_IQ3_XXS || type == GGML_TYPE_IQ3_S || type == GGML_TYPE_IQ2_S
|
||||
? MAX_DOT_PRODUCT_ERROR_LOWBIT
|
||||
: type == GGML_TYPE_Q2_2 || type == GGML_TYPE_Q1_3 || type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0
|
||||
: type == GGML_TYPE_TQ1_0 || type == GGML_TYPE_TQ2_0
|
||||
? MAX_DOT_PRODUCT_ERROR_TERNARY
|
||||
: MAX_DOT_PRODUCT_ERROR;
|
||||
failed = !(vec_dot_error < max_allowed_error);
|
||||
|
Loading…
Reference in New Issue
Block a user