ggml : even faster TQ2_0

This commit is contained in:
Francis Couture-Harpin 2024-07-30 23:36:52 -04:00
parent 77b8f84ae7
commit 560873f337

View File

@ -3456,8 +3456,8 @@ void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y,
for (size_t m = 0; m < 32; ++m) { for (size_t m = 0; m < 32; ++m) {
uint8_t q = 0; uint8_t q = 0;
for (size_t n = 0; n < 4; ++n) { for (size_t n = 0; n < 4; ++n) {
// -1, 0, 1 -> 1, 2, 3 // -1, 0, 1 -> 0, 1, 2
int xi = nearest_int(x[m + n*32] * id) + 2; int xi = nearest_int(x[m + n*32] * id) + 1;
q += (xi & 3) << (2*n); q += (xi & 3) << (2*n);
} }
y[i].q[j + m] = q; y[i].q[j + m] = q;
@ -3544,7 +3544,7 @@ void dequantize_row_tq2_0(const block_tq2_0 * restrict x, float * restrict y, in
for (size_t j = 0; j < sizeof(x->q); j += 32) { for (size_t j = 0; j < sizeof(x->q); j += 32) {
for (size_t l = 0; l < 4; ++l) { for (size_t l = 0; l < 4; ++l) {
for (size_t m = 0; m < 32; ++m) { for (size_t m = 0; m < 32; ++m) {
*y++ = (float) (((x[i].q[j + m] >> (l*2)) & 3) - 2) * d; *y++ = (float) (((x[i].q[j + m] >> (l*2)) & 3) - 1) * d;
} }
} }
} }
@ -6127,33 +6127,31 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void *
__m256i qx2 = _mm256_srli_epi16(qx0, 4); __m256i qx2 = _mm256_srli_epi16(qx0, 4);
__m256i qx3 = _mm256_srli_epi16(qx0, 6); __m256i qx3 = _mm256_srli_epi16(qx0, 6);
// 1, 2, 3 => -1, 0, 1 // 0, 1, 2 (should not be 3)
qx0 = _mm256_sub_epi8(_mm256_and_si256(qx0, _mm256_set1_epi8(3)), _mm256_set1_epi8(2)); qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3));
qx1 = _mm256_sub_epi8(_mm256_and_si256(qx1, _mm256_set1_epi8(3)), _mm256_set1_epi8(2)); qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3));
qx2 = _mm256_sub_epi8(_mm256_and_si256(qx2, _mm256_set1_epi8(3)), _mm256_set1_epi8(2)); qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3));
qx3 = _mm256_sub_epi8(_mm256_and_si256(qx3, _mm256_set1_epi8(3)), _mm256_set1_epi8(2)); qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3));
const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0)); const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0));
const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32)); const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32));
const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64)); const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64));
const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96)); const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96));
qx0 = _mm256_sign_epi8(qy0, qx0); qx0 = _mm256_maddubs_epi16(qx0, qy0);
qx1 = _mm256_sign_epi8(qy1, qx1); qx1 = _mm256_maddubs_epi16(qx1, qy1);
qx2 = _mm256_sign_epi8(qy2, qx2); qx2 = _mm256_maddubs_epi16(qx2, qy2);
qx3 = _mm256_sign_epi8(qy3, qx3); qx3 = _mm256_maddubs_epi16(qx3, qy3);
qx0 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx0);
qx1 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx1);
qx2 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx2);
qx3 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx3);
sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1)); sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1));
sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3)); sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3));
} }
const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums);
const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)); const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d));
sumi0 = _mm256_add_epi16(sumi0, sumi1); sumi0 = _mm256_add_epi16(sumi0, sumi1);
sumi0 = _mm256_sub_epi16(sumi0, ysum);
sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1)); sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1));
sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf); sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf);
@ -6169,7 +6167,7 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void *
for (size_t j = 0; j < sizeof(x->q); j += 32) { for (size_t j = 0; j < sizeof(x->q); j += 32) {
for (size_t l = 0; l < 4; ++l) { for (size_t l = 0; l < 4; ++l) {
for (size_t k = 0; k < 32; ++k) { for (size_t k = 0; k < 32; ++k) {
sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].q[j + k] >> (l*2)) & 3) - 2); sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].q[j + k] >> (l*2)) & 3) - 1);
} }
} }
} }