Add support for the LoongArch backend in sgemm.cpp

Author: Tianzhengshuyuan
Date: 2024-07-25 22:32:41 +08:00
parent de280085e7
commit 53ad4bd89f


@@ -58,7 +58,7 @@
#define NOINLINE __attribute__((__noinline__))
#endif
-#if defined(__ARM_NEON) || defined(__AVX512F__)
+#if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__loongarch_asx)
#define VECTOR_REGISTERS 32
#else
#define VECTOR_REGISTERS 16
@@ -66,6 +66,15 @@
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
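// The inline asm below identifies the physical register backing an operand by
// iterating candidate names with .irp/.ifc assembler loops; clang names SIMD
// registers $vrN/$xrN while GCC exposes them as $fN, hence the per-compiler prefixes.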
#ifdef __clang__
#define VREGS_PREFIX "$vr"
#define XREGS_PREFIX "$xr"
#else // GCC
#define VREGS_PREFIX "$f"
#define XREGS_PREFIX "$f"
#endif
#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
namespace {
inline float unhalf(ggml_fp16_t d) {
@@ -105,6 +114,12 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if defined(__loongarch_asx)
inline __m256 add(__m256 x, __m256 y) { return __lasx_xvfadd_s(x, y); }
inline __m256 sub(__m256 x, __m256 y) { return __lasx_xvfsub_s(x, y); }
inline __m256 mul(__m256 x, __m256 y) { return __lasx_xvfmul_s(x, y); }
#endif // __loongarch_asx
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED FUSED MULTIPLY ADD
@@ -144,6 +159,12 @@ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
#endif
#endif
#if defined(__loongarch_asx)
template <>
inline __m256 madd(__m256 a, __m256 b, __m256 c) {
return __lasx_xvfmadd_s(a, b, c);
}
#endif //__loongarch_asx
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED HORIZONTAL SUM
@@ -189,6 +210,68 @@ inline float hsum(__m512 x) {
}
#endif // __AVX512F__
#if defined(__loongarch_asx)
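// Extract the low 128 bits of a 256-bit LASX register. There is no cross-class
// move intrinsic, so the .irp/.ifc loops locate the allocated registers and
// emit a vori.b move, skipped entirely when out and in share a register.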
static inline __m128i lasx_extracti128_lo(__m256i in) {
__m128i out;
__asm__ volatile (
".ifnc %[out], %[in] \n\t"
".irp i," __ALL_REGS "\n\t"
" .ifc %[out], " VREGS_PREFIX "\\i \n\t"
" .irp j," __ALL_REGS "\n\t"
" .ifc %[in], " XREGS_PREFIX "\\j \n\t"
" vori.b $vr\\i, $vr\\j, 0 \n\t"
" .endif \n\t"
" .endr \n\t"
" .endif \n\t"
".endr \n\t"
".endif \n\t"
: [out] "=f" (out) : [in] "f" (in)
);
return out;
}
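// Extract the high 128 bits: xvpermi.q with immediate 0x11 moves the upper
// lane of the source down into the destination's low 128 bits.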
static inline __m128i lasx_extracti128_hi(__m256i in) {
__m128i out;
__asm__ volatile (
".irp i," __ALL_REGS "\n\t"
" .ifc %[out], " VREGS_PREFIX "\\i \n\t"
" .irp j," __ALL_REGS "\n\t"
" .ifc %[in], " XREGS_PREFIX "\\j \n\t"
" xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t"
" .endif \n\t"
" .endr \n\t"
" .endif \n\t"
".endr \n\t"
: [out] "=f" (out) : [in] "f" (in)
);
return out;
}
static inline __m128 lasx_extractf128(__m256 a, int pos) {
__m128 ret;
if (pos == 0) {
ret = (__m128)lasx_extracti128_lo((__m256i)a);
} else {
ret = (__m128)lasx_extracti128_hi((__m256i)a);
}
return ret;
}
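// Horizontal sum of 8 floats: add the high lane onto the low lane, then
// finish the 4-float reduction with two more shuffle-and-add steps.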
static inline float hsum_float_8(const __m256 x) {
__m128 res = lasx_extractf128(x, 1);
ft_union tmp;
res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
tmp.i = __lsx_vpickve2gr_w(res, 0);
return tmp.f;
}
inline float hsum(__m256 x) {
return hsum_float_8(x);
}
#endif // __loongarch_asx
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED MEMORY LOADING
@@ -235,6 +318,47 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
}
#endif // __AVX512F__
#if defined(__loongarch_asx)
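// Plain 256-bit load of 8 packed floats.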
template <> inline __m256 load(const float *p) {
return (__m256)__lasx_xvld(p, 0);
}
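// Join two 128-bit halves into a 256-bit vector: xvpermi.q 0x20 packs the low
// lanes of lo and hi into lanes 0 and 1, then xvori.b copies the result to out
// when the register allocator placed them apart.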
static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) {
__m256i out;
__asm__ volatile (
".irp i," __ALL_REGS "\n\t"
" .ifc %[hi], " VREGS_PREFIX "\\i \n\t"
" .irp j," __ALL_REGS "\n\t"
" .ifc %[lo], " VREGS_PREFIX "\\j \n\t"
" xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"//拼接hi和lo
" .endif \n\t"
" .endr \n\t"
" .endif \n\t"
".endr \n\t"
".ifnc %[out], %[hi] \n\t"
".irp i," __ALL_REGS "\n\t"
" .ifc %[out], " XREGS_PREFIX "\\i \n\t"
" .irp j," __ALL_REGS "\n\t"
" .ifc %[hi], " VREGS_PREFIX "\\j \n\t"
" xvori.b $xr\\i, $xr\\j, 0 \n\t"//复制hi到out
" .endif \n\t"
" .endr \n\t"
" .endif \n\t"
".endr \n\t"
".endif \n\t"
: [out] "=f" (out), [hi] "+f" (inhi)
: [lo] "f" (inlo)
);
return out;
}
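// Load 8 fp16 values and widen to f32: vfcvtl/vfcvth convert the low/high
// four elements, and lasx_set_q packs the two halves back in order.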
template <> inline __m256 load(const ggml_fp16_t *p) {
__m128i vector16 = __lsx_vld((const __m128i *)p, 0);
__m128i hi = (__m128i)__lsx_vfcvth_s_h(vector16);
__m128i lo = (__m128i)__lsx_vfcvtl_s_h(vector16);
return (__m256)lasx_set_q(hi, lo);
}
#endif // __loongarch_asx
////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT MATRIX MULTIPLICATION
@@ -813,6 +937,165 @@ class tinyBLAS_Q0_AVX {
};
#endif // __AVX__
#if defined(__loongarch_asx)
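// Quantized matmul kernel for LASX, mirroring the tinyBLAS_Q0_AVX class above.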
template <typename TA, typename TB, typename TC>
class tinyBLAS_Q0_LOONGARCH {
public:
tinyBLAS_Q0_LOONGARCH(int64_t k,
const TA *A, int64_t lda,
const TB *B, int64_t ldb,
TC *C, int64_t ldc,
int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
void matmul(int64_t m, int64_t n) {
mnpack(0, m, 0, n);
}
private:
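// Recursively tile the m x n output: the selector packs min(m-m0,4) into the
// high nibble and min(n-n0,4) into the low nibble, dispatches the largest
// matching kernel, then recurses on the row and column remainders.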
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t mc, nc, mp, np;
switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
case 0x44:
case 0x43:
case 0x42:
case 0x41:
mc = 4;
nc = 1;
gemm<4, 1>(m0, m, n0, n);
break;
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2>(m0, m, n0, n);
break;
case 0x14:
mc = 1;
nc = 4;
gemm<1, 4>(m0, m, n0, n);
break;
case 0x31:
mc = 3;
nc = 1;
gemm<3, 1>(m0, m, n0, n);
break;
case 0x13:
mc = 1;
nc = 3;
gemm<1, 3>(m0, m, n0, n);
break;
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1>(m0, m, n0, n);
break;
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1>(m0, m, n0, n);
break;
default:
return;
}
mp = m0 + (m - m0) / mc * mc;
np = n0 + (n - n0) / nc * nc;
mnpack(mp, m, n0, np);
mnpack(m0, m, np, n);
}
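// q8_0 quants are raw int8, so one 256-bit load fetches the whole block.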
inline __m256i load(const block_q8_0 *b) {
return __lasx_xvld((const __m256i *)b->qs, 0);
}
static __m256i lasx_insertf128( __m128i x, __m128i y) {
return lasx_set_q(x, y);
}
// Unpack 32 4-bit fields into 32 bytes.
// Each byte of the output vector lies in the interval [0, 15].
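// e.g. source byte 0x2F yields 0x0F in the low lane and 0x02 in the high lane.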
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);
__m128i hi = __lsx_vsrli_h(lo, 4);
return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf);
}
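// q4_0 nibbles are stored with a bias of 8; subtract to get values in [-8, 7].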
inline __m256i load(const block_q4_0 *b) {
return __lasx_xvsub_b(bytes_from_nibbles_32(b->qs), __lasx_xvreplgr2vr_b(8));
}
// Add int16_t lanes pairwise and return the sums as a float vector.
static inline __m256 sum_i16_pairs_float(const __m256i x) {
__m256i v = __lasx_xvpackod_h(x, x);
__m256i summed_pairs = __lasx_xvaddwev_w_h(x, v);
return __lasx_xvffint_s_w(summed_pairs);
}
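// LASX stand-in for _mm256_maddubs_epi16: multiply even/odd byte lanes into
// 16-bit products and saturating-add each pair (valid here since ax holds |x|).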
static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
__m256i tmp1, tmp2;
tmp1 = __lasx_xvmulwev_h_b(a, b);
tmp2 = __lasx_xvmulwod_h_b(a, b);
return __lasx_xvsadd_h(tmp1, tmp2);
}
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
const __m256i dot = lasx_maddubs_h(ax, sy);
return sum_i16_pairs_float(dot);
}
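// Signed int8 dot product: take |x| and transfer x's sign onto y, so
// ax * sy equals x * y lane by lane.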
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
const __m256i ax = __lasx_xvsigncov_b(x, x);
const __m256i sy = __lasx_xvsigncov_b(x, y);
return mul_sum_us8_pairs_float(ax, sy);
}
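// Compute RM x RN tiles of C; tiles are linearized into jobs and divided
// evenly among nth threads, with thread ith handling [start, end).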
template <int RM, int RN>
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
if (end > tiles)
end = tiles;
for (int64_t job = start; job < end; ++job) {
int64_t ii = m0 + job / xtiles * RM;
int64_t jj = n0 + job % xtiles * RN;
__m256 Cv[RN][RM] = {};
for (int64_t l = 0; l < k; ++l)
for (int64_t j = 0; j < RN; ++j)
for (int64_t i = 0; i < RM; ++i) {
__m256 udTmp = mul_sum_i8_pairs_float(load(B + ldb * (jj + j) + l),
load(A + lda * (ii + i) + l));
Cv[j][i] = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(unhalf(A[lda * (ii + i) + l].d) *
unhalf(B[ldb * (jj + j) + l].d)),
udTmp,
Cv[j][i]);
}
for (int64_t j = 0; j < RN; ++j)
for (int64_t i = 0; i < RM; ++i)
C[ldc * (jj + j) + (ii + i)] = hsum_float_8(Cv[j][i]);
}
}
const TA *const A;
const TB *const B;
TC *const C;
const int64_t k;
const int64_t lda;
const int64_t ldb;
const int64_t ldc;
const int ith;
const int nth;
};
#endif // __loongarch_asx
} // namespace
/**
@@ -897,6 +1180,16 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
ith, nth};
tb.matmul(m, n);
return true;
// #elif defined(__loongarch_asx)
// if (k % 8)
// return false;
// tinyBLAS<8, __m256, __m256, float, float, float> tb{
// k, (const float *)A, lda,
// (const float *)B, ldb,
// (float *)C, ldc,
// ith, nth};
// tb.matmul(m, n);
// return true;
#else
return false;
#endif
@@ -953,6 +1246,18 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
ith, nth};
tb.matmul(m, n);
return true;
#elif defined(__loongarch_asx)
if (k % 8)
return false;
if (Btype != GGML_TYPE_F32)
return false;
tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
k, (const ggml_fp16_t *)A, lda,
(const float *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n);
return true;
#else
return false;
#endif
@@ -977,6 +1282,14 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
ith, nth};
tb.matmul(m, n);
return true;
#elif defined(__loongarch_asx)
tinyBLAS_Q0_LOONGARCH<block_q8_0, block_q8_0, float> tb{
k, (const block_q8_0 *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n);
return true;
#else
return false;
#endif
@@ -1001,6 +1314,14 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
ith, nth};
tb.matmul(m, n);
return true;
#elif defined(__loongarch_asx)
tinyBLAS_Q0_LOONGARCH<block_q4_0, block_q8_0, float> tb{
k, (const block_q4_0 *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n);
return true;
#else
return false;
#endif