From 66ea164e1d6a0b95628123ce39f66ab76b8a69f0 Mon Sep 17 00:00:00 2001 From: Matvey Soloviev Date: Thu, 23 Mar 2023 04:28:51 +0100 Subject: [PATCH] Kahan summation on Q4_1 --- ggml.c | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/ggml.c b/ggml.c index 8f405468d..890b4077f 100644 --- a/ggml.c +++ b/ggml.c @@ -1704,6 +1704,9 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void // Accumulator for constant offsets __m128 acc_offset = _mm_setzero_ps(); //0.0f; + __m256 acc_err = _mm256_setzero_ps(); + __m128 acc_offset_err = _mm_setzero_ps(); + // Main loop for (int i = 0; i < nb; ++i) { const float * m0 = (const float *) (pm0 + i*bs); @@ -1756,17 +1759,30 @@ inline static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void __m256i sumsi = _mm256_or_si256( xsumi, _mm256_slli_si256( ysumi, 4 ) ); __m256 sums = _mm256_cvtepi32_ps( sumsi ); - // Apply the scales, and accumulate - // acc += d0*m1*x + d1*m0*y - acc = _mm256_fmadd_ps( cross_scales, sums, acc ); - // Convert int32_t to float __m256 p = _mm256_cvtepi32_ps( i32 ); - // acc += d0*d1*x*y - acc = _mm256_fmadd_ps( scale_01, p, acc ); + + // Apply the scales, and accumulate + // Use Kahan error compensation + // acc += d0*m1*x + d1*m0*y + d0*d1*x*y + __m256 delta = _mm256_mul_ps( scale_01, p ); + delta = _mm256_fmadd_ps( cross_scales, sums, delta ); + delta = _mm256_sub_ps( delta, acc_err ); + + __m256 acc_next = _mm256_add_ps( acc, delta ); + acc_err = _mm256_sub_ps( _mm256_sub_ps( acc_next, acc ), delta ); + + acc = acc_next; + + __m128 offs_delta = _mm_mul_ss( _mm256_castps256_ps128( m0v ), _mm256_castps256_ps128( m1v ) ); + offs_delta = _mm_sub_ss( offs_delta, acc_offset_err ); + + __m128 offs_next = _mm_add_ss( acc_offset, offs_delta ); + acc_offset_err = _mm_sub_ss( _mm_sub_ss( offs_next, acc_offset ), offs_delta ); + acc_offset = offs_next; // acc_offset += m0*m1 (avoid reloading from RAM) - acc_offset = _mm_fmadd_ss( _mm256_castps256_ps128( m0v ), _mm256_castps256_ps128( m1v ), acc_offset ); + //acc_offset = _mm_fmadd_ss( _mm256_castps256_ps128( m0v ), _mm256_castps256_ps128( m1v ), acc_offset ); } // Return horizontal sum of the acc vector