mirror of https://github.com/ggerganov/llama.cpp.git (synced 2024-12-28 04:14:35 +00:00)

ggml-quants : deduplicate TQ1_0 and TQ2_0 __ARM_FEATURE_DOTPROD support

parent 82b240406d
commit 35cc5567c8
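Summary: both ggml_vec_dot_tq1_0_q8_K and ggml_vec_dot_tq2_0_q8_K previously carried two nearly identical ARM NEON implementations, one guarded by __ARM_NEON && __ARM_FEATURE_DOTPROD (accumulating with vdotq_s32) and a fallback guarded by __ARM_NEON alone (accumulating with vmlal_s8). This commit keeps a single #if defined(__ARM_NEON) body and moves the feature test inside it, so only the accumulator declarations, the multiply-accumulate step, and the final reduction are switched on __ARM_FEATURE_DOTPROD.

A minimal compilable sketch of the pattern (a hypothetical function, not the actual kernel), showing how one loop body can serve both feature levels:

    #include <stdint.h>
    #include <stddef.h>

    // Toy model of the deduplicated structure: the loop is written once and
    // only the accumulator width differs, mirroring int32x4_t/vdotq_s32 vs
    // int16x8_t/vmlal_s8 in the real code.
    static int32_t dot_i8(const int8_t * a, const int8_t * b, size_t n) {
    #if defined(__ARM_FEATURE_DOTPROD)
        int32_t acc = 0;                        // wide accumulator (int32x4_t lanes)
        for (size_t i = 0; i < n; ++i) acc += (int32_t) a[i] * b[i];
        return acc;
    #else
        int16_t acc = 0;                        // narrow accumulator (int16x8_t lanes)
        for (size_t i = 0; i < n; ++i) acc = (int16_t)(acc + a[i] * b[i]);
        return acc;                             // widened once at the end (vaddlvq_s16)
    #endif
    }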
@@ -5667,7 +5667,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *

     const int nb = n / QK_K;

-#if defined __ARM_NEON && defined __ARM_FEATURE_DOTPROD
+#if defined(__ARM_NEON)
     float sumf = 0.0f;

     uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
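The k_shift table holds per-lane powers of 3: when the four qh bytes are broadcast across a vector (qx5 in the code below) and multiplied lane-wise by {1,1,1,1, 3,3,3,3, 9,9,9,9, 27,27,27,27}, every lane ends up with a different base-3 digit of one qh byte shifted toward the top of the byte. Extraction then works the same for qs and qh: multiply by 3 and keep the two bits that carry out above bit 7.

A scalar model of that digit extraction (a sketch, assuming TQ1_0's base-3 fixed-point packing, where each qs byte stores 5 ternary digits and each qh byte stores 4, extracted most-significant first):

    #include <stdint.h>

    // Extract the k-th ternary digit (0, 1, or 2) from a packed byte.
    static int trit(uint8_t q, int k) {
        uint8_t x = q;
        for (int j = 0; j < k; ++j)
            x = (uint8_t)(x * 3);   // like vmulq_u8 by 3, 9, 27, 81: keep low 8 bits
        // "multiply by 3 and keep the 2 bits above 8 bits"; the NEON code gets
        // the exact same value overflow-free as ((x + (x >> 1)) >> 1) >> 6,
        // since x + (x >> 1) == floor(3x/2) and the nested shifts nest the floors.
        return (x * 3) >> 8;
    }

The digits come out offset by one (0..2 encodes -1..1); the kernel removes the offset once per block via bsums rather than per element.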
@@ -5675,8 +5675,13 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
     const uint8x16_t shift = vld1q_u8(k_shift);

     for (int i = 0; i < nb; ++i) {
+#if defined(__ARM_FEATURE_DOTPROD)
         int32x4_t sumi0 = vdupq_n_s32(0);
         int32x4_t sumi1 = vdupq_n_s32(0);
+#else
+        int16x8_t sumi0 = vdupq_n_s16(0);
+        int16x8_t sumi1 = vdupq_n_s16(0);
+#endif

         // first 32 bytes of 5 elements
         {
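With the accumulators declared under the feature test, the only per-iteration difference left is the multiply-accumulate intrinsic. vdotq_s32 folds four i8×i8 products into each of four int32 lanes per call, while the fallback's vmlal_s8 is a widening i8×i8→i16 multiply-accumulate over one 8-lane half, which is why the #else path issues a low-half and a high-half call per vector pair. Both paths keep two accumulators (sumi0/sumi1), presumably to overlap instruction latencies. Scalar models of the two intrinsics:

    #include <stdint.h>

    // vdotq_s32(acc, a, b): each int32 lane accumulates a 4-byte dot product.
    static void vdotq_s32_model(int32_t acc[4], const int8_t a[16], const int8_t b[16]) {
        for (int l = 0; l < 4; ++l)
            for (int k = 0; k < 4; ++k)
                acc[l] += (int32_t) a[4*l + k] * b[4*l + k];
    }

    // vmlal_s8(acc, a, b): widening multiply-accumulate of 8 signed bytes into
    // 8 int16 lanes; the kernel calls it twice per vector (low/high halves).
    static void vmlal_s8_model(int16_t acc[8], const int8_t a[8], const int8_t b[8]) {
        for (int l = 0; l < 8; ++l)
            acc[l] = (int16_t)(acc[l] + (int16_t) a[l] * b[l]);
    }

The i16 accumulators cannot overflow here: the unpacked values are at most 2, so each lane gains at most 2 × 127 per call over the bounded number of calls in one block.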
@@ -5714,6 +5719,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
             const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
             const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);

+#if defined(__ARM_FEATURE_DOTPROD)
             sumi0 = vdotq_s32(sumi0, sqx0, qy0);
             sumi1 = vdotq_s32(sumi1, sqx1, qy1);
             sumi0 = vdotq_s32(sumi0, sqx2, qy2);
@@ -5724,103 +5730,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
             sumi1 = vdotq_s32(sumi1, sqx7, qy7);
             sumi0 = vdotq_s32(sumi0, sqx8, qy8);
             sumi1 = vdotq_s32(sumi1, sqx9, qy9);
-        }
-
-        // last 16 bytes of 5-element, along with the 4 bytes of 4 elements
-        {
-            uint8x16_t qx0 = vld1q_u8(x[i].qs + 32);
-            uint8x16_t qx1 = vmulq_u8(qx0, vdupq_n_u8(3));
-            uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(9));
-            uint8x16_t qx3 = vmulq_u8(qx0, vdupq_n_u8(27));
-            uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(81));
-            uint32_t qh;
-            memcpy(&qh, x[i].qh, sizeof(qh)); // potentially unaligned
-            uint8x16_t qx5 = vreinterpretq_u8_u32(vdupq_n_u32(qh));
-            qx5 = vmulq_u8(qx5, shift);
-
-            // multiply by 3 and keep the 2 bits above 8 bits
-            int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
-            int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
-            int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
-            int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
-            int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
-            int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
-
-            const int8x16_t qy0 = vld1q_s8(y[i].qs + 160);
-            const int8x16_t qy1 = vld1q_s8(y[i].qs + 176);
-            const int8x16_t qy2 = vld1q_s8(y[i].qs + 192);
-            const int8x16_t qy3 = vld1q_s8(y[i].qs + 208);
-            const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
-            const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);
-
-            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
-            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
-            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
-            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
-            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
-            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
-        }
-
-        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
-        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
-
-        sumi0 = vaddq_s32(sumi0, sumi1);
-        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
-
-        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-
-        sumf += d * (float) vaddvq_s32(sumi0);
-    }
-
-    *s = sumf;
-
-#elif defined __ARM_NEON
-    float sumf = 0.0f;
-
-    uint8_t k_shift[16] = {1, 1, 1, 1, 3, 3, 3, 3, 9, 9, 9, 9, 27, 27, 27, 27};
-
-    const uint8x16_t shift = vld1q_u8(k_shift);
-
-    for (int i = 0; i < nb; ++i) {
-        int16x8_t sumi0 = vdupq_n_s16(0);
-        int16x8_t sumi1 = vdupq_n_s16(0);
-
-        // first 32 bytes of 5 elements
-        {
-            uint8x16_t qx0 = vld1q_u8(x[i].qs + 0);
-            uint8x16_t qx1 = vld1q_u8(x[i].qs + 16);
-            uint8x16_t qx2 = vmulq_u8(qx0, vdupq_n_u8(3));
-            uint8x16_t qx3 = vmulq_u8(qx1, vdupq_n_u8(3));
-            uint8x16_t qx4 = vmulq_u8(qx0, vdupq_n_u8(9));
-            uint8x16_t qx5 = vmulq_u8(qx1, vdupq_n_u8(9));
-            uint8x16_t qx6 = vmulq_u8(qx0, vdupq_n_u8(27));
-            uint8x16_t qx7 = vmulq_u8(qx1, vdupq_n_u8(27));
-            uint8x16_t qx8 = vmulq_u8(qx0, vdupq_n_u8(81));
-            uint8x16_t qx9 = vmulq_u8(qx1, vdupq_n_u8(81));
-
-            // multiply by 3 and keep the 2 bits above 8 bits
-            int8x16_t sqx0 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx0, vshrq_n_u8(qx0, 1)), 6));
-            int8x16_t sqx1 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx1, vshrq_n_u8(qx1, 1)), 6));
-            int8x16_t sqx2 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx2, vshrq_n_u8(qx2, 1)), 6));
-            int8x16_t sqx3 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx3, vshrq_n_u8(qx3, 1)), 6));
-            int8x16_t sqx4 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx4, vshrq_n_u8(qx4, 1)), 6));
-            int8x16_t sqx5 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx5, vshrq_n_u8(qx5, 1)), 6));
-            int8x16_t sqx6 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx6, vshrq_n_u8(qx6, 1)), 6));
-            int8x16_t sqx7 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx7, vshrq_n_u8(qx7, 1)), 6));
-            int8x16_t sqx8 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx8, vshrq_n_u8(qx8, 1)), 6));
-            int8x16_t sqx9 = vreinterpretq_s8_u8(vshrq_n_u8(vhaddq_u8(qx9, vshrq_n_u8(qx9, 1)), 6));
-
-            const int8x16_t qy0 = vld1q_s8(y[i].qs + 0);
-            const int8x16_t qy1 = vld1q_s8(y[i].qs + 16);
-            const int8x16_t qy2 = vld1q_s8(y[i].qs + 32);
-            const int8x16_t qy3 = vld1q_s8(y[i].qs + 48);
-            const int8x16_t qy4 = vld1q_s8(y[i].qs + 64);
-            const int8x16_t qy5 = vld1q_s8(y[i].qs + 80);
-            const int8x16_t qy6 = vld1q_s8(y[i].qs + 96);
-            const int8x16_t qy7 = vld1q_s8(y[i].qs + 112);
-            const int8x16_t qy8 = vld1q_s8(y[i].qs + 128);
-            const int8x16_t qy9 = vld1q_s8(y[i].qs + 144);
-
+#else
             sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
             sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
             sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
@@ -5841,6 +5751,7 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
             sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx8), vget_high_s8(qy8));
             sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx9), vget_low_s8(qy9));
             sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx9), vget_high_s8(qy9));
+#endif
         }

         // last 16 bytes of 5-element, along with the 4 bytes of 4 elements
@@ -5870,6 +5781,14 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
             const int8x16_t qy4 = vld1q_s8(y[i].qs + 224);
             const int8x16_t qy5 = vld1q_s8(y[i].qs + 240);

+#if defined(__ARM_FEATURE_DOTPROD)
+            sumi0 = vdotq_s32(sumi0, sqx0, qy0);
+            sumi1 = vdotq_s32(sumi1, sqx1, qy1);
+            sumi0 = vdotq_s32(sumi0, sqx2, qy2);
+            sumi1 = vdotq_s32(sumi1, sqx3, qy3);
+            sumi0 = vdotq_s32(sumi0, sqx4, qy4);
+            sumi1 = vdotq_s32(sumi1, sqx5, qy5);
+#else
             sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
             sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
             sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
@@ -5882,22 +5801,30 @@ void ggml_vec_dot_tq1_0_q8_K(int n, float * restrict s, size_t bs, const void *
             sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx4), vget_high_s8(qy4));
             sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx5), vget_low_s8(qy5));
             sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx5), vget_high_s8(qy5));
+#endif
         }

         const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
         const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);

+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumi0 = vaddq_s32(sumi0, sumi1);
+        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
+
+        sumf += d * (float) vaddvq_s32(sumi0);
+#else
         sumi0 = vaddq_s16(sumi0, sumi1);
         sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));

-        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-
         sumf += d * (float) vaddlvq_s16(sumi0);
+#endif
     }

     *s = sumf;

-#elif defined __AVX2__
+#elif defined(__AVX2__)
     __m256 sumf = _mm256_setzero_ps();

     for (int i = 0; i < nb; ++i) {
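The merged tail hoists d above the feature test and keeps the offset correction per path. Since the stored digits are q = t + 1 with t in {-1, 0, 1}, the loops above actually accumulate sum(q·y); subtracting sum(y), which Q8_K precomputes per 16-element group in bsums, recovers the signed result. In the dotprod path the i16 bsums are widened pairwise to int32 lanes with vpaddlq_s16 so they line up with the int32x4_t accumulator. A scalar sketch of the identity:

    #include <stdint.h>
    #include <stddef.h>

    // With q[k] = t[k] + 1 (t in -1..1):  sum(q*y) - sum(y) == sum(t*y).
    static int32_t signed_dot(const uint8_t * q, const int8_t * y, size_t n) {
        int32_t sum_qy = 0;   // what the vdotq_s32 / vmlal_s8 loops accumulate
        int32_t sum_y  = 0;   // what Q8_K precomputes in bsums
        for (size_t k = 0; k < n; ++k) {
            sum_qy += q[k] * y[k];
            sum_y  += y[k];
        }
        return sum_qy - sum_y;
    }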
@@ -6063,14 +5990,19 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void *

     const int nb = n / QK_K;

-#if defined __ARM_NEON && defined __ARM_FEATURE_DOTPROD
+#if defined(__ARM_NEON)
     float sumf = 0.0f;

     const uint8x16_t m3 = vdupq_n_u8(3);

     for (int i = 0; i < nb; ++i) {
+#if defined(__ARM_FEATURE_DOTPROD)
         int32x4_t sumi0 = vdupq_n_s32(0);
         int32x4_t sumi1 = vdupq_n_s32(0);
+#else
+        int16x8_t sumi0 = vdupq_n_s16(0);
+        int16x8_t sumi1 = vdupq_n_s16(0);
+#endif

         for (size_t j = 0; j < sizeof(x->qs); j += 32) {
             uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
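The same restructuring is applied to the TQ2_0 kernel. Here each qs byte packs four 2-bit values, which is why the j loop consumes 32 bytes of x[i].qs against 128 bytes of y[i].qs (the j*4 in the qy loads) per iteration. A scalar model of the unpacking done with vshrq_n_u8 and the m3 mask (visible in the removed fallback copy below, and unchanged in the kept code):

    #include <stdint.h>

    // Extract the k-th 2-bit digit (k in 0..3) from a packed TQ2_0 byte;
    // like TQ1_0, the value is stored as t + 1, so 0..2 encodes -1..1.
    static int tq2_0_digit(uint8_t q, int k) {
        return (q >> (2 * k)) & 3;   // vshrq_n_u8 by 0/2/4/6, then vandq_u8 with m3
    }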
@@ -6100,6 +6032,7 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void *
             const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96);
             const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);

+#if defined(__ARM_FEATURE_DOTPROD)
             sumi0 = vdotq_s32(sumi0, sqx0, qy0);
             sumi1 = vdotq_s32(sumi1, sqx1, qy1);
             sumi0 = vdotq_s32(sumi0, sqx2, qy2);
@@ -6108,58 +6041,7 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void *
             sumi1 = vdotq_s32(sumi1, sqx5, qy5);
             sumi0 = vdotq_s32(sumi0, sqx6, qy6);
             sumi1 = vdotq_s32(sumi1, sqx7, qy7);
-        }
-
-        const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
-        const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);
-
-        sumi0 = vaddq_s32(sumi0, sumi1);
-        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
-
-        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-
-        sumf += d * (float) vaddvq_s32(sumi0);
-    }
-
-    *s = sumf;
-
-#elif defined __ARM_NEON
-    float sumf = 0.0f;
-
-    const uint8x16_t m3 = vdupq_n_u8(3);
-
-    for (int i = 0; i < nb; ++i) {
-        int16x8_t sumi0 = vdupq_n_s16(0);
-        int16x8_t sumi1 = vdupq_n_s16(0);
-
-        for (size_t j = 0; j < sizeof(x->qs); j += 32) {
-            uint8x16_t qx0 = vld1q_u8(x[i].qs + j);
-            uint8x16_t qx1 = vld1q_u8(x[i].qs + j + 16);
-            uint8x16_t qx2 = vshrq_n_u8(qx0, 2);
-            uint8x16_t qx3 = vshrq_n_u8(qx1, 2);
-            uint8x16_t qx4 = vshrq_n_u8(qx0, 4);
-            uint8x16_t qx5 = vshrq_n_u8(qx1, 4);
-            uint8x16_t qx6 = vshrq_n_u8(qx0, 6);
-            uint8x16_t qx7 = vshrq_n_u8(qx1, 6);
-
-            int8x16_t sqx0 = vreinterpretq_s8_u8(vandq_u8(qx0, m3));
-            int8x16_t sqx1 = vreinterpretq_s8_u8(vandq_u8(qx1, m3));
-            int8x16_t sqx2 = vreinterpretq_s8_u8(vandq_u8(qx2, m3));
-            int8x16_t sqx3 = vreinterpretq_s8_u8(vandq_u8(qx3, m3));
-            int8x16_t sqx4 = vreinterpretq_s8_u8(vandq_u8(qx4, m3));
-            int8x16_t sqx5 = vreinterpretq_s8_u8(vandq_u8(qx5, m3));
-            int8x16_t sqx6 = vreinterpretq_s8_u8(vandq_u8(qx6, m3));
-            int8x16_t sqx7 = vreinterpretq_s8_u8(vandq_u8(qx7, m3));
-
-            const int8x16_t qy0 = vld1q_s8(y[i].qs + j*4 + 0);
-            const int8x16_t qy1 = vld1q_s8(y[i].qs + j*4 + 16);
-            const int8x16_t qy2 = vld1q_s8(y[i].qs + j*4 + 32);
-            const int8x16_t qy3 = vld1q_s8(y[i].qs + j*4 + 48);
-            const int8x16_t qy4 = vld1q_s8(y[i].qs + j*4 + 64);
-            const int8x16_t qy5 = vld1q_s8(y[i].qs + j*4 + 80);
-            const int8x16_t qy6 = vld1q_s8(y[i].qs + j*4 + 96);
-            const int8x16_t qy7 = vld1q_s8(y[i].qs + j*4 + 112);
-
+#else
             sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx0), vget_low_s8(qy0));
             sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx0), vget_high_s8(qy0));
             sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx1), vget_low_s8(qy1));
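The final hunk mirrors the TQ1_0 tail merge: the scale d moves above the feature test so both branches share it, and the two branches then differ only in the horizontal reduction. Scalar models of those two reductions (intrinsic semantics, not kernel code):

    #include <stdint.h>

    // vaddvq_s32: add the four int32 lanes across the vector (dotprod path).
    static int32_t vaddvq_s32_model(const int32_t v[4]) {
        return v[0] + v[1] + v[2] + v[3];
    }

    // vaddlvq_s16: widening add of the eight int16 lanes to int32 (fallback path).
    static int32_t vaddlvq_s16_model(const int16_t v[8]) {
        int32_t s = 0;
        for (int l = 0; l < 8; ++l) s += v[l];
        return s;
    }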
@@ -6176,22 +6058,30 @@ void ggml_vec_dot_tq2_0_q8_K(int n, float * restrict s, size_t bs, const void *
             sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx6), vget_high_s8(qy6));
             sumi0 = vmlal_s8(sumi0, vget_low_s8(sqx7), vget_low_s8(qy7));
             sumi1 = vmlal_s8(sumi1, vget_high_s8(sqx7), vget_high_s8(qy7));
+#endif
         }

         const int16x8_t ysum0 = vld1q_s16(y[i].bsums);
         const int16x8_t ysum1 = vld1q_s16(y[i].bsums + 8);

+        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
+
+#if defined(__ARM_FEATURE_DOTPROD)
+        sumi0 = vaddq_s32(sumi0, sumi1);
+        sumi0 = vsubq_s32(sumi0, vpaddlq_s16(vaddq_s16(ysum0, ysum1)));
+
+        sumf += d * (float) vaddvq_s32(sumi0);
+#else
         sumi0 = vaddq_s16(sumi0, sumi1);
         sumi0 = vsubq_s16(sumi0, vaddq_s16(ysum0, ysum1));

-        const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d;
-
         sumf += d * (float) vaddlvq_s16(sumi0);
+#endif
     }

     *s = sumf;

-#elif defined __AVX2__
+#elif defined(__AVX2__)
     __m256 sumf = _mm256_setzero_ps();

     for (int i = 0; i < nb; ++i) {