iq1_s: WIP AVX2 dot product - something is not right

This commit is contained in:
Iwan Kawrakow 2024-02-11 17:22:42 +02:00
parent d94139bf27
commit 592b3b26bb

View File

@ -9282,6 +9282,14 @@ void ggml_vec_dot_iq3_xxs_q8_K(int n, float * restrict s, size_t bs, const void
#endif #endif
} }
#ifdef __AVX2__
static inline __m256i mul_add_epi8(const __m256i x, const __m256i y) {
const __m256i ax = _mm256_sign_epi8(x, x);
const __m256i sy = _mm256_sign_epi8(y, x);
return _mm256_maddubs_epi16(ax, sy);
}
#endif
void ggml_vec_dot_iq1_s_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { void ggml_vec_dot_iq1_s_q8_K(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
assert(n % QK_K == 0); assert(n % QK_K == 0);
@ -9290,6 +9298,59 @@ void ggml_vec_dot_iq1_s_q8_K(const int n, float * restrict s, const void * restr
const int nb = n / QK_K; const int nb = n / QK_K;
#if defined __AVX2__
const __m128i m8 = _mm_set1_epi8(0x08);
const __m128i m7 = _mm_set1_epi8(0x07);
const __m128i shuffle_h = _mm_set_epi8(15, 7, 14, 6, 13, 5, 12, 4, 11, 3, 10, 2, 9, 1, 8, 0);
const __m128i shuffle_s[4] = {
_mm_set_epi32(0x03030303, 0x02020202, 0x01010101, 0x00000000),
_mm_set_epi32(0x07070707, 0x06060606, 0x05050505, 0x04040404),
_mm_set_epi32(0x0b0b0b0b, 0x0a0a0a0a, 0x09090909, 0x08080808),
_mm_set_epi32(0x0f0f0f0f, 0x0e0e0e0e, 0x0d0d0d0d, 0x0c0c0c0c)
};
uint64_t aux64;
__m256i v_gindex;
const uint16_t * gindex = (const uint16_t *)&v_gindex;
__m256 accum = _mm256_setzero_ps();
for (int i = 0; i < nb; ++i) {
const int8_t * q8 = y[i].qs;
const uint8_t * qs = x[i].qs;
const uint8_t * sc = x[i].scales;
__m256i sumi = _mm256_setzero_si256();
for (int i128 = 0; i128 < QK_K/128; ++i128) {
const __m128i ql = _mm_loadu_si128((const __m128i*)qs); qs += 16;
memcpy(&aux64, sc, 8); sc += 8;
const __m128i qh = _mm_shuffle_epi8(_mm_set_epi64x(aux64 >> 4, aux64), shuffle_h);
const __m256i hbit = _mm256_cvtepi8_epi16(_mm_and_si128(qh, m8));
v_gindex = _mm256_or_si256(_mm256_cvtepi8_epi16(ql), _mm256_slli_epi16(hbit, 5));
const __m128i scales = _mm_and_si128(qh, m7);
for (int i32 = 0; i32 < 4; ++i32) {
const __m256i q8b = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;
const __m256i q1b = _mm256_set_epi64x(iq1s_grid[gindex[4*i32+3]], iq1s_grid[gindex[4*i32+2]],
iq1s_grid[gindex[4*i32+1]], iq1s_grid[gindex[4*i32+0]]);
const __m256i dot = mul_add_epi8(q1b, q8b);
const __m256i s16 = _mm256_cvtepi8_epi16(_mm_shuffle_epi8(scales, shuffle_s[i32]));
const __m256i p = _mm256_madd_epi16(s16, dot);
sumi = _mm256_add_epi32(sumi, p);
}
}
accum = _mm256_fmadd_ps(_mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)), _mm256_cvtepi32_ps(sumi), accum);
}
*s = hsum_float_8(accum);
#else
int db[4]; int db[4];
uint16_t idx[4]; uint16_t idx[4];
@ -9326,6 +9387,8 @@ void ggml_vec_dot_iq1_s_q8_K(const int n, float * restrict s, const void * restr
*s = sumf; *s = sumf;
#endif
} }
// ================================ IQ2 quantization ============================================= // ================================ IQ2 quantization =============================================