Improve AVX2 for vec_dot_q4_3_q8_0 (#1138)

parent c6524f46eb
commit 53c8434398

ggml.c: 14 lines changed
@@ -2947,6 +2947,7 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
+    float summs = 0.0f;
 
     // Main loop
     for (int i = 0; i < nb; i++) {
@@ -2954,9 +2955,8 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
         const __m128 d1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].d));
         const __m256 dx = _mm256_set_m128(d1, d0);
 
-        const __m128 m0 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 0].m));
-        const __m128 m1 = _mm_set1_ps(GGML_FP16_TO_FP32(x[2*i + 1].m));
-        const __m256 mx = _mm256_set_m128(m1, m0);
+        summs += GGML_FP16_TO_FP32(x[2*i + 0].m) * y[i].s0
+               + GGML_FP16_TO_FP32(x[2*i + 1].m) * y[i].s1;
 
         const __m128i bx0 = bytes_from_nibbles_16(x[2*i + 0].qs);
         const __m128i bx1 = bytes_from_nibbles_16(x[2*i + 1].qs);
@@ -2965,16 +2965,12 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void *
         const __m256 dy = _mm256_broadcast_ss(&y[i].d);
         const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
 
-        const __m256i syi = _mm256_maddubs_epi16(_mm256_set1_epi8(1), by);
-        const __m256 syf = sum_i16_pairs_float(syi);
-
         const __m256 q = mul_sum_i8_pairs_float(bx, by);
 
-        const __m256 sxy = _mm256_fmadd_ps(q, dx, _mm256_mul_ps(mx, syf));
-        acc = _mm256_fmadd_ps(sxy, dy, acc);
+        acc = _mm256_fmadd_ps(q, _mm256_mul_ps(dx, dy), acc);
     }
 
-    *s = hsum_float_8(acc);
+    *s = hsum_float_8(acc) + summs;
 #else
     // scalar
     float sumf = 0.0;
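The algebra behind the change: a Q4_3 weight decodes as w = d*q + m, so each block dot product splits into d*dy*sum(q*qy) plus m*sum(dy*qy). Assuming y[i].s0 and y[i].s1 hold the pre-scaled sums of each 16-value half of the Q8_0 block (which is what the removed mx*syf*dy path computed per iteration), the m term reduces to a scalar multiply-add, and the _mm256_maddubs_epi16 / sum_i16_pairs_float work plus one FMA drop out of the inner loop. Below is a minimal scalar sketch of that identity; the structs are simplified stand-ins for illustration, not the real packed ggml layouts.

#include <stdint.h>

/*
 * Simplified stand-ins: the real ggml blocks pack two 4-bit quants per byte
 * and store d/m as fp16. Here the nibbles are assumed already unpacked, and
 * s0/s1 are assumed to be d * (sum of each 16-value half of qs).
 */
typedef struct { float d, m; uint8_t q[16]; } half_q4_3;         /* weight w = d*q + m */
typedef struct { float d, s0, s1; int8_t qs[32]; } blk_q8_0_ref; /* activation a = d*qs */

/* Per-block reference reduction equivalent to the new AVX2 path:
 *   acc   += d*dy * sum(q*qy)            (vectorized via mul_sum_i8_pairs_float)
 *   summs += m * (dy * sum(qy)) = m * s  (plain scalar, the new "summs" term) */
static float dot_block_ref(const half_q4_3 *x0, const half_q4_3 *x1, const blk_q8_0_ref *y) {
    int isum0 = 0, isum1 = 0;
    for (int j = 0; j < 16; j++) {
        isum0 += x0->q[j] * y->qs[j];
        isum1 += x1->q[j] * y->qs[16 + j];
    }
    const float acc   = y->d * (x0->d * isum0 + x1->d * isum1);
    const float summs = x0->m * y->s0 + x1->m * y->s1;
    return acc + summs;
}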