mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 12:10:18 +00:00
RMSE-optimized quants for all quantization types
By default this new option is ON. One can turn it off by setting LLAMA_NO_RMSE. With this option enabled, the Q4_3 quantization results in a perplexity of 6.0344, so 0.0273 lower than simple Q4_3 quantization.
This commit is contained in:
parent
0e018fe008
commit
e435bfd93c
@ -68,6 +68,9 @@ option(LLAMA_ACCELERATE "llama: enable Accelerate framework"
|
||||
option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF)
|
||||
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
|
||||
|
||||
# RMSE minimization when quantizing
|
||||
option(LLAMA_NO_RMSE "llama: disable RMSE minimization" OFF)
|
||||
|
||||
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
|
||||
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
|
||||
|
||||
@ -99,6 +102,10 @@ if (NOT MSVC)
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (LLAMA_NO_RMSE)
|
||||
add_compile_definitions(GGML_NO_RMSE)
|
||||
endif()
|
||||
|
||||
if (APPLE AND LLAMA_ACCELERATE)
|
||||
find_library(ACCELERATE_FRAMEWORK Accelerate)
|
||||
if (ACCELERATE_FRAMEWORK)
|
||||
|
4
Makefile
4
Makefile
@ -134,6 +134,10 @@ ifneq ($(filter armv8%,$(UNAME_M)),)
|
||||
CFLAGS += -mfp16-format=ieee -mno-unaligned-access
|
||||
endif
|
||||
|
||||
ifdef LLAMA_NO_RMSE
|
||||
CFLAGS += -DGGML_NO_RMSE
|
||||
endif
|
||||
|
||||
#
|
||||
# Print build information
|
||||
#
|
||||
|
353
ggml.c
353
ggml.c
@ -670,10 +670,107 @@ typedef struct {
|
||||
} block_q8_0;
|
||||
static_assert(sizeof(block_q8_0) == 3*sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
|
||||
|
||||
#ifndef GGML_NO_RMSE
|
||||
// Stuff for RMSE-minimizing quantization
|
||||
static inline int nearest_int(float fval) {
|
||||
assert(fval <= 4194303.f);
|
||||
float val = fval + 12582912.f;
|
||||
int i; memcpy(&i, &val, sizeof(int));
|
||||
return (i & 0x007fffff) - 0x00400000;
|
||||
}
|
||||
|
||||
static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates,
|
||||
const float * restrict candidates, int8_t * restrict L) {
|
||||
assert (nmin >= INT8_MIN);
|
||||
assert (nmax <= INT8_MAX);
|
||||
float amax = 0;
|
||||
for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
|
||||
if (!amax) { // all zero
|
||||
for (int i=0; i<n; ++i) L[i] = 0;
|
||||
return 1.f;
|
||||
}
|
||||
float best = 0, bestScale = 0;
|
||||
for (int si=0; si<nCandidates; ++si) {
|
||||
float iscale = candidates[si]/amax;
|
||||
float sumlxP = 0; int suml2P = 0;
|
||||
float sumlxM = 0; int suml2M = 0;
|
||||
for (int i=0; i<n; ++i) {
|
||||
int l = nearest_int(iscale*X[i]);
|
||||
int lp = MAX(nmin, MIN(nmax, +l));
|
||||
int lm = MAX(nmin, MIN(nmax, -l));
|
||||
sumlxP += X[i]*lp; suml2P += lp*lp;
|
||||
sumlxM += X[i]*lm; suml2M += lm*lm;
|
||||
}
|
||||
float sumlxP2 = sumlxP*sumlxP;
|
||||
float sumlxM2 = sumlxM*sumlxM;
|
||||
if (sumlxP2*suml2M > sumlxM2*suml2P) {
|
||||
if (sumlxP2 > best*suml2P) {
|
||||
best = sumlxP2/suml2P; bestScale = iscale;
|
||||
}
|
||||
} else {
|
||||
if (sumlxM2 > best*suml2M) {
|
||||
best = sumlxM2/suml2M; bestScale = -iscale;
|
||||
}
|
||||
}
|
||||
}
|
||||
float sumlx = 0; int suml2 = 0;
|
||||
for (int i=0; i<n; ++i) {
|
||||
int l = nearest_int(bestScale*X[i]);
|
||||
l = MAX(nmin, MIN(nmax, l));
|
||||
sumlx += X[i]*l; suml2 += l*l;
|
||||
L[i] = l;
|
||||
}
|
||||
float scale = sumlx/suml2;
|
||||
return scale;
|
||||
}
|
||||
static float kquantize_q4_with_bound_plus(int n, int nmax, const float * restrict X, int nCandidates,
|
||||
const float * restrict candidates, int8_t * restrict L) {
|
||||
assert (nmax <= INT8_MAX);
|
||||
float amax = 0;
|
||||
for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
|
||||
if (!amax) { // all zero
|
||||
for (int i=0; i<n; ++i) L[i] = 0;
|
||||
return 1.f;
|
||||
}
|
||||
float best = 0, bestScale = 0;
|
||||
for (int si=0; si<nCandidates; ++si) {
|
||||
float iscale = candidates[si]/amax;
|
||||
float sumlx = 0; int suml2 = 0;
|
||||
for (int i=0; i<n; ++i) {
|
||||
int l = nearest_int(iscale*X[i]);
|
||||
l = MAX(0, MIN(nmax, l));
|
||||
sumlx += X[i]*l; suml2 += l*l;
|
||||
}
|
||||
float sumlx2 = sumlx*sumlx;
|
||||
if (sumlx2 > best*suml2) {
|
||||
best = sumlx2/suml2; bestScale = iscale;
|
||||
}
|
||||
}
|
||||
float sumlx = 0; int suml2 = 0;
|
||||
for (int i=0; i<n; ++i) {
|
||||
int l = nearest_int(bestScale*X[i]);
|
||||
l = MAX(0, MIN(nmax, l));
|
||||
sumlx += X[i]*l; suml2 += l*l;
|
||||
L[i] = l;
|
||||
}
|
||||
float scale = sumlx/suml2;
|
||||
return scale;
|
||||
}
|
||||
|
||||
static void quantize_row_q4_0_rmse(const float * restrict x, block_q4_0 * restrict y, int k);
|
||||
static void quantize_row_q4_1_rmse(const float * restrict x, block_q4_1 * restrict y, int k);
|
||||
static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restrict y, int k);
|
||||
static void quantize_row_q4_3_rmse(const float * restrict x, block_q4_3 * restrict y, int k);
|
||||
#endif
|
||||
|
||||
|
||||
// reference implementation for deterministic creation of model files
|
||||
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
|
||||
assert(k % QK4_0 == 0);
|
||||
#ifndef GGML_NO_RMSE
|
||||
quantize_row_q4_0_rmse(x, y, k);
|
||||
return;
|
||||
#endif
|
||||
const int nb = k / QK4_0;
|
||||
|
||||
uint8_t pp[QK4_0/2];
|
||||
@ -714,6 +811,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
||||
|
||||
block_q4_0 * restrict y = vy;
|
||||
|
||||
#ifndef GGML_NO_RMSE
|
||||
quantize_row_q4_0_rmse(x, y, k);
|
||||
return;
|
||||
#endif
|
||||
|
||||
#if defined(__POWER9_VECTOR__)
|
||||
const vector float v85 = vec_splats(8.5f);
|
||||
for (int i = 0; i < nb; i++) {
|
||||
@ -964,6 +1066,10 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
|
||||
const int nb = k / QK4_1;
|
||||
|
||||
block_q4_1 * restrict y = vy;
|
||||
#ifndef GGML_NO_RMSE
|
||||
quantize_row_q4_1_rmse(x, y, k);
|
||||
return;
|
||||
#endif
|
||||
|
||||
uint8_t pp[QK4_1/2];
|
||||
|
||||
@ -1007,6 +1113,11 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
|
||||
|
||||
block_q4_1 * restrict y = vy;
|
||||
|
||||
#ifndef GGML_NO_RMSE
|
||||
quantize_row_q4_1_rmse(x, y, k);
|
||||
return;
|
||||
#endif
|
||||
|
||||
#if defined(__AVX2__)
|
||||
for (int i = 0; i < nb; i++) {
|
||||
// Load elements into 4 AVX vectors
|
||||
@ -1127,6 +1238,11 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
|
||||
static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
|
||||
assert(k % QK4_2 == 0);
|
||||
|
||||
#ifndef GGML_NO_RMSE
|
||||
quantize_row_q4_2_rmse(x, y, k);
|
||||
return;
|
||||
#endif
|
||||
|
||||
const int nb = k / QK4_2;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
@ -1158,56 +1274,121 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
|
||||
}
|
||||
}
|
||||
|
||||
static inline int nearest_int(float fval) {
|
||||
assert(fval <= 4194303.f);
|
||||
float val = fval + 12582912.f;
|
||||
int i; memcpy(&i, &val, sizeof(int));
|
||||
return (i & 0x007fffff) - 0x00400000;
|
||||
#ifndef GGML_NO_RMSE
|
||||
static void quantize_row_q41_helper(int n, const float * restrict x, int8_t * restrict L, float * restrict tmp_x,
|
||||
float * result_a, float * result_b) {
|
||||
#define CANDIDATE_COUNT 20
|
||||
static const float candidates[CANDIDATE_COUNT] = {
|
||||
+18.5f, +17.75f, +17.25f, +16.75f, +16.25f, +15.75f, +15.25f, +15.0f, +14.75f,
|
||||
+14.25f, +13.75f, +13.25f, +12.75f, +12.25, +11.75f, +11.25f, +10.75f, +10.25f, +9.25, +8.25
|
||||
};
|
||||
const float epsilon = 1e-5;
|
||||
float min = x[0], max = x[0];
|
||||
for (int j=1; j<n; ++j) {
|
||||
if (x[j] < min) min = x[j];
|
||||
if (x[j] > max) max = x[j];
|
||||
}
|
||||
if (max == min) {
|
||||
*result_a = min;
|
||||
*result_b = 1.f;
|
||||
for (int j=0; j<n; ++j) L[j] = 0;
|
||||
return;
|
||||
}
|
||||
float a = min, b = (max - min)/15;
|
||||
float bi = 15/(max - min);
|
||||
float simple_err = 0;
|
||||
for (int j=0; j<n; ++j) {
|
||||
L[j] = nearest_int(bi*(x[j] - min));
|
||||
float diff = x[j] - a - b*L[j];
|
||||
simple_err += diff*diff;
|
||||
}
|
||||
for (int itry=0; itry<3; ++itry) {
|
||||
for (int j=0; j<n; ++j) tmp_x[j] = x[j] - a;
|
||||
kquantize_q4_with_bound_plus(n, 15, tmp_x, CANDIDATE_COUNT, candidates, L);
|
||||
float sumlx = 0, sumx = 0;
|
||||
int suml2 = 0, suml = 0;
|
||||
for (int j=0; j<n; ++j) {
|
||||
int l = L[j];
|
||||
sumlx += x[j]*l;
|
||||
suml2 += l*l;
|
||||
suml += l;
|
||||
sumx += x[j];
|
||||
}
|
||||
int64_t D = suml2*n - suml*suml;
|
||||
if (!D) break;
|
||||
float aold = a, bold = b;
|
||||
a = (sumx*suml2 - sumlx*suml)/D;
|
||||
b = (sumlx*n - sumx*suml)/D;
|
||||
if (itry > 0 && fabsf(a - aold) < epsilon*fabsf(aold) && fabsf(b - bold) < epsilon*fabsf(bold)) break;
|
||||
}
|
||||
float err = 0;
|
||||
for (int j=0; j<n; ++j) {
|
||||
float diff = x[j] - a - b*L[j];
|
||||
err += diff*diff;
|
||||
}
|
||||
if (err > simple_err) {
|
||||
a = min; b = (max - min)/15;
|
||||
for (int j=0; j<n; ++j) {
|
||||
L[j] = nearest_int(bi*(x[j] - min));
|
||||
}
|
||||
}
|
||||
*result_a = a;
|
||||
*result_b = b;
|
||||
#undef CANDIDATE_COUNT
|
||||
}
|
||||
static void quantize_row_q4_1_rmse(const float * restrict x, block_q4_1 * restrict y, int k) {
|
||||
assert(k % QK4_1 == 0);
|
||||
|
||||
int8_t L[QK4_1];
|
||||
float tmp_x[QK4_1];
|
||||
|
||||
const int nb = k / QK4_1;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float a, b;
|
||||
quantize_row_q41_helper(QK4_1, x, L, tmp_x, &a, &b);
|
||||
y[i].d = b;
|
||||
y[i].m = a;
|
||||
|
||||
for (int l = 0; l < QK4_1; l += 2) {
|
||||
const uint8_t vi0 = (uint8_t)(L[l+0]);
|
||||
const uint8_t vi1 = (uint8_t)(L[l+1]);
|
||||
|
||||
assert(vi0 < 16);
|
||||
assert(vi1 < 16);
|
||||
|
||||
y[i].qs[l/2] = vi0 | (vi1 << 4);
|
||||
}
|
||||
|
||||
static float kquantize_q4_with_bounds(int n, int nmin, int nmax, const float * restrict X, int nCandidates,
|
||||
const float * restrict candidates, int8_t * restrict L) {
|
||||
assert (nmin >= INT8_MIN);
|
||||
assert (nmax <= INT8_MAX);
|
||||
float amax = 0;
|
||||
for (int i=0; i<n; ++i) amax = MAX(amax, fabsf(X[i]));
|
||||
if (!amax) { // all zero
|
||||
for (int i=0; i<n; ++i) L[i] = 0;
|
||||
return 1.f;
|
||||
}
|
||||
float best = 0, bestScale = 0;
|
||||
for (int si=0; si<nCandidates; ++si) {
|
||||
float iscale = candidates[si]/amax;
|
||||
float sumlxP = 0; int suml2P = 0;
|
||||
float sumlxM = 0; int suml2M = 0;
|
||||
for (int i=0; i<n; ++i) {
|
||||
int l = nearest_int(iscale*X[i]);
|
||||
int lp = MAX(nmin, MIN(nmax, +l));
|
||||
int lm = MAX(nmin, MIN(nmax, -l));
|
||||
sumlxP += X[i]*lp; suml2P += lp*lp;
|
||||
sumlxM += X[i]*lm; suml2M += lm*lm;
|
||||
}
|
||||
float sumlxP2 = sumlxP*sumlxP;
|
||||
float sumlxM2 = sumlxM*sumlxM;
|
||||
if (sumlxP2*suml2M > sumlxM2*suml2P) {
|
||||
if (sumlxP2 > best*suml2P) {
|
||||
best = sumlxP2/suml2P; bestScale = iscale;
|
||||
}
|
||||
} else {
|
||||
if (sumlxM2 > best*suml2M) {
|
||||
best = sumlxM2/suml2M; bestScale = -iscale;
|
||||
x += QK4_1;
|
||||
}
|
||||
}
|
||||
static void quantize_row_q4_3_rmse(const float * restrict x, block_q4_3 * restrict y, int k) {
|
||||
assert(k % QK4_3 == 0);
|
||||
|
||||
int8_t L[QK4_3];
|
||||
float tmp_x[QK4_3];
|
||||
|
||||
const int nb = k / QK4_3;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float a, b;
|
||||
quantize_row_q41_helper(QK4_3, x, L, tmp_x, &a, &b);
|
||||
y[i].d = GGML_FP32_TO_FP16(b);
|
||||
y[i].m = GGML_FP32_TO_FP16(a);
|
||||
|
||||
for (int l = 0; l < QK4_3; l += 2) {
|
||||
const uint8_t vi0 = (uint8_t)(L[l+0]);
|
||||
const uint8_t vi1 = (uint8_t)(L[l+1]);
|
||||
|
||||
assert(vi0 < 16);
|
||||
assert(vi1 < 16);
|
||||
|
||||
y[i].qs[l/2] = vi0 | (vi1 << 4);
|
||||
}
|
||||
float sumlx = 0; int suml2 = 0;
|
||||
for (int i=0; i<n; ++i) {
|
||||
int l = nearest_int(bestScale*X[i]);
|
||||
l = MAX(nmin, MIN(nmax, l));
|
||||
sumlx += X[i]*l; suml2 += l*l;
|
||||
L[i] = l;
|
||||
|
||||
x += QK4_3;
|
||||
}
|
||||
float scale = sumlx/suml2;
|
||||
return scale;
|
||||
}
|
||||
|
||||
static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restrict y, int k) {
|
||||
@ -1236,19 +1417,48 @@ static void quantize_row_q4_2_rmse(const float * restrict x, block_q4_2 * restri
|
||||
x += QK4_2;
|
||||
}
|
||||
}
|
||||
static void quantize_row_q4_0_rmse(const float * restrict x, block_q4_0 * restrict y, int k) {
|
||||
static const float candidates[CANDIDATE_COUNT] = { +8.7f, +8.3f, +8.1f, +7.8f, +7.3f, +7.0f, +6.3f, +5.7f };
|
||||
assert(k % QK4_0 == 0);
|
||||
|
||||
int8_t L[QK4_0];
|
||||
|
||||
const int nb = k / QK4_0;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
float scale = kquantize_q4_with_bounds(QK4_0, -8, 7, x, CANDIDATE_COUNT, candidates, L);
|
||||
y[i].d = scale;
|
||||
|
||||
for (int l = 0; l < QK4_0; l += 2) {
|
||||
const uint8_t vi0 = (uint8_t)(L[l+0] + 8);
|
||||
const uint8_t vi1 = (uint8_t)(L[l+1] + 8);
|
||||
|
||||
assert(vi0 < 16);
|
||||
assert(vi1 < 16);
|
||||
|
||||
y[i].qs[l/2] = vi0 | (vi1 << 4);
|
||||
}
|
||||
|
||||
x += QK4_0;
|
||||
}
|
||||
}
|
||||
#undef CANDIDATE_COUNT
|
||||
#endif
|
||||
|
||||
static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) {
|
||||
assert(k % QK4_2 == 0);
|
||||
|
||||
block_q4_2 * restrict y = vy;
|
||||
|
||||
//quantize_row_q4_2_reference(x, y, k);
|
||||
// This produces the exact same format, just better match to the input floats ("better" as measured by RMSE)
|
||||
quantize_row_q4_2_rmse(x, y, k);
|
||||
quantize_row_q4_2_reference(x, y, k);
|
||||
}
|
||||
|
||||
static void quantize_row_q4_3_reference(const float * restrict x, block_q4_3 * restrict y, int k) {
|
||||
assert(k % QK4_3 == 0);
|
||||
#ifndef GGML_NO_RMSE
|
||||
quantize_row_q4_3_rmse(x, y, k);
|
||||
return;
|
||||
#endif
|
||||
const int nb = k / QK4_3;
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
@ -1787,7 +1997,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
||||
},
|
||||
[GGML_TYPE_Q4_1] = {
|
||||
.dequantize_row_q = dequantize_row_q4_1,
|
||||
.quantize_row_q = quantize_row_q4_1,
|
||||
.quantize_row_q = (quantize_row_q_t) quantize_row_q4_1,
|
||||
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
|
||||
.quantize_row_q_dot = quantize_row_q8_0,
|
||||
.vec_dot_q = ggml_vec_dot_q4_1_q8_0,
|
||||
@ -1795,7 +2005,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
|
||||
[GGML_TYPE_Q4_2] = {
|
||||
.dequantize_row_q = dequantize_row_q4_2,
|
||||
.quantize_row_q = quantize_row_q4_2,
|
||||
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_rmse, //quantize_row_q4_2_reference,
|
||||
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
|
||||
.quantize_row_q_dot = quantize_row_q8_0,
|
||||
.vec_dot_q = ggml_vec_dot_q4_2_q8_0,
|
||||
},
|
||||
@ -12074,6 +12284,16 @@ enum ggml_opt_result ggml_opt(
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static void collect_quant_histogram(int n, const uint8_t * restrict qs, int64_t * hist) {
|
||||
for (int l = 0; l < n; l += 2) {
|
||||
const uint8_t vi0 = qs[l/2] & 0xF;
|
||||
const uint8_t vi1 = qs[l/2] >> 4;
|
||||
|
||||
hist[vi0]++;
|
||||
hist[vi1]++;
|
||||
}
|
||||
}
|
||||
|
||||
size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
|
||||
assert(k % QK4_0 == 0);
|
||||
const int nb = k / QK4_0;
|
||||
@ -12084,13 +12304,7 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
|
||||
quantize_row_q4_0_reference(src + j, y, k);
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
for (int l = 0; l < QK4_0; l += 2) {
|
||||
const uint8_t vi0 = y[i].qs[l/2] & 0xF;
|
||||
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
||||
|
||||
hist[vi0]++;
|
||||
hist[vi1]++;
|
||||
}
|
||||
collect_quant_histogram(QK4_0, y[i].qs, hist);
|
||||
}
|
||||
}
|
||||
|
||||
@ -12107,13 +12321,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
|
||||
quantize_row_q4_1_reference(src + j, y, k);
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
for (int l = 0; l < QK4_1; l += 2) {
|
||||
const uint8_t vi0 = y[i].qs[l/2] & 0xF;
|
||||
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
||||
|
||||
hist[vi0]++;
|
||||
hist[vi1]++;
|
||||
}
|
||||
collect_quant_histogram(QK4_1, y[i].qs, hist);
|
||||
}
|
||||
}
|
||||
|
||||
@ -12127,17 +12335,10 @@ size_t ggml_quantize_q4_2(const float * src, void * dst, int n, int k, int64_t *
|
||||
for (int j = 0; j < n; j += k) {
|
||||
block_q4_2 * restrict y = (block_q4_2 *)dst + j/QK4_2;
|
||||
|
||||
//quantize_row_q4_2_reference(src + j, y, k);
|
||||
quantize_row_q4_2_rmse(src + j, y, k);
|
||||
quantize_row_q4_2_reference(src + j, y, k);
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
for (int l = 0; l < QK4_2; l += 2) {
|
||||
const uint8_t vi0 = y[i].qs[l/2] & 0xF;
|
||||
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
||||
|
||||
hist[vi0]++;
|
||||
hist[vi1]++;
|
||||
}
|
||||
collect_quant_histogram(QK4_2, y[i].qs, hist);
|
||||
}
|
||||
}
|
||||
|
||||
@ -12154,13 +12355,7 @@ size_t ggml_quantize_q4_3(const float * src, void * dst, int n, int k, int64_t *
|
||||
quantize_row_q4_3_reference(src + j, y, k);
|
||||
|
||||
for (int i = 0; i < nb; i++) {
|
||||
for (int l = 0; l < QK4_3; l += 2) {
|
||||
const uint8_t vi0 = y[i].qs[l/2] & 0xF;
|
||||
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
||||
|
||||
hist[vi0]++;
|
||||
hist[vi1]++;
|
||||
}
|
||||
collect_quant_histogram(QK4_3, y[i].qs, hist);
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user