diff --git a/ggml-quants.c b/ggml-quants.c index dadff5bbe..183aaefb2 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -10089,18 +10089,33 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v #if defined(__ARM_NEON) + typedef union { + uint16x8_t vec_index; + uint16_t index[8]; + } vec_index_t; + static const uint8_t k_mask1[32] = {0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x01, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03, 0x03 }; static const uint8_t k_mask2[16] = {0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80,}; - const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1); - const uint8x16_t mask2 = vld1q_u8(k_mask2); + static const int16_t k_shift[8] = {8, 7, 6, 5, 4, 3, 2, 1}; + + const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1); + const uint8x16_t mask2 = vld1q_u8(k_mask2); + const int16x8_t hshift = vld1q_s16(k_shift); + const uint16x8_t m256 = vdupq_n_u16(256); uint8x16x2_t vs; ggml_int8x16x4_t q3s; ggml_int8x16x4_t q8b; + vec_index_t idx; + +#if QK_K == 256 + uint32_t scales32[2]; + const uint8_t * scales8 = (const uint8_t *)scales32; +#endif float sumf = 0; for (int i = 0; i < nb; ++i) { @@ -10109,18 +10124,29 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const uint8_t * restrict qh = x[i].qh; const uint16_t * restrict signs = (const uint16_t *)x[i].signs; const int8_t * restrict q8 = y[i].qs; + +#if QK_K == 256 + memcpy(scales32, x[i].scales, 4); + scales32[1] = (((scales32[0] >> 4) & 0x0f0f0f0f) << 1) | 0x01010101; + scales32[0] = ((scales32[0] & 0x0f0f0f0f) << 1) | 0x01010101; +#endif + int sumi1 = 0, sumi2 = 0; for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { q8b = ggml_vld1q_s8_x4(q8); q8 += 64; - const uint32x4_t aux32x4_0 = {iq3xs_grid[qs[ 0] | ((qh[ib32+0] << 8) & 256)], iq3xs_grid[qs[ 1] | ((qh[ib32+0] << 7) & 256)], - iq3xs_grid[qs[ 2] | ((qh[ib32+0] << 6) & 256)], iq3xs_grid[qs[ 3] | ((qh[ib32+0] << 5) & 256)]}; - const uint32x4_t aux32x4_1 = {iq3xs_grid[qs[ 4] | ((qh[ib32+0] << 4) & 256)], iq3xs_grid[qs[ 5] | ((qh[ib32+0] << 3) & 256)], - iq3xs_grid[qs[ 6] | ((qh[ib32+0] << 2) & 256)], iq3xs_grid[qs[ 7] | ((qh[ib32+0] << 1) & 256)]}; - const uint32x4_t aux32x4_2 = {iq3xs_grid[qs[ 8] | ((qh[ib32+1] << 8) & 256)], iq3xs_grid[qs[ 9] | ((qh[ib32+1] << 7) & 256)], - iq3xs_grid[qs[10] | ((qh[ib32+1] << 6) & 256)], iq3xs_grid[qs[11] | ((qh[ib32+1] << 5) & 256)]}; - const uint32x4_t aux32x4_3 = {iq3xs_grid[qs[12] | ((qh[ib32+1] << 4) & 256)], iq3xs_grid[qs[13] | ((qh[ib32+1] << 3) & 256)], - iq3xs_grid[qs[14] | ((qh[ib32+1] << 2) & 256)], iq3xs_grid[qs[15] | ((qh[ib32+1] << 1) & 256)]}; - qs += 16; + + const uint8x16_t idx_l = vld1q_u8(qs); qs += 16; + idx.vec_index = vorrq_u16(vmovl_u8(vget_low_u8 (idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+0]), hshift), m256)); + const uint32x4_t aux32x4_0 = {iq3xs_grid[idx.index[0]], iq3xs_grid[idx.index[1]], + iq3xs_grid[idx.index[2]], iq3xs_grid[idx.index[3]]}; + const uint32x4_t aux32x4_1 = {iq3xs_grid[idx.index[4]], iq3xs_grid[idx.index[5]], + iq3xs_grid[idx.index[6]], iq3xs_grid[idx.index[7]]}; + idx.vec_index = vorrq_u16(vmovl_u8(vget_high_u8(idx_l)), vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), hshift), m256)); + const uint32x4_t aux32x4_2 = {iq3xs_grid[idx.index[0]], iq3xs_grid[idx.index[1]], + iq3xs_grid[idx.index[2]], iq3xs_grid[idx.index[3]]}; + const uint32x4_t aux32x4_3 = {iq3xs_grid[idx.index[4]], iq3xs_grid[idx.index[5]], + iq3xs_grid[idx.index[6]], iq3xs_grid[idx.index[7]]}; + vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16))); vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2); @@ -10144,8 +10170,13 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[0], q8b.val[0]), q3s.val[1], q8b.val[1]); const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]); +#if QK_K == 256 + sumi1 += vaddvq_s32(p1) * scales8[ib32/2+0]; + sumi2 += vaddvq_s32(p2) * scales8[ib32/2+4]; +#else sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf)); sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4)); +#endif } sumf += d*(sumi1 + sumi2); }