mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-14 23:09:53 +00:00
Add support for loongarch backend in sgemm.cpp
This commit is contained in:
parent
de280085e7
commit
53ad4bd89f
@ -58,7 +58,7 @@
|
|||||||
#define NOINLINE __attribute__((__noinline__))
|
#define NOINLINE __attribute__((__noinline__))
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(__ARM_NEON) || defined(__AVX512F__)
|
#if defined(__ARM_NEON) || defined(__AVX512F__) || defined(__loongarch_asx)
|
||||||
#define VECTOR_REGISTERS 32
|
#define VECTOR_REGISTERS 32
|
||||||
#else
|
#else
|
||||||
#define VECTOR_REGISTERS 16
|
#define VECTOR_REGISTERS 16
|
||||||
@ -66,6 +66,15 @@
|
|||||||
|
|
||||||
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
|
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
|
||||||
|
|
||||||
|
#ifdef __clang__
|
||||||
|
#define VREGS_PREFIX "$vr"
|
||||||
|
#define XREGS_PREFIX "$xr"
|
||||||
|
#else // GCC
|
||||||
|
#define VREGS_PREFIX "$f"
|
||||||
|
#define XREGS_PREFIX "$f"
|
||||||
|
#endif
|
||||||
|
#define __ALL_REGS "0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
inline float unhalf(ggml_fp16_t d) {
|
inline float unhalf(ggml_fp16_t d) {
|
||||||
@ -105,6 +114,12 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
|
|||||||
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
|
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
|
||||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||||
|
|
||||||
|
#if defined(__loongarch_asx)
|
||||||
|
inline __m256 add(__m256 x, __m256 y) { return __lasx_xvfadd_s(x, y); }
|
||||||
|
inline __m256 sub(__m256 x, __m256 y) { return __lasx_xvfsub_s(x, y); }
|
||||||
|
inline __m256 mul(__m256 x, __m256 y) { return __lasx_xvfmul_s(x, y); }
|
||||||
|
#endif // __loongarch_asx
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// VECTORIZED FUSED MULTIPLY ADD
|
// VECTORIZED FUSED MULTIPLY ADD
|
||||||
|
|
||||||
@ -144,6 +159,12 @@ inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
|
|||||||
#endif
|
#endif
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(__loongarch_asx)
|
||||||
|
template <>
|
||||||
|
inline __m256 madd(__m256 a, __m256 b, __m256 c) {
|
||||||
|
return __lasx_xvfmadd_s(a, b, c);
|
||||||
|
}
|
||||||
|
#endif //__loongarch_asx
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// VECTORIZED HORIZONTAL SUM
|
// VECTORIZED HORIZONTAL SUM
|
||||||
|
|
||||||
@ -189,6 +210,68 @@ inline float hsum(__m512 x) {
|
|||||||
}
|
}
|
||||||
#endif // __AVX512F__
|
#endif // __AVX512F__
|
||||||
|
|
||||||
|
#if defined(__loongarch_asx)
|
||||||
|
static inline __m128i lasx_extracti128_lo(__m256i in) {
|
||||||
|
__m128i out;
|
||||||
|
__asm__ volatile (
|
||||||
|
".ifnc %[out], %[in] \n\t"
|
||||||
|
".irp i," __ALL_REGS "\n\t"
|
||||||
|
" .ifc %[out], " VREGS_PREFIX "\\i \n\t"
|
||||||
|
" .irp j," __ALL_REGS "\n\t"
|
||||||
|
" .ifc %[in], " XREGS_PREFIX "\\j \n\t"
|
||||||
|
" vori.b $vr\\i, $vr\\j, 0 \n\t"
|
||||||
|
" .endif \n\t"
|
||||||
|
" .endr \n\t"
|
||||||
|
" .endif \n\t"
|
||||||
|
".endr \n\t"
|
||||||
|
".endif \n\t"
|
||||||
|
: [out] "=f" (out) : [in] "f" (in)
|
||||||
|
);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline __m128i lasx_extracti128_hi(__m256i in) {
|
||||||
|
__m128i out;
|
||||||
|
__asm__ volatile (
|
||||||
|
".irp i," __ALL_REGS "\n\t"
|
||||||
|
" .ifc %[out], " VREGS_PREFIX "\\i \n\t"
|
||||||
|
" .irp j," __ALL_REGS "\n\t"
|
||||||
|
" .ifc %[in], " XREGS_PREFIX "\\j \n\t"
|
||||||
|
" xvpermi.q $xr\\i, $xr\\j, 0x11 \n\t"
|
||||||
|
" .endif \n\t"
|
||||||
|
" .endr \n\t"
|
||||||
|
" .endif \n\t"
|
||||||
|
".endr \n\t"
|
||||||
|
: [out] "=f" (out) : [in] "f" (in)
|
||||||
|
);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m128 lasx_extractf128( __m256 a, int pos) {
|
||||||
|
__m128 ret;
|
||||||
|
if( pos == 0)
|
||||||
|
{
|
||||||
|
ret = (__m128)lasx_extracti128_lo((__m256i)a);
|
||||||
|
} else {
|
||||||
|
ret = (__m128)lasx_extracti128_hi((__m256i)a);
|
||||||
|
}
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
static inline float hsum_float_8(const __m256 x) {
|
||||||
|
__m128 res = lasx_extractf128(x, 1);
|
||||||
|
ft_union tmp;
|
||||||
|
res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
|
||||||
|
res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
|
||||||
|
res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
|
||||||
|
tmp.i = __lsx_vpickve2gr_w(res, 0);
|
||||||
|
return tmp.f;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline float hsum(__m256 x) {
|
||||||
|
return hsum_float_8(x);
|
||||||
|
}
|
||||||
|
#endif // __loongarch_asx
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// VECTORIZED MEMORY LOADING
|
// VECTORIZED MEMORY LOADING
|
||||||
|
|
||||||
@ -235,6 +318,47 @@ template <> inline __m512 load(const ggml_fp16_t *p) {
|
|||||||
}
|
}
|
||||||
#endif // __AVX512F__
|
#endif // __AVX512F__
|
||||||
|
|
||||||
|
#if defined(__loongarch_asx)
|
||||||
|
template <> inline __m256 load(const float *p) {
|
||||||
|
return (__m256)__lasx_xvld(p, 0);
|
||||||
|
}
|
||||||
|
static inline __m256i lasx_set_q(__m128i inhi, __m128i inlo) {
|
||||||
|
__m256i out;
|
||||||
|
__asm__ volatile (
|
||||||
|
".irp i," __ALL_REGS "\n\t"
|
||||||
|
" .ifc %[hi], " VREGS_PREFIX "\\i \n\t"
|
||||||
|
" .irp j," __ALL_REGS "\n\t"
|
||||||
|
" .ifc %[lo], " VREGS_PREFIX "\\j \n\t"
|
||||||
|
" xvpermi.q $xr\\i, $xr\\j, 0x20 \n\t"//拼接hi和lo
|
||||||
|
" .endif \n\t"
|
||||||
|
" .endr \n\t"
|
||||||
|
" .endif \n\t"
|
||||||
|
".endr \n\t"
|
||||||
|
".ifnc %[out], %[hi] \n\t"
|
||||||
|
".irp i," __ALL_REGS "\n\t"
|
||||||
|
" .ifc %[out], " XREGS_PREFIX "\\i \n\t"
|
||||||
|
" .irp j," __ALL_REGS "\n\t"
|
||||||
|
" .ifc %[hi], " VREGS_PREFIX "\\j \n\t"
|
||||||
|
" xvori.b $xr\\i, $xr\\j, 0 \n\t"//复制hi到out
|
||||||
|
" .endif \n\t"
|
||||||
|
" .endr \n\t"
|
||||||
|
" .endif \n\t"
|
||||||
|
".endr \n\t"
|
||||||
|
".endif \n\t"
|
||||||
|
: [out] "=f" (out), [hi] "+f" (inhi)
|
||||||
|
: [lo] "f" (inlo)
|
||||||
|
);
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
template <> inline __m256 load(const ggml_fp16_t *p) {
|
||||||
|
__m128i vector16 = __lsx_vld((const __m128i *)p, 0);
|
||||||
|
__m128i hi = (__m128i)__lsx_vfcvth_s_h(vector16);
|
||||||
|
__m128i lo = (__m128i)__lsx_vfcvtl_s_h(vector16);
|
||||||
|
return (__m256)lasx_set_q(hi,lo);
|
||||||
|
|
||||||
|
}
|
||||||
|
#endif // __loongarch_asx
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// FLOATING POINT MATRIX MULTIPLICATION
|
// FLOATING POINT MATRIX MULTIPLICATION
|
||||||
|
|
||||||
@ -813,6 +937,165 @@ class tinyBLAS_Q0_AVX {
|
|||||||
};
|
};
|
||||||
#endif // __AVX__
|
#endif // __AVX__
|
||||||
|
|
||||||
|
#if defined(__loongarch_asx)
|
||||||
|
template <typename TA, typename TB, typename TC>
|
||||||
|
class tinyBLAS_Q0_LOONGARCH {
|
||||||
|
public:
|
||||||
|
tinyBLAS_Q0_LOONGARCH(int64_t k,
|
||||||
|
const TA *A, int64_t lda,
|
||||||
|
const TB *B, int64_t ldb,
|
||||||
|
TC *C, int64_t ldc,
|
||||||
|
int ith, int nth)
|
||||||
|
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
||||||
|
}
|
||||||
|
|
||||||
|
void matmul(int64_t m, int64_t n) {
|
||||||
|
mnpack(0, m, 0, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||||
|
int64_t mc, nc, mp, np;
|
||||||
|
switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 1)) {
|
||||||
|
case 0x44:
|
||||||
|
case 0x43:
|
||||||
|
case 0x42:
|
||||||
|
case 0x41:
|
||||||
|
mc = 4;
|
||||||
|
nc = 1;
|
||||||
|
gemm<4, 1>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x22:
|
||||||
|
mc = 2;
|
||||||
|
nc = 2;
|
||||||
|
gemm<2, 2>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x14:
|
||||||
|
mc = 1;
|
||||||
|
nc = 4;
|
||||||
|
gemm<1, 4>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x31:
|
||||||
|
mc = 3;
|
||||||
|
nc = 1;
|
||||||
|
gemm<3, 1>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x13:
|
||||||
|
mc = 1;
|
||||||
|
nc = 3;
|
||||||
|
gemm<1, 3>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x21:
|
||||||
|
mc = 2;
|
||||||
|
nc = 1;
|
||||||
|
gemm<2, 1>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x12:
|
||||||
|
mc = 1;
|
||||||
|
nc = 2;
|
||||||
|
gemm<1, 2>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
case 0x11:
|
||||||
|
mc = 1;
|
||||||
|
nc = 1;
|
||||||
|
gemm<1, 1>(m0, m, n0, n);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
mp = m0 + (m - m0) / mc * mc;
|
||||||
|
np = n0 + (n - n0) / nc * nc;
|
||||||
|
mnpack(mp, m, n0, np);
|
||||||
|
mnpack(m0, m, np, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline __m256i load(const block_q8_0 *b) {
|
||||||
|
return __lasx_xvld(((const __m256i *)b->qs), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m256i lasx_insertf128( __m128i x, __m128i y) {
|
||||||
|
return lasx_set_q(x, y);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unpack 32 4-bit fields into 32 bytes
|
||||||
|
// The output vector contains 32 bytes, each one in [ 0 .. 15 ] interval
|
||||||
|
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
|
||||||
|
const __m128i lo = __lsx_vld((const __m128i *)rsi, 0);
|
||||||
|
__m128i hi = __lsx_vsrli_h(lo, 4);
|
||||||
|
return __lasx_xvandi_b(lasx_insertf128(hi, lo), 0xf);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline __m256i load(const block_q4_0 *b) {
|
||||||
|
return __lasx_xvsub_b(bytes_from_nibbles_32(b->qs), __lasx_xvreplgr2vr_b(8));
|
||||||
|
}
|
||||||
|
|
||||||
|
// add int16_t pairwise and return as float vector
|
||||||
|
static inline __m256 sum_i16_pairs_float(const __m256i x) {
|
||||||
|
__m256i v = __lasx_xvpackod_h(x, x);
|
||||||
|
__m256i summed_pairs = __lasx_xvaddwev_w_h(x, v);
|
||||||
|
return __lasx_xvffint_s_w(summed_pairs);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
|
||||||
|
__m256i tmp1, tmp2;
|
||||||
|
tmp1 = __lasx_xvmulwev_h_b(a, b);
|
||||||
|
tmp2 = __lasx_xvmulwod_h_b(a, b);
|
||||||
|
return __lasx_xvsadd_h(tmp1, tmp2);
|
||||||
|
}
|
||||||
|
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
|
||||||
|
const __m256i dot = lasx_maddubs_h(ax, sy);
|
||||||
|
return sum_i16_pairs_float(dot);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline __m256 mul_sum_i8_pairs_float(const __m256i x, const __m256i y) {
|
||||||
|
const __m256i ax = __lasx_xvsigncov_b(x, x);
|
||||||
|
const __m256i sy = __lasx_xvsigncov_b(x, y);
|
||||||
|
return mul_sum_us8_pairs_float(ax, sy);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template <int RM, int RN>
|
||||||
|
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||||
|
int64_t ytiles = (m - m0) / RM;
|
||||||
|
int64_t xtiles = (n - n0) / RN;
|
||||||
|
int64_t tiles = xtiles * ytiles;
|
||||||
|
int64_t duty = (tiles + nth - 1) / nth;
|
||||||
|
int64_t start = duty * ith;
|
||||||
|
int64_t end = start + duty;
|
||||||
|
if (end > tiles)
|
||||||
|
end = tiles;
|
||||||
|
for (int64_t job = start; job < end; ++job) {
|
||||||
|
int64_t ii = m0 + job / xtiles * RM;
|
||||||
|
int64_t jj = n0 + job % xtiles * RN;
|
||||||
|
__m256 Cv[RN][RM] = {};
|
||||||
|
for (int64_t l = 0; l < k; ++l)
|
||||||
|
for (int64_t j = 0; j < RN; ++j)
|
||||||
|
for (int64_t i = 0; i < RM; ++i) {
|
||||||
|
__m256 udTmp = mul_sum_i8_pairs_float(load(B + ldb * (jj + j) + l),
|
||||||
|
load(A + lda * (ii + i) + l));
|
||||||
|
|
||||||
|
Cv[j][i] = __lasx_xvfmadd_s(__lasx_xvreplfr2vr_s(unhalf(A[lda * (ii + i) + l].d) *
|
||||||
|
unhalf(B[ldb * (jj + j) + l].d)),
|
||||||
|
udTmp,
|
||||||
|
Cv[j][i]);
|
||||||
|
}
|
||||||
|
for (int64_t j = 0; j < RN; ++j)
|
||||||
|
for (int64_t i = 0; i < RM; ++i)
|
||||||
|
C[ldc * (jj + j) + (ii + i)] = hsum_float_8(Cv[j][i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const TA *const A;
|
||||||
|
const TB *const B;
|
||||||
|
TC *const C;
|
||||||
|
const int64_t k;
|
||||||
|
const int64_t lda;
|
||||||
|
const int64_t ldb;
|
||||||
|
const int64_t ldc;
|
||||||
|
const int ith;
|
||||||
|
const int nth;
|
||||||
|
};
|
||||||
|
#endif // __loongarch_asx
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -897,6 +1180,16 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||||||
ith, nth};
|
ith, nth};
|
||||||
tb.matmul(m, n);
|
tb.matmul(m, n);
|
||||||
return true;
|
return true;
|
||||||
|
// #elif defined(__loongarch_asx)
|
||||||
|
// if (k % 8)
|
||||||
|
// return false;
|
||||||
|
// tinyBLAS<8, __m256, __m256, float, float, float> tb{
|
||||||
|
// k, (const float *)A, lda,
|
||||||
|
// (const float *)B, ldb,
|
||||||
|
// (float *)C, ldc,
|
||||||
|
// ith, nth};
|
||||||
|
// tb.matmul(m, n);
|
||||||
|
// return true;
|
||||||
#else
|
#else
|
||||||
return false;
|
return false;
|
||||||
#endif
|
#endif
|
||||||
@ -953,6 +1246,18 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||||||
ith, nth};
|
ith, nth};
|
||||||
tb.matmul(m, n);
|
tb.matmul(m, n);
|
||||||
return true;
|
return true;
|
||||||
|
#elif defined(__loongarch_asx)
|
||||||
|
if (k % 8)
|
||||||
|
return false;
|
||||||
|
if (Btype != GGML_TYPE_F32)
|
||||||
|
return false;
|
||||||
|
tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
|
||||||
|
k, (const ggml_fp16_t *)A, lda,
|
||||||
|
(const float *)B, ldb,
|
||||||
|
(float *)C, ldc,
|
||||||
|
ith, nth};
|
||||||
|
tb.matmul(m, n);
|
||||||
|
return true;
|
||||||
#else
|
#else
|
||||||
return false;
|
return false;
|
||||||
#endif
|
#endif
|
||||||
@ -977,6 +1282,14 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||||||
ith, nth};
|
ith, nth};
|
||||||
tb.matmul(m, n);
|
tb.matmul(m, n);
|
||||||
return true;
|
return true;
|
||||||
|
#elif defined(__loongarch_asx)
|
||||||
|
tinyBLAS_Q0_LOONGARCH<block_q8_0, block_q8_0, float> tb{
|
||||||
|
k, (const block_q8_0 *)A, lda,
|
||||||
|
(const block_q8_0 *)B, ldb,
|
||||||
|
(float *)C, ldc,
|
||||||
|
ith, nth};
|
||||||
|
tb.matmul(m, n);
|
||||||
|
return true;
|
||||||
#else
|
#else
|
||||||
return false;
|
return false;
|
||||||
#endif
|
#endif
|
||||||
@ -1001,6 +1314,14 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||||||
ith, nth};
|
ith, nth};
|
||||||
tb.matmul(m, n);
|
tb.matmul(m, n);
|
||||||
return true;
|
return true;
|
||||||
|
#elif defined(__loongarch_asx)
|
||||||
|
tinyBLAS_Q0_LOONGARCH<block_q4_0, block_q8_0, float> tb{
|
||||||
|
k, (const block_q4_0 *)A, lda,
|
||||||
|
(const block_q8_0 *)B, ldb,
|
||||||
|
(float *)C, ldc,
|
||||||
|
ith, nth};
|
||||||
|
tb.matmul(m, n);
|
||||||
|
return true;
|
||||||
#else
|
#else
|
||||||
return false;
|
return false;
|
||||||
#endif
|
#endif
|
||||||
|
Loading…
Reference in New Issue
Block a user