diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 349db94c7..c20afaf3a 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -11371,40 +11371,68 @@ void ggml_vec_dot_q1_3_q8_0(int n, float * restrict s, size_t bs, const void * r __m256 accumf = _mm256_setzero_ps(); for (int i = 0; i < nb; ++i) { - { - __m256i x0 = _mm256_set_epi32(q1_3_grid[x[i].q[7]], q1_3_grid[x[i].q[6]], - q1_3_grid[x[i].q[5]], q1_3_grid[x[i].q[4]], - q1_3_grid[x[i].q[3]], q1_3_grid[x[i].q[2]], - q1_3_grid[x[i].q[1]], q1_3_grid[x[i].q[0]]); - __m256i y0 = _mm256_lddqu_si256((const __m256i_u *) (y[2*i].qs)); - - __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i].d)); - - __m256 q = mul_sum_i8_pairs_float(x0, y0); - - accumf = _mm256_fmadd_ps(d, q, accumf); - } + // __m128i x12b = _mm_maskload_epi32((const int32_t *) x[i].q, _mm_set_epi32(0, -1, -1, -1)); + // __m128i x12b = _mm_insert_epi8(x12a, x[i].qs[0], 12); + // WARNING: reading 3 bytes further than necessary. It's faster than the above on my CPU, though. + __m128i x12b = _mm_loadu_si128((const __m128i_u *) x[i].q); + __m256i x12 = MM256_SET_M128I(x12b, x12b); { - __m256i x1 = _mm256_castsi128_si256(_mm_set_epi32(q1_3_grid[x[i].q[11]], q1_3_grid[x[i].q[10]], - q1_3_grid[x[i].q[9]], q1_3_grid[x[i].q[8]])); - __m256i x2 = _mm256_cvtepu8_epi16(_mm_maskload_epi32((const int32_t *) x[i].q, _mm_set_epi32(0, -1, -1, -1))); + __m256i x0l = _mm256_shuffle_epi8(x12, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1, + 4, -1, 4, -1, 4, -1, 4, -1, + 1, -1, 1, -1, 1, -1, 1, -1, + 0, -1, 0, -1, 0, -1, 0, -1)); + __m256i x0h = _mm256_shuffle_epi8(x12, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1, + 6, -1, 6, -1, 6, -1, 6, -1, + 3, -1, 3, -1, 3, -1, 3, -1, + 2, -1, 2, -1, 2, -1, 2, -1)); + __m256i x1l = _mm256_shuffle_epi8(x12, _mm256_set_epi8(7, -1, 6, -1, 5, -1, 4, -1, + 3, -1, 2, -1, 1, -1, 0, -1, + 9, -1, 9, -1, 9, -1, 9, -1, + 8, -1, 8, -1, 8, -1, 8, -1)); + __m256i x1h = _mm256_shuffle_epi8(x12, _mm256_set_epi8(12, -1, 12, -1, 12, -1, 12, -1, + 11, -1, 10, -1, 9, -1, 8, -1, + 11, -1, 11, -1, 11, -1, 11, -1, + 10, -1, 10, -1, 10, -1, 10, -1)); + const __m256i shift0 = _mm256_set_epi16(3, 9, 27, 81, + 3, 9, 27, 81, + 3, 9, 27, 81, + 3, 9, 27, 81); + const __m256i shift1l = _mm256_set_epi16(1, 1, 1, 1, + 1, 1, 1, 1, + 3, 9, 27, 81, + 3, 9, 27, 81); + const __m256i shift1h = _mm256_set_epi16(3, 9, 27, 81, + 1, 1, 1, 1, + 3, 9, 27, 81, + 3, 9, 27, 81); + x0l = _mm256_mullo_epi16(x0l, shift0); + x0h = _mm256_mullo_epi16(x0h, shift0); + x1l = _mm256_mullo_epi16(x1l, shift1l); + x1h = _mm256_mullo_epi16(x1h, shift1h); + x0l = _mm256_mulhi_epu16(x0l, _mm256_set1_epi16(3)); + x0h = _mm256_mulhi_epu16(x0h, _mm256_set1_epi16(3)); + x1l = _mm256_mulhi_epu16(x1l, _mm256_set1_epi16(3)); + x1h = _mm256_mulhi_epu16(x1h, _mm256_set1_epi16(3)); + x0l = _mm256_sub_epi16(x0l, _mm256_set1_epi16(1)); + x0h = _mm256_sub_epi16(x0h, _mm256_set1_epi16(1)); + x1l = _mm256_sub_epi16(x1l, _mm256_set1_epi16(1)); + x1h = _mm256_sub_epi16(x1h, _mm256_set1_epi16(1)); + + __m256i x0 = _mm256_packs_epi16(x0l, x0h); + __m256i x1 = _mm256_packs_epi16(x1l, x1h); + + __m256i y0 = _mm256_lddqu_si256((const __m256i_u *) (y[2*i + 0].qs)); __m256i y1 = _mm256_lddqu_si256((const __m256i_u *) (y[2*i + 1].qs)); - x2 = _mm256_mulhi_epu16(x2, _mm256_set1_epi16(3 << 8)); - x2 = _mm256_sub_epi16(x2, _mm256_set1_epi16(1)); + __m256 d0 = _mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i].d)); + __m256 d1 = _mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i + 1].d)); - // TODO: reduce shuffling - x2 = _mm256_packs_epi16(x2, _mm256_setzero_si256()); - x2 = _mm256_permute4x64_epi64(x2, _MM_SHUFFLE(3, 1, 2, 0)); - __m128i x2_l = _mm_insert_epi32(_mm256_castsi256_si128(x2), q1_3_grid[x[i].qs[0]], 3); - x1 = _mm256_inserti128_si256(x1, x2_l, 1); + __m256 q0 = mul_sum_i8_pairs_float(x0, y0); + __m256 q1 = mul_sum_i8_pairs_float(x1, y1); - __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(y[2*i + 1].d)); - - __m256 q = mul_sum_i8_pairs_float(x1, y1); - - accumf = _mm256_fmadd_ps(d, q, accumf); + accumf = _mm256_fmadd_ps(d0, q0, accumf); + accumf = _mm256_fmadd_ps(d1, q1, accumf); } } diff --git a/gguf-py/gguf/quants.py b/gguf-py/gguf/quants.py index a2beb0d53..46820dce3 100644 --- a/gguf-py/gguf/quants.py +++ b/gguf-py/gguf/quants.py @@ -148,9 +148,9 @@ def __quantize_q1_3_rows(n: np.ndarray) -> np.ndarray: q48 = np.sum(q48 * pow3.reshape((1, 1, 4)), axis=2, keepdims=True).reshape((n_blocks, 12)) q4 = np.sum(q4 * pow3.reshape((1, 4)), axis=1, keepdims=True) q48 = q48 + (q12 * 81) - q = np.concatenate([q48, q4], axis=1); + q = np.concatenate([q48, q4], axis=1) q = ((q.astype(np.uint16) * 256) // 243).astype(np.uint8) - q = np.where(q != 0, q + 1, 0); + q = np.where(q != 0, q + 1, 0) return q.reshape(__quantize_q1_3_shape_change(shape)) @@ -170,4 +170,3 @@ def quantize_q1_3(data: np.ndarray): return __quantize_q1_3_lazy(data) else: return __quantize_q1_3_array(data) -