mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 04:44:34 +00:00
ggml-quants : allow using ARM dot product instructions for TQ1_0
This commit is contained in:
parent
895004f3f8
commit
69f772682e
@ -5667,7 +5667,114 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
|
|||||||
|
|
||||||
const int nb = n / QK_K;
|
const int nb = n / QK_K;
|
||||||
|
|
||||||
#if defined __ARM_NEON
|
#if defined __ARM_NEON && defined __ARM_FEATURE_DOTPROD
|
||||||
|
float sumf = 0.0f;
|
||||||
|
|
||||||
|
uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
|
||||||
|
|
||||||
|
const uint8x16_t shift = vld1q_u8(k_shift);
|
||||||
|
|
||||||
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
int32x4_t sumi0 = vdupq_n_s32(0);
|
||||||
|
int32x4_t sumi1 = vdupq_n_s32(0);
|
||||||
|
|
||||||
|
// first 32 bytes of 5 elements
|
||||||
|
{
|
||||||
|
uint8x16_t qx0 = vld1q_u8(x[i].qs + 0);
|
||||||
|
uint8x16_t qx1 = vld1q_u8(x[i].qs + 16);
|
||||||
|
uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3));
|
||||||
|
uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3));
|
||||||
|
uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9));
|
||||||
|
uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9));
|
||||||
|
uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27));
|
||||||
|
uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27));
|
||||||
|
uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81));
|
||||||
|
uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81));
|
||||||
|
|
||||||
|
// multiply by 3 and keep the 2 bits above 8 bits
|
||||||
|
int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
|
||||||
|
int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
|
||||||
|
int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
|
||||||
|
int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
|
||||||
|
int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
|
||||||
|
int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
|
||||||
|
int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6));
|
||||||
|
int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6));
|
||||||
|
int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6));
|
||||||
|
int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6));
|
||||||
|
|
||||||
|
const int8x16_t qy0 = vld1q_s8(y[i].qs + 0);
|
||||||
|
const int8x16_t qy1 = vld1q_s8(y[i].qs + 16);
|
||||||
|
const int8x16_t qy2 = vld1q_s8(y[i].qs + 32);
|
||||||
|
const int8x16_t qy3 = vld1q_s8(y[i].qs + 48);
|
||||||
|
const int8x16_t qy4 = vld1q_s8(y[i].qs + 64);
|
||||||
|
const int8x16_t qy5 = vld1q_s8(y[i].qs + 80);
|
||||||
|
const int8x16_t qy6 = vld1q_s8(y[i].qs + 96);
|
||||||
|
const int8x16_t qy7 = vld1q_s8(y[i].qs + 112);
|
||||||
|
const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
|
||||||
|
const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);
|
||||||
|
|
||||||
|
sumi0 = vdotq_s32(sumi0, sqx0, qy0);
|
||||||
|
sumi1 = vdotq_s32(sumi1, sqx1, qy1);
|
||||||
|
sumi0 = vdotq_s32(sumi0, sqx2, qy2);
|
||||||
|
sumi1 = vdotq_s32(sumi1, sqx3, qy3);
|
||||||
|
sumi0 = vdotq_s32(sumi0, sqx4, qy4);
|
||||||
|
sumi1 = vdotq_s32(sumi1, sqx5, qy5);
|
||||||
|
sumi0 = vdotq_s32(sumi0, sqx6, qy6);
|
||||||
|
sumi1 = vdotq_s32(sumi1, sqx7, qy7);
|
||||||
|
sumi0 = vdotq_s32(sumi0, sqx8, qy8);
|
||||||
|
sumi1 = vdotq_s32(sumi1, sqx9, qy9);
|
||||||
|
}
|
||||||
|
|
||||||
|
// last 16 bytes of 5-element, along with the 4 bytes of 4 elements
|
||||||
|
{
|
||||||
|
uint8x16_t qx0 = vld1q_u8(x[i].qs + 32);
|
||||||
|
uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3));
|
||||||
|
uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9));
|
||||||
|
uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27));
|
||||||
|
uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81));
|
||||||
|
uint32_t qh;
|
||||||
|
memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
|
||||||
|
uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh));
|
||||||
|
qx5 = vmulq_u8(qx5, shift);
|
||||||
|
|
||||||
|
// multiply by 3 and keep the 2 bits above 8 bits
|
||||||
|
int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
|
||||||
|
int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
|
||||||
|
int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
|
||||||
|
int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
|
||||||
|
int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
|
||||||
|
int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
|
||||||
|
|
||||||
|
const int8x16_t qy0 = vld1q_s8(y[i].qs + 160);
|
||||||
|
const int8x16_t qy1 = vld1q_s8(y[i].qs + 176);
|
||||||
|
const int8x16_t qy2 = vld1q_s8(y[i].qs + 192);
|
||||||
|
const int8x16_t qy3 = vld1q_s8(y[i].qs + 208);
|
||||||
|
const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
|
||||||
|
const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);
|
||||||
|
|
||||||
|
sumi0 = vdotq_s32(sumi0, sqx0, qy0);
|
||||||
|
sumi1 = vdotq_s32(sumi1, sqx1, qy1);
|
||||||
|
sumi0 = vdotq_s32(sumi0, sqx2, qy2);
|
||||||
|
sumi1 = vdotq_s32(sumi1, sqx3, qy3);
|
||||||
|
sumi0 = vdotq_s32(sumi0, sqx4, qy4);
|
||||||
|
sumi1 = vdotq_s32(sumi1, sqx5, qy5);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
|
||||||
|
const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
|
||||||
|
|
||||||
|
sumi0 = vaddq_s32(sumi0, sumi1);
|
||||||
|
sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
|
||||||
|
|
||||||
|
const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
|
||||||
|
|
||||||
|
sumf += d * (float) vaddvq_s32(sumi0);
|
||||||
|
}
|
||||||
|
|
||||||
|
*s = sumf;
|
||||||
|
|
||||||
|
#elif defined __ARM_NEON
|
||||||
float sumf = 0.0f;
|
float sumf = 0.0f;
|
||||||
|
|
||||||
uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
|
uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
|
||||||
|
Loading…
Reference in New Issue
Block a user