diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 44506f731..eb71aa9aa 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -3456,8 +3456,8 @@ void quantize_row_tq2_0_ref(const float * restrict x, block_tq2_0 * restrict y, for (size_t m = 0; m < 32; ++m) { uint8_t q = 0; for (size_t n = 0; n < 4; ++n) { - // -1, 0, 1 -> 1, 2, 3 - int xi = nearest_int(x[m + n*32] * id) + 2; + // -1, 0, 1 -> 0, 1, 2 + int xi = nearest_int(x[m + n*32] * id) + 1; q += (xi & 3) << (2*n); } y[i].q[j + m] = q; @@ -3544,7 +3544,7 @@ void dequantize_row_tq2_0(const block_tq2_0 * restrict x, float * restrict y, in for (size_t j = 0; j < sizeof(x->q); j += 32) { for (size_t l = 0; l < 4; ++l) { for (size_t m = 0; m < 32; ++m) { - *y++ = (float) (((x[i].q[j + m] >> (l*2)) & 3) - 2) * d; + *y++ = (float) (((x[i].q[j + m] >> (l*2)) & 3) - 1) * d; } } } @@ -6127,33 +6127,31 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * __m256i qx2 = _mm256_srli_epi16(qx0, 4); __m256i qx3 = _mm256_srli_epi16(qx0, 6); - // 1, 2, 3 => -1, 0, 1 - qx0 = _mm256_sub_epi8(_mm256_and_si256(qx0, _mm256_set1_epi8(3)), _mm256_set1_epi8(2)); - qx1 = _mm256_sub_epi8(_mm256_and_si256(qx1, _mm256_set1_epi8(3)), _mm256_set1_epi8(2)); - qx2 = _mm256_sub_epi8(_mm256_and_si256(qx2, _mm256_set1_epi8(3)), _mm256_set1_epi8(2)); - qx3 = _mm256_sub_epi8(_mm256_and_si256(qx3, _mm256_set1_epi8(3)), _mm256_set1_epi8(2)); + // 0, 1, 2 (should not be 3) + qx0 = _mm256_and_si256(qx0, _mm256_set1_epi8(3)); + qx1 = _mm256_and_si256(qx1, _mm256_set1_epi8(3)); + qx2 = _mm256_and_si256(qx2, _mm256_set1_epi8(3)); + qx3 = _mm256_and_si256(qx3, _mm256_set1_epi8(3)); const __m256i qy0 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 0)); const __m256i qy1 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 32)); const __m256i qy2 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 64)); const __m256i qy3 = _mm256_loadu_si256((const __m256i *) (y[i].qs + j*4 + 96)); - qx0 = _mm256_sign_epi8(qy0, qx0); - qx1 = _mm256_sign_epi8(qy1, qx1); - qx2 = _mm256_sign_epi8(qy2, qx2); - qx3 = _mm256_sign_epi8(qy3, qx3); - - qx0 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx0); - qx1 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx1); - qx2 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx2); - qx3 = _mm256_maddubs_epi16(_mm256_set1_epi8(1), qx3); + qx0 = _mm256_maddubs_epi16(qx0, qy0); + qx1 = _mm256_maddubs_epi16(qx1, qy1); + qx2 = _mm256_maddubs_epi16(qx2, qy2); + qx3 = _mm256_maddubs_epi16(qx3, qy3); sumi0 = _mm256_add_epi16(sumi0, _mm256_add_epi16(qx0, qx1)); sumi1 = _mm256_add_epi16(sumi1, _mm256_add_epi16(qx2, qx3)); } + + const __m256i ysum = _mm256_loadu_si256((const __m256i *) y[i].bsums); const __m256 d = _mm256_set1_ps(y[i].d * GGML_FP16_TO_FP32(x[i].d)); sumi0 = _mm256_add_epi16(sumi0, sumi1); + sumi0 = _mm256_sub_epi16(sumi0, ysum); sumi0 = _mm256_madd_epi16(sumi0, _mm256_set1_epi16(1)); sumf = _mm256_add_ps(_mm256_mul_ps(_mm256_cvtepi32_ps(sumi0), d), sumf); @@ -6169,7 +6167,7 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void * for (size_t j = 0; j < sizeof(x->q); j += 32) { for (size_t l = 0; l < 4; ++l) { for (size_t k = 0; k < 32; ++k) { - sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].q[j + k] >> (l*2)) & 3) - 2); + sumi += y[i].qs[j*4 + l*32 + k] * (((x[i].q[j + k] >> (l*2)) & 3) - 1); } } }