Mirror of https://github.com/ggerganov/llama.cpp.git, synced 2024-11-14 14:59:52 +00:00
Add support for loongarch backend in sgemm.cpp
commit 53ad4bd89f
parent de280085e7
sgemm.cpp
@@ -58,7 +58,7 @@
 #define NOINLINE __attribute__((__noinline__))
 #endif
 
-#if defined(__ARM_NEON) || defined(__AVX512F__)
+#if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__loongarch_asx)
 #define VECTOR_REGISTERS 32
 #else
 #define VECTOR_REGISTERS 16
@@ -66,6 +66,15 @@
 
 #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
 
+#ifdef __clang__
+#define VREGS_PREFIX "$vr"
+#define XREGS_PREFIX "$xr"
+#else // GCC
+#define VREGS_PREFIX "$f"
+#define XREGS_PREFIX "$f"
+#endif
+#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
+
 namespace {
 
 inline float unhalf(ggml_fp16_t d) {
@@ -105,6 +114,12 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
 inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
 #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 
+#if defined(__loongarch_asx)
+inline __m256 add(__m256 x, __m256 y) { return __lasx_xvfadd_s(x, y); }
+inline __m256 sub(__m256 x, __m256 y) { return __lasx_xvfsub_s(x, y); }
+inline __m256 mul(__m256 x, __m256 y) { return __lasx_xvfmul_s(x, y); }
+#endif // __loongarch_asx
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED FUSED MULTIPLY ADD
 
@@ -144,6 +159,12 @@ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
 #endif
 #endif
 
+#if defined(__loongarch_asx)
+template <>
+inline __m256 madd(__m256 a, __m256 b, __m256 c) {
+    return __lasx_xvfmadd_s(a, b, c);
+}
+#endif // __loongarch_asx
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED HORIZONTAL SUM
 
@@ -189,6 +210,68 @@ inline float hsum(__m512 x) {
 }
 #endif // __AVX512F__
 
+#if defined(__loongarch_asx)
+static inline __m128i lasx_extracti128_lo(__m256i in) {
+    __m128i out;
+    __asm__ volatile (
+        ".ifnc %[out], %[in] \n\t"
+        ".irp i," __ALL_REGS "\n\t"
+        " .ifc %[out], " VREGS_PREFIX "\\i \n\t"
+        "  .irp j," __ALL_REGS "\n\t"
+        "   .ifc %[in], " XREGS_PREFIX "\\j \n\t"
+        "    vori.b $vr\\i, $vr\\j, 0 \n\t"
+        "   .endif \n\t"
+        "  .endr \n\t"
+        " .endif \n\t"
+        ".endr \n\t"
+        ".endif \n\t"
+        : [out] "=f" (out) : [in] "f" (in)
+    );
+    return out;
+}
+
+static inline __m128i lasx_extracti128_hi(__m256i in) {
+    __m128i out;
+    __asm__ volatile (
+        ".irp i," __ALL_REGS "\n\t"
+        " .ifc %[out], " VREGS_PREFIX "\\i \n\t"
+        "  .irp j," __ALL_REGS "\n\t"
+        "   .ifc %[in], " XREGS_PREFIX "\\j \n\t"
+        "    xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t"
+        "   .endif \n\t"
+        "  .endr \n\t"
+        " .endif \n\t"
+        ".endr \n\t"
+        : [out] "=f" (out) : [in] "f" (in)
+    );
+    return out;
+}
+
+static __m128 lasx_extractf128(__m256 a, int pos) {
+    __m128 ret;
+    if (pos == 0) {
+        ret = (__m128)lasx_extracti128_lo((__m256i)a);
+    } else {
+        ret = (__m128)lasx_extracti128_hi((__m256i)a);
+    }
+    return ret;
+}
+
+static inline float hsum_float_8(const __m256 x) {
+    __m128 res = lasx_extractf128(x, 1);
+    ft_union tmp;
+    res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
+    res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
+    res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
+    tmp.i = __lsx_vpickve2gr_w(res, 0);
+    return tmp.f;
+}
+
+inline float hsum(__m256 x) {
+    return hsum_float_8(x);
+}
+#endif // __loongarch_asx
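Note: hsum_float_8 folds the eight lanes stepwise (8 -> 4 by adding the two
128-bit halves, 4 -> 2 via vpickod_d, 2 -> 1 via the lane insert). A scalar
sketch of the same reduction (illustrative only; scalar_hsum8 is a
hypothetical name, not part of the patch):

    static inline float scalar_hsum8(const float v[8]) {
        float sum = 0.0f;
        for (int i = 0; i < 8; ++i)  // same sum, up to floating-point association
            sum += v[i];
        return sum;
    }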
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // VECTORIZED MEMORY LOADING
 
@@ -235,6 +318,47 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
 }
 #endif // __AVX512F__
 
+#if defined(__loongarch_asx)
+template <> inline __m256 load(const float *p) {
+    return (__m256)__lasx_xvld(p, 0);
+}
+static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) {
+    __m256i out;
+    __asm__ volatile (
+        ".irp i," __ALL_REGS "\n\t"
+        " .ifc %[hi], " VREGS_PREFIX "\\i \n\t"
+        "  .irp j," __ALL_REGS "\n\t"
+        "   .ifc %[lo], " VREGS_PREFIX "\\j \n\t"
+        "    xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t" // concatenate hi and lo
+        "   .endif \n\t"
+        "  .endr \n\t"
+        " .endif \n\t"
+        ".endr \n\t"
+        ".ifnc %[out], %[hi] \n\t"
+        ".irp i," __ALL_REGS "\n\t"
+        " .ifc %[out], " XREGS_PREFIX "\\i \n\t"
+        "  .irp j," __ALL_REGS "\n\t"
+        "   .ifc %[hi], " VREGS_PREFIX "\\j \n\t"
+        "    xvori.b $xr\\i, $xr\\j, 0 \n\t" // copy hi to out
+        "   .endif \n\t"
+        "  .endr \n\t"
+        " .endif \n\t"
+        ".endr \n\t"
+        ".endif \n\t"
+        : [out] "=f" (out), [hi] "+f" (inhi)
+        : [lo] "f" (inlo)
+    );
+    return out;
+}
+template <> inline __m256 load(const ggml_fp16_t *p) {
+    __m128i vector16 = __lsx_vld((const __m128i *)p, 0);
+    __m128i hi = (__m128i)__lsx_vfcvth_s_h(vector16);
+    __m128i lo = (__m128i)__lsx_vfcvtl_s_h(vector16);
+    return (__m256)lasx_set_q(hi, lo);
+}
+#endif // __loongarch_asx
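Note: the fp16 load widens 8 half-precision values to 8 floats: vfcvtl_s_h
converts the low four halves, vfcvth_s_h the high four, and lasx_set_q glues
the two 128-bit halves back together. A scalar sketch (illustrative; the
helper name is hypothetical, GGML_FP16_TO_FP32 is ggml's existing macro):

    static inline void scalar_load_fp16x8(const ggml_fp16_t *p, float out[8]) {
        for (int i = 0; i < 8; ++i)
            out[i] = GGML_FP16_TO_FP32(p[i]);  // widen each half to float
    }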
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 // FLOATING POINT MATRIX MULTIPLICATION
 
@@ -813,6 +937,165 @@ class tinyBLAS_Q0_AVX {
 };
 #endif // __AVX__
 
+#if defined(__loongarch_asx)
+template <typename TA, typename TB, typename TC>
+class tinyBLAS_Q0_LOONGARCH {
+  public:
+    tinyBLAS_Q0_LOONGARCH(int64_t k,
+                          const TA *A, int64_t lda,
+                          const TB *B, int64_t ldb,
+                          TC *C, int64_t ldc,
+                          int ith, int nth)
+        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+    }
+
+    void matmul(int64_t m, int64_t n) {
+        mnpack(0, m, 0, n);
+    }
+
+  private:
+    void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t mc, nc, mp, np;
+        switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 1)) {
+        case 0x44:
+        case 0x43:
+        case 0x42:
+        case 0x41:
+            mc = 4;
+            nc = 1;
+            gemm<4, 1>(m0, m, n0, n);
+            break;
+        case 0x22:
+            mc = 2;
+            nc = 2;
+            gemm<2, 2>(m0, m, n0, n);
+            break;
+        case 0x14:
+            mc = 1;
+            nc = 4;
+            gemm<1, 4>(m0, m, n0, n);
+            break;
+        case 0x31:
+            mc = 3;
+            nc = 1;
+            gemm<3, 1>(m0, m, n0, n);
+            break;
+        case 0x13:
+            mc = 1;
+            nc = 3;
+            gemm<1, 3>(m0, m, n0, n);
+            break;
+        case 0x21:
+            mc = 2;
+            nc = 1;
+            gemm<2, 1>(m0, m, n0, n);
+            break;
+        case 0x12:
+            mc = 1;
+            nc = 2;
+            gemm<1, 2>(m0, m, n0, n);
+            break;
+        case 0x11:
+            mc = 1;
+            nc = 1;
+            gemm<1, 1>(m0, m, n0, n);
+            break;
+        default:
+            return;
+        }
+        mp = m0 + (m - m0) / mc * mc;
+        np = n0 + (n - n0) / nc * nc;
+        mnpack(mp, m, n0, np);
+        mnpack(m0, m, np, n);
+    }
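Note: the switch key packs the remaining tile extents into one byte: the
clamped row count in the high nibble, the clamped column count in the low
nibble. A hedged re-statement (tile_key is a hypothetical helper, not part of
the patch; MIN is the file's existing macro):

    static inline int tile_key(int64_t m_rem, int64_t n_rem) {
        // e.g. m_rem = 5, n_rem = 3  ->  (4 << 4) | 1 == 0x41  ->  gemm<4, 1>
        return (int)((MIN(m_rem, 4) << 4) | MIN(n_rem, 1));
    }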
+
+    inline __m256i load(const block_q8_0 *b) {
+        return __lasx_xvld(((const __m256i *)b->qs), 0);
+    }
+
+    static __m256i lasx_insertf128(__m128i x, __m128i y) {
+        return lasx_set_q(x, y);
+    }
+
+    // Unpack 32 4-bit fields into 32 bytes
+    // The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
+    static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
+        const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);
+        __m128i hi = __lsx_vsrli_h(lo, 4);
+        return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf);
+    }
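Note: a scalar sketch of bytes_from_nibbles_32 (illustrative; the helper name
is hypothetical). The masked low nibbles land in the low 128-bit half, the
shifted high nibbles in the high half:

    static inline void scalar_nibbles_32(const uint8_t *src, uint8_t out[32]) {
        for (int i = 0; i < 16; ++i) {
            out[i]      = src[i] & 0xf;  // low nibble
            out[i + 16] = src[i] >> 4;   // high nibble
        }
    }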
+
+    inline __m256i load(const block_q4_0 *b) {
+        return __lasx_xvsub_b(bytes_from_nibbles_32(b->qs), __lasx_xvreplgr2vr_b(8));
+    }
+
+    // add int16_t pairwise and return as float vector
+    static inline __m256 sum_i16_pairs_float(const __m256i x) {
+        __m256i v = __lasx_xvpackod_h(x, x);
+        __m256i summed_pairs = __lasx_xvaddwev_w_h(x, v);
+        return __lasx_xvffint_s_w(summed_pairs);
+    }
+
+    static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
+        __m256i tmp1, tmp2;
+        tmp1 = __lasx_xvmulwev_h_b(a, b);
+        tmp2 = __lasx_xvmulwod_h_b(a, b);
+        return __lasx_xvsadd_h(tmp1, tmp2);
+    }
+
+    static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
+        const __m256i dot = lasx_maddubs_h(ax, sy);
+        return sum_i16_pairs_float(dot);
+    }
+
+    static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
+        const __m256i ax = __lasx_xvsigncov_b(x, x);
+        const __m256i sy = __lasx_xvsigncov_b(x, y);
+        return mul_sum_us8_pairs_float(ax, sy);
+    }
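Note: mul_sum_i8_pairs_float relies on the identity x*y == |x| * (sign(x)*y)
so the multiply can take the unsigned-by-signed path, mirroring the x86
maddubs trick. A scalar sketch for the four byte products behind one 32-bit
lane (illustrative helper, not part of the patch):

    static inline float scalar_mul_sum_i8_quad(const int8_t x[4], const int8_t y[4]) {
        int32_t acc = 0;
        for (int i = 0; i < 4; ++i) {
            int ax = x[i] < 0 ? -x[i] : x[i];                    // |x|
            int sy = x[i] == 0 ? 0 : (x[i] < 0 ? -y[i] : y[i]);  // sign(x) * y
            acc += ax * sy;                                      // == x[i] * y[i]
        }
        return (float)acc;
    }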
+
+    template <int RM, int RN>
+    NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
+        int64_t ytiles = (m - m0) / RM;
+        int64_t xtiles = (n - n0) / RN;
+        int64_t tiles = xtiles * ytiles;
+        int64_t duty = (tiles + nth - 1) / nth;
+        int64_t start = duty * ith;
+        int64_t end = start + duty;
+        if (end > tiles)
+            end = tiles;
+        for (int64_t job = start; job < end; ++job) {
+            int64_t ii = m0 + job / xtiles * RM;
+            int64_t jj = n0 + job % xtiles * RN;
+            __m256 Cv[RN][RM] = {};
+            for (int64_t l = 0; l < k; ++l)
+                for (int64_t j = 0; j < RN; ++j)
+                    for (int64_t i = 0; i < RM; ++i) {
+                        __m256 udTmp = mul_sum_i8_pairs_float(load(B + ldb * (jj + j) + l),
+                                                              load(A + lda * (ii + i) + l));
+                        Cv[j][i] = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(unhalf(A[lda * (ii + i) + l].d) *
+                                                                         unhalf(B[ldb * (jj + j) + l].d)),
+                                                    udTmp,
+                                                    Cv[j][i]);
+                    }
+            for (int64_t j = 0; j < RN; ++j)
+                for (int64_t i = 0; i < RM; ++i)
+                    C[ldc * (jj + j) + (ii + i)] = hsum_float_8(Cv[j][i]);
+        }
+    }
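Note: gemm statically partitions the RM x RN output tiles across threads; each
thread takes a contiguous run of tile indices. A hedged re-statement of the
schedule (tile_range is a hypothetical helper): with tiles = 10 and nth = 4,
duty = 3, so thread ith handles jobs [3*ith, min(3*ith + 3, 10)).

    static void tile_range(int64_t tiles, int ith, int nth,
                           int64_t *start, int64_t *end) {
        int64_t duty = (tiles + nth - 1) / nth;  // ceil(tiles / nth)
        *start = duty * ith;
        *end = *start + duty;
        if (*end > tiles)
            *end = tiles;                        // last thread takes the remainder
    }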
+
+    const TA *const A;
+    const TB *const B;
+    TC *const C;
+    const int64_t k;
+    const int64_t lda;
+    const int64_t ldb;
+    const int64_t ldc;
+    const int ith;
+    const int nth;
+};
+#endif // __loongarch_asx
 } // namespace
 
 /**
@@ -897,6 +1180,16 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
                     ith, nth};
         tb.matmul(m, n);
         return true;
+// #elif defined(__loongarch_asx)
+//     if (k % 8)
+//         return false;
+//     tinyBLAS<8, __m256, __m256, float, float, float> tb{
+//         k, (const float *)A, lda,
+//         (const float *)B, ldb,
+//         (float *)C, ldc,
+//         ith, nth};
+//     tb.matmul(m, n);
+//     return true;
 #else
         return false;
 #endif
@@ -953,6 +1246,18 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
                     ith, nth};
         tb.matmul(m, n);
         return true;
+#elif defined(__loongarch_asx)
+        if (k % 8)
+            return false;
+        if (Btype != GGML_TYPE_F32)
+            return false;
+        tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
+            k, (const ggml_fp16_t *)A, lda,
+            (const float *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
 #else
         return false;
 #endif
@@ -977,6 +1282,14 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
                     ith, nth};
         tb.matmul(m, n);
         return true;
+#elif defined(__loongarch_asx)
+        tinyBLAS_Q0_LOONGARCH<block_q8_0, block_q8_0, float> tb{
+            k, (const block_q8_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
 #else
         return false;
 #endif
@@ -1001,6 +1314,14 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
                     ith, nth};
         tb.matmul(m, n);
         return true;
+#elif defined(__loongarch_asx)
+        tinyBLAS_Q0_LOONGARCH<block_q4_0, block_q8_0, float> tb{
+            k, (const block_q4_0 *)A, lda,
+            (const block_q8_0 *)B, ldb,
+            (float *)C, ldc,
+            ith, nth};
+        tb.matmul(m, n);
+        return true;
 #else
         return false;
 #endif
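Note: a hedged sketch of how a caller reaches the new Q4_0 x Q8_0 path. The
trailing task/type parameters are assumed from the surrounding sgemm.cpp and
ggml headers, not shown in this diff:

    bool ok = llamafile_sgemm(m, n, k,
                              A, lda,    // block_q4_0 weights
                              B, ldb,    // block_q8_0 activations
                              C, ldc,    // float output
                              ith, nth,  // this thread / total threads
                              GGML_TASK_TYPE_COMPUTE,
                              GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_F32);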