mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 12:24:35 +00:00
iq2_xs: better ARM_NEON dot product
We are now at 19.5 t/s for TG-128 and 61 t/s for PP-512 when running on the CPU.
This commit is contained in:
parent
ff49d876c6
commit
52ea3f7930
@ -7558,14 +7558,26 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
|
||||
int8x16x4_t q2s;
|
||||
int8x16x4_t q8b;
|
||||
|
||||
int32x4x4_t scales32;
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||
const uint16_t * restrict q2 = x[i].qs;
|
||||
const uint8_t * restrict sc = x[i].scales;
|
||||
const int8_t * restrict q8 = y[i].qs;
|
||||
float sumf1 = 0, sumf2 = 0;
|
||||
for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
|
||||
const uint8x8_t scales8 = vld1_u8(x[i].scales);
|
||||
const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf));
|
||||
const uint8x8_t scales_h = vshr_n_u8(scales8, 4);
|
||||
uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
|
||||
scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1));
|
||||
const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales));
|
||||
const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales));
|
||||
scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1)));
|
||||
scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1)));
|
||||
scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2)));
|
||||
scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
|
||||
int32x4_t sumi = vdupq_n_s32(0);
|
||||
for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
|
||||
q8b = vld1q_s8_x4(q8); q8 += 64;
|
||||
q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
|
||||
q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
|
||||
@ -7583,16 +7595,13 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
|
||||
const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]);
|
||||
const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]);
|
||||
const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]);
|
||||
sumf1 += vaddvq_s32(p1) * (0.5f + (sc[0] & 0xf));
|
||||
sumf2 += vaddvq_s32(p2) * (0.5f + (sc[0] >> 4));
|
||||
sumf1 += vaddvq_s32(p3) * (0.5f + (sc[1] & 0xf));
|
||||
sumf2 += vaddvq_s32(p4) * (0.5f + (sc[1] >> 4));
|
||||
const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4));
|
||||
sumi = vmlaq_s32(sumi, p, scales32.val[ib64]);
|
||||
q2 += 8;
|
||||
sc += 2;
|
||||
}
|
||||
sumf += d*(sumf1 + sumf2);
|
||||
sumf += d*vaddvq_s32(sumi);
|
||||
}
|
||||
*s = 0.25f * sumf;
|
||||
*s = 0.125f * sumf;
|
||||
|
||||
#elif defined(z__AVX2__)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user