iq2_xs: better ARM_NEON dot product

We are now at 19.5 t/s for TG-128 and 61 t/s for PP-512 when
running on the CPU.
This commit is contained in:
Iwan Kawrakow 2024-01-09 19:43:39 +01:00
parent ff49d876c6
commit 52ea3f7930

View File

@ -7558,14 +7558,26 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
int8x16x4_t q2s;
int8x16x4_t q8b;
int32x4x4_t scales32;
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
const uint16_t * restrict q2 = x[i].qs;
const uint8_t * restrict sc = x[i].scales;
const int8_t * restrict q8 = y[i].qs;
float sumf1 = 0, sumf2 = 0;
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
const uint8x8_t scales8 = vld1_u8(x[i].scales);
const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf));
const uint8x8_t scales_h = vshr_n_u8(scales8, 4);
uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1));
const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales));
const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales));
scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1)));
scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1)));
scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2)));
scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
int32x4_t sumi = vdupq_n_s32(0);
for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
q8b = vld1q_s8_x4(q8); q8 += 64;
q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
@ -7583,16 +7595,13 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]);
const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]);
const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]);
sumf1 += vaddvq_s32(p1) * (0.5f + (sc[0] & 0xf));
sumf2 += vaddvq_s32(p2) * (0.5f + (sc[0] >> 4));
sumf1 += vaddvq_s32(p3) * (0.5f + (sc[1] & 0xf));
sumf2 += vaddvq_s32(p4) * (0.5f + (sc[1] >> 4));
const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4));
sumi = vmlaq_s32(sumi, p, scales32.val[ib64]);
q2 += 8;
sc += 2;
}
sumf += d*(sumf1 + sumf2);
sumf += d*vaddvq_s32(sumi);
}
*s = 0.25f * sumf;
*s = 0.125f * sumf;
#elif defined(z__AVX2__)