diff --git a/ggml.c b/ggml.c
index 2ea4e68fd..814776381 100644
--- a/ggml.c
+++ b/ggml.c
@@ -450,6 +450,24 @@ static inline __m128i bytes_from_nibbles_16(const uint8_t * rsi)
     return bytes;
 }
 
+// horizontally add 8 floats
+static inline float hsum_float_8(const __m256 x) {
+    __m128 res = _mm256_extractf128_ps(x, 1);
+    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
+    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
+    res = _mm_add_ss(res, _mm_movehdup_ps(res));
+    return _mm_cvtss_f32(res);
+}
+
+// horizontally add 8 int32_t
+static inline int hsum_i32_8(const __m256i a) {
+    const __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
+    const __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
+    const __m128i sum64 = _mm_add_epi32(hi64, sum128);
+    const __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
+    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
+}
+
 #if __AVX2__ || __AVX512F__
 // Unpack 32 4-bit fields into 32 bytes
 // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
@@ -470,6 +488,24 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
     return bytes;
 }
 
+// add int16_t pairwise and return as float vector
+static inline __m256 sum_i16_pairs_float(const __m256i x) {
+    const __m256i ones = _mm256_set1_epi16(1);
+    const __m256i summed_pairs = _mm256_madd_epi16(ones, x);
+    return _mm256_cvtepi32_ps(summed_pairs);
+}
+
+// multiply int8_t, add results pairwise twice and return as float vector
+static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+    // Get absolute values of x vectors
+    const __m256i ax = _mm256_sign_epi8(x, x);
+    // Sign the values of the y vectors
+    const __m256i sy = _mm256_sign_epi8(y, x);
+    // Perform multiplication and create 16-bit values
+    const __m256i dot = _mm256_maddubs_epi16(ax, sy);
+    return sum_i16_pairs_float(dot);
+}
+
 static inline __m128i packNibbles( __m256i bytes )
 {
     // Move bits within 16-bit lanes from 0000_abcd_0000_efgh into 0000_0000_abcd_efgh
@@ -1273,29 +1309,6 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r
     }
 }
 
-#ifdef __AVX2__
-// There is no better way of doing this?
-// I guess not, AVX is not very good at horizontal sums.
-// The commented solution for a hotrizontal sum was suggested by @pubby as being slightly
-// faster than the solution below. As I don't have an AVX2 system handt right now to test,
-// keeping the original.
-// TODO: Please try and if it does make a differece, uncomment and remove the implementation below.
-//static inline float horizontal_sum(__m256i a) {
-//    __m256i b = _mm256_castps_si256(_mm256_movehdup_ps(_mm256_castsi256_ps(a)));
-//    __m256i sum = _mm256_add_epi32(a, b);
-//    __m256i hi = _mm256_unpackhi_epi64(sum, sum);
-//    sum = _mm256_add_epi32(sum, hi);
-//    return _mm256_cvtsi256_si32(sum) + _mm256_extract_epi32(sum, 4);
-//}
-static inline float horizontal_sum(__m256i a) {
-    __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extracti128_si256(a, 1));
-    __m128i hi64 = _mm_unpackhi_epi64(sum128, sum128);
-    __m128i sum64 = _mm_add_epi32(hi64, sum128);
-    __m128i hi32 = _mm_shuffle_epi32(sum64, _MM_SHUFFLE(2, 3, 0, 1));
-    return _mm_cvtsi128_si32(_mm_add_epi32(sum64, hi32));
-}
-#endif
-
 static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int k) {
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;
@@ -1384,9 +1397,8 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         __m256i i3 = _mm256_cvtps_epi32( v3 );
 
 #if defined(__AVX2__)
-        // Compute the sum of the quants and set y[i].s
-        y[i].s = d * horizontal_sum(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
+        y[i].s = d * hsum_i32_8(_mm256_add_epi32(_mm256_add_epi32(i0, i1), _mm256_add_epi32(i2, i3)));
 
         // Convert int32 to int16
         i0 = _mm256_packs_epi32( i0, i1 );  // 0, 1, 2, 3,  8, 9, 10, 11,  4, 5, 6, 7, 12, 13, 14, 15
@@ -1413,6 +1425,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
         __m128i ni6 = _mm256_castsi256_si128( i3 );
         __m128i ni7 = _mm256_extractf128_si256( i3, 1);
 
+        // Compute the sum of the quants and set y[i].s
+        const __m128i s0 = _mm_add_epi32(_mm_add_epi32(ni0, ni1), _mm_add_epi32(ni2, ni3));
+        const __m128i s1 = _mm_add_epi32(_mm_add_epi32(ni4, ni5), _mm_add_epi32(ni6, ni7));
+        y[i].s = d * hsum_i32_8(_mm256_set_m128i(s1, s0));
+
         // Convert int32 to int16
         ni0 = _mm_packs_epi32( ni0, ni1 );
         ni2 = _mm_packs_epi32( ni2, ni3 );
@@ -1430,14 +1447,6 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int
     // scalar
     quantize_row_q8_0_reference(x, y, k);
 #endif
-#if defined __AVX__
-    // TODO: vectorize this
-    for (int i=0; i
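For illustration only (not part of the patch): a minimal standalone sketch showing how the helpers introduced above -- mul_sum_i8_pairs_float, sum_i16_pairs_float and hsum_float_8 -- compose into a 32-element int8 dot product, checked against a scalar loop. The file layout, build flags (e.g. gcc -mavx2 -O2) and the test values in main() are assumptions for the example, not taken from the diff.

#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

// helpers copied from the patch above
static inline __m256 sum_i16_pairs_float(const __m256i x) {
    const __m256i ones = _mm256_set1_epi16(1);
    const __m256i summed_pairs = _mm256_madd_epi16(ones, x);   // int16 pairs -> int32
    return _mm256_cvtepi32_ps(summed_pairs);
}

static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
    const __m256i ax  = _mm256_sign_epi8(x, x);                // |x|
    const __m256i sy  = _mm256_sign_epi8(y, x);                // y * sign(x)
    const __m256i dot = _mm256_maddubs_epi16(ax, sy);          // unsigned*signed, pairwise add to int16
    return sum_i16_pairs_float(dot);
}

static inline float hsum_float_8(const __m256 x) {
    __m128 res = _mm256_extractf128_ps(x, 1);
    res = _mm_add_ps(res, _mm256_castps256_ps128(x));
    res = _mm_add_ps(res, _mm_movehl_ps(res, res));
    res = _mm_add_ss(res, _mm_movehdup_ps(res));
    return _mm_cvtss_f32(res);
}

int main(void) {
    int8_t a[32], b[32];
    int ref = 0;
    for (int i = 0; i < 32; ++i) {
        a[i] = (int8_t)(i - 16);          // illustrative test values only
        b[i] = (int8_t)(3 * i - 40);
        ref += a[i] * b[i];               // scalar reference dot product
    }
    const __m256i va = _mm256_loadu_si256((const __m256i *)a);
    const __m256i vb = _mm256_loadu_si256((const __m256i *)b);
    // 32 int8 products reduced to 8 floats, then horizontally summed
    const float simd = hsum_float_8(mul_sum_i8_pairs_float(va, vb));
    printf("scalar = %d, simd = %.0f\n", ref, simd);
    return 0;
}

The sign_epi8/maddubs_epi16 pairing is the key trick: maddubs requires one unsigned operand, so the patch feeds it |x| and y*sign(x), which leaves the products (and hence the dot product) unchanged.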