iq2_xs: better ARM_NEON dot product
We are now at 19.5 t/s for TG-128 and 61 t/s for PP-512 when running on the CPU.
parent ff49d876c6
commit 52ea3f7930
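The speedup comes from keeping the sub-block scales in integer registers. The old loop weighted each vaddvq_s32 result by the float 0.5f + s (s being a 4-bit scale nibble); the new loop multiplies by the integer 2*s + 1 via vmlaq_s32 and folds the 0.5f into the final constant, which is why *s changes from 0.25f * sumf to 0.125f * sumf: 0.25f*(0.5f + s) == 0.125f*(2*s + 1). A minimal scalar sketch, not part of the commit, that checks this identity:

#include <stdio.h>

int main(void) {
    for (int s = 0; s < 16; ++s) {            /* s = 4-bit sub-block scale nibble */
        float old_w = 0.25f  * (0.5f + s);    /* old path: float weight per sub-block */
        float new_w = 0.125f * (2*s + 1);     /* new path: integer scale, one final multiply */
        printf("s=%2d  old=%.4f  new=%.4f\n", s, old_w, new_w);
    }
    return 0;
}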
@@ -7558,14 +7558,26 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
     int8x16x4_t q2s;
     int8x16x4_t q8b;
 
+    int32x4x4_t scales32;
+
     float sumf = 0;
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
         const uint16_t * restrict q2 = x[i].qs;
-        const uint8_t  * restrict sc = x[i].scales;
         const int8_t   * restrict q8 = y[i].qs;
-        float sumf1 = 0, sumf2 = 0;
-        for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
+        const uint8x8_t scales8 = vld1_u8(x[i].scales);
+        const uint8x8_t scales_l = vand_u8(scales8, vdup_n_u8(0xf));
+        const uint8x8_t scales_h = vshr_n_u8(scales8, 4);
+        uint8x16_t scales = vcombine_u8(vzip1_u8(scales_l, scales_h), vzip2_u8(scales_l, scales_h));
+        scales = vaddq_u8(vshlq_n_u8(scales, 1), vdupq_n_u8(1));
+        const uint16x8_t scales1 = vmovl_u8(vget_low_u8(scales));
+        const uint16x8_t scales2 = vmovl_u8(vget_high_u8(scales));
+        scales32.val[0] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales1)));
+        scales32.val[1] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales1)));
+        scales32.val[2] = vreinterpretq_s32_u32(vmovl_u16(vget_low_u16(scales2)));
+        scales32.val[3] = vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales2)));
+        int32x4_t sumi = vdupq_n_s32(0);
+        for (int ib64 = 0; ib64 < QK_K/64; ++ib64) {
             q8b = vld1q_s8_x4(q8); q8 += 64;
             q2u.val[0] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[0] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[1] & 511))));
             q2u.val[1] = vcombine_s8(vld1_s8((const void *)(iq2xs_grid + (q2[2] & 511))), vld1_s8((const void *)(iq2xs_grid + (q2[3] & 511))));
@@ -7583,16 +7595,13 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
             const int32x4_t p1 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]);
             const int32x4_t p2 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[1], q8b.val[1]);
             const int32x4_t p3 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]);
             const int32x4_t p4 = ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[3], q8b.val[3]);
-            sumf1 += vaddvq_s32(p1) * (0.5f + (sc[0] & 0xf));
-            sumf2 += vaddvq_s32(p2) * (0.5f + (sc[0] >> 4));
-            sumf1 += vaddvq_s32(p3) * (0.5f + (sc[1] & 0xf));
-            sumf2 += vaddvq_s32(p4) * (0.5f + (sc[1] >> 4));
+            const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4));
+            sumi = vmlaq_s32(sumi, p, scales32.val[ib64]);
             q2 += 8;
-            sc += 2;
         }
-        sumf += d*(sumf1 + sumf2);
+        sumf += d*vaddvq_s32(sumi);
     }
-    *s = 0.25f * sumf;
+    *s = 0.125f * sumf;
 
 #elif defined(__AVX2__)
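For readers unpacking the new inner loop: vpaddq_s32 pairwise-adds its two arguments, so two levels of it collapse the four 4-lane dot products into one vector whose lanes are the four sub-block sums, and a single vmlaq_s32 then applies all four integer scales. A self-contained sketch of that reduction, assuming AArch64 NEON; combine4 is an illustrative name, not from the commit:

#include <arm_neon.h>

/* Illustrative helper (not from the commit): fold four per-sub-block
 * dot products into the running integer accumulator. */
static inline int32x4_t combine4(int32x4_t p1, int32x4_t p2,
                                 int32x4_t p3, int32x4_t p4,
                                 int32x4_t sumi, int32x4_t scales) {
    /* vpaddq_s32(a, b) = {a0+a1, a2+a3, b0+b1, b2+b3}; after two levels
     * the lanes hold {sum(p1), sum(p2), sum(p3), sum(p4)}. */
    const int32x4_t p = vpaddq_s32(vpaddq_s32(p1, p2), vpaddq_s32(p3, p4));
    /* sumi += p * scales, element-wise: one multiply-accumulate applies
     * all four sub-block scales at once. */
    return vmlaq_s32(sumi, p, scales);
}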