From 9465ec6e12be0498f409af7d8d5e978403058d9a Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Tue, 25 Jun 2024 01:32:14 -0400
Subject: [PATCH] ggml-quants : ARM NEON vec_dot for q2_2 and q1_3

---
 ggml/src/ggml-quants.c | 188 +++++++++++++++++++++++++++++++++++++----
 ggml/src/ggml.c        |  12 +++
 2 files changed, 182 insertions(+), 18 deletions(-)

diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c
index 927737fa0..14a1ee4e9 100644
--- a/ggml/src/ggml-quants.c
+++ b/ggml/src/ggml-quants.c
@@ -686,6 +686,10 @@ void quantize_row_q2_2_reference(const float * restrict x, block_q2_2 * restrict
     }
 }
 
+void quantize_row_q2_2(const float * restrict x, void * restrict y, int64_t k) {
+    quantize_row_q2_2_reference(x, y, k);
+}
+
 // reference implementation for deterministic creation of model files
 void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
     static const int qk = QK4_0;
@@ -3900,17 +3904,81 @@ void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * r
     }
 
     *s = hsum_float_8(acc);
+#elif defined(__ARM_NEON)
+    float sumf0 = 0.0f;
+    float sumf1 = 0.0f;
+
+    const uint8x8_t mask = vdup_n_u8(3);
+    const int8x8_t offset = vdup_n_s8(2);
+
+    const int leftovers = nb % 2;
+
+    for (int i = 0; i < nb - leftovers; i += 2) {
+        const uint8x8_t xq8_0 = vld1_u8(x[0].qs);
+        const uint8x8_t xq8_1 = vld1_u8(x[1].qs);
+
+        const int8x8_t xq8_0_0 = vsub_s8(vreinterpret_s8_u8(vand_u8(xq8_0, mask)), offset);
+        const int8x8_t xq8_0_1 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_0, 2), mask)), offset);
+        const int8x8_t xq8_0_2 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_0, 4), mask)), offset);
+        const int8x8_t xq8_0_3 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_0, 6), mask)), offset);
+        const int8x8_t xq8_1_0 = vsub_s8(vreinterpret_s8_u8(vand_u8(xq8_1, mask)), offset);
+        const int8x8_t xq8_1_1 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_1, 2), mask)), offset);
+        const int8x8_t xq8_1_2 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_1, 4), mask)), offset);
+        const int8x8_t xq8_1_3 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_1, 6), mask)), offset);
+
+        const int8x16_t xq8_0_l = vcombine_s8(xq8_0_0, xq8_0_1);
+        const int8x16_t xq8_0_h = vcombine_s8(xq8_0_2, xq8_0_3);
+        const int8x16_t xq8_1_l = vcombine_s8(xq8_1_0, xq8_1_1);
+        const int8x16_t xq8_1_h = vcombine_s8(xq8_1_2, xq8_1_3);
+
+        const int8x16_t yq8_0_l = vld1q_s8(y[0].qs + 0);
+        const int8x16_t yq8_0_h = vld1q_s8(y[0].qs + 16);
+        const int8x16_t yq8_1_l = vld1q_s8(y[1].qs + 0);
+        const int8x16_t yq8_1_h = vld1q_s8(y[1].qs + 16);
+
+        const int16x8_t dot0 = vaddq_s16(vpaddlq_s8(vmulq_s8(xq8_0_l, yq8_0_l)), vpaddlq_s8(vmulq_s8(xq8_0_h, yq8_0_h)));
+        const int16x8_t dot1 = vaddq_s16(vpaddlq_s8(vmulq_s8(xq8_1_l, yq8_1_l)), vpaddlq_s8(vmulq_s8(xq8_1_h, yq8_1_h)));
+
+        sumf0 += GGML_FP16_TO_FP32(y[0].d) * (float) vaddlvq_s16(dot0);
+        sumf1 += GGML_FP16_TO_FP32(y[1].d) * (float) vaddlvq_s16(dot1);
+        x += 2;
+        y += 2;
+    }
+
+    // one block at a time
+    for (int i = nb - leftovers; i < nb; ++i) {
+        const uint8x8_t xq8 = vld1_u8(x->qs);
+        const int8x8_t xq8_0 = vsub_s8(vreinterpret_s8_u8(vand_u8(xq8, mask)), offset);
+        const int8x8_t xq8_1 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8, 2), mask)), offset);
+        const int8x8_t xq8_2 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8, 4), mask)), offset);
+        const int8x8_t xq8_3 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8, 6), mask)), offset);
+
+        const int8x16_t xq8_l = vcombine_s8(xq8_0, xq8_1);
+        const int8x16_t xq8_h = vcombine_s8(xq8_2, xq8_3);
+
+        const int8x16_t yq8_l = vld1q_s8(y->qs + 0);
+        const int8x16_t yq8_h = vld1q_s8(y->qs + 16);
+
+        const int16x8_t dot0 = vpaddlq_s8(vmulq_s8(xq8_l, yq8_l));
+        const int16x8_t dot1 = vpaddlq_s8(vmulq_s8(xq8_h, yq8_h));
+
+        sumf0 += GGML_FP16_TO_FP32(y->d) * (float) vaddlvq_s16(vaddq_s16(dot0, dot1));
+        x += 1;
+        y += 1;
+    }
+
+    *s = sumf0 + sumf1;
 #else
-    float sumf = 0.0;
+    float sumf = 0.0f;
 
     for (int i = 0; i < nb; i++) {
         int sumi = 0;
         for (int j = 0; j < qk / 4; j++) {
             const uint8_t weight = x[i].qs[j];
-            sumi += (int)y[i].qs[j + 0*qk/4] * ((weight >> 0) & 3) - 2;
-            sumi += (int)y[i].qs[j + 1*qk/4] * ((weight >> 2) & 3) - 2;
-            sumi += (int)y[i].qs[j + 2*qk/4] * ((weight >> 4) & 3) - 2;
-            sumi += (int)y[i].qs[j + 3*qk/4] * ((weight >> 6) & 3) - 2;
+            sumi += (int)y[i].qs[j + 0*qk/4] * (((weight >> 0) & 3) - 2);
+            sumi += (int)y[i].qs[j + 1*qk/4] * (((weight >> 2) & 3) - 2);
+            sumi += (int)y[i].qs[j + 2*qk/4] * (((weight >> 4) & 3) - 2);
+            sumi += (int)y[i].qs[j + 3*qk/4] * (((weight >> 6) & 3) - 2);
         }
         sumf += (float)(sumi)*(GGML_FP16_TO_FP32(y[i].d));
     }
@@ -11314,27 +11382,27 @@ void ggml_vec_dot_q1_3_q8_0(int n, float * restrict s, size_t bs, const void * r
 
     for (int i = 0; i < nb; ++i) {
         // const __m128i x12a = _mm_maskload_epi32((const int32_t *) x, _mm_set_epi32(0, -1, -1, -1));
-        // const __m128i x12b = _mm_insert_epi8(x12a, x->qs[0], 12);
+        // const __m128i x13b = _mm_insert_epi8(x12a, x->qs[0], 12);
         // WARNING: reading 3 bytes further than necessary.
         // It's measurably faster than a masked load on an Intel Core m3-8100Y
-        const __m128i x12b = _mm_loadu_si128((const __m128i_u *) x);
-        const __m256i x12 = MM256_SET_M128I(x12b, x12b);
+        const __m128i x13b = _mm_loadu_si128((const __m128i_u *) x);
+        const __m256i x13 = MM256_SET_M128I(x13b, x13b);
 
         {
             // pre-shift the values by 8 bits, and prepare the layout for later packing
-            __m256i x0l = _mm256_shuffle_epi8(x12, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1,
+            __m256i x0l = _mm256_shuffle_epi8(x13, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1,
                                                                    4, -1, 4, -1, 4, -1, 4, -1,
                                                                    1, -1, 1, -1, 1, -1, 1, -1,
                                                                    0, -1, 0, -1, 0, -1, 0, -1));
-            __m256i x0h = _mm256_shuffle_epi8(x12, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1,
+            __m256i x0h = _mm256_shuffle_epi8(x13, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1,
                                                                    6, -1, 6, -1, 6, -1, 6, -1,
                                                                    3, -1, 3, -1, 3, -1, 3, -1,
                                                                    2, -1, 2, -1, 2, -1, 2, -1));
-            __m256i x1l = _mm256_shuffle_epi8(x12, _mm256_set_epi8(7, -1, 6, -1, 5, -1, 4, -1,
+            __m256i x1l = _mm256_shuffle_epi8(x13, _mm256_set_epi8(7, -1, 6, -1, 5, -1, 4, -1,
                                                                    3, -1, 2, -1, 1, -1, 0, -1,
                                                                    9, -1, 9, -1, 9, -1, 9, -1,
                                                                    8, -1, 8, -1, 8, -1, 8, -1));
-            __m256i x1h = _mm256_shuffle_epi8(x12, _mm256_set_epi8(12, -1, 12, -1, 12, -1, 12, -1,
+            __m256i x1h = _mm256_shuffle_epi8(x13, _mm256_set_epi8(12, -1, 12, -1, 12, -1, 12, -1,
                                                                    11, -1, 10, -1, 9, -1, 8, -1,
                                                                    11, -1, 11, -1, 11, -1, 11, -1,
                                                                    10, -1, 10, -1, 10, -1, 10, -1));
@@ -11385,6 +11453,88 @@ void ggml_vec_dot_q1_3_q8_0(int n, float * restrict s, size_t bs, const void * r
     }
 
     *s = hsum_float_8(accumf);
+#elif defined(__ARM_NEON)
+
+    static const uint8_t k_mask0[16] = {0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3};
+    static const uint8_t k_mask1[16] = {4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7};
+    static const uint8_t k_mask2[16] = {8, 8, 8, 8, 9, 9, 9, 9, 10, 10, 10, 10, 11, 11, 11, 11};
+    static const uint8_t k_mask3[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 12, 12, 12};
+
+    static const uint8_t k_shift0[16] = {81, 27, 9, 3, 81, 27, 9, 3, 81, 27, 9, 3, 81, 27, 9, 3};
+    static const uint8_t k_shift3[16] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 81, 27, 9, 3};
+
+    // float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    // float32x4_t sumv1 = vdupq_n_f32(0.0f);
+
+    float sumf0 = 0.0f;
+    float sumf1 = 0.0f;
+
+    const uint8x16_t mask0 = vld1q_u8(k_mask0);
+    const uint8x16_t mask1 = vld1q_u8(k_mask1);
+    const uint8x16_t mask2 = vld1q_u8(k_mask2);
+    const uint8x16_t mask3 = vld1q_u8(k_mask3);
+
+    const uint8x16_t shift0 = vld1q_u8(k_shift0);
+    const uint8x16_t shift3 = vld1q_u8(k_shift3);
+
+    const int8x16_t one = vdupq_n_s8(1);
+
+    for (int i = 0; i < nb; ++i) {
+        // WARNING: reading 3 bytes further than necessary
+        const uint8x16_t x13b = vld1q_u8((const uint8_t *) x);
+
+        uint8x16_t x0 = vqtbl1q_u8(x13b, mask0);
+        uint8x16_t x1 = vqtbl1q_u8(x13b, mask1);
+        uint8x16_t x2 = vqtbl1q_u8(x13b, mask2);
+        uint8x16_t x3 = vqtbl1q_u8(x13b, mask3);
+
+        x0 = vmulq_u8(x0, shift0);
+        x1 = vmulq_u8(x1, shift0);
+        x2 = vmulq_u8(x2, shift0);
+        x3 = vmulq_u8(x3, shift3);
+
+        // multiply by 3 and keep the 2 bits above 8 bits
+        x0 = vshrq_n_u8(vhaddq_u8(x0, vshrq_n_u8(x0, 1)), 6);
+        x1 = vshrq_n_u8(vhaddq_u8(x1, vshrq_n_u8(x1, 1)), 6);
+        x2 = vshrq_n_u8(vhaddq_u8(x2, vshrq_n_u8(x2, 1)), 6);
+        x3 = vshrq_n_u8(vhaddq_u8(x3, vshrq_n_u8(x3, 1)), 6);
+
+        // 0, 1, 2 => -1, 0, 1
+        int8x16_t x0i = vsubq_s8(vreinterpretq_s8_u8(x0), one);
+        int8x16_t x1i = vsubq_s8(vreinterpretq_s8_u8(x1), one);
+        int8x16_t x2i = vsubq_s8(vreinterpretq_s8_u8(x2), one);
+        int8x16_t x3i = vsubq_s8(vreinterpretq_s8_u8(x3), one);
+
+        const int8x16_t y0 = vld1q_s8(y[0].qs + 0);
+        const int8x16_t y1 = vld1q_s8(y[0].qs + 16);
+        const int8x16_t y2 = vld1q_s8(y[1].qs + 0);
+        const int8x16_t y3 = vld1q_s8(y[1].qs + 16);
+
+        // const int32x4_t p0 = vpaddlq_s16(vaddq_s16(vpaddlq_s8(x0i), vpaddlq_s8(x1i)));
+        // const int32x4_t p1 = vpaddlq_s16(vaddq_s16(vpaddlq_s8(x2i), vpaddlq_s8(x3i)));
+
+        // there's no direct equivalent to _mm_sign_epi8, unfortunately
+        x0i = vmulq_s8(x0i, y0);
+        x1i = vmulq_s8(x1i, y1);
+        x2i = vmulq_s8(x2i, y2);
+        x3i = vmulq_s8(x3i, y3);
+
+        // overall 18.5% faster than with vector sums on a cortex-A72
+        sumf0 += GGML_FP16_TO_FP32(y[0].d) * (float) vaddlvq_s16(vaddq_s16(vpaddlq_s8(x0i), vpaddlq_s8(x1i)));
+        sumf1 += GGML_FP16_TO_FP32(y[1].d) * (float) vaddlvq_s16(vaddq_s16(vpaddlq_s8(x2i), vpaddlq_s8(x3i)));
+
+        // const int32x4_t p0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), x0i, y0), x1i, y1);
+        // const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), x2i, y2), x3i, y3);
+
+        // sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p0), GGML_FP16_TO_FP32(y[0].d));
+        // sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p1), GGML_FP16_TO_FP32(y[1].d));
+
+        y += 2;
+        x += 1;
+    }
+
+    // *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+    *s = sumf0 + sumf1;
 #else
     float sumf = 0.0f;
 
@@ -11393,34 +11543,36 @@ void ggml_vec_dot_q1_3_q8_0(int n, float * restrict s, size_t bs, const void * r
         for (int j = 0; j < 8; ++j) {
             const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].q[j]);
             for (int k = 0; k < 4; ++k) {
-                sum += xj[k] * (int16_t) y[2*i].qs[4*j + k];
+                sum += xj[k] * (int16_t) y->qs[4*j + k];
             }
         }
 
-        sumf += GGML_FP16_TO_FP32(y[2*i].d) * sum;
+        sumf += GGML_FP16_TO_FP32(y->d) * sum;
+        y += 1;
 
         sum = 0;
         for (int j = 0; j < 4; ++j) {
            const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].q[8 + j]);
            for (int k = 0; k < 4; ++k) {
-               sum += xj[k] * (int16_t) y[2*i + 1].qs[4*j + k];
+               sum += xj[k] * (int16_t) y->qs[4*j + k];
            }
         }
 
         for (size_t j = 0; j < 12; ++j) {
             uint16_t xj = x[i].q[j];
             xj = (xj * 3) >> 8;
-            sum += ((int16_t) xj - 1) * (int16_t) y[2*i + 1].qs[16 + j];
+            sum += ((int16_t) xj - 1) * (int16_t) y->qs[16 + j];
         }
 
         {
             const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].qs[0]);
             for (int k = 0; k < 4; ++k) {
-                sum += (int16_t) xj[k] * (int16_t) y[2*i + 1].qs[28 + k];
+                sum += (int16_t) xj[k] * (int16_t) y->qs[28 + k];
            }
         }
 
-        sumf += GGML_FP16_TO_FP32(y[2*i + 1].d) * sum;
+        sumf += GGML_FP16_TO_FP32(y->d) * sum;
+        y += 1;
     }
 
     *s = sumf;
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 8c444d0b6..a3a062ed8 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -823,6 +823,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .vec_dot_type             = GGML_TYPE_Q8_K,
         .nrows                    = 1,
     },
+    [GGML_TYPE_Q2_2] = {
+        .type_name                = "q2_2",
+        .blck_size                = QK2_2,
+        .type_size                = sizeof(block_q2_2),
+        .is_quantized             = true,
+        .to_float                 = (ggml_to_float_t) dequantize_row_q2_2,
+        .from_float               = quantize_row_q2_2,
+        .from_float_reference     = (ggml_from_float_t) quantize_row_q2_2_reference,
+        .vec_dot                  = ggml_vec_dot_q2_2_q8_0,
+        .vec_dot_type             = GGML_TYPE_Q8_0,
+        .nrows                    = 1,
+    },
     [GGML_TYPE_Q1_3] = {
         .type_name                = "q1_3",
         .blck_size                = QK1_3,
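
For reference, a minimal scalar sketch of the per-byte unpacking that the two NEON paths in this patch vectorize. This is not part of the diff: the helper names are illustrative only, and the byte layouts are assumed to match this branch's q2_2 and q1_3 packing.

    #include <stdint.h>

    // q2_2: each byte packs four 2-bit values; ((q >> 2*i) & 3) - 2 maps a
    // 2-bit field to [-2, 1], which is what the vand_u8/vshr_n_u8 + vsub_s8
    // sequence in ggml_vec_dot_q2_2_q8_0 does lane-wise.
    static inline void q2_2_byte_to_int8(uint8_t q, int8_t out[4]) {
        for (int i = 0; i < 4; ++i) {
            out[i] = (int8_t) (((q >> (2 * i)) & 3) - 2);
        }
    }

    // q1_3: an 8-bit wrapping multiply by a power of 3 moves one base-3 digit
    // of the byte just below bit 8, and ((x * 3) >> 8) then keeps only that
    // digit. This is the scalar form of vmulq_u8 by {81, 27, 9, 3} followed by
    // vshrq_n_u8(vhaddq_u8(x, vshrq_n_u8(x, 1)), 6) in ggml_vec_dot_q1_3_q8_0.
    static inline void q1_3_byte_to_ternary(uint8_t q, int8_t out[4]) {
        static const uint8_t pow3[4] = {81, 27, 9, 3};
        for (int i = 0; i < 4; ++i) {
            const uint8_t shifted = (uint8_t) (q * pow3[i]); // wraps mod 256
            out[i] = (int8_t) (((shifted * 3) >> 8) - 1);    // {0,1,2} -> {-1,0,1}
        }
    }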