mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 03:14:35 +00:00
ggml : speed-up ggml_vec_dot_q4_1() ARM_NEON + 32-bit ARM support (#900)
* ggml : speed-up q4_1 ARM_NEON by ~5% * ggml : implement vaddvq when missing * ggml : implement vminvq and vmaxvq when missing * ggml : implement vzip when missing * ggml : fix comment * ggml : try to use correct ifdef
This commit is contained in:
parent
9190e8eac8
commit
d990e3fffc
160
ggml.c
160
ggml.c
@ -491,6 +491,77 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
|
||||
}
|
||||
#endif
|
||||
|
||||
#if __ARM_NEON
|
||||
|
||||
#if !defined(__aarch64__)
|
||||
|
||||
inline static uint16_t vaddvq_u8(uint8x16_t v) {
|
||||
return
|
||||
(uint16_t)vgetq_lane_u8(v, 0) + (uint16_t)vgetq_lane_u8(v, 1) +
|
||||
(uint16_t)vgetq_lane_u8(v, 2) + (uint16_t)vgetq_lane_u8(v, 3) +
|
||||
(uint16_t)vgetq_lane_u8(v, 4) + (uint16_t)vgetq_lane_u8(v, 5) +
|
||||
(uint16_t)vgetq_lane_u8(v, 6) + (uint16_t)vgetq_lane_u8(v, 7) +
|
||||
(uint16_t)vgetq_lane_u8(v, 8) + (uint16_t)vgetq_lane_u8(v, 9) +
|
||||
(uint16_t)vgetq_lane_u8(v, 10) + (uint16_t)vgetq_lane_u8(v, 11) +
|
||||
(uint16_t)vgetq_lane_u8(v, 12) + (uint16_t)vgetq_lane_u8(v, 13) +
|
||||
(uint16_t)vgetq_lane_u8(v, 14) + (uint16_t)vgetq_lane_u8(v, 15);
|
||||
}
|
||||
|
||||
inline static int32_t vaddvq_s16(int16x8_t v) {
|
||||
return
|
||||
(int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
|
||||
(int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
|
||||
(int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
|
||||
(int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
|
||||
}
|
||||
|
||||
inline static uint32_t vaddvq_u16(uint16x8_t v) {
|
||||
return
|
||||
(uint32_t)vgetq_lane_u16(v, 0) + (uint32_t)vgetq_lane_u16(v, 1) +
|
||||
(uint32_t)vgetq_lane_u16(v, 2) + (uint32_t)vgetq_lane_u16(v, 3) +
|
||||
(uint32_t)vgetq_lane_u16(v, 4) + (uint32_t)vgetq_lane_u16(v, 5) +
|
||||
(uint32_t)vgetq_lane_u16(v, 6) + (uint32_t)vgetq_lane_u16(v, 7);
|
||||
}
|
||||
|
||||
inline static int32_t vaddvq_s32(int32x4_t v) {
|
||||
return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
|
||||
}
|
||||
|
||||
inline static float vaddvq_f32(float32x4_t v) {
|
||||
return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
|
||||
}
|
||||
|
||||
inline float vminvq_f32(float32x4_t v) {
|
||||
return
|
||||
MIN(MIN(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
|
||||
MIN(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
|
||||
}
|
||||
|
||||
inline float vmaxvq_f32(float32x4_t v) {
|
||||
return
|
||||
MAX(MAX(vgetq_lane_f32(v, 0), vgetq_lane_f32(v, 1)),
|
||||
MAX(vgetq_lane_f32(v, 2), vgetq_lane_f32(v, 3)));
|
||||
}
|
||||
|
||||
inline int8x8_t vzip1_s8(int8x8_t a, int8x8_t b) {
|
||||
return vget_low_s8(vcombine_s8(a, b));
|
||||
}
|
||||
|
||||
inline int8x8_t vzip2_s8(int8x8_t a, int8x8_t b) {
|
||||
return vget_high_s8(vcombine_s8(a, b));
|
||||
}
|
||||
|
||||
inline uint8x8_t vzip1_u8(uint8x8_t a, uint8x8_t b) {
|
||||
return vget_low_u8(vcombine_u8(a, b));
|
||||
}
|
||||
|
||||
inline uint8x8_t vzip2_u8(uint8x8_t a, uint8x8_t b) {
|
||||
return vget_high_u8(vcombine_u8(a, b));
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// method 5
|
||||
// blocks of QK elements
|
||||
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
|
||||
@ -1218,15 +1289,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
||||
#define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c)
|
||||
#define GGML_F32x4_ADD vaddq_f32
|
||||
#define GGML_F32x4_MUL vmulq_f32
|
||||
#if defined(__ARM_FEATURE_QRDMX)
|
||||
#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x)
|
||||
#else
|
||||
#define GGML_F32x4_REDUCE_ONE(x) \
|
||||
(vgetq_lane_f32(x, 0) + \
|
||||
vgetq_lane_f32(x, 1) + \
|
||||
vgetq_lane_f32(x, 2) + \
|
||||
vgetq_lane_f32(x, 3))
|
||||
#endif
|
||||
#define GGML_F32x4_REDUCE(res, x) \
|
||||
{ \
|
||||
for (int i = 0; i < GGML_F32_ARR/2; ++i) { \
|
||||
@ -1849,55 +1912,43 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
||||
// 4-bit -> 8-bit
|
||||
const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b));
|
||||
const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b));
|
||||
|
||||
const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
|
||||
const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4));
|
||||
|
||||
const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b));
|
||||
const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b));
|
||||
|
||||
const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4));
|
||||
const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4));
|
||||
|
||||
// sub 8
|
||||
const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b);
|
||||
const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b);
|
||||
|
||||
const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b);
|
||||
const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b);
|
||||
|
||||
const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b);
|
||||
const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b);
|
||||
|
||||
const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b);
|
||||
const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b);
|
||||
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
// dot product into int16x8_t
|
||||
// dot product into int32x4_t
|
||||
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls);
|
||||
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls);
|
||||
|
||||
p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs);
|
||||
p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs);
|
||||
|
||||
// scalar
|
||||
#if defined(__ARM_FEATURE_QRDMX)
|
||||
sum0 += x0->d*y0->d*vaddvq_s32(p_0);
|
||||
sum1 += x1->d*y1->d*vaddvq_s32(p_1);
|
||||
#else
|
||||
sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3));
|
||||
sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3));
|
||||
#endif
|
||||
#else
|
||||
const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls));
|
||||
const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls));
|
||||
|
||||
const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs));
|
||||
const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs));
|
||||
|
||||
const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls));
|
||||
const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls));
|
||||
|
||||
const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs));
|
||||
const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs));
|
||||
|
||||
@ -1910,14 +1961,8 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
||||
const int16x8_t p_0 = vaddq_s16(pl_0, ph_0);
|
||||
const int16x8_t p_1 = vaddq_s16(pl_1, ph_1);
|
||||
|
||||
// scalar
|
||||
#if defined(__ARM_FEATURE_QRDMX)
|
||||
sum0 += x0->d*y0->d*vaddvq_s16(p_0);
|
||||
sum1 += x1->d*y1->d*vaddvq_s16(p_1);
|
||||
#else
|
||||
sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7));
|
||||
sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7));
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
@ -2265,36 +2310,71 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
|
||||
float sum10 = 0.0f;
|
||||
float sum11 = 0.0f;
|
||||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
for (int i = 0; i < nb; i += 2) {
|
||||
const block_q4_1 * restrict x0 = &x[i + 0];
|
||||
const block_q4_1 * restrict y0 = &y[i + 0];
|
||||
const block_q4_1 * restrict x1 = &x[i + 1];
|
||||
const block_q4_1 * restrict y1 = &y[i + 1];
|
||||
|
||||
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
||||
|
||||
const uint8x16_t v0_0 = vld1q_u8(x0->qs);
|
||||
const uint8x16_t v1_0 = vld1q_u8(y0->qs);
|
||||
const uint8x16_t v0_1 = vld1q_u8(x1->qs);
|
||||
const uint8x16_t v1_1 = vld1q_u8(y1->qs);
|
||||
|
||||
// and with 0xf
|
||||
// 4-bit -> 8-bit
|
||||
const uint8x16_t v0_0l = vandq_u8(v0_0, m4b);
|
||||
const uint8x16_t v1_0l = vandq_u8(v1_0, m4b);
|
||||
|
||||
const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4);
|
||||
const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4);
|
||||
|
||||
// dot product into uint16x8_t
|
||||
const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
|
||||
const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
|
||||
|
||||
const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
|
||||
const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
|
||||
|
||||
const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h);
|
||||
const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h);
|
||||
const uint8x16_t v0_1l = vandq_u8(v0_1, m4b);
|
||||
const uint8x16_t v1_1l = vandq_u8(v1_1, m4b);
|
||||
const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4);
|
||||
const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4);
|
||||
|
||||
sum00 += x0->m*y0->m;
|
||||
sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h));
|
||||
sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h));
|
||||
sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
|
||||
|
||||
sum00 += x1->m*y1->m;
|
||||
sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h));
|
||||
sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h));
|
||||
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
// dot product into int32x4_t
|
||||
int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l);
|
||||
int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l);
|
||||
|
||||
p_0 = vdotq_s32(p_0, v0_0h, v1_0h);
|
||||
p_1 = vdotq_s32(p_1, v0_1h, v1_1h);
|
||||
|
||||
sum11 += x0->d*y0->d*vaddvq_s32(p_0);
|
||||
sum11 += x1->d*y1->d*vaddvq_s32(p_1);
|
||||
#else
|
||||
const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l));
|
||||
const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l));
|
||||
const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h));
|
||||
const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h));
|
||||
|
||||
const uint16x8_t pl1l = vmull_u8(vget_low_u8 (v0_1l), vget_low_u8 (v1_1l));
|
||||
const uint16x8_t pl1h = vmull_u8(vget_high_u8(v0_1l), vget_high_u8(v1_1l));
|
||||
const uint16x8_t ph1l = vmull_u8(vget_low_u8 (v0_1h), vget_low_u8 (v1_1h));
|
||||
const uint16x8_t ph1h = vmull_u8(vget_high_u8(v0_1h), vget_high_u8(v1_1h));
|
||||
|
||||
const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h);
|
||||
const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h);
|
||||
|
||||
const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h);
|
||||
const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h);
|
||||
|
||||
const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0);
|
||||
const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1);
|
||||
|
||||
sum11 += x0->d*y0->d*vaddvq_u16(p_0);
|
||||
sum11 += x1->d*y1->d*vaddvq_u16(p_1);
|
||||
#endif
|
||||
}
|
||||
|
||||
sumf = QK*sum00 + sum01 + sum10 + sum11;
|
||||
|
Loading…
Reference in New Issue
Block a user