iq2_xs: AVX2 dot product - 19.5 t/s

Iwan Kawrakow 2024-01-10 08:49:38 +02:00
parent 52ea3f7930
commit 3198e94f00

@@ -7603,40 +7603,38 @@ void ggml_vec_dot_iq2_xs_q8_K(const int n, float * restrict s, const void * rest
     }
     *s = 0.125f * sumf;
-#elif defined(z__AVX2__)
+#elif defined(__AVX2__)
     const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;
-    uint32_t aux32[4];
-    const uint8_t * aux8 = (const uint8_t *)aux32;
     __m256 accumf = _mm256_setzero_ps();
     for (int i = 0; i < nb; ++i) {
         const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
         const uint16_t * restrict q2 = x[i].qs;
+        const uint8_t  * restrict sc = x[i].scales;
         const int8_t   * restrict q8 = y[i].qs;
         __m256i sumi1 = _mm256_setzero_si256();
         __m256i sumi2 = _mm256_setzero_si256();
         for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) {
             const __m256i q8_1 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
             const __m256i q8_2 = _mm256_loadu_si256((const __m256i *)q8); q8 += 32;
-            memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8;
-            const __m256i q2_1 = _mm256_set_epi64x(iq2xxs_grid[aux8[ 3]], iq2xxs_grid[aux8[ 2]], iq2xxs_grid[aux8[1]], iq2xxs_grid[aux8[0]]);
-            const __m256i q2_2 = _mm256_set_epi64x(iq2xxs_grid[aux8[11]], iq2xxs_grid[aux8[10]], iq2xxs_grid[aux8[9]], iq2xxs_grid[aux8[8]]);
-            const __m256i s2_1 = _mm256_set_epi64x(signs64[(aux32[1] >> 21) & 127], signs64[(aux32[1] >> 14) & 127],
-                                                   signs64[(aux32[1] >>  7) & 127], signs64[(aux32[1] >>  0) & 127]);
-            const __m256i s2_2 = _mm256_set_epi64x(signs64[(aux32[3] >> 21) & 127], signs64[(aux32[3] >> 14) & 127],
-                                                   signs64[(aux32[3] >>  7) & 127], signs64[(aux32[3] >>  0) & 127]);
+            const __m256i q2_1 = _mm256_set_epi64x(iq2xs_grid[q2[3] & 511], iq2xs_grid[q2[2] & 511], iq2xs_grid[q2[1] & 511], iq2xs_grid[q2[0] & 511]);
+            const __m256i q2_2 = _mm256_set_epi64x(iq2xs_grid[q2[7] & 511], iq2xs_grid[q2[6] & 511], iq2xs_grid[q2[5] & 511], iq2xs_grid[q2[4] & 511]);
+            const __m256i s2_1 = _mm256_set_epi64x(signs64[q2[3] >> 9], signs64[q2[2] >> 9], signs64[q2[1] >> 9], signs64[q2[0] >> 9]);
+            const __m256i s2_2 = _mm256_set_epi64x(signs64[q2[7] >> 9], signs64[q2[6] >> 9], signs64[q2[5] >> 9], signs64[q2[4] >> 9]);
             const __m256i q8s_1 = _mm256_sign_epi8(q8_1, s2_1);
             const __m256i q8s_2 = _mm256_sign_epi8(q8_2, s2_2);
             const __m256i dot1 = _mm256_maddubs_epi16(q2_1, q8s_1);
             const __m256i dot2 = _mm256_maddubs_epi16(q2_2, q8s_2);
-            const uint16_t ls1 = aux32[1] >> 28;
-            const uint16_t ls2 = aux32[3] >> 28;
-            const __m256i p1 = _mm256_madd_epi16(dot1, _mm256_set1_epi16(2*ls1+1));
-            const __m256i p2 = _mm256_madd_epi16(dot2, _mm256_set1_epi16(2*ls2+1));
+            const uint16_t ls1 = 2*(sc[0] & 0xf) + 1, ls2 = 2*(sc[0] >> 4) + 1;
+            const uint16_t ls3 = 2*(sc[1] & 0xf) + 1, ls4 = 2*(sc[1] >> 4) + 1;
+            const __m256i p1 = _mm256_madd_epi16(dot1, MM256_SET_M128I(_mm_set1_epi16(ls2), _mm_set1_epi16(ls1)));
+            const __m256i p2 = _mm256_madd_epi16(dot2, MM256_SET_M128I(_mm_set1_epi16(ls4), _mm_set1_epi16(ls3)));
             sumi1 = _mm256_add_epi32(sumi1, p1);
             sumi2 = _mm256_add_epi32(sumi2, p2);
+            q2 += 8;
+            sc += 2;
         }
         accumf = _mm256_fmadd_ps(_mm256_set1_ps(d), _mm256_cvtepi32_ps(_mm256_add_epi32(sumi1, sumi2)), accumf);
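
For reference, a scalar sketch of what each slice of the ib32 loop computes, written for this note rather than taken from the commit. The table shapes (a 512-entry iq2xs_grid implied by the 9-bit index, a 128-entry keven_signs_q2xs implied by the 7-bit index, with one +1/-1 byte per weight judging by the _mm256_sign_epi8 usage) and the block layout are assumptions read off the diff above.

#include <stdint.h>

extern const uint64_t iq2xs_grid[512];        /* assumed: 8 packed magnitudes per entry     */
extern const uint64_t keven_signs_q2xs[128];  /* assumed: 8 bytes, each +1 or -1, per entry */

/* Dot product of 16 quantized weights (two 16-bit words of q2) with 16 int8
 * activations, scaled by one 4-bit scale nibble the same way the AVX2 path does. */
static int32_t iq2xs_dot_16(const uint16_t * q2, const int8_t * q8, uint8_t scale_nibble) {
    const int ls = 2*scale_nibble + 1;   /* same (2*ls + 1) mapping as ls1..ls4 above */
    int32_t sumi = 0;
    for (int l = 0; l < 2; ++l) {
        const uint8_t * grid  = (const uint8_t *)(iq2xs_grid       + (q2[l] & 511)); /* low 9 bits: grid row  */
        const int8_t  * signs = (const int8_t  *)(keven_signs_q2xs + (q2[l] >>  9)); /* high 7 bits: sign row */
        for (int j = 0; j < 8; ++j) sumi += (int32_t)grid[j] * q8[8*l + j] * signs[j];
    }
    return ls * sumi;  /* caller accumulates these, then applies the block scale d and the 0.125f factor seen above */
}

The signs are applied to the q8 side (via _mm256_sign_epi8 in the kernel, or the signs[j] factor here) because _mm256_maddubs_epi16 treats its first operand as unsigned bytes; keeping the grid magnitudes unsigned and flipping the activations instead yields the same products.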