Update quantize_row_q4_0 for AVX/AVX2

This commit is contained in:
Håkon H. Hitland 2023-04-05 01:02:43 +02:00 committed by Georgi Gerganov
parent 3698f79e6a
commit 5d5f2b2efa

65
ggml.c
View File

@ -794,22 +794,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
__m256 v3 = _mm256_loadu_ps( x + 24 );
x += 32;
// Compute max(abs(e)) for the block
const __m256 signBit = _mm256_set1_ps( -0.0f );
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
// Compute max for the block
__m256 max = _mm256_max_ps( v0, v1 );
__m256 maxTmp = _mm256_max_ps( v2, v3 );
max = _mm256_max_ps( max, maxTmp );
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
const float maxScalar = _mm_cvtss_f32( max4 );
// Compute min for the block
__m256 min = _mm256_min_ps( v0, v1 );
__m256 minTmp = _mm256_min_ps( v2, v3 );
min = _mm256_min_ps( min, minTmp );
__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
const float minScalar = _mm_cvtss_f32( min4 );
// Quantize these floats
const float d = maxScalar / 7.0f;
const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
const float d = magnitude / -8.0f;
y[i].d = d;
const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
// Apply the multiplier
@ -842,9 +851,11 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 );
i0 = _mm256_permutevar8x32_epi32( i0, perm );
// Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
// Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
const __m256i off = _mm256_set1_epi8( 8 );
i0 = _mm256_add_epi8( i0, off );
const __m256i maxNibble = _mm256_set1_epi8( 15 );
i0 = _mm256_min_epi8( i0, maxNibble );
// Compress the vector into 4 bit/value, and store
__m128i res = packNibbles( i0 );
@ -859,22 +870,31 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
__m256 v3 = _mm256_loadu_ps( x + 24 );
x += 32;
// Compute max(abs(e)) for the block
const __m256 signBit = _mm256_set1_ps( -0.0f );
__m256 maxAbs = _mm256_andnot_ps( signBit, v0 );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) );
maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) );
// Compute max for the block
__m256 max = _mm256_max_ps( v0, v1 );
__m256 maxTmp = _mm256_max_ps( v2, v3 );
max = _mm256_max_ps( max, maxTmp );
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) );
__m128 max4 = _mm_max_ps( _mm256_extractf128_ps( max, 1 ), _mm256_castps256_ps128( max ) );
max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) );
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) );
const float maxScalar = _mm_cvtss_f32( max4 );
// Compute min for the block
__m256 min = _mm256_min_ps( v0, v1 );
__m256 minTmp = _mm256_min_ps( v2, v3 );
min = _mm256_min_ps( min, minTmp );
__m128 min4 = _mm_min_ps( _mm256_extractf128_ps( min, 1 ), _mm256_castps256_ps128( min ) );
min4 = _mm_min_ps( min4, _mm_movehl_ps( min4, min4 ) );
min4 = _mm_min_ss( min4, _mm_movehdup_ps( min4 ) );
const float minScalar = _mm_cvtss_f32( min4 );
// Quantize these floats
const float d = maxScalar / 7.0f;
const float magnitude = maxScalar >= fabsf(minScalar) ? maxScalar : minScalar;
const float d = magnitude / -8.0f;
y[i].d = d;
const float id = ( maxScalar != 0.0f ) ? 7.0f / maxScalar : 0.0f;
const float id = ( magnitude != 0.0f ) ? -8.0f / magnitude : 0.0f;
const __m256 mul = _mm256_set1_ps( id );
// Apply the multiplier
@ -915,10 +935,13 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
ni0 = _mm_packs_epi16( ni0, ni2 );
ni4 = _mm_packs_epi16( ni4, ni6 );
// Apply offset to translate the range from [ -7 .. +7 ] into [ +1 .. +15 ]
const __m128i off = _mm_set1_epi8( 8);
// Apply offset and clamp to translate the range from [ -8 .. +8 ] into [ +0 .. +15 ]
const __m128i off = _mm_set1_epi8( 8 );
ni0 = _mm_add_epi8( ni0, off );
ni4 = _mm_add_epi8( ni4, off );
const __m128i maxNibble = _mm_set1_epi8( 15 );
ni0 = _mm_min_epi8( ni0, maxNibble );
ni4 = _mm_min_epi8( ni4, maxNibble );
// Compress the vector into 4 bit/value, and store
__m128i res = packNibbles( ni0, ni4 );