ggml-quants : ARM NEON vec_dot for q2_2 and q1_3

This commit is contained in:
Francis Couture-Harpin 2024-06-25 01:32:14 -04:00
parent 638ad52f87
commit 9465ec6e12
2 changed files with 182 additions and 18 deletions

View File

@ -686,6 +686,10 @@ void quantize_row_q2_2_reference(const float * restrict x, block_q2_2 * restrict
}
}
void quantize_row_q2_2(const float * restrict x, void * restrict y, int64_t k) {
quantize_row_q2_2_reference(x, y, k);
}
// reference implementation for deterministic creation of model files
void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
static const int qk = QK4_0;
@ -3900,17 +3904,81 @@ void ggml_vec_dot_q2_2_q8_0(int n, float * restrict s, size_t bs, const void * r
}
*s = hsum_float_8(acc);
#elif defined(__ARM_NEON)
float sumf0 = 0.0f;
float sumf1 = 0.0f;
const uint8x8_t mask = vdup_n_u8(3);
const int8x8_t offset = vdup_n_s8(2);
const int leftovers = nb % 2;
for (int i = 0; i < nb - leftovers; i += 2) {
const uint8x8_t xq8_0 = vld1_u8(x[0].qs);
const uint8x8_t xq8_1 = vld1_u8(x[1].qs);
const int8x8_t xq8_0_0 = vsub_s8(vreinterpret_s8_u8(vand_u8(xq8_0, mask)), offset);
const int8x8_t xq8_0_1 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_0, 2), mask)), offset);
const int8x8_t xq8_0_2 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_0, 4), mask)), offset);
const int8x8_t xq8_0_3 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_0, 6), mask)), offset);
const int8x8_t xq8_1_0 = vsub_s8(vreinterpret_s8_u8(vand_u8(xq8_1, mask)), offset);
const int8x8_t xq8_1_1 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_1, 2), mask)), offset);
const int8x8_t xq8_1_2 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_1, 4), mask)), offset);
const int8x8_t xq8_1_3 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8_1, 6), mask)), offset);
const int8x16_t xq8_0_l = vcombine_s8(xq8_0_0, xq8_0_1);
const int8x16_t xq8_0_h = vcombine_s8(xq8_0_2, xq8_0_3);
const int8x16_t xq8_1_l = vcombine_s8(xq8_1_0, xq8_1_1);
const int8x16_t xq8_1_h = vcombine_s8(xq8_1_2, xq8_1_3);
const int8x16_t yq8_0_l = vld1q_s8(y[0].qs + 0);
const int8x16_t yq8_0_h = vld1q_s8(y[0].qs + 16);
const int8x16_t yq8_1_l = vld1q_s8(y[1].qs + 0);
const int8x16_t yq8_1_h = vld1q_s8(y[1].qs + 16);
const int16x8_t dot0 = vaddq_s16(vpaddlq_s8(vmulq_s8(xq8_0_l, yq8_0_l)), vpaddlq_s8(vmulq_s8(xq8_0_h, yq8_0_h)));
const int16x8_t dot1 = vaddq_s16(vpaddlq_s8(vmulq_s8(xq8_1_l, yq8_1_l)), vpaddlq_s8(vmulq_s8(xq8_1_h, yq8_1_h)));
sumf0 += GGML_FP16_TO_FP32(y[0].d) * (float) vaddlvq_s16(dot0);
sumf1 += GGML_FP16_TO_FP32(y[1].d) * (float) vaddlvq_s16(dot1);
x += 2;
y += 2;
}
// one block at a time
for (int i = nb - leftovers; i < nb; ++i) {
const uint8x8_t xq8 = vld1_u8(x->qs);
const int8x8_t xq8_0 = vsub_s8(vreinterpret_s8_u8(vand_u8(xq8, mask)), offset);
const int8x8_t xq8_1 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8, 2), mask)), offset);
const int8x8_t xq8_2 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8, 4), mask)), offset);
const int8x8_t xq8_3 = vsub_s8(vreinterpret_s8_u8(vand_u8(vshr_n_u8(xq8, 6), mask)), offset);
const int8x16_t xq8_l = vcombine_s8(xq8_0, xq8_1);
const int8x16_t xq8_h = vcombine_s8(xq8_2, xq8_3);
const int8x16_t yq8_l = vld1q_s8(y->qs + 0);
const int8x16_t yq8_h = vld1q_s8(y->qs + 16);
const int16x8_t dot0 = vpaddlq_s8(vmulq_s8(xq8_l, yq8_l));
const int16x8_t dot1 = vpaddlq_s8(vmulq_s8(xq8_h, yq8_h));
sumf0 += GGML_FP16_TO_FP32(y->d) * (float) vaddlvq_s16(vaddq_s16(dot0, dot1));
x += 1;
y += 1;
}
*s = sumf0 + sumf1;
#else
float sumf = 0.0;
float sumf = 0.0f;
for (int i = 0; i < nb; i++) {
int sumi = 0;
for (int j = 0; j < qk / 4; j++) {
const uint8_t weight = x[i].qs[j];
sumi += (int)y[i].qs[j + 0*qk/4] * ((weight >> 0) & 3) - 2;
sumi += (int)y[i].qs[j + 1*qk/4] * ((weight >> 2) & 3) - 2;
sumi += (int)y[i].qs[j + 2*qk/4] * ((weight >> 4) & 3) - 2;
sumi += (int)y[i].qs[j + 3*qk/4] * ((weight >> 6) & 3) - 2;
sumi += (int)y[i].qs[j + 0*qk/4] * (((weight >> 0) & 3) - 2);
sumi += (int)y[i].qs[j + 1*qk/4] * (((weight >> 2) & 3) - 2);
sumi += (int)y[i].qs[j + 2*qk/4] * (((weight >> 4) & 3) - 2);
sumi += (int)y[i].qs[j + 3*qk/4] * (((weight >> 6) & 3) - 2);
}
sumf += (float)(sumi)*(GGML_FP16_TO_FP32(y[i].d));
}
@ -11314,27 +11382,27 @@ void ggml_vec_dot_q1_3_q8_0(int n, float * restrict s, size_t bs, const void * r
for (int i = 0; i < nb; ++i) {
// const __m128i x12a = _mm_maskload_epi32((const int32_t *) x, _mm_set_epi32(0, -1, -1, -1));
// const __m128i x12b = _mm_insert_epi8(x12a, x->qs[0], 12);
// const __m128i x13b = _mm_insert_epi8(x12a, x->qs[0], 12);
// WARNING: reading 3 bytes further than necessary.
// It's measurably faster than a masked load on an Intel Core m3-8100Y
const __m128i x12b = _mm_loadu_si128((const __m128i_u *) x);
const __m256i x12 = MM256_SET_M128I(x12b, x12b);
const __m128i x13b = _mm_loadu_si128((const __m128i_u *) x);
const __m256i x13 = MM256_SET_M128I(x13b, x13b);
{
// pre-shift the values by 8 bits, and prepare the layout for later packing
__m256i x0l = _mm256_shuffle_epi8(x12, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1,
__m256i x0l = _mm256_shuffle_epi8(x13, _mm256_set_epi8(5, -1, 5, -1, 5, -1, 5, -1,
4, -1, 4, -1, 4, -1, 4, -1,
1, -1, 1, -1, 1, -1, 1, -1,
0, -1, 0, -1, 0, -1, 0, -1));
__m256i x0h = _mm256_shuffle_epi8(x12, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1,
__m256i x0h = _mm256_shuffle_epi8(x13, _mm256_set_epi8(7, -1, 7, -1, 7, -1, 7, -1,
6, -1, 6, -1, 6, -1, 6, -1,
3, -1, 3, -1, 3, -1, 3, -1,
2, -1, 2, -1, 2, -1, 2, -1));
__m256i x1l = _mm256_shuffle_epi8(x12, _mm256_set_epi8(7, -1, 6, -1, 5, -1, 4, -1,
__m256i x1l = _mm256_shuffle_epi8(x13, _mm256_set_epi8(7, -1, 6, -1, 5, -1, 4, -1,
3, -1, 2, -1, 1, -1, 0, -1,
9, -1, 9, -1, 9, -1, 9, -1,
8, -1, 8, -1, 8, -1, 8, -1));
__m256i x1h = _mm256_shuffle_epi8(x12, _mm256_set_epi8(12, -1, 12, -1, 12, -1, 12, -1,
__m256i x1h = _mm256_shuffle_epi8(x13, _mm256_set_epi8(12, -1, 12, -1, 12, -1, 12, -1,
11, -1, 10, -1, 9, -1, 8, -1,
11, -1, 11, -1, 11, -1, 11, -1,
10, -1, 10, -1, 10, -1, 10, -1));
@ -11385,6 +11453,88 @@ void ggml_vec_dot_q1_3_q8_0(int n, float * restrict s, size_t bs, const void * r
}
*s = hsum_float_8(accumf);
#elif defined(__ARM_NEON)
static const uint8_t k_mask0[16] = {0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3};
static const uint8_t k_mask1[16] = {4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7};
static const uint8_t k_mask2[16] = {8, 8, 8, 8, 9, 9, 9, 9, 10, 10, 10, 10, 11, 11, 11, 11};
static const uint8_t k_mask3[16] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 12, 12, 12};
static const uint8_t k_shift0[16] = {81, 27, 9, 3, 81, 27, 9, 3, 81, 27, 9, 3, 81, 27, 9, 3};
static const uint8_t k_shift3[16] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 81, 27, 9, 3};
// float32x4_t sumv0 = vdupq_n_f32(0.0f);
// float32x4_t sumv1 = vdupq_n_f32(0.0f);
float sumf0 = 0.0f;
float sumf1 = 0.0f;
const uint8x16_t mask0 = vld1q_u8(k_mask0);
const uint8x16_t mask1 = vld1q_u8(k_mask1);
const uint8x16_t mask2 = vld1q_u8(k_mask2);
const uint8x16_t mask3 = vld1q_u8(k_mask3);
const uint8x16_t shift0 = vld1q_u8(k_shift0);
const uint8x16_t shift3 = vld1q_u8(k_shift3);
const int8x16_t one = vdupq_n_s8(1);
for (int i = 0; i < nb; ++i) {
// WARNING: reading 3 bytes further than necessary
const uint8x16_t x13b = vld1q_u8((const uint8_t *) x);
uint8x16_t x0 = vqtbl1q_u8(x13b, mask0);
uint8x16_t x1 = vqtbl1q_u8(x13b, mask1);
uint8x16_t x2 = vqtbl1q_u8(x13b, mask2);
uint8x16_t x3 = vqtbl1q_u8(x13b, mask3);
x0 = vmulq_u8(x0, shift0);
x1 = vmulq_u8(x1, shift0);
x2 = vmulq_u8(x2, shift0);
x3 = vmulq_u8(x3, shift3);
// multiply by 3 and keep the 2 bits above 8 bits
x0 = vshrq_n_u8(vhaddq_u8(x0, vshrq_n_u8(x0, 1)), 6);
x1 = vshrq_n_u8(vhaddq_u8(x1, vshrq_n_u8(x1, 1)), 6);
x2 = vshrq_n_u8(vhaddq_u8(x2, vshrq_n_u8(x2, 1)), 6);
x3 = vshrq_n_u8(vhaddq_u8(x3, vshrq_n_u8(x3, 1)), 6);
// 0, 1, 2 => -1, 0, 1
int8x16_t x0i = vsubq_s8(vreinterpretq_s8_u8(x0), one);
int8x16_t x1i = vsubq_s8(vreinterpretq_s8_u8(x1), one);
int8x16_t x2i = vsubq_s8(vreinterpretq_s8_u8(x2), one);
int8x16_t x3i = vsubq_s8(vreinterpretq_s8_u8(x3), one);
const int8x16_t y0 = vld1q_s8(y[0].qs + 0);
const int8x16_t y1 = vld1q_s8(y[0].qs + 16);
const int8x16_t y2 = vld1q_s8(y[1].qs + 0);
const int8x16_t y3 = vld1q_s8(y[1].qs + 16);
// const int32x4_t p0 = vpaddlq_s16(vaddq_s16(vpaddlq_s8(x0i), vpaddlq_s8(x1i)));
// const int32x4_t p1 = vpaddlq_s16(vaddq_s16(vpaddlq_s8(x2i), vpaddlq_s8(x3i)));
// there's no direct equivalent to _mm_sign_epi8, unfortunately
x0i = vmulq_s8(x0i, y0);
x1i = vmulq_s8(x1i, y1);
x2i = vmulq_s8(x2i, y2);
x3i = vmulq_s8(x3i, y3);
// overall 18.5% faster than with vector sums on a cortex-A72
sumf0 += GGML_FP16_TO_FP32(y[0].d) * (float) vaddlvq_s16(vaddq_s16(vpaddlq_s8(x0i), vpaddlq_s8(x1i)));
sumf1 += GGML_FP16_TO_FP32(y[1].d) * (float) vaddlvq_s16(vaddq_s16(vpaddlq_s8(x2i), vpaddlq_s8(x3i)));
// const int32x4_t p0 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), x0i, y0), x1i, y1);
// const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), x2i, y2), x3i, y3);
// sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p0), GGML_FP16_TO_FP32(y[0].d));
// sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p1), GGML_FP16_TO_FP32(y[1].d));
y += 2;
x += 1;
}
// *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
*s = sumf0 + sumf1;
#else
float sumf = 0.0f;
@ -11393,34 +11543,36 @@ void ggml_vec_dot_q1_3_q8_0(int n, float * restrict s, size_t bs, const void * r
for (int j = 0; j < 8; ++j) {
const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].q[j]);
for (int k = 0; k < 4; ++k) {
sum += xj[k] * (int16_t) y[2*i].qs[4*j + k];
sum += xj[k] * (int16_t) y->qs[4*j + k];
}
}
sumf += GGML_FP16_TO_FP32(y[2*i].d) * sum;
sumf += GGML_FP16_TO_FP32(y->d) * sum;
y += 1;
sum = 0;
for (int j = 0; j < 4; ++j) {
const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].q[8 + j]);
for (int k = 0; k < 4; ++k) {
sum += xj[k] * (int16_t) y[2*i + 1].qs[4*j + k];
sum += xj[k] * (int16_t) y->qs[4*j + k];
}
}
for (size_t j = 0; j < 12; ++j) {
uint16_t xj = x[i].q[j];
xj = (xj * 3) >> 8;
sum += ((int16_t) xj - 1) * (int16_t) y[2*i + 1].qs[16 + j];
sum += ((int16_t) xj - 1) * (int16_t) y->qs[16 + j];
}
{
const int8_t * xj = (const int8_t *) (q1_3_grid + x[i].qs[0]);
for (int k = 0; k < 4; ++k) {
sum += (int16_t) xj[k] * (int16_t) y[2*i + 1].qs[28 + k];
sum += (int16_t) xj[k] * (int16_t) y->qs[28 + k];
}
}
sumf += GGML_FP16_TO_FP32(y[2*i + 1].d) * sum;
sumf += GGML_FP16_TO_FP32(y->d) * sum;
y += 1;
}
*s = sumf;

View File

@ -823,6 +823,18 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
},
[GGML_TYPE_Q2_2] = {
.type_name = "q2_2",
.blck_size = QK2_2,
.type_size = sizeof(block_q2_2),
.is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q2_2,
.from_float = quantize_row_q2_2,
.from_float_reference = (ggml_from_float_t) quantize_row_q2_2_reference,
.vec_dot = ggml_vec_dot_q2_2_q8_0,
.vec_dot_type = GGML_TYPE_Q8_0,
.nrows = 1,
},
[GGML_TYPE_Q1_3] = {
.type_name = "q1_3",
.blck_size = QK1_3,