ggml : speed-up Q5_0 + Q5_1 at 4 threads

This commit is contained in:
Georgi Gerganov 2023-05-10 22:58:45 +03:00
parent ffd76e18d6
commit e116eb638c
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

223
ggml.c
View File

@ -339,8 +339,9 @@ static float table_f32_f16[1 << 16];
#define B7(c,s,n) B6(c,s,n ## c), B6(c,s,n ## s)
#define B8(c,s ) B7(c,s, c), B7(c,s, s)
// precomputed tables for expanding 8bits to 8 bytes (shl 4)
static const uint64_t table_b2b_u[1 << 8] = { B8(00, 10) };
// precomputed tables for expanding 8bits to 8 bytes:
static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
#endif
// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
@ -2307,68 +2308,102 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
const block_q8_0 * restrict y = vy;
#if defined(__ARM_NEON)
float32x4_t sumv = vdupq_n_f32(0.0f);
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
uint64_t tmp[4];
uint32_t qh0;
uint32_t qh1;
for (int i = 0; i < nb; ++i) {
uint64_t tmp0[4];
uint64_t tmp1[4];
for (int i = 0; i < nb; i += 2) {
const block_q5_0 * restrict x0 = &x[i];
const block_q5_0 * restrict x1 = &x[i + 1];
const block_q8_0 * restrict y0 = &y[i];
const block_q8_0 * restrict y1 = &y[i + 1];
const uint8x16_t m4b = vdupq_n_u8(0x0F);
const int8x16_t s16b = vdupq_n_s8(0x10);
const uint8x16_t m4b = vdupq_n_u8(0x0F);
// extract the 5th bit
uint32_t qh;
memcpy(&qh, x0->qh, sizeof(qh));
// extract the 5th bit via lookup table ((!b) << 4)
memcpy(&qh0, x0->qh, sizeof(qh0));
memcpy(&qh1, x1->qh, sizeof(qh1));
tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
tmp[3] = table_b2b_u[(qh >> 24) ];
tmp0[0] = table_b2b_1[(qh0 >> 0) & 0xFF];
tmp0[1] = table_b2b_1[(qh0 >> 8) & 0xFF];
tmp0[2] = table_b2b_1[(qh0 >> 16) & 0xFF];
tmp0[3] = table_b2b_1[(qh0 >> 24) ];
const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
tmp1[0] = table_b2b_1[(qh1 >> 0) & 0xFF];
tmp1[1] = table_b2b_1[(qh1 >> 8) & 0xFF];
tmp1[2] = table_b2b_1[(qh1 >> 16) & 0xFF];
tmp1[3] = table_b2b_1[(qh1 >> 24) ];
const uint8x16_t v0 = vld1q_u8(x0->qs);
const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
// 4-bit -> 8-bit
const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, m4b));
const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// add high bit and sub 16
const int8x16_t v0lf = vsubq_s8(vorrq_s8(v0l, qhl), s16b);
const int8x16_t v0hf = vsubq_s8(vorrq_s8(v0h, qhh), s16b);
// add high bit and sub 16 (equivalent to sub 0x10 when bit is zero)
const int8x16_t v0_0lf = vsubq_s8(v0_0l, qhl0);
const int8x16_t v0_0hf = vsubq_s8(v0_0h, qhh0);
const int8x16_t v0_1lf = vsubq_s8(v0_1l, qhl1);
const int8x16_t v0_1hf = vsubq_s8(v0_1h, qhh1);
// load y
const int8x16_t v1l = vld1q_s8(y0->qs);
const int8x16_t v1h = vld1q_s8(y0->qs + 16);
const int8x16_t v1_0l = vld1q_s8(y0->qs);
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
const int8x16_t v1_1l = vld1q_s8(y1->qs);
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
const float x0d = GGML_FP16_TO_FP32(x0->d);
const float x1d = GGML_FP16_TO_FP32(x1->d);
#if defined(__ARM_FEATURE_DOTPROD)
sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), x0d*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), x1d*y1->d);
#else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1d*y1->d);
#endif
}
*s = vaddvq_f32(sumv);
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
#elif defined(__wasm_simd128__)
v128_t sumv = wasm_f32x4_splat(0.0f);
uint32_t qh;
uint64_t tmp[4];
// TODO: check if unrolling this is better
for (int i = 0; i < nb; ++i) {
const block_q5_0 * restrict x0 = &x[i];
const block_q8_0 * restrict y0 = &y[i];
@ -2377,13 +2412,12 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
const v128_t s16b = wasm_i8x16_splat(0x10);
// extract the 5th bit
uint32_t qh;
memcpy(&qh, x0->qh, sizeof(qh));
tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
tmp[3] = table_b2b_u[(qh >> 24) ];
tmp[0] = table_b2b_1[(qh >> 0) & 0xFF];
tmp[1] = table_b2b_1[(qh >> 8) & 0xFF];
tmp[2] = table_b2b_1[(qh >> 16) & 0xFF];
tmp[3] = table_b2b_1[(qh >> 24) ];
const v128_t qhl = wasm_v128_load(tmp + 0);
const v128_t qhh = wasm_v128_load(tmp + 2);
@ -2395,8 +2429,8 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
const v128_t v0h = wasm_u8x16_shr(v0, 4);
// add high bit and sub 16
const v128_t v0lf = wasm_i8x16_sub(wasm_v128_or(v0l, qhl), s16b);
const v128_t v0hf = wasm_i8x16_sub(wasm_v128_or(v0h, qhh), s16b);
const v128_t v0lf = wasm_i8x16_sub(v0l, qhl);
const v128_t v0hf = wasm_i8x16_sub(v0h, qhh);
// load y
const v128_t v1l = wasm_v128_load(y0->qs);
@ -2488,69 +2522,107 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
const block_q8_1 * restrict y = vy;
#if defined(__ARM_NEON)
float32x4_t sumv = vdupq_n_f32(0.0f);
float32x4_t sumv0 = vdupq_n_f32(0.0f);
float32x4_t sumv1 = vdupq_n_f32(0.0f);
float summs = 0.0f;
float summs0 = 0.0f;
float summs1 = 0.0f;
uint64_t tmp[4];
uint32_t qh0;
uint32_t qh1;
for (int i = 0; i < nb; ++i) {
uint64_t tmp0[4];
uint64_t tmp1[4];
for (int i = 0; i < nb; i += 2) {
const block_q5_1 * restrict x0 = &x[i];
const block_q5_1 * restrict x1 = &x[i + 1];
const block_q8_1 * restrict y0 = &y[i];
const block_q8_1 * restrict y1 = &y[i + 1];
summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
const uint8x16_t m4b = vdupq_n_u8(0x0F);
// extract the 5th bit
uint32_t qh;
memcpy(&qh, x0->qh, sizeof(qh));
summs0 += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
summs1 += GGML_FP16_TO_FP32(x1->m) * (y1->s0 + y1->s1);
tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
tmp[3] = table_b2b_u[(qh >> 24) ];
// extract the 5th bit via lookup table ((b) << 4)
memcpy(&qh0, x0->qh, sizeof(qh0));
memcpy(&qh1, x1->qh, sizeof(qh1));
const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 2));
tmp0[0] = table_b2b_0[(qh0 >> 0) & 0xFF];
tmp0[1] = table_b2b_0[(qh0 >> 8) & 0xFF];
tmp0[2] = table_b2b_0[(qh0 >> 16) & 0xFF];
tmp0[3] = table_b2b_0[(qh0 >> 24) ];
const uint8x16_t v0 = vld1q_u8(x0->qs);
tmp1[0] = table_b2b_0[(qh1 >> 0) & 0xFF];
tmp1[1] = table_b2b_0[(qh1 >> 8) & 0xFF];
tmp1[2] = table_b2b_0[(qh1 >> 16) & 0xFF];
tmp1[3] = table_b2b_0[(qh1 >> 24) ];
const int8x16_t qhl0 = vld1q_s8((const int8_t *)(tmp0 + 0));
const int8x16_t qhh0 = vld1q_s8((const int8_t *)(tmp0 + 2));
const int8x16_t qhl1 = vld1q_s8((const int8_t *)(tmp1 + 0));
const int8x16_t qhh1 = vld1q_s8((const int8_t *)(tmp1 + 2));
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
// 4-bit -> 8-bit
const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8 (v0, vdupq_n_u8(0x0F)));
const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_0, m4b));
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8 (v0_1, m4b));
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
// add
const int8x16_t v0lf = vorrq_s8(v0l, qhl);
const int8x16_t v0hf = vorrq_s8(v0h, qhh);
// add 5th bit
const int8x16_t v0_0lf = vorrq_s8(v0_0l, qhl0);
const int8x16_t v0_0hf = vorrq_s8(v0_0h, qhh0);
const int8x16_t v0_1lf = vorrq_s8(v0_1l, qhl1);
const int8x16_t v0_1hf = vorrq_s8(v0_1h, qhh1);
// load y
const int8x16_t v1l = vld1q_s8(y0->qs);
const int8x16_t v1h = vld1q_s8(y0->qs + 16);
const int8x16_t v1_0l = vld1q_s8(y0->qs);
const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
const int8x16_t v1_1l = vld1q_s8(y1->qs);
const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
const float x0d = GGML_FP16_TO_FP32(x0->d);
const float x1d = GGML_FP16_TO_FP32(x1->d);
#if defined(__ARM_FEATURE_DOTPROD)
sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(
vdotq_s32(vdupq_n_s32(0), v0_0lf, v1_0l),
vdotq_s32(vdupq_n_s32(0), v0_0hf, v1_0h))), x0d*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(
vdotq_s32(vdupq_n_s32(0), v0_1lf, v1_1l),
vdotq_s32(vdupq_n_s32(0), v0_1hf, v1_1h))), x1d*y1->d);
#else
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lf), vget_low_s8 (v1_0l));
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lf), vget_high_s8(v1_0l));
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hf), vget_low_s8 (v1_0h));
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hf), vget_high_s8(v1_0h));
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lf), vget_low_s8 (v1_1l));
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lf), vget_high_s8(v1_1l));
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hf), vget_low_s8 (v1_1h));
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hf), vget_high_s8(v1_1h));
const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1d*y1->d);
#endif
}
*s = vaddvq_f32(sumv) + summs;
*s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + summs0 + summs1;
#elif defined(__wasm_simd128__)
v128_t sumv = wasm_f32x4_splat(0.0f);
float summs = 0.0f;
uint32_t qh;
uint64_t tmp[4];
for (int i = 0; i < nb; ++i) {
@ -2562,13 +2634,12 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
const v128_t m4b = wasm_i8x16_splat(0x0F);
// extract the 5th bit
uint32_t qh;
memcpy(&qh, x0->qh, sizeof(qh));
tmp[0] = table_b2b_u[(qh >> 0) & 0xFF];
tmp[1] = table_b2b_u[(qh >> 8) & 0xFF];
tmp[2] = table_b2b_u[(qh >> 16) & 0xFF];
tmp[3] = table_b2b_u[(qh >> 24) ];
tmp[0] = table_b2b_0[(qh >> 0) & 0xFF];
tmp[1] = table_b2b_0[(qh >> 8) & 0xFF];
tmp[2] = table_b2b_0[(qh >> 16) & 0xFF];
tmp[3] = table_b2b_0[(qh >> 24) ];
const v128_t qhl = wasm_v128_load(tmp + 0);
const v128_t qhh = wasm_v128_load(tmp + 2);