iq1_s: ARM_NEON dot product. Works, but not very fast

This commit is contained in:
Iwan Kawrakow 2024-02-12 11:40:31 +02:00
parent 2ffb05acc8
commit 773014926f

View File

@ -9303,7 +9303,64 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const
const int nb = n / QK_K; const int nb = n / QK_K;
#if defined __AVX2__ #if defined __ARM_NEON
const uint8x16_t m8 = vdupq_n_u8(0x08);
const uint8x16_t m7 = vdupq_n_u8(0x07);
const uint8x16_t m1 = vdupq_n_u8(0x01);
const int32x4_t vzero = vdupq_n_s32(0);
uint16_t gindex[8];
uint16x8x2_t vindex;
int8x16x4_t q1b;
int8x16x4_t q8b;
uint16x8x4_t scales;
int32x4x2_t sumi;
int32x4x2_t dotq;
float sumf = 0;
for (int i = 0; i < nb; ++i) {
const int8_t * q8 = y[i].qs;
const uint8_t * qs = x[i].qs;
const uint8_t * sc = x[i].scales;
sumi.val[0] = sumi.val[1] = vzero;
for (int i128 = 0; i128 < QK_K/128; ++i128) {
const uint8x16_t ql = vld1q_u8(qs); qs += 16;
const uint8x8_t tm1 = vld1_u8 (sc); sc += 8;
const uint8x8_t tm2 = vshr_n_u8(tm1, 4);
const uint8x16_t qh = vcombine_u8(vzip1_u8(tm1, tm2), vzip2_u8(tm1, tm2));
const uint8x16_t hbit = vandq_u8(qh, m8);
vindex.val[0] = vorrq_u16(vmovl_u8(vget_low_u8 (ql)), vshlq_n_u16(vmovl_u8(vget_low_u8 (hbit)), 5));
vindex.val[1] = vorrq_u16(vmovl_u8(vget_high_u8(ql)), vshlq_n_u16(vmovl_u8(vget_high_u8(hbit)), 5));
const uint8x16_t scales8 = vorrq_u8(vshlq_n_u8(vandq_u8(qh, m7), 1), m1);
scales.val[0] = vmovl_u8(vget_low_u8 (scales8));
scales.val[1] = vmovl_u8(vget_high_u8 (scales8));
for (int l = 0; l < 2; ++l) {
vst1q_u16(gindex+0, vindex.val[l]);
q1b.val[0] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[0])), vld1_s8((const void *)(iq1s_grid+gindex[1])));
q1b.val[1] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[2])), vld1_s8((const void *)(iq1s_grid+gindex[3])));
q1b.val[2] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[4])), vld1_s8((const void *)(iq1s_grid+gindex[5])));
q1b.val[3] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[6])), vld1_s8((const void *)(iq1s_grid+gindex[7])));
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
dotq.val[0] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(vzero, q1b.val[1], q8b.val[1]));
dotq.val[1] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(vzero, q1b.val[3], q8b.val[3]));
sumi.val[0] = vmlaq_s32(sumi.val[0], dotq.val[0], vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales.val[l]))));
sumi.val[1] = vmlaq_s32(sumi.val[1], dotq.val[1], vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales.val[l]))));
}
}
sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * vaddvq_s32(vaddq_s32(sumi.val[0], sumi.val[1]));
}
*s = sumf;
#elif defined __AVX2__
const __m128i m8 = _mm_set1_epi8(0x08); const __m128i m8 = _mm_set1_epi8(0x08);
const __m128i m7 = _mm_set1_epi8(0x07); const __m128i m7 = _mm_set1_epi8(0x07);