ggml : remove Q4_1 bit shuffling (ARM NEON + reference)

Georgi Gerganov committed 2023-05-04 20:53:14 +03:00
parent a546dc6d60
commit edb6c8bb66

ggml.c: 418 changed lines
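Every hunk below makes the same layout change: the old Q4 packing stored quants 2*j and 2*j+1 interleaved in the two nibbles of byte j, while the new packing stores quant l in the low nibble of byte l and quant l + qk/2 in the high nibble. A minimal standalone sketch of the two packings (hypothetical helper names; qk = 32 as in ggml):

#include <stdint.h>
#include <stdio.h>

#define QK 32

// old layout: byte j holds quants 2j (low nibble) and 2j+1 (high nibble)
static void pack_interleaved(const uint8_t v[QK], uint8_t out[QK/2]) {
    for (int j = 0; j < QK/2; ++j) {
        out[j] = (v[2*j] & 0x0F) | ((v[2*j + 1] & 0x0F) << 4);
    }
}

// new layout: byte l holds quants l (low nibble) and l + QK/2 (high nibble)
static void pack_split(const uint8_t v[QK], uint8_t out[QK/2]) {
    for (int l = 0; l < QK/2; ++l) {
        out[l] = (v[l] & 0x0F) | ((v[l + QK/2] & 0x0F) << 4);
    }
}

int main(void) {
    uint8_t v[QK], old_qs[QK/2], new_qs[QK/2];
    for (int i = 0; i < QK; ++i) v[i] = i % 16;

    pack_interleaved(v, old_qs);
    pack_split(v, new_qs);

    printf("old byte0 = 0x%02X, new byte0 = 0x%02X\n", old_qs[0], new_qs[0]);
    return 0;
}

With the split layout, masking 16 consecutive bytes with 0x0F yields the first half of the block already in order, which is why the NEON dot product below can drop its vzip interleave step.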

@@ -837,17 +837,19 @@ static_assert(sizeof(block_q8_1) == 3*sizeof(float) + QK8_1, "wrong q8_1 block size/padding");

 // reference implementation for deterministic creation of model files
 static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
-    assert(QK4_0 / 16 == 0);
-    assert(k % QK4_0 == 0);
-    const int nb = k / QK4_0;
+    static const int qk = QK4_0;
+
+    assert(qk / 16 == 0);
+    assert(k % qk == 0);
+
+    const int nb = k / qk;

     for (int i = 0; i < nb; i++) {
         float amax = 0.0f; // absolute max
         float max  = 0.0f;

-        for (int l = 0; l < QK4_0; l++) {
-            const float v = x[i*QK4_0 + l];
+        for (int l = 0; l < qk; l++) {
+            const float v = x[i*qk + l];
             if (amax < fabsf(v)) {
                 amax = fabsf(v);
                 max  = v;
@@ -862,9 +864,9 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
         uint64_t qs[QK4_0 / 16] = {0};

         // pack first half of weights into low nibbles and second half into high nibbles
-        for (int l = 0; l < QK4_0/2; ++l) {
-            const float v0 = x[i*QK4_0 + 0       + l]*id;
-            const float v1 = x[i*QK4_0 + QK4_0/2 + l]*id;
+        for (int l = 0; l < qk/2; ++l) {
+            const float v0 = x[i*qk + 0    + l]*id;
+            const float v1 = x[i*qk + qk/2 + l]*id;

             const uint64_t vi0 = MIN(15, (int8_t)(v0 + 8.5f));
             const uint64_t vi1 = MIN(15, (int8_t)(v1 + 8.5f));
@@ -877,28 +879,25 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
     }
 }

-static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK4_0 == 0);
-
-    block_q4_0 * restrict y = vy;
-
+static void quantize_row_q4_0(const float * restrict x, void * restrict y, int k) {
     quantize_row_q4_0_reference(x, y, k);
 }

-static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK4_1 == 0);
-    const int nb = k / QK4_1;
-
-    block_q4_1 * restrict y = vy;
-
-    uint8_t pp[QK4_1/2];
+static void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int k) {
+    const int qk = QK4_1;
+
+    assert(qk / 16 == 0);
+    assert(k % qk == 0);
+
+    const int nb = k / qk;

     for (int i = 0; i < nb; i++) {
         float min = FLT_MAX;
         float max = -FLT_MAX;

-        for (int l = 0; l < QK4_1; l++) {
-            const float v = x[i*QK4_1 + l];
+        for (int l = 0; l < qk; l++) {
+            const float v = x[i*qk + l];
             if (v < min) min = v;
             if (v > max) max = v;
         }
@@ -909,144 +908,26 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
         y[i].d = d;
         y[i].m = min;

-        for (int l = 0; l < QK4_1; l += 2) {
-            const float v0 = (x[i*QK4_1 + l + 0] - min)*id;
-            const float v1 = (x[i*QK4_1 + l + 1] - min)*id;
-
-            const uint8_t vi0 = roundf(v0);
-            const uint8_t vi1 = roundf(v1);
-
-            assert(vi0 < 16);
-            assert(vi1 < 16);
-
-            pp[l/2] = vi0 | (vi1 << 4);
-        }
-
-        memcpy(y[i].qs, pp, sizeof(pp));
+        uint64_t qs[QK4_1 / 16] = {0};
+
+        // pack first half of weights into low nibbles and second half into high nibbles
+        for (int l = 0; l < qk/2; ++l) {
+            const float v0 = (x[i*qk + 0    + l] - min)*id;
+            const float v1 = (x[i*qk + qk/2 + l] - min)*id;
+
+            const uint64_t vi0 = MIN(15, (int8_t)(v0 + 0.5f));
+            const uint64_t vi1 = MIN(15, (int8_t)(v1 + 0.5f));
+
+            qs[l/8] |= vi0 << (8*(l & 7));
+            qs[l/8] |= vi1 << (8*(l & 7) + 4);
+        }
+
+        memcpy(y[i].qs, qs, sizeof(qs));
     }
 }

-static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
-    assert(k % QK4_1 == 0);
-
-    const int nb = k / QK4_1;
-
-    block_q4_1 * restrict y = vy;
-
-#if defined(__AVX2__)
-    for (int i = 0; i < nb; i++) {
-        // Load elements into 4 AVX vectors
-        __m256 v0 = _mm256_loadu_ps( x );
-        __m256 v1 = _mm256_loadu_ps( x + 8 );
-        __m256 v2 = _mm256_loadu_ps( x + 16 );
-        __m256 v3 = _mm256_loadu_ps( x + 24 );
-        x += 32;
-
-        // Compute max for the block
-        __m256 vmax;
-        vmax = _mm256_max_ps( v0, v1 );
-        vmax = _mm256_max_ps( vmax, v2 );
-        vmax = _mm256_max_ps( vmax, v3 );
-
-        __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( vmax, 1 ), _mm256_castps256_ps128( vmax ) );
-        max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
-        max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
-        const float maxScalar = _mm_cvtss_f32( max4 );
-
-        // Compute min for the block
-        __m256 vmin;
-        vmin = _mm256_min_ps( v0, v1 );
-        vmin = _mm256_min_ps( vmin, v2 );
-        vmin = _mm256_min_ps( vmin, v3 );
-
-        __m128 min4 = _mm_min_ps( _mm256_extractf128_ps( vmin, 1 ), _mm256_castps256_ps128( vmin ) );
-        min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
-        min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
-        const float minScalar = _mm_cvtss_f32( min4 );
-
-        // Quantize these floats
-        const float d = (maxScalar - minScalar) / ((1 << 4) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        y[i].m = minScalar;
-        y[i].d = d;
-
-        // x = (x-min)*id
-        const __m256 mul = _mm256_set1_ps( id );
-        const __m256 off = _mm256_set1_ps( minScalar );
-        v0 = _mm256_mul_ps( _mm256_sub_ps( v0, off ), mul );
-        v1 = _mm256_mul_ps( _mm256_sub_ps( v1, off ), mul );
-        v2 = _mm256_mul_ps( _mm256_sub_ps( v2, off ), mul );
-        v3 = _mm256_mul_ps( _mm256_sub_ps( v3, off ), mul );
-
-        // Round to nearest integer
-        v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST );
-        v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST );
-        v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST );
-        v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST );
-
-        // Convert floats to integers
-        __m256i i0 = _mm256_cvtps_epi32( v0 );
-        __m256i i1 = _mm256_cvtps_epi32( v1 );
-        __m256i i2 = _mm256_cvtps_epi32( v2 );
-        __m256i i3 = _mm256_cvtps_epi32( v3 );
-
-        // Convert int32 to int16
-        i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15
-        i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31
-        // Convert int16 to int8
-        i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31
-
-        // We got our precious signed bytes, but the order is now wrong
-        // These AVX2 pack instructions process 16-byte pieces independently
-        // The following instruction is fixing the order
-        const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
-        i0 = _mm256_permutevar8x32_epi32( i0, perm );
-
-        // Compress the vector into 4 bit/value, and store
-        __m128i res = packNibbles( i0 );
-        _mm_storeu_si128( ( __m128i* )y[i].qs, res );
-    }
-#elif __ARM_NEON
-    for (int i = 0; i < nb; i++) {
-        float32x4_t srcv[8];
-        float32x4_t minv[8];
-        float32x4_t maxv[8];
-
-        for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*QK4_1 + 4*l);
-
-        for (int l = 0; l < 4; l++) minv[2*l] = vminq_f32(srcv[2*l], srcv[2*l + 1]);
-        for (int l = 0; l < 2; l++) minv[4*l] = vminq_f32(minv[4*l], minv[4*l + 2]);
-        for (int l = 0; l < 1; l++) minv[8*l] = vminq_f32(minv[8*l], minv[8*l + 4]);
-
-        for (int l = 0; l < 4; l++) maxv[2*l] = vmaxq_f32(srcv[2*l], srcv[2*l + 1]);
-        for (int l = 0; l < 2; l++) maxv[4*l] = vmaxq_f32(maxv[4*l], maxv[4*l + 2]);
-        for (int l = 0; l < 1; l++) maxv[8*l] = vmaxq_f32(maxv[8*l], maxv[8*l + 4]);
-
-        const float min = vminvq_f32(minv[0]);
-        const float max = vmaxvq_f32(maxv[0]);
-
-        const float d = (max - min) / ((1 << 4) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        y[i].d = d;
-        y[i].m = min;
-
-        const float32x4_t minv0 = vdupq_n_f32(min);
-
-        for (int l = 0; l < 8; l++) {
-            const float32x4_t v  = vmulq_n_f32(vsubq_f32(srcv[l], minv0), id);
-            const float32x4_t vf = vaddq_f32(v, vdupq_n_f32(0.5f)); // needed to round to nearest
-            const int32x4_t   vi = vcvtq_s32_f32(vf);
-
-            y[i].qs[2*l + 0] = vgetq_lane_s32(vi, 0) | (vgetq_lane_s32(vi, 1) << 4);
-            y[i].qs[2*l + 1] = vgetq_lane_s32(vi, 2) | (vgetq_lane_s32(vi, 3) << 4);
-        }
-    }
-#else
-    // scalar
-    quantize_row_q4_1_reference(x, vy, k);
-#endif
+static void quantize_row_q4_1(const float * restrict x, void * restrict y, int k) {
+    quantize_row_q4_1_reference(x, y, k);
 }

 // reference implementation for deterministic creation of model files
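The rewritten reference quantizer above packs nibbles through uint64_t words instead of a byte buffer. A small sketch checking that the shift arithmetic (8*(l & 7) for the low nibble, plus 4 for the high one) lands each value where the later memcpy expects it; like ggml's other 64-bit nibble tricks, this assumes a little-endian target:

#include <assert.h>
#include <stdint.h>
#include <string.h>

int main(void) {
    uint8_t v[16];
    for (int i = 0; i < 16; ++i) v[i] = (i*3) % 16; // arbitrary 4-bit values

    // 8 low nibbles + 8 high nibbles accumulated into one 64-bit word
    uint64_t qs = 0;
    for (int l = 0; l < 8; ++l) {
        qs |= (uint64_t)v[l]     << (8*(l & 7));     // low nibble of byte l
        qs |= (uint64_t)v[l + 8] << (8*(l & 7) + 4); // high nibble of byte l
    }

    uint8_t bytes[8];
    memcpy(bytes, &qs, sizeof(qs));

    for (int l = 0; l < 8; ++l) {
        assert(bytes[l] == (v[l] | (v[l + 8] << 4))); // holds on little-endian
    }
    return 0;
}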
@@ -1546,13 +1427,13 @@ static void quantize_row_q8_1(const float * restrict x, void * restrict vy, int k) {
 #endif
 }

-static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
-    assert(QK4_0 / 16 == 0);
-    assert(k % QK4_0 == 0);
-    const int nb = k / QK4_0;
-
-    const block_q4_0 * restrict x = vx;
+static void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int k) {
+    static const int qk = QK4_0;
+
+    assert(qk / 16 == 0);
+    assert(k % qk == 0);
+
+    const int nb = k / qk;

     for (int i = 0; i < nb; i++) {
         const float d = x[i].d;
@@ -1561,125 +1442,48 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
         uint64_t qs[QK4_0 / 8] = {0};

         memcpy(qs + 0,          x[i].qs, sizeof(x[i].qs));
-        memcpy(qs + QK4_0 / 16, x[i].qs, sizeof(x[i].qs));

-        for (int l = 0; l < QK4_0 / 16; ++l) {
-            qs[l           ] = (qs[l           ] & 0x0F0F0F0F0F0F0F0FULL) >> 0;
-            qs[l + QK4_0/16] = (qs[l + QK4_0/16] & 0xF0F0F0F0F0F0F0F0ULL) >> 4;
+        for (int l = 0; l < qk / 16; ++l) {
+            qs[l + qk/16] = (qs[l] & 0xF0F0F0F0F0F0F0F0ULL) >> 4;
+            qs[l + 0    ] = (qs[l] & 0x0F0F0F0F0F0F0F0FULL) >> 0;
         }

         const uint8_t * restrict qsp = (const uint8_t * restrict) qs;

-        for (int l = 0; l < QK4_0; ++l) {
-            y[i*QK4_0 + l] = (qsp[l] - 8)*d;
+        for (int l = 0; l < qk; ++l) {
+            y[i*qk + l] = (qsp[l] - 8)*d;
         }
     }
 }

-static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
-    assert(k % QK4_1 == 0);
-    const int nb = k / QK4_1;
-
-    const block_q4_1 * restrict x = vx;
-
-#if defined(__AVX2__)
-    for (int i = 0; i < nb; i++) {
-        const __m256 d_v = _mm256_broadcast_ss(&x[i].d);
-        const __m256 d_m = _mm256_broadcast_ss(&x[i].m);
-
-        const uint8_t * restrict pp = x[i].qs;
-
-        for (int l = 0; l < QK4_1; l += 32) {
-            // Load 32x4-bit integers into 32x8-bit integers
-            __m256i vx8 = bytes_from_nibbles_32(pp+l/2);
-
-            // Convert to 16-bit int
-            const __m256i vx16_lo = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 0));
-            const __m256i vx16_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vx8, 1));
-
-            // Convert to 32-bit int -> float 32
-            const __m256 vf[4] = {
-                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 0))),
-                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_lo, 1))),
-                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 0))),
-                _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_extracti128_si256(vx16_hi, 1)))
-            };
-
-            // Scale, add m and store
-            for (int j = 0; j < 4; j++) {
-                const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
-                _mm256_storeu_ps(y + i * QK4_1 + l + j*8, result);
-            }
-        }
-    }
-#elif defined(__ARM_NEON)
-    for (int i = 0; i < nb; i++) {
-        const float32x4_t vd = vdupq_n_f32(x[i].d);
-        const float32x4_t vm = vdupq_n_f32(x[i].m);
-
-        const uint8_t * restrict pp = x[i].qs;
-
-        for (int l = 0; l < QK4_1; l += 16) {
-            // Load 16x4-bit integers into 8x8-bit integers
-            const uint8x8_t v8 = vld1_u8(pp + l/2);
-
-            // Expand 4-bit qs to 8-bit bytes
-            const uint8x8_t v0 = vand_u8(v8, vdup_n_u8(0x0F));
-            const uint8x8_t v1 = vshr_n_u8(v8, 4);
-
-            // Interleave and combine
-            const uint8x8_t vx_0 = vzip1_u8(v0, v1);
-            const uint8x8_t vx_1 = vzip2_u8(v0, v1);
-
-            const uint8x16_t vq = vcombine_u8(vx_0, vx_1);
-
-            // convert to 2x uint16x8_t
-            const uint16x8_t vi_0 = vmovl_u8(vget_low_u8 (vq));
-            const uint16x8_t vi_1 = vmovl_u8(vget_high_u8(vq));
-
-            // convert to 4x float32x4_t
-            const float32x4_t vf_0 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_0)));
-            const float32x4_t vf_1 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_0)));
-            const float32x4_t vf_2 = vcvtq_f32_u32(vmovl_u16(vget_low_u16 (vi_1)));
-            const float32x4_t vf_3 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vi_1)));
-
-            // multiply by d and add m
-            const float32x4_t r0 = vmlaq_f32(vm, vf_0, vd);
-            const float32x4_t r1 = vmlaq_f32(vm, vf_1, vd);
-            const float32x4_t r2 = vmlaq_f32(vm, vf_2, vd);
-            const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
-
-            // Store
-            vst1q_f32(y + i*QK4_1 + l +  0, r0);
-            vst1q_f32(y + i*QK4_1 + l +  4, r1);
-            vst1q_f32(y + i*QK4_1 + l +  8, r2);
-            vst1q_f32(y + i*QK4_1 + l + 12, r3);
-        }
-    }
-#else
+static void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int k) {
+    static const int qk = QK4_1;
+
+    assert(qk / 16 == 0);
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
     for (int i = 0; i < nb; i++) {
         const float d = x[i].d;
         const float m = x[i].m;

-        const uint8_t * restrict pp = x[i].qs;
+        // unpack nibbles into bytes
+        uint64_t qs[QK4_0 / 8] = {0};

-        for (int l = 0; l < QK4_1; l += 2) {
-            const uint8_t vi = pp[l/2];
+        memcpy(qs + 0, x[i].qs, sizeof(x[i].qs));

-            const int8_t vi0 = vi & 0x0F;
-            const int8_t vi1 = vi >> 4;
+        for (int l = 0; l < qk / 16; ++l) {
+            qs[l + qk/16] = (qs[l] & 0xF0F0F0F0F0F0F0F0ULL) >> 4;
+            qs[l + 0    ] = (qs[l] & 0x0F0F0F0F0F0F0F0FULL) >> 0;
+        }

-            const float v0 = vi0*d + m;
-            const float v1 = vi1*d + m;
+        const uint8_t * restrict qsp = (const uint8_t * restrict) qs;

-            y[i*QK4_1 + l + 0] = v0;
-            y[i*QK4_1 + l + 1] = v1;
-
-            assert(!isnan(y[i*QK4_1 + l + 0]));
-            assert(!isnan(y[i*QK4_1 + l + 1]));
+        for (int l = 0; l < qk; ++l) {
+            y[i*qk + l] = qsp[l]*d + m;
         }
     }
-#endif
 }

 static void dequantize_row_q4_2(const void * restrict vx, float * restrict y, int k) {
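Both dequantizers above (and the scalar dot products further down) expand a block with two 64-bit masks instead of per-byte shifts, peeling eight low nibbles and eight high nibbles off each uint64_t at once. A self-contained sketch of the trick, again assuming little-endian byte order:

#include <stdint.h>
#include <stdio.h>
#include <string.h>

int main(void) {
    const uint8_t packed[8] = { 0x10, 0x32, 0x54, 0x76, 0x98, 0xBA, 0xDC, 0xFE };

    uint64_t qs[2] = {0};
    memcpy(qs, packed, sizeof(packed));

    qs[1] = (qs[0] & 0xF0F0F0F0F0F0F0F0ULL) >> 4; // high nibbles -> second half
    qs[0] = (qs[0] & 0x0F0F0F0F0F0F0F0FULL) >> 0; // low nibbles  -> first half

    const uint8_t * b = (const uint8_t *) qs;
    for (int l = 0; l < 16; ++l) {
        printf("%X ", b[l]); // prints 0 2 4 6 8 A C E, then 1 3 5 7 9 B D F
    }
    printf("\n");
    return 0;
}

Note the high-nibble word must be written first, since the second assignment overwrites qs[0] in place.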
@@ -1810,7 +1614,7 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy);
 static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
     [GGML_TYPE_Q4_0] = {
-        .dequantize_row_q         = dequantize_row_q4_0,
+        .dequantize_row_q         = (dequantize_row_q_t) dequantize_row_q4_0,
         .quantize_row_q           = quantize_row_q4_0,
         .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_0_reference,
         .quantize_row_q_dot       = quantize_row_q8_0,
@@ -1818,7 +1622,7 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_0,
     },
     [GGML_TYPE_Q4_1] = {
-        .dequantize_row_q         = dequantize_row_q4_1,
+        .dequantize_row_q         = (dequantize_row_q_t)dequantize_row_q4_1,
         .quantize_row_q           = quantize_row_q4_1,
         .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q4_1_reference,
         .quantize_row_q_dot       = quantize_row_q8_1,
@@ -2467,9 +2271,10 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
 }

 static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK8_0;
+    const int qk = QK8_0;
+    const int nb = n / qk;

-    assert(n % QK8_0 == 0);
+    assert(n % qk == 0);
     assert(nb % 2 == 0);

     const block_q4_0 * restrict x = vx;
@@ -2604,41 +2409,45 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
 #else
     // scalar
     float sumf = 0.0;
+
     for (int i = 0; i < nb; i++) {
-        const float d0 = x[i].d;
-        const float d1 = y[i].d;
-
-        const uint8_t * restrict p0 = x[i].qs;
-        const  int8_t * restrict p1 = y[i].qs;
+        // unpack nibbles into bytes
+        uint64_t qs[QK8_0 / 8] = {0};
+
+        memcpy(qs + 0, x[i].qs, sizeof(x[i].qs));
+
+        for (int l = 0; l < qk / 16; ++l) {
+            qs[l + qk/16] = (qs[l] & 0xF0F0F0F0F0F0F0F0ULL) >> 4;
+            qs[l + 0    ] = (qs[l] & 0x0F0F0F0F0F0F0F0FULL) >> 0;
+        }
+
+        const uint8_t * restrict px = (const uint8_t * restrict) qs;
+        const  int8_t * restrict py = y[i].qs;

         int sumi = 0;
-        for (int j = 0; j < QK8_0/2; j++) {
-            const uint8_t v0 = p0[j];

-            const int i0 = (int8_t) (v0 & 0x0F) - 8;
-            const int i1 = (int8_t) (v0 >>   4) - 8;
-
-            const int i2 = p1[2*j + 0];
-            const int i3 = p1[2*j + 1];
-
-            sumi += i0*i2 + i1*i3;
+        for (int j = 0; j < qk; ++j) {
+            sumi += (px[j] - 8) * py[j];
         }
-        sumf += d0*d1*sumi;
+
+        sumf += (x[i].d*y[i].d)*sumi;
     }
+
     *s = sumf;
 #endif
 }

 static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    const int nb = n / QK8_1;
+    const int qk = QK8_1;
+    const int nb = n / qk;

-    assert(n % QK8_1 == 0);
+    assert(n % qk == 0);
     assert(nb % 2 == 0);

     const block_q4_1 * restrict x = vx;
     const block_q8_1 * restrict y = vy;

-    // TODO: add AVX / WASM SIMD / etc
+    // TODO: add WASM SIMD
 #if defined(__ARM_NEON)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
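The rewritten scalar Q4_0 path can work on raw nibbles because the affine decode factors out of the block sum: a Q4_0 quant v decodes to d*(v - 8) and a Q8_0 quant q to dy*q, so the per-block dot product is (d*dy) * sum((v - 8)*q). A tiny sketch of that identity (values chosen so the float arithmetic is exact):

#include <assert.h>

int main(void) {
    const float d = 0.5f, dy = 0.25f;
    const int   v = 11,   q  = -3;

    const float x = d*(v - 8); // dequantized Q4_0 value
    const float y = dy*q;      // dequantized Q8_0 value

    // the scales factor out of the product, so they factor out of the block sum
    assert(x*y == (d*dy)*((v - 8)*q));
    return 0;
}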
@@ -2664,12 +2473,6 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
         const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8  (v0_1, m4b));
         const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));

-        // interleave
-        const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
-        const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
-        const int8x16_t v0_1lz = vzip1q_s8(v0_1l, v0_1h);
-        const int8x16_t v0_1hz = vzip2q_s8(v0_1l, v0_1h);
-
         // load y
         const int8x16_t v1_0l = vld1q_s8(y0->qs);
         const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
@@ -2678,21 +2481,21 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
 #if defined(__ARM_FEATURE_DOTPROD)
         // dot product into int32x4_t
-        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h);
-        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h);
+        const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l), v0_0h, v1_0h);
+        const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l), v0_1h, v1_1h);

         sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d);
         sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d);
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
-
-        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l));
-        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l));
-        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h));
-        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0h), vget_low_s8 (v1_0h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0h), vget_high_s8(v1_0h));
+
+        const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1l), vget_low_s8 (v1_1l));
+        const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1l), vget_high_s8(v1_1l));
+        const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1h));
+        const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1h));

         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
@@ -2738,27 +2541,30 @@ static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
 #else
     // scalar
     float sumf = 0.0;
+
     for (int i = 0; i < nb; i++) {
-        const float d0 = x[i].d;
-        const float m0 = x[i].m;
-        const float d1 = y[i].d;
-
-        const uint8_t * restrict p0 = x[i].qs;
-        const  int8_t * restrict p1 = y[i].qs;
-
-        // TODO: this is very slow ..
-        for (int j = 0; j < QK8_1/2; j++) {
-            const uint8_t v0 = p0[j];
-
-            const float f0 = d0*(v0 & 0x0F) + m0;
-            const float f1 = d0*(v0 >>   4) + m0;
-
-            const float f2 = d1*p1[2*j + 0];
-            const float f3 = d1*p1[2*j + 1];
-
-            sumf += f0*f2 + f1*f3;
+        // unpack nibbles into bytes
+        uint64_t qs[QK8_1 / 8] = {0};
+
+        memcpy(qs + 0, x[i].qs, sizeof(x[i].qs));
+
+        for (int l = 0; l < qk / 16; ++l) {
+            qs[l + qk/16] = (qs[l] & 0xF0F0F0F0F0F0F0F0ULL) >> 4;
+            qs[l + 0    ] = (qs[l] & 0x0F0F0F0F0F0F0F0FULL) >> 0;
+        }
+
+        const uint8_t * restrict px = (const uint8_t * restrict) qs;
+        const  int8_t * restrict py = y[i].qs;
+
+        int sumi = 0;
+
+        for (int j = 0; j < qk; ++j) {
+            sumi += px[j]*py[j];
         }
+
+        sumf += (x[i].d*y[i].d)*sumi + x[i].m*(y[i].s0 + y[i].s1);
     }
+
     *s = sumf;
 #endif
 }
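The closing sumf expression of the scalar Q4_1 path uses the same factoring plus an offset term: with x_j = d*v_j + m (v_j unsigned) and y_j = dy*q_j, the block dot product is (d*dy)*sum(v_j*q_j) + m*(dy*sum(q_j)), and dy*sum(q_j) is what the Q8_1 block precomputes, here assumed to be split across its s0 and s1 fields with one per half block. A sketch of the identity with exactly representable values:

#include <assert.h>

int main(void) {
    const float d = 0.5f, m = 2.0f, dy = 0.25f;
    const int   v[2] = {  3, 12 }; // Q4_1 quants (unsigned)
    const int   q[2] = { -5,  7 }; // Q8_1 quants (signed)

    float direct = 0.0f; // dot product of the dequantized values
    int   sumi   = 0;    // sum of v_j*q_j
    int   sumq   = 0;    // sum of q_j

    for (int j = 0; j < 2; ++j) {
        direct += (d*v[j] + m)*(dy*q[j]);
        sumi   += v[j]*q[j];
        sumq   += q[j];
    }

    const float s = dy*sumq; // what q8_1 stores across s0 + s1 (assumed)
    assert(direct == (d*dy)*sumi + m*s);
    return 0;
}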