mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 11:40:17 +00:00
sgemm : improved Q4_0 and Q8_0 performance via 4xN and Mx4 gemm (#8908)
This commit is contained in:
parent
49271efbaf
commit
ea5d7478b1
@ -606,17 +606,29 @@ class tinyBLAS_Q0_AVX {
|
|||||||
case 0x44:
|
case 0x44:
|
||||||
mc = 4;
|
mc = 4;
|
||||||
nc = 4;
|
nc = 4;
|
||||||
|
#if defined(__AVX2__) && defined(__F16C__)
|
||||||
|
gemm4xN<4>(m0, m, n0, n);
|
||||||
|
#else
|
||||||
gemm<4, 4>(m0, m, n0, n);
|
gemm<4, 4>(m0, m, n0, n);
|
||||||
|
#endif
|
||||||
break;
|
break;
|
||||||
case 0x43:
|
case 0x43:
|
||||||
mc = 4;
|
mc = 4;
|
||||||
nc = 3;
|
nc = 3;
|
||||||
|
#if defined(__AVX2__) && defined(__F16C__)
|
||||||
|
gemm4xN<3>(m0, m, n0, n);
|
||||||
|
#else
|
||||||
gemm<4, 3>(m0, m, n0, n);
|
gemm<4, 3>(m0, m, n0, n);
|
||||||
|
#endif
|
||||||
break;
|
break;
|
||||||
case 0x34:
|
case 0x34:
|
||||||
mc = 3;
|
mc = 3;
|
||||||
nc = 4;
|
nc = 4;
|
||||||
|
#if defined(__AVX2__) && defined(__F16C__)
|
||||||
|
gemmMx4<3>(m0, m, n0, n);
|
||||||
|
#else
|
||||||
gemm<3, 4>(m0, m, n0, n);
|
gemm<3, 4>(m0, m, n0, n);
|
||||||
|
#endif
|
||||||
break;
|
break;
|
||||||
case 0x33:
|
case 0x33:
|
||||||
mc = 3;
|
mc = 3;
|
||||||
@ -626,12 +638,20 @@ class tinyBLAS_Q0_AVX {
|
|||||||
case 0x42:
|
case 0x42:
|
||||||
mc = 4;
|
mc = 4;
|
||||||
nc = 2;
|
nc = 2;
|
||||||
|
#if defined(__AVX2__) && defined(__F16C__)
|
||||||
|
gemm4xN<2>(m0, m, n0, n);
|
||||||
|
#else
|
||||||
gemm<4, 2>(m0, m, n0, n);
|
gemm<4, 2>(m0, m, n0, n);
|
||||||
|
#endif
|
||||||
break;
|
break;
|
||||||
case 0x24:
|
case 0x24:
|
||||||
mc = 2;
|
mc = 2;
|
||||||
nc = 4;
|
nc = 4;
|
||||||
|
#if defined(__AVX2__) && defined(__F16C__)
|
||||||
|
gemmMx4<2>(m0, m, n0, n);
|
||||||
|
#else
|
||||||
gemm<2, 4>(m0, m, n0, n);
|
gemm<2, 4>(m0, m, n0, n);
|
||||||
|
#endif
|
||||||
break;
|
break;
|
||||||
#else
|
#else
|
||||||
case 0x44:
|
case 0x44:
|
||||||
@ -639,13 +659,21 @@ class tinyBLAS_Q0_AVX {
|
|||||||
case 0x42:
|
case 0x42:
|
||||||
mc = 4;
|
mc = 4;
|
||||||
nc = 2;
|
nc = 2;
|
||||||
|
#if defined(__AVX2__) && defined(__F16C__)
|
||||||
|
gemm4xN<2>(m0, m, n0, n);
|
||||||
|
#else
|
||||||
gemm<4, 2>(m0, m, n0, n);
|
gemm<4, 2>(m0, m, n0, n);
|
||||||
|
#endif
|
||||||
break;
|
break;
|
||||||
case 0x34:
|
case 0x34:
|
||||||
case 0x24:
|
case 0x24:
|
||||||
mc = 2;
|
mc = 2;
|
||||||
nc = 4;
|
nc = 4;
|
||||||
|
#if defined(__AVX2__) && defined(__F16C__)
|
||||||
|
gemmMx4<2>(m0, m, n0, n);
|
||||||
|
#else
|
||||||
gemm<2, 4>(m0, m, n0, n);
|
gemm<2, 4>(m0, m, n0, n);
|
||||||
|
#endif
|
||||||
break;
|
break;
|
||||||
case 0x33:
|
case 0x33:
|
||||||
#endif
|
#endif
|
||||||
@ -662,7 +690,11 @@ class tinyBLAS_Q0_AVX {
|
|||||||
case 0x41:
|
case 0x41:
|
||||||
mc = 4;
|
mc = 4;
|
||||||
nc = 1;
|
nc = 1;
|
||||||
|
#if defined(__AVX2__) && defined(__F16C__)
|
||||||
|
gemm4xN<1>(m0, m, n0, n);
|
||||||
|
#else
|
||||||
gemm<4, 1>(m0, m, n0, n);
|
gemm<4, 1>(m0, m, n0, n);
|
||||||
|
#endif
|
||||||
break;
|
break;
|
||||||
case 0x22:
|
case 0x22:
|
||||||
mc = 2;
|
mc = 2;
|
||||||
@ -672,7 +704,11 @@ class tinyBLAS_Q0_AVX {
|
|||||||
case 0x14:
|
case 0x14:
|
||||||
mc = 1;
|
mc = 1;
|
||||||
nc = 4;
|
nc = 4;
|
||||||
|
#if defined(__AVX2__) && defined(__F16C__)
|
||||||
|
gemmMx4<1>(m0, m, n0, n);
|
||||||
|
#else
|
||||||
gemm<1, 4>(m0, m, n0, n);
|
gemm<1, 4>(m0, m, n0, n);
|
||||||
|
#endif
|
||||||
break;
|
break;
|
||||||
case 0x31:
|
case 0x31:
|
||||||
mc = 3;
|
mc = 3;
|
||||||
@ -708,6 +744,119 @@ class tinyBLAS_Q0_AVX {
|
|||||||
mnpack(m0, m, np, n);
|
mnpack(m0, m, np, n);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(__AVX2__) && defined(__F16C__)
|
||||||
|
// Templated functions for gemm of dimensions 4xN
|
||||||
|
template <int RN>
|
||||||
|
NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||||
|
int64_t ytiles = (m - m0) / 4;
|
||||||
|
int64_t xtiles = (n - n0) / RN;
|
||||||
|
int64_t tiles = xtiles * ytiles;
|
||||||
|
int64_t duty = (tiles + nth - 1) / nth;
|
||||||
|
int64_t start = duty * ith;
|
||||||
|
int64_t end = start + duty;
|
||||||
|
if (end > tiles)
|
||||||
|
end = tiles;
|
||||||
|
for (int64_t job = start; job < end; ++job) {
|
||||||
|
int64_t ii = m0 + job / xtiles * 4;
|
||||||
|
int64_t jj = n0 + job % xtiles * RN;
|
||||||
|
__m256 Cv[RN][4] = {};
|
||||||
|
for (int64_t l = 0; l < k; ++l) {
|
||||||
|
uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d);
|
||||||
|
// Convert delta values for four blocks to float values
|
||||||
|
__m128 da = _mm_cvtph_ps(_mm_set_epi64x(0, a_delta));
|
||||||
|
__m256i avec0 = load(A + lda * (ii + 0) + l);
|
||||||
|
__m256i avec1 = load(A + lda * (ii + 1) + l);
|
||||||
|
__m256i avec2 = load(A + lda * (ii + 2) + l);
|
||||||
|
__m256i avec3 = load(A + lda * (ii + 3) + l);
|
||||||
|
for (int64_t j = 0; j < RN; ++j) {
|
||||||
|
__m128 db = _mm_set1_ps(unhalf(B[ldb * (jj + j) + l].d));
|
||||||
|
// Computation of product of delta values for four blocks and replicate it across 256 bit lane
|
||||||
|
__m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
|
||||||
|
dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
|
||||||
|
// Computation of dot product and multiplication with appropriate delta value products
|
||||||
|
Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
|
||||||
|
updot(_mm256_sign_epi8(avec0, avec0),
|
||||||
|
_mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)),
|
||||||
|
Cv[j][0]);
|
||||||
|
Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
|
||||||
|
updot(_mm256_sign_epi8(avec1, avec1),
|
||||||
|
_mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)),
|
||||||
|
Cv[j][1]);
|
||||||
|
Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
|
||||||
|
updot(_mm256_sign_epi8(avec2, avec2),
|
||||||
|
_mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)),
|
||||||
|
Cv[j][2]);
|
||||||
|
Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
|
||||||
|
updot(_mm256_sign_epi8(avec3, avec3),
|
||||||
|
_mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)),
|
||||||
|
Cv[j][3]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int64_t j = 0; j < RN; ++j)
|
||||||
|
for (int64_t i = 0; i < 4; ++i)
|
||||||
|
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Templated functions for gemm of dimensions Mx4
|
||||||
|
template <int RM>
|
||||||
|
NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||||
|
int64_t ytiles = (m - m0) / RM;
|
||||||
|
int64_t xtiles = (n - n0) / 4;
|
||||||
|
int64_t tiles = xtiles * ytiles;
|
||||||
|
int64_t duty = (tiles + nth - 1) / nth;
|
||||||
|
int64_t start = duty * ith;
|
||||||
|
int64_t end = start + duty;
|
||||||
|
if (end > tiles)
|
||||||
|
end = tiles;
|
||||||
|
for (int64_t job = start; job < end; ++job) {
|
||||||
|
int64_t ii = m0 + job / xtiles * RM;
|
||||||
|
int64_t jj = n0 + job % xtiles * 4;
|
||||||
|
__m256 Cv[4][RM] = {};
|
||||||
|
for (int64_t l = 0; l < k; ++l) {
|
||||||
|
uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d);
|
||||||
|
// Convert delta values for four blocks to float values
|
||||||
|
__m128 db = _mm_cvtph_ps(_mm_set_epi64x(0, b_delta));
|
||||||
|
__m256i bvec0 = load(B + ldb * (jj + 0) + l);
|
||||||
|
__m256i bvec1 = load(B + ldb * (jj + 1) + l);
|
||||||
|
__m256i bvec2 = load(B + ldb * (jj + 2) + l);
|
||||||
|
__m256i bvec3 = load(B + ldb * (jj + 3) + l);
|
||||||
|
for (int64_t i = 0; i < RM; ++i) {
|
||||||
|
__m128 da = _mm_set1_ps(unhalf((A[lda * (ii + i) + l].d)));
|
||||||
|
// Computation of product of delta values for four blocks and replicate it across 256 bit lane
|
||||||
|
__m256 dvec = _mm256_castps128_ps256(_mm_mul_ps(da, db));
|
||||||
|
dvec = _mm256_permute2f128_ps(dvec ,dvec, 0);
|
||||||
|
// Computation of dot product and multiplication with appropriate delta value products
|
||||||
|
Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0),
|
||||||
|
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
||||||
|
load(A + lda * (ii + i) + l)),
|
||||||
|
_mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))),
|
||||||
|
Cv[0][i]);
|
||||||
|
Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85),
|
||||||
|
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
||||||
|
load(A + lda * (ii + i) + l)),
|
||||||
|
_mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))),
|
||||||
|
Cv[1][i]);
|
||||||
|
Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170),
|
||||||
|
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
||||||
|
load(A + lda * (ii + i) + l)),
|
||||||
|
_mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))),
|
||||||
|
Cv[2][i]);
|
||||||
|
Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255),
|
||||||
|
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
|
||||||
|
load(A + lda * (ii + i) + l)),
|
||||||
|
_mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))),
|
||||||
|
Cv[3][i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (int64_t j = 0; j < 4; ++j)
|
||||||
|
for (int64_t i = 0; i < RM; ++i)
|
||||||
|
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
template <int RM, int RN>
|
template <int RM, int RN>
|
||||||
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||||
int64_t ytiles = (m - m0) / RM;
|
int64_t ytiles = (m - m0) / RM;
|
||||||
|
Loading…
Reference in New Issue
Block a user