From 592b3b26bb10efe19f22c49b1f620dd4362a9087 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sun, 11 Feb 2024 17:22:42 +0200 Subject: [PATCH] iq1_s: WIP AVX2 dot product - something is not right --- ggml-quants.c | 63 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/ggml-quants.c b/ggml-quants.c index bb3f7bb55..eef36e962 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -9282,6 +9282,14 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void #endif } +#ifdef __AVX2__ +static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) { + const __m256i ax = _mm256_sign_epi8(x, x); + const __m256i sy = _mm256_sign_epi8(y, x); + return _mm256_maddubs_epi16(ax, sy); +} +#endif + void ggml_vec_dot_iq1_s_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { assert(n % QK_K == 0); @@ -9290,6 +9298,59 @@ void ggml_vec_dot_iq1_s_q8_K(const int n, float * restrict s, const void * restr const int nb = n / QK_K; +#if defined __AVX2__ + + const __m128i m8 = _mm_set1_epi8(0x08); + const __m128i m7 = _mm_set1_epi8(0x07); + const __m128i shuffle_h = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0); + const __m128i shuffle_s[4] = { + _mm_set_epi32(0x03030303, 0x02020202, 0x01010101, 0x00000000), + _mm_set_epi32(0x07070707, 0x06060606, 0x05050505, 0x04040404), + _mm_set_epi32(0x0b0b0b0b, 0x0a0a0a0a, 0x09090909, 0x08080808), + _mm_set_epi32(0x0f0f0f0f, 0x0e0e0e0e, 0x0d0d0d0d, 0x0c0c0c0c) + }; + + uint64_t aux64; + + __m256i v_gindex; + const uint16_t * gindex = (const uint16_t *)&v_gindex; + + __m256 accum = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint8_t * sc = x[i].scales; + + __m256i sumi = _mm256_setzero_si256(); + for (int i128 = 0; i128 < QK_K/128; ++i128) { + const __m128i ql = _mm_loadu_si128((const __m128i*)qs); qs += 16; + memcpy(&aux64, sc, 8); sc += 8; + const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h); + const __m256i hbit = _mm256_cvtepi8_epi16(_mm_and_si128(qh, m8)); + v_gindex = _mm256_or_si256(_mm256_cvtepi8_epi16(ql), _mm256_slli_epi16(hbit, 5)); + const __m128i scales = _mm_and_si128(qh, m7); + + for (int i32 = 0; i32 < 4; ++i32) { + const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; + const __m256i q1b = _mm256_set_epi64x(iq1s_grid[gindex[4*i32+3]], iq1s_grid[gindex[4*i32+2]], + iq1s_grid[gindex[4*i32+1]], iq1s_grid[gindex[4*i32+0]]); + const __m256i dot = mul_add_epi8(q1b, q8b); + const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32])); + const __m256i p = _mm256_madd_epi16(s16, dot); + sumi = _mm256_add_epi32(sumi, p); + } + + } + + accum = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)), _mm256_cvtepi32_ps(sumi), accum); + + } + + *s = hsum_float_8(accum); + +#else + int db[4]; uint16_t idx[4]; @@ -9326,6 +9387,8 @@ void ggml_vec_dot_iq1_s_q8_K(const int n, float * restrict s, const void * restr *s = sumf; +#endif + } // ================================ IQ2 quantization =============================================