diff --git a/ggml-quants.c b/ggml-quants.c index 59056e409..047a8cb4a 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -7558,14 +7558,26 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest int8x16x4_t q2s; int8x16x4_t q8b; + int32x4x4_t scales32; + float sumf = 0; for (int i = 0; i < nb; ++i) { const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; const uint16_t * restrict q2 = x[i].qs; - const uint8_t * restrict sc = x[i].scales; const int8_t * restrict q8 = y[i].qs; - float sumf1 = 0, sumf2 = 0; - for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + const uint8x8_t scales8 = vld1_u8(x[i].scales); + const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf)); + const uint8x8_t scales_h = vshr_n_u8(scales8, 4); + uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h)); + scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1)); + const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales)); + const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales)); + scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1))); + scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1))); + scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2))); + scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2))); + int32x4_t sumi = vdupq_n_s32(0); + for (int ib64 = 0; ib64 < QK_K/64; ++ib64) { q8b = vld1q_s8_x4(q8); q8 += 64; q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511)))); q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511)))); @@ -7583,16 +7595,13 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]); const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]); const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]); - sumf1 += vaddvq_s32(p1) * (0.5f + (sc[0] & 0xf)); - sumf2 += vaddvq_s32(p2) * (0.5f + (sc[0] >> 4)); - sumf1 += vaddvq_s32(p3) * (0.5f + (sc[1] & 0xf)); - sumf2 += vaddvq_s32(p4) * (0.5f + (sc[1] >> 4)); + const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4)); + sumi = vmlaq_s32(sumi, p, scales32.val[ib64]); q2 += 8; - sc += 2; } - sumf += d*(sumf1 + sumf2); + sumf += d*vaddvq_s32(sumi); } - *s = 0.25f * sumf; + *s = 0.125f * sumf; #elif defined(z__AVX2__)