ggml : remove Q4_2 bit shuffling (WIP, BROKEN)

This commit is contained in:
Georgi Gerganov 2023-05-04 22:07:40 +03:00
parent 086cfea11f
commit a6a1d96c91
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

109
ggml.c
View File

@ -884,7 +884,7 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
static const int qk = QK4_0; static const int qk = QK4_0;
assert(qk / 16 == 0); assert(qk / 16 == 0);
assert(k % qk == 0); assert( k % qk == 0);
const int nb = k / qk; const int nb = k / qk;
@ -919,7 +919,7 @@ static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * r
const int qk = QK4_1; const int qk = QK4_1;
assert(qk / 16 == 0); assert(qk / 16 == 0);
assert(k % qk == 0); assert( k % qk == 0);
const int nb = k / qk; const int nb = k / qk;
@ -952,16 +952,19 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k
// reference implementation for deterministic creation of model files // reference implementation for deterministic creation of model files
static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) { static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * restrict y, int k) {
assert(k % QK4_2 == 0); static const int qk = QK4_2;
const int nb = k / QK4_2; assert(qk / 16 == 0);
assert( k % qk == 0);
const int nb = k / qk;
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
float amax = 0.0f; // absolute max float amax = 0.0f; // absolute max
float max = 0.0f; float max = 0.0f;
for (int l = 0; l < QK4_2; l++) { for (int l = 0; l < qk; l++) {
const float v = x[i*QK4_2 + l]; const float v = x[i*qk + l];
if (amax < fabsf(v)) { if (amax < fabsf(v)) {
amax = fabsf(v); amax = fabsf(v);
max = v; max = v;
@ -969,31 +972,17 @@ static void quantize_row_q4_2_reference(const float * restrict x, block_q4_2 * r
} }
const float d = max / -8; const float d = max / -8;
const float id = d ? 1.0f/d : 0.0f; const float id = d ? 1.0f/d : 0.0f;
y[i].d = GGML_FP32_TO_FP16(d); y[i].d = GGML_FP32_TO_FP16(d);
for (int l = 0; l < QK4_2; l += 2) { uint64_t qs[QK4_2 / 16] = {0};
const float v0 = x[i*QK4_2 + l + 0]*id;
const float v1 = x[i*QK4_2 + l + 1]*id;
const uint8_t vi0 = MIN(15, (uint8_t)(v0 + 8.5f)); nibbles_from_floats_64_0(qk, x + i*qk, id, y[i].qs, qs);
const uint8_t vi1 = MIN(15, (uint8_t)(v1 + 8.5f));
assert(vi0 < 16);
assert(vi1 < 16);
y[i].qs[l/2] = vi0 | (vi1 << 4);
}
} }
} }
static void quantize_row_q4_2(const float * restrict x, void * restrict vy, int k) { static void quantize_row_q4_2(const float * restrict x, void * restrict y, int k) {
assert(k % QK4_2 == 0);
block_q4_2 * restrict y = vy;
quantize_row_q4_2_reference(x, y, k); quantize_row_q4_2_reference(x, y, k);
} }
@ -1451,7 +1440,7 @@ static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict
static const int qk = QK4_0; static const int qk = QK4_0;
assert(qk / 16 == 0); assert(qk / 16 == 0);
assert(k % qk == 0); assert( k % qk == 0);
const int nb = k / qk; const int nb = k / qk;
@ -1472,7 +1461,7 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict
static const int qk = QK4_1; static const int qk = QK4_1;
assert(qk / 16 == 0); assert(qk / 16 == 0);
assert(k % qk == 0); assert( k % qk == 0);
const int nb = k / qk; const int nb = k / qk;
@ -1490,31 +1479,23 @@ static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict
} }
} }
static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) { static void dequantize_row_q4_2(const block_q4_2 * restrict x, float * restrict y, int k) {
assert(k % QK4_2 == 0); static const int qk = QK4_2;
const int nb = k / QK4_2;
const block_q4_2 * restrict x = vx; assert(qk / 16 == 0);
assert( k % qk == 0);
const int nb = k / qk;
uint64_t qs[QK4_2 / 8];
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
const float d = GGML_FP16_TO_FP32(x[i].d); const float d = GGML_FP16_TO_FP32(x[i].d);
const uint8_t * restrict pp = x[i].qs; const uint8_t * qsp = bytes_from_nibbles_64(qk, x[i].qs, qs);
for (int l = 0; l < QK4_2; l += 2) { for (int l = 0; l < qk; ++l) {
const uint8_t vi = pp[l/2]; y[i*qk + l] = (qsp[l] - 8)*d;
const int8_t vi0 = vi & 0x0F;
const int8_t vi1 = vi >> 4;
const float v0 = (vi0 - 8)*d;
const float v1 = (vi1 - 8)*d;
y[i*QK4_2 + l + 0] = v0;
y[i*QK4_2 + l + 1] = v1;
assert(!isnan(y[i*QK4_2 + l + 0]));
assert(!isnan(y[i*QK4_2 + l + 1]));
} }
} }
} }
@ -1634,7 +1615,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_1, .vec_dot_type = GGML_TYPE_Q8_1,
}, },
[GGML_TYPE_Q4_2] = { [GGML_TYPE_Q4_2] = {
.dequantize_row_q = dequantize_row_q4_2, .dequantize_row_q = (dequantize_row_q_t) dequantize_row_q4_2,
.quantize_row_q = quantize_row_q4_2, .quantize_row_q = quantize_row_q4_2,
.quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference, .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_2_reference,
.quantize_row_q_dot = quantize_row_q8_0, .quantize_row_q_dot = quantize_row_q8_0,
@ -2559,11 +2540,13 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void *
} }
static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int nb = n / QK8_0; const int qk = QK8_0;
const int nb = n / qk;
assert(n % QK8_0 == 0); assert(n % qk == 0);
assert(nb % 2 == 0); assert(nb % 2 == 0);
assert(QK8_0 == 2*QK4_2);
assert(qk == 2*QK4_2);
const block_q4_2 * restrict x = vx; const block_q4_2 * restrict x = vx;
const block_q8_0 * restrict y = vy; const block_q8_0 * restrict y = vy;
@ -2599,12 +2582,6 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
// interleave
const int8x16_t v0_0lz = vzip1q_s8(v0_0ls, v0_0hs);
const int8x16_t v0_0hz = vzip2q_s8(v0_0ls, v0_0hs);
const int8x16_t v0_1lz = vzip1q_s8(v0_1ls, v0_1hs);
const int8x16_t v0_1hz = vzip2q_s8(v0_1ls, v0_1hs);
// load y // load y
const int8x16_t v1_0l = vld1q_s8(y0->qs); const int8x16_t v1_0l = vld1q_s8(y0->qs);
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16); const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
@ -2613,22 +2590,22 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
#if defined(__ARM_FEATURE_DOTPROD) #if defined(__ARM_FEATURE_DOTPROD)
sumv0 = vmlaq_n_f32(sumv0, vaddq_f32( sumv0 = vmlaq_n_f32(sumv0, vaddq_f32(
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)), vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l)), GGML_FP16_TO_FP32(x0_0->d)),
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d); vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hs, v1_0h)), GGML_FP16_TO_FP32(x0_1->d))), y0->d);
sumv1 = vmlaq_n_f32(sumv1, vaddq_f32( sumv1 = vmlaq_n_f32(sumv1, vaddq_f32(
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)), vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l)), GGML_FP16_TO_FP32(x1_0->d)),
vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d); vmulq_n_f32(vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hs, v1_1h)), GGML_FP16_TO_FP32(x1_1->d))), y1->d);
#else #else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l)); const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l)); const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h)); const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h)); const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l)); const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l)); const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h)); const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h)); const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h));
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));