ggml : restore vec dot stride arg names (#5453)

This commit is contained in:
Georgi Gerganov 2024-02-18 22:58:57 +02:00
parent b1de96824b
commit 14278f55d2
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -3855,7 +3855,7 @@ static inline __m128i get_scale_shuffle(int i) {
} }
#endif #endif
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0; const int qk = QK8_0;
const int nb = n / qk; const int nb = n / qk;
@ -3866,8 +3866,8 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
assert(nrc == 1); assert(nrc == 1);
#endif #endif
UNUSED(nrc); UNUSED(nrc);
UNUSED(bbx); UNUSED(bx);
UNUSED(bby); UNUSED(by);
UNUSED(bs); UNUSED(bs);
const block_q4_0 * restrict x = vx; const block_q4_0 * restrict x = vx;
@ -4024,15 +4024,15 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs); const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs);
__m128i bx = _mm_and_si128(lowMask, tmp); __m128i bx_0 = _mm_and_si128(lowMask, tmp);
__m128i by = _mm_loadu_si128((const __m128i *)y[i].qs); __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs);
bx = _mm_sub_epi8(bx, off); bx_0 = _mm_sub_epi8(bx_0, off);
const __m128i i32_0 = mul_sum_i8_pairs(bx, by); const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0);
bx = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4));
by = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); by_0 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16));
bx = _mm_sub_epi8(bx, off); bx_0 = _mm_sub_epi8(bx_0, off);
const __m128i i32_1 = mul_sum_i8_pairs(bx, by); const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0);
// Convert int32_t to float // Convert int32_t to float
__m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1));
@ -4222,7 +4222,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
#endif #endif
} }
void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_1; const int qk = QK8_1;
const int nb = n / qk; const int nb = n / qk;
@ -4233,8 +4233,8 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
assert(nrc == 1); assert(nrc == 1);
#endif #endif
UNUSED(nrc); UNUSED(nrc);
UNUSED(bbx); UNUSED(bx);
UNUSED(bby); UNUSED(by);
UNUSED(bs); UNUSED(bs);
const block_q4_1 * restrict x = vx; const block_q4_1 * restrict x = vx;
@ -4440,7 +4440,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r
#endif #endif
} }
void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0; const int qk = QK8_0;
const int nb = n / qk; const int nb = n / qk;
@ -4448,8 +4448,8 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
assert(qk == QK5_0); assert(qk == QK5_0);
assert(nrc == 1); assert(nrc == 1);
UNUSED(nrc); UNUSED(nrc);
UNUSED(bbx); UNUSED(bx);
UNUSED(bby); UNUSED(by);
UNUSED(bs); UNUSED(bs);
const block_q5_0 * restrict x = vx; const block_q5_0 * restrict x = vx;
@ -4618,21 +4618,21 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
/* Compute combined scale for the block */ /* Compute combined scale for the block */
const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d)); const __m256 d = _mm256_set1_ps(GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(y[i].d));
__m256i bx = bytes_from_nibbles_32(x[i].qs); __m256i bx_0 = bytes_from_nibbles_32(x[i].qs);
const __m256i bxhi = bytes_from_bits_32(x[i].qh); const __m256i bxhi = bytes_from_bits_32(x[i].qh);
__m128i bxhil = _mm256_castsi256_si128(bxhi); __m128i bxhil = _mm256_castsi256_si128(bxhi);
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1); __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
bxhil = _mm_andnot_si128(bxhil, mask); bxhil = _mm_andnot_si128(bxhil, mask);
bxhih = _mm_andnot_si128(bxhih, mask); bxhih = _mm_andnot_si128(bxhih, mask);
__m128i bxl = _mm256_castsi256_si128(bx); __m128i bxl = _mm256_castsi256_si128(bx_0);
__m128i bxh = _mm256_extractf128_si256(bx, 1); __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
bxl = _mm_or_si128(bxl, bxhil); bxl = _mm_or_si128(bxl, bxhil);
bxh = _mm_or_si128(bxh, bxhih); bxh = _mm_or_si128(bxh, bxhih);
bx = MM256_SET_M128I(bxh, bxl); bx_0 = MM256_SET_M128I(bxh, bxl);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_i8_pairs_float(bx, by); const __m256 q = mul_sum_i8_pairs_float(bx_0, by_0);
/* Multiply q with scale and accumulate */ /* Multiply q with scale and accumulate */
acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc); acc = _mm256_add_ps(_mm256_mul_ps(d, q), acc);
@ -4731,7 +4731,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r
#endif #endif
} }
void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_1; const int qk = QK8_1;
const int nb = n / qk; const int nb = n / qk;
@ -4739,8 +4739,8 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
assert(qk == QK5_1); assert(qk == QK5_1);
assert(nrc == 1); assert(nrc == 1);
UNUSED(nrc); UNUSED(nrc);
UNUSED(bbx); UNUSED(bx);
UNUSED(bby); UNUSED(by);
UNUSED(bs); UNUSED(bs);
const block_q5_1 * restrict x = vx; const block_q5_1 * restrict x = vx;
@ -4925,22 +4925,22 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s; summs += GGML_FP16_TO_FP32(x[i].m) * y[i].s;
__m256i bx = bytes_from_nibbles_32(x[i].qs); __m256i bx_0 = bytes_from_nibbles_32(x[i].qs);
const __m256i bxhi = bytes_from_bits_32(x[i].qh); const __m256i bxhi = bytes_from_bits_32(x[i].qh);
__m128i bxhil = _mm256_castsi256_si128(bxhi); __m128i bxhil = _mm256_castsi256_si128(bxhi);
__m128i bxhih = _mm256_extractf128_si256(bxhi, 1); __m128i bxhih = _mm256_extractf128_si256(bxhi, 1);
bxhil = _mm_and_si128(bxhil, mask); bxhil = _mm_and_si128(bxhil, mask);
bxhih = _mm_and_si128(bxhih, mask); bxhih = _mm_and_si128(bxhih, mask);
__m128i bxl = _mm256_castsi256_si128(bx); __m128i bxl = _mm256_castsi256_si128(bx_0);
__m128i bxh = _mm256_extractf128_si256(bx, 1); __m128i bxh = _mm256_extractf128_si256(bx_0, 1);
bxl = _mm_or_si128(bxl, bxhil); bxl = _mm_or_si128(bxl, bxhil);
bxh = _mm_or_si128(bxh, bxhih); bxh = _mm_or_si128(bxh, bxhih);
bx = MM256_SET_M128I(bxh, bxl); bx_0 = MM256_SET_M128I(bxh, bxl);
const __m256 dy = _mm256_set1_ps(y[i].d); const __m256 dy = _mm256_set1_ps(y[i].d);
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs); const __m256i by_0 = _mm256_loadu_si256((const __m256i *)y[i].qs);
const __m256 q = mul_sum_us8_pairs_float(bx, by); const __m256 q = mul_sum_us8_pairs_float(bx_0, by_0);
acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc); acc = _mm256_add_ps(_mm256_mul_ps(q, _mm256_mul_ps(dx, dy)), acc);
} }
@ -5035,7 +5035,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r
#endif #endif
} }
void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bbx, const void * restrict vy, size_t bby, int nrc) { void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
const int qk = QK8_0; const int qk = QK8_0;
const int nb = n / qk; const int nb = n / qk;
@ -5046,8 +5046,8 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
assert(nrc == 1); assert(nrc == 1);
#endif #endif
UNUSED(nrc); UNUSED(nrc);
UNUSED(bbx); UNUSED(bx);
UNUSED(bby); UNUSED(by);
UNUSED(bs); UNUSED(bs);
const block_q8_0 * restrict x = vx; const block_q8_0 * restrict x = vx;
@ -5169,10 +5169,10 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
// load elements // load elements
vint8m1_t bx = __riscv_vle8_v_i8m1(x[i].qs, vl); vint8m1_t bx_0 = __riscv_vle8_v_i8m1(x[i].qs, vl);
vint8m1_t by = __riscv_vle8_v_i8m1(y[i].qs, vl); vint8m1_t by_0 = __riscv_vle8_v_i8m1(y[i].qs, vl);
vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx, by, vl); vint16m2_t vw_mul = __riscv_vwmul_vv_i16m2(bx_0, by_0, vl);
vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl); vint32m1_t v_zero = __riscv_vmv_v_x_i32m1(0, vl);
vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl); vint32m1_t v_sum = __riscv_vwredsum_vs_i16m2_i32m1(vw_mul, v_zero, vl);