llama.cpp/sgemm.cpp

1031 lines
31 KiB
C++
Raw Normal View History

// Copyright 2024 Mozilla Foundation
//
// Permission is hereby granted, free of charge, to any person obtaining
// a copy of this software and associated documentation files (the
// "Software"), to deal in the Software without restriction, including
// without limitation the rights to use, copy, modify, merge, publish,
// distribute, sublicense, and/or sell copies of the Software, and to
// permit persons to whom the Software is furnished to do so, subject to
// the following conditions:
//
// The above copyright notice and this permission notice shall be
// included in all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
// EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
// MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
// BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
// ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
// CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
//
// _ _ ___ _ _ ___
// | |_(_)_ _ _ _| _ ) | /_\ / __|
// | _| | ' \ || | _ \ |__ / _ \\__ \.
// \__|_|_||_\_, |___/____/_/ \_\___/
// |__/
//
// BASIC LINEAR ALGEBRA SUBPROGRAMS
//
//
// This file implements multithreaded CPU matrix multiplication for the
// common contiguous use case C = Aᵀ * B. These kernels are designed to
// have excellent performance[1] for matrices that fit in the CPU cache
// without imposing any overhead such as cache filling or malloc calls.
//
// This implementation does not guarantee any upper bound with rounding
// errors, which grow along with k. Our goal's to maximally exploit the
// hardware for performance, and then use whatever resources remain for
// improving numerical accuracy.
//
// [1] J. Tunney, LLaMA Now Goes Faster on CPUs, Mar. 2024. [Online].
// Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024].
#pragma GCC diagnostic ignored "-Wpedantic"
#pragma GCC diagnostic ignored "-Wignored-attributes"
#include "sgemm.h"
#include "ggml-impl.h"
#include "ggml-quants.h"
#ifdef _MSC_VER
#define NOINLINE __declspec(noinline)
#else
#define NOINLINE __attribute__((__noinline__))
#endif
#if defined(__ARM_NEON) || defined(__AVX512F__)
#define VECTOR_REGISTERS 32
#else
#define VECTOR_REGISTERS 16
#endif
#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
namespace {
inline float unhalf(ggml_fp16_t d) {
return GGML_FP16_TO_FP32(d);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED ARITHMETIC OPERATIONS
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); }
inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); }
inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); }
#endif // __SSE__
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); }
inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); }
inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); }
#endif // __AVX__
#if defined(__AVX512F__)
inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); }
inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); }
inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); }
#endif // __AVX512F__
#if defined(__ARM_NEON)
inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); }
inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); }
inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); }
#endif // __ARM_NEON
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, y); }
inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED FUSED MULTIPLY ADD
/**
* Computes a * b + c.
*/
template <typename T, typename U>
inline U madd(T a, T b, U c) {
return add(mul(a, b), c);
}
#if defined(__FMA__)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <>
inline __m256 madd(__m256 a, __m256 b, __m256 c) {
return _mm256_fmadd_ps(a, b, c);
}
#endif
#if defined(__AVX512F__)
template <>
inline __m512 madd(__m512 a, __m512 b, __m512 c) {
return _mm512_fmadd_ps(a, b, c);
}
#endif
#endif
#if defined(__ARM_FEATURE_FMA)
template <>
inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) {
return vfmaq_f32(c, b, a);
}
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
template <>
inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) {
return vfmaq_f16(c, b, a);
}
#endif
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED HORIZONTAL SUM
#if defined(__ARM_NEON)
inline float hsum(float32x4_t x) {
return vaddvq_f32(x);
}
#endif // __ARM_NEON
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
inline float hsum(float16x8_t x) {
return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)),
vcvt_f32_f16(vget_high_f16(x))));
}
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m128 x) {
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
x = _mm_add_ps(x, _mm_movehl_ps(x, x));
x = _mm_add_ss(x, _mm_movehdup_ps(x));
#else
__m128 t;
t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1));
x = _mm_add_ps(x, t);
t = _mm_movehl_ps(t, x);
x = _mm_add_ss(x, t);
#endif
return _mm_cvtss_f32(x);
}
#endif
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
inline float hsum(__m256 x) {
return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1),
_mm256_castps256_ps128(x)));
}
#endif // __AVX__
#if defined(__AVX512F__)
inline float hsum(__m512 x) {
return _mm512_reduce_add_ps(x);
}
#endif // __AVX512F__
////////////////////////////////////////////////////////////////////////////////////////////////////
// VECTORIZED MEMORY LOADING
template <typename T, typename U> T load(const U *);
#if defined(__ARM_NEON)
template <> inline float32x4_t load(const float *p) {
return vld1q_f32(p);
}
#if !defined(_MSC_VER)
template <> inline float16x8_t load(const ggml_fp16_t *p) {
return vld1q_f16((const float16_t *)p);
}
template <> inline float32x4_t load(const ggml_fp16_t *p) {
return vcvt_f32_f16(vld1_f16((const float16_t *)p));
}
#endif // _MSC_VER
#endif // __ARM_NEON
#if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m128 load(const float *p) {
return _mm_loadu_ps(p);
}
#endif // __SSE__
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
template <> inline __m256 load(const float *p) {
return _mm256_loadu_ps(p);
}
#endif // __AVX__
#if defined(__F16C__)
template <> inline __m256 load(const ggml_fp16_t *p) {
return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p));
}
#endif // __F16C__
#if defined(__AVX512F__)
template <> inline __m512 load(const float *p) {
return _mm512_loadu_ps(p);
}
template <> inline __m512 load(const ggml_fp16_t *p) {
return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p));
}
#endif // __AVX512F__
////////////////////////////////////////////////////////////////////////////////////////////////////
// FLOATING POINT MATRIX MULTIPLICATION
template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
class tinyBLAS {
public:
tinyBLAS(int64_t k,
const TA *A, int64_t lda,
const TB *B, int64_t ldb,
TC *C, int64_t ldc,
int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
void matmul(int64_t m, int64_t n, int task) {
if (task == GGML_TASK_TYPE_COMPUTE)
mnpack(0, m, 0, n);
}
private:
NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t mc, nc, mp, np;
switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) {
#if VECTOR_REGISTERS == 32
case 0x55:
mc = 5;
nc = 5;
gemm<5, 5>(m0, m, n0, n);
break;
case 0x45:
mc = 4;
nc = 5;
gemm<4, 5>(m0, m, n0, n);
break;
case 0x54:
mc = 5;
nc = 4;
gemm<5, 4>(m0, m, n0, n);
break;
case 0x44:
mc = 4;
nc = 4;
gemm<4, 4>(m0, m, n0, n);
break;
case 0x53:
mc = 5;
nc = 3;
gemm<5, 3>(m0, m, n0, n);
break;
case 0x35:
mc = 3;
nc = 5;
gemm<3, 5>(m0, m, n0, n);
break;
case 0x43:
mc = 4;
nc = 3;
gemm<4, 3>(m0, m, n0, n);
break;
#else
case 0x55:
case 0x54:
case 0x53:
case 0x45:
case 0x44:
case 0x43:
mc = 4;
nc = 3;
gemm<4, 3>(m0, m, n0, n);
break;
case 0x35:
#endif
case 0x34:
mc = 3;
nc = 4;
gemm<3, 4>(m0, m, n0, n);
break;
case 0x52:
mc = 5;
nc = 2;
gemm<5, 2>(m0, m, n0, n);
break;
case 0x33:
mc = 3;
nc = 3;
gemm<3, 3>(m0, m, n0, n);
break;
case 0x25:
mc = 2;
nc = 5;
gemm<2, 5>(m0, m, n0, n);
break;
case 0x42:
mc = 4;
nc = 2;
gemm<4, 2>(m0, m, n0, n);
break;
case 0x24:
mc = 2;
nc = 4;
gemm<2, 4>(m0, m, n0, n);
break;
case 0x32:
mc = 3;
nc = 2;
gemm<3, 2>(m0, m, n0, n);
break;
case 0x23:
mc = 2;
nc = 3;
gemm<2, 3>(m0, m, n0, n);
break;
case 0x51:
mc = 5;
nc = 1;
gemm<5, 1>(m0, m, n0, n);
break;
case 0x41:
mc = 4;
nc = 1;
gemm<4, 1>(m0, m, n0, n);
break;
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2>(m0, m, n0, n);
break;
case 0x15:
mc = 1;
nc = 5;
gemm<1, 5>(m0, m, n0, n);
break;
case 0x14:
mc = 1;
nc = 4;
gemm<1, 4>(m0, m, n0, n);
break;
case 0x31:
mc = 3;
nc = 1;
gemm<3, 1>(m0, m, n0, n);
break;
case 0x13:
mc = 1;
nc = 3;
gemm<1, 3>(m0, m, n0, n);
break;
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1>(m0, m, n0, n);
break;
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1>(m0, m, n0, n);
break;
default:
return;
}
mp = m0 + (m - m0) / mc * mc;
np = n0 + (n - n0) / nc * nc;
mnpack(mp, m, n0, np);
mnpack(m0, m, np, n);
}
template <int RM, int RN>
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
if (end > tiles)
end = tiles;
for (int64_t job = start; job < end; ++job) {
int64_t ii = m0 + job / xtiles * RM;
int64_t jj = n0 + job % xtiles * RN;
D Cv[RN][RM] = {};
for (int64_t l = 0; l < k; l += KN)
for (int64_t j = 0; j < RN; ++j)
for (int64_t i = 0; i < RM; ++i)
Cv[j][i] = madd(load<V>(A + lda * (ii + i) + l),
load<V>(B + ldb * (jj + j) + l),
Cv[j][i]);
for (int64_t j = 0; j < RN; ++j)
for (int64_t i = 0; i < RM; ++i)
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
}
}
const TA *const A;
const TB *const B;
TC *const C;
const int64_t k;
const int64_t lda;
const int64_t ldb;
const int64_t ldc;
const int ith;
const int nth;
};
//////////////////////////////////////////////////////////////////////////////////////////
// QUANT ZERO MATRIX MULTIPLICATION
#if defined(__ARM_FEATURE_DOTPROD)
template <typename TA>
class tinyBLAS_Q0_ARM {
public:
tinyBLAS_Q0_ARM(int64_t k,
const TA *A, int64_t lda,
const block_q8_0 *B, int64_t ldb,
float *C, int64_t ldc,
int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
void matmul(int64_t m, int64_t n, int task) {
if (task == GGML_TASK_TYPE_COMPUTE)
mnpack(0, m, 0, n);
}
private:
NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t mc, nc, mp, np;
switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) {
case 0x33:
mc = 3;
nc = 3;
gemm<3, 3>(m0, m, n0, n);
break;
case 0x32:
mc = 3;
nc = 2;
gemm<3, 2>(m0, m, n0, n);
break;
case 0x23:
mc = 2;
nc = 3;
gemm<2, 3>(m0, m, n0, n);
break;
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2>(m0, m, n0, n);
break;
case 0x31:
mc = 3;
nc = 1;
gemm<3, 1>(m0, m, n0, n);
break;
case 0x13:
mc = 1;
nc = 3;
gemm<1, 3>(m0, m, n0, n);
break;
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1>(m0, m, n0, n);
break;
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1>(m0, m, n0, n);
break;
default:
return;
}
mp = m0 + (m - m0) / mc * mc;
np = n0 + (n - n0) / nc * nc;
mnpack(mp, m, n0, np);
mnpack(m0, m, np, n);
}
template <int RM, int RN>
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
if (end > tiles)
end = tiles;
for (int64_t job = start; job < end; ++job) {
int64_t ii = m0 + job / xtiles * RM;
int64_t jj = n0 + job % xtiles * RN;
float32x4_t Cv[RN][RM] = {};
for (int64_t l = 0; l < k; ++l)
for (int64_t j = 0; j < RN; ++j)
for (int64_t i = 0; i < RM; ++i)
Cv[j][i] = vmlaq_n_f32(Cv[j][i],
vcvtq_f32_s32(vdotq_s32(
vdotq_s32(vdupq_n_s32(0),
load_lo(A + lda * (ii + i) + l),
load_lo(B + ldb * (jj + j) + l)),
load_hi(A + lda * (ii + i) + l),
load_hi(B + ldb * (jj + j) + l))),
unhalf(A[lda * (ii + i) + l].d) *
unhalf(B[ldb * (jj + j) + l].d));
for (int64_t j = 0; j < RN; ++j)
for (int64_t i = 0; i < RM; ++i)
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
}
}
inline int8x16_t load_lo(const block_q8_0 *b) {
return vld1q_s8(b->qs);
}
inline int8x16_t load_hi(const block_q8_0 *b) {
return vld1q_s8(b->qs + 16);
}
inline int8x16_t load_lo(const block_q4_0 *b) {
return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs),
vdupq_n_u8(0x0f))),
vdupq_n_s8(0x8));
}
inline int8x16_t load_hi(const block_q4_0 *b) {
return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)),
vdupq_n_s8(0x8));
}
const TA *const A;
const block_q8_0 *const B;
float *const C;
const int64_t k;
const int64_t lda;
const int64_t ldb;
const int64_t ldc;
const int ith;
const int nth;
};
#endif // __ARM_FEATURE_DOTPROD
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
template <typename TA, typename TB, typename TC>
class tinyBLAS_Q0_AVX {
public:
tinyBLAS_Q0_AVX(int64_t k,
const TA *A, int64_t lda,
const TB *B, int64_t ldb,
TC *C, int64_t ldc,
int ith, int nth)
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
}
void matmul(int64_t m, int64_t n, int task) {
if (task == GGML_TASK_TYPE_COMPUTE)
mnpack(0, m, 0, n);
}
private:
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t mc, nc, mp, np;
switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) {
#if VECTOR_REGISTERS == 32
case 0x44:
mc = 4;
nc = 4;
gemm<4, 4>(m0, m, n0, n);
break;
case 0x43:
mc = 4;
nc = 3;
gemm<4, 3>(m0, m, n0, n);
break;
case 0x34:
mc = 3;
nc = 4;
gemm<3, 4>(m0, m, n0, n);
break;
case 0x33:
mc = 3;
nc = 3;
gemm<3, 3>(m0, m, n0, n);
break;
case 0x42:
mc = 4;
nc = 2;
gemm<4, 2>(m0, m, n0, n);
break;
case 0x24:
mc = 2;
nc = 4;
gemm<2, 4>(m0, m, n0, n);
break;
#else
case 0x44:
case 0x43:
case 0x42:
mc = 4;
nc = 2;
gemm<4, 2>(m0, m, n0, n);
break;
case 0x34:
case 0x24:
mc = 2;
nc = 4;
gemm<2, 4>(m0, m, n0, n);
break;
case 0x33:
#endif
case 0x32:
mc = 3;
nc = 2;
gemm<3, 2>(m0, m, n0, n);
break;
case 0x23:
mc = 2;
nc = 3;
gemm<2, 3>(m0, m, n0, n);
break;
case 0x41:
mc = 4;
nc = 1;
gemm<4, 1>(m0, m, n0, n);
break;
case 0x22:
mc = 2;
nc = 2;
gemm<2, 2>(m0, m, n0, n);
break;
case 0x14:
mc = 1;
nc = 4;
gemm<1, 4>(m0, m, n0, n);
break;
case 0x31:
mc = 3;
nc = 1;
gemm<3, 1>(m0, m, n0, n);
break;
case 0x13:
mc = 1;
nc = 3;
gemm<1, 3>(m0, m, n0, n);
break;
case 0x21:
mc = 2;
nc = 1;
gemm<2, 1>(m0, m, n0, n);
break;
case 0x12:
mc = 1;
nc = 2;
gemm<1, 2>(m0, m, n0, n);
break;
case 0x11:
mc = 1;
nc = 1;
gemm<1, 1>(m0, m, n0, n);
break;
default:
return;
}
mp = m0 + (m - m0) / mc * mc;
np = n0 + (n - n0) / nc * nc;
mnpack(mp, m, n0, np);
mnpack(m0, m, np, n);
}
template <int RM, int RN>
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
int64_t ytiles = (m - m0) / RM;
int64_t xtiles = (n - n0) / RN;
int64_t tiles = xtiles * ytiles;
int64_t duty = (tiles + nth - 1) / nth;
int64_t start = duty * ith;
int64_t end = start + duty;
if (end > tiles)
end = tiles;
for (int64_t job = start; job < end; ++job) {
int64_t ii = m0 + job / xtiles * RM;
int64_t jj = n0 + job % xtiles * RN;
__m256 Cv[RN][RM] = {};
for (int64_t l = 0; l < k; ++l)
for (int64_t j = 0; j < RN; ++j)
for (int64_t i = 0; i < RM; ++i) {
#if defined(__AVX2__)
__m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
load(A + lda * (ii + i) + l)),
_mm256_sign_epi8(load(B + ldb * (jj + j) + l),
load(A + lda * (ii + i) + l)));
#else
__m128i ali0 = load0(A + lda * (ii + i) + l);
__m128i ali1 = load1(A + lda * (ii + i) + l);
__m128i blj0 = load0(B + ldb * (jj + j) + l);
__m128i blj1 = load1(B + ldb * (jj + j) + l);
__m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
__m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
__m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
__m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
// updot
const __m128i oneFill = _mm_set1_epi16(1);
__m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
__m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
__m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
#endif
Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
unhalf(B[ldb * (jj + j) + l].d)),
udTmp,
Cv[j][i]);
}
for (int64_t j = 0; j < RN; ++j)
for (int64_t i = 0; i < RM; ++i)
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
}
}
inline __m256i load(const block_q8_0 *b) {
return _mm256_loadu_si256((const __m256i *)b->qs);
}
inline __m128i load0(const block_q8_0 *b) {
return _mm_loadu_si128((const __m128i *)b->qs);
}
inline __m128i load1(const block_q8_0 *b) {
return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
}
inline __m256i load(const block_q4_0 *b) {
return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
}
inline __m128i load0(const block_q4_0 *b) {
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
}
inline __m128i load1(const block_q4_0 *b) {
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
}
inline __m256 updot(__m256i u, __m256i s) {
__m256i res;
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
#else
res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
#endif
return _mm256_cvtepi32_ps(res);
}
static inline __m256i denibble(const uint8_t *p) {
__m128i x = _mm_loadu_si128((const __m128i *)p);
return _mm256_and_si256(_mm256_set1_epi8(15),
_mm256_insertf128_si256(_mm256_castsi128_si256(x),
_mm_srli_epi16(x, 4), 1));
}
const TA *const A;
const TB *const B;
TC *const C;
const int64_t k;
const int64_t lda;
const int64_t ldb;
const int64_t ldc;
const int ith;
const int nth;
};
#endif // __AVX__
} // namespace
/**
* Performs optimized matrix multiplication on CPU.
*
* This subroutine may compute C = Aᵀ * B with column major ordering.
* Despite its name, this isn't a generalized implementation. Work is
* only performed when a handwritten kernel is written and available.
* Otherwise the caller should fall back to a general matmul routine.
*
* For example, for single-threaded single-precision GEMM you can say
*
* llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc,
* 0, 1, GGML_TASK_TYPE_COMPUTE,
* GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32);
*
* @param m is rows in `A` and `C`
* @param n is cols in `B` and `C`
* @param k is cols in `A` and rows in `B`
* @param A is first input matrix (always transposed)
* @param lda is row stride of `A`
* @param B is second input matrix (never transposed)
* @param ldb is row stride of `B`
* @param C is input/output array of output matrices
* @param ldc is row stride of `C`
* @param ith is thread id (must be less than `nth`)
* @param nth is number of threads (must be greater than zero)
* @param task is GGML task type
* @param Atype is GGML data type of `A`
* @param Btype is GGML data type of `B`
* @param Ctype is GGML data type of `C`
* @return true if this function was able to service the matmul request
*/
bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) {
assert(m >= 0);
assert(n >= 0);
assert(k >= 0);
assert(lda >= k);
assert(ldb >= k);
assert(ldc >= m);
assert(nth > 0);
assert(ith < nth);
if (Ctype != GGML_TYPE_F32)
return false;
switch (Atype) {
case GGML_TYPE_F32: {
if (Btype != GGML_TYPE_F32)
return false;
#if defined(__AVX512F__)
if (k % 16)
return false;
tinyBLAS<16, __m512, __m512, float, float, float> tb{
k, (const float *)A, lda,
(const float *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__AVX__) || defined(__AVX2__)
if (k % 8)
return false;
tinyBLAS<8, __m256, __m256, float, float, float> tb{
k, (const float *)A, lda,
(const float *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_NEON)
if (n < 4)
return false;
if (k % 4)
return false;
tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
k, (const float *)A, lda,
(const float *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
return true;
#else
return false;
#endif
}
case GGML_TYPE_F16: {
#if defined(__AVX512F__)
if (k % 16)
return false;
if (Btype != GGML_TYPE_F32)
return false;
tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{
k, (const ggml_fp16_t *)A, lda,
(const float *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
return true;
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
if (k % 8)
return false;
if (Btype != GGML_TYPE_F32)
return false;
tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{
k, (const ggml_fp16_t *)A, lda,
(const float *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
if (n < 8)
return false;
if (k % 8)
return false;
if (Btype != GGML_TYPE_F16)
return false;
tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
k, (const ggml_fp16_t *)A, lda,
(const ggml_fp16_t *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (k % 4)
return false;
if (Btype != GGML_TYPE_F32)
return false;
tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
k, (const ggml_fp16_t *)A, lda,
(const float *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
return true;
#else
return false;
#endif
}
case GGML_TYPE_Q8_0: {
if (Btype != GGML_TYPE_Q8_0)
return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
k, (const block_q8_0 *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
tinyBLAS_Q0_ARM<block_q8_0> tb{
k, (const block_q8_0 *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
return true;
#else
return false;
#endif
}
case GGML_TYPE_Q4_0: {
if (Btype != GGML_TYPE_Q8_0)
return false;
#if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
k, (const block_q4_0 *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
return true;
#elif defined(__ARM_FEATURE_DOTPROD)
tinyBLAS_Q0_ARM<block_q4_0> tb{
k, (const block_q4_0 *)A, lda,
(const block_q8_0 *)B, ldb,
(float *)C, ldc,
ith, nth};
tb.matmul(m, n, task);
return true;
#else
return false;
#endif
}
default:
return false;
}
(void)m;
(void)n;
(void)k;
(void)A;
(void)lda;
(void)B;
(void)ldb;
(void)C;
(void)ldc;
(void)ith;
(void)nth;
(void)task;
(void)Atype;
(void)Btype;
(void)Ctype;
}