// Copyright 2024 Mozilla Foundation // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the // "Software"), to deal in the Software without restriction, including // without limitation the rights to use, copy, modify, merge, publish, // distribute, sublicense, and/or sell copies of the Software, and to // permit persons to whom the Software is furnished to do so, subject to // the following conditions: // // The above copyright notice and this permission notice shall be // included in all copies or substantial portions of the Software. // // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, // EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF // MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND // NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS // BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN // ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN // CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. // // _ _ ___ _ _ ___ // | |_(_)_ _ _ _| _ ) | /_\ / __| // | _| | ' \ || | _ \ |__ / _ \\__ \. // \__|_|_||_\_, |___/____/_/ \_\___/ // |__/ // // BASIC LINEAR ALGEBRA SUBPROGRAMS // // // This file implements multithreaded CPU matrix multiplication for the // common contiguous use case C = Aᵀ * B. These kernels are designed to // have excellent performance[1] for matrices that fit in the CPU cache // without imposing any overhead such as cache filling or malloc calls. // // This implementation does not guarantee any upper bound on rounding // errors, which grow along with k. Our goal is to maximally exploit the // hardware for performance, and then use whatever resources remain for // improving numerical accuracy. // // [1] J. Tunney, ‘LLaMA Now Goes Faster on CPUs’, Mar. 2024. [Online]. // Available: https://justine.lol/matmul/. [Accessed: 29-Mar-2024]. 
#if defined(__GNUC__) #pragma GCC diagnostic ignored "-Wpedantic" #pragma GCC diagnostic ignored "-Wignored-attributes" #endif #include "sgemm.h" #include "ggml-impl.h" #include "ggml-quants.h" #ifdef _MSC_VER #define NOINLINE __declspec(noinline) #else #define NOINLINE __attribute__((__noinline__)) #endif #if defined(__ARM_NEON) || defined(__AVX512F__) #define VECTOR_REGISTERS 32 #else #define VECTOR_REGISTERS 16 #endif #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1) namespace { inline float unhalf(ggml_fp16_t d) { return GGML_FP16_TO_FP32(d); } //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED ARITHMETIC OPERATIONS #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) inline __m128 add(__m128 x, __m128 y) { return _mm_add_ps(x, y); } inline __m128 sub(__m128 x, __m128 y) { return _mm_sub_ps(x, y); } inline __m128 mul(__m128 x, __m128 y) { return _mm_mul_ps(x, y); } #endif // __SSE__ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) inline __m256 add(__m256 x, __m256 y) { return _mm256_add_ps(x, y); } inline __m256 sub(__m256 x, __m256 y) { return _mm256_sub_ps(x, y); } inline __m256 mul(__m256 x, __m256 y) { return _mm256_mul_ps(x, y); } #endif // __AVX__ #if defined(__AVX512F__) inline __m512 add(__m512 x, __m512 y) { return _mm512_add_ps(x, y); } inline __m512 sub(__m512 x, __m512 y) { return _mm512_sub_ps(x, y); } inline __m512 mul(__m512 x, __m512 y) { return _mm512_mul_ps(x, y); } #endif // __AVX512F__ #if defined(__ARM_NEON) inline float32x4_t add(float32x4_t x, float32x4_t y) { return vaddq_f32(x, y); } inline float32x4_t sub(float32x4_t x, float32x4_t y) { return vsubq_f32(x, y); } inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vmulq_f32(x, y); } #endif // __ARM_NEON #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) inline float16x8_t add(float16x8_t x, float16x8_t y) { return vaddq_f16(x, 
y); } inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); } inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); } #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED FUSED MULTIPLY ADD /** * Computes a * b + c. */ template inline U madd(T a, T b, U c) { return add(mul(a, b), c); } #if defined(__FMA__) #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) template <> inline __m256 madd(__m256 a, __m256 b, __m256 c) { return _mm256_fmadd_ps(a, b, c); } #endif #if defined(__AVX512F__) template <> inline __m512 madd(__m512 a, __m512 b, __m512 c) { return _mm512_fmadd_ps(a, b, c); } #endif #endif #if defined(__ARM_FEATURE_FMA) template <> inline float32x4_t madd(float32x4_t a, float32x4_t b, float32x4_t c) { return vfmaq_f32(c, b, a); } #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) template <> inline float16x8_t madd(float16x8_t a, float16x8_t b, float16x8_t c) { return vfmaq_f16(c, b, a); } #endif #endif //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED HORIZONTAL SUM #if defined(__ARM_NEON) inline float hsum(float32x4_t x) { return vaddvq_f32(x); } #endif // __ARM_NEON #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) inline float hsum(float16x8_t x) { return vaddvq_f32(vaddq_f32(vcvt_f32_f16(vget_low_f16(x)), vcvt_f32_f16(vget_high_f16(x)))); } #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) inline float hsum(__m128 x) { #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) x = _mm_add_ps(x, _mm_movehl_ps(x, x)); x = _mm_add_ss(x, _mm_movehdup_ps(x)); #else __m128 t; t = _mm_shuffle_ps(x, x, _MM_SHUFFLE(2, 3, 0, 1)); x = _mm_add_ps(x, t); t = _mm_movehl_ps(t, x); x = _mm_add_ss(x, t); 
#endif return _mm_cvtss_f32(x); } #endif #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) inline float hsum(__m256 x) { return hsum(_mm_add_ps(_mm256_extractf128_ps(x, 1), _mm256_castps256_ps128(x))); } #endif // __AVX__ #if defined(__AVX512F__) inline float hsum(__m512 x) { return _mm512_reduce_add_ps(x); } #endif // __AVX512F__ //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED MEMORY LOADING template T load(const U *); #if defined(__ARM_NEON) template <> inline float32x4_t load(const float *p) { return vld1q_f32(p); } #if !defined(_MSC_VER) template <> inline float16x8_t load(const ggml_fp16_t *p) { return vld1q_f16((const float16_t *)p); } template <> inline float32x4_t load(const ggml_fp16_t *p) { return vcvt_f32_f16(vld1_f16((const float16_t *)p)); } #endif // _MSC_VER #endif // __ARM_NEON #if defined(__SSE__) || defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) template <> inline __m128 load(const float *p) { return _mm_loadu_ps(p); } #endif // __SSE__ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) template <> inline __m256 load(const float *p) { return _mm256_loadu_ps(p); } #endif // __AVX__ #if defined(__F16C__) template <> inline __m256 load(const ggml_fp16_t *p) { return _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)p)); } #endif // __F16C__ #if defined(__AVX512F__) template <> inline __m512 load(const float *p) { return _mm512_loadu_ps(p); } template <> inline __m512 load(const ggml_fp16_t *p) { return _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)p)); } #endif // __AVX512F__ //////////////////////////////////////////////////////////////////////////////////////////////////// // FLOATING POINT MATRIX MULTIPLICATION template class tinyBLAS { public: tinyBLAS(int64_t k, const TA *A, int64_t lda, const TB *B, int64_t ldb, TC *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } 
void matmul(int64_t m, int64_t n) { mnpack(0, m, 0, n); } private: NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { int64_t mc, nc, mp, np; switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) { #if VECTOR_REGISTERS == 32 case 0x55: mc = 5; nc = 5; gemm<5, 5>(m0, m, n0, n); break; case 0x45: mc = 4; nc = 5; gemm<4, 5>(m0, m, n0, n); break; case 0x54: mc = 5; nc = 4; gemm<5, 4>(m0, m, n0, n); break; case 0x44: mc = 4; nc = 4; gemm<4, 4>(m0, m, n0, n); break; case 0x53: mc = 5; nc = 3; gemm<5, 3>(m0, m, n0, n); break; case 0x35: mc = 3; nc = 5; gemm<3, 5>(m0, m, n0, n); break; case 0x43: mc = 4; nc = 3; gemm<4, 3>(m0, m, n0, n); break; #else case 0x55: case 0x54: case 0x53: case 0x45: case 0x44: case 0x43: mc = 4; nc = 3; gemm<4, 3>(m0, m, n0, n); break; case 0x35: #endif case 0x34: mc = 3; nc = 4; gemm<3, 4>(m0, m, n0, n); break; case 0x52: mc = 5; nc = 2; gemm<5, 2>(m0, m, n0, n); break; case 0x33: mc = 3; nc = 3; gemm<3, 3>(m0, m, n0, n); break; case 0x25: mc = 2; nc = 5; gemm<2, 5>(m0, m, n0, n); break; case 0x42: mc = 4; nc = 2; gemm<4, 2>(m0, m, n0, n); break; case 0x24: mc = 2; nc = 4; gemm<2, 4>(m0, m, n0, n); break; case 0x32: mc = 3; nc = 2; gemm<3, 2>(m0, m, n0, n); break; case 0x23: mc = 2; nc = 3; gemm<2, 3>(m0, m, n0, n); break; case 0x51: mc = 5; nc = 1; gemm<5, 1>(m0, m, n0, n); break; case 0x41: mc = 4; nc = 1; gemm<4, 1>(m0, m, n0, n); break; case 0x22: mc = 2; nc = 2; gemm<2, 2>(m0, m, n0, n); break; case 0x15: mc = 1; nc = 5; gemm<1, 5>(m0, m, n0, n); break; case 0x14: mc = 1; nc = 4; gemm<1, 4>(m0, m, n0, n); break; case 0x31: mc = 3; nc = 1; gemm<3, 1>(m0, m, n0, n); break; case 0x13: mc = 1; nc = 3; gemm<1, 3>(m0, m, n0, n); break; case 0x21: mc = 2; nc = 1; gemm<2, 1>(m0, m, n0, n); break; case 0x12: mc = 1; nc = 2; gemm<1, 2>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1>(m0, m, n0, n); break; default: return; } mp = m0 + (m - m0) / mc * mc; np = n0 + (n - n0) / nc * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); 
} template NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; int64_t duty = (tiles + nth - 1) / nth; int64_t start = duty * ith; int64_t end = start + duty; if (end > tiles) end = tiles; for (int64_t job = start; job < end; ++job) { int64_t ii = m0 + job / xtiles * RM; int64_t jj = n0 + job % xtiles * RN; D Cv[RN][RM] = {}; for (int64_t l = 0; l < k; l += KN) for (int64_t j = 0; j < RN; ++j) for (int64_t i = 0; i < RM; ++i) Cv[j][i] = madd(load(A + lda * (ii + i) + l), load(B + ldb * (jj + j) + l), Cv[j][i]); for (int64_t j = 0; j < RN; ++j) for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); } } const TA *const A; const TB *const B; TC *const C; const int64_t k; const int64_t lda; const int64_t ldb; const int64_t ldc; const int ith; const int nth; }; ////////////////////////////////////////////////////////////////////////////////////////// // QUANT ZERO MATRIX MULTIPLICATION #if defined(__ARM_FEATURE_DOTPROD) template class tinyBLAS_Q0_ARM { public: tinyBLAS_Q0_ARM(int64_t k, const TA *A, int64_t lda, const block_q8_0 *B, int64_t ldb, float *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } void matmul(int64_t m, int64_t n) { mnpack(0, m, 0, n); } private: NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { int64_t mc, nc, mp, np; switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) { case 0x33: mc = 3; nc = 3; gemm<3, 3>(m0, m, n0, n); break; case 0x32: mc = 3; nc = 2; gemm<3, 2>(m0, m, n0, n); break; case 0x23: mc = 2; nc = 3; gemm<2, 3>(m0, m, n0, n); break; case 0x22: mc = 2; nc = 2; gemm<2, 2>(m0, m, n0, n); break; case 0x31: mc = 3; nc = 1; gemm<3, 1>(m0, m, n0, n); break; case 0x13: mc = 1; nc = 3; gemm<1, 3>(m0, m, n0, n); break; case 0x21: mc = 2; nc = 1; gemm<2, 1>(m0, m, n0, n); break; case 0x12: mc = 1; nc = 2; gemm<1, 2>(m0, m, n0, n); 
break; case 0x11: mc = 1; nc = 1; gemm<1, 1>(m0, m, n0, n); break; default: return; } mp = m0 + (m - m0) / mc * mc; np = n0 + (n - n0) / nc * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); } template NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; int64_t duty = (tiles + nth - 1) / nth; int64_t start = duty * ith; int64_t end = start + duty; if (end > tiles) end = tiles; for (int64_t job = start; job < end; ++job) { int64_t ii = m0 + job / xtiles * RM; int64_t jj = n0 + job % xtiles * RN; float32x4_t Cv[RN][RM] = {}; for (int64_t l = 0; l < k; ++l) for (int64_t j = 0; j < RN; ++j) for (int64_t i = 0; i < RM; ++i) Cv[j][i] = vmlaq_n_f32(Cv[j][i], vcvtq_f32_s32(vdotq_s32( vdotq_s32(vdupq_n_s32(0), load_lo(A + lda * (ii + i) + l), load_lo(B + ldb * (jj + j) + l)), load_hi(A + lda * (ii + i) + l), load_hi(B + ldb * (jj + j) + l))), unhalf(A[lda * (ii + i) + l].d) * unhalf(B[ldb * (jj + j) + l].d)); for (int64_t j = 0; j < RN; ++j) for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); } } inline int8x16_t load_lo(const block_q8_0 *b) { return vld1q_s8(b->qs); } inline int8x16_t load_hi(const block_q8_0 *b) { return vld1q_s8(b->qs + 16); } inline int8x16_t load_lo(const block_q4_0 *b) { return vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vld1q_u8(b->qs), vdupq_n_u8(0x0f))), vdupq_n_s8(0x8)); } inline int8x16_t load_hi(const block_q4_0 *b) { return vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(vld1q_u8(b->qs), 4)), vdupq_n_s8(0x8)); } const TA *const A; const block_q8_0 *const B; float *const C; const int64_t k; const int64_t lda; const int64_t ldb; const int64_t ldc; const int ith; const int nth; }; #endif // __ARM_FEATURE_DOTPROD #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) template class tinyBLAS_Q0_AVX { public: tinyBLAS_Q0_AVX(int64_t k, const TA *A, int64_t lda, const TB *B, int64_t ldb, TC *C, int64_t ldc, int ith, 
int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } void matmul(int64_t m, int64_t n) { mnpack(0, m, 0, n); } private: void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { int64_t mc, nc, mp, np; switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) { #if VECTOR_REGISTERS == 32 case 0x44: mc = 4; nc = 4; gemm<4, 4>(m0, m, n0, n); break; case 0x43: mc = 4; nc = 3; gemm<4, 3>(m0, m, n0, n); break; case 0x34: mc = 3; nc = 4; gemm<3, 4>(m0, m, n0, n); break; case 0x33: mc = 3; nc = 3; gemm<3, 3>(m0, m, n0, n); break; case 0x42: mc = 4; nc = 2; gemm<4, 2>(m0, m, n0, n); break; case 0x24: mc = 2; nc = 4; gemm<2, 4>(m0, m, n0, n); break; #else case 0x44: case 0x43: case 0x42: mc = 4; nc = 2; gemm<4, 2>(m0, m, n0, n); break; case 0x34: case 0x24: mc = 2; nc = 4; gemm<2, 4>(m0, m, n0, n); break; case 0x33: #endif case 0x32: mc = 3; nc = 2; gemm<3, 2>(m0, m, n0, n); break; case 0x23: mc = 2; nc = 3; gemm<2, 3>(m0, m, n0, n); break; case 0x41: mc = 4; nc = 1; gemm<4, 1>(m0, m, n0, n); break; case 0x22: mc = 2; nc = 2; gemm<2, 2>(m0, m, n0, n); break; case 0x14: mc = 1; nc = 4; gemm<1, 4>(m0, m, n0, n); break; case 0x31: mc = 3; nc = 1; gemm<3, 1>(m0, m, n0, n); break; case 0x13: mc = 1; nc = 3; gemm<1, 3>(m0, m, n0, n); break; case 0x21: mc = 2; nc = 1; gemm<2, 1>(m0, m, n0, n); break; case 0x12: mc = 1; nc = 2; gemm<1, 2>(m0, m, n0, n); break; case 0x11: mc = 1; nc = 1; gemm<1, 1>(m0, m, n0, n); break; default: return; } mp = m0 + (m - m0) / mc * mc; np = n0 + (n - n0) / nc * nc; mnpack(mp, m, n0, np); mnpack(m0, m, np, n); } template NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; int64_t duty = (tiles + nth - 1) / nth; int64_t start = duty * ith; int64_t end = start + duty; if (end > tiles) end = tiles; for (int64_t job = start; job < end; ++job) { int64_t ii = m0 + job / xtiles * RM; int64_t jj = n0 + job % xtiles * RN; 
__m256 Cv[RN][RM] = {}; for (int64_t l = 0; l < k; ++l) for (int64_t j = 0; j < RN; ++j) for (int64_t i = 0; i < RM; ++i) { #if defined(__AVX2__) __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), load(A + lda * (ii + i) + l)), _mm256_sign_epi8(load(B + ldb * (jj + j) + l), load(A + lda * (ii + i) + l))); #else __m128i ali0 = load0(A + lda * (ii + i) + l); __m128i ali1 = load1(A + lda * (ii + i) + l); __m128i blj0 = load0(B + ldb * (jj + j) + l); __m128i blj1 = load1(B + ldb * (jj + j) + l); __m128i sepAA0 = _mm_sign_epi8(ali0, ali0); __m128i sepAA1 = _mm_sign_epi8(ali1, ali1); __m128i sepBA0 = _mm_sign_epi8(blj0, ali0); __m128i sepBA1 = _mm_sign_epi8(blj1, ali1); // updot const __m128i oneFill = _mm_set1_epi16(1); __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0); __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1); __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0))); #endif Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * unhalf(B[ldb * (jj + j) + l].d)), udTmp, Cv[j][i]); } for (int64_t j = 0; j < RN; ++j) for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); } } inline __m256i load(const block_q8_0 *b) { return _mm256_loadu_si256((const __m256i *)b->qs); } inline __m128i load0(const block_q8_0 *b) { return _mm_loadu_si128((const __m128i *)b->qs); } inline __m128i load1(const block_q8_0 *b) { return _mm_loadu_si128(((const __m128i *)b->qs) + 1); } inline __m256i load(const block_q4_0 *b) { return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); } inline __m128i load0(const block_q4_0 *b) { const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8)); } inline __m128i load1(const block_q4_0 *b) { const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8)); } inline __m256 
updot(__m256i u, __m256i s) { __m256i res; #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s); #else res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s)); #endif return _mm256_cvtepi32_ps(res); } static inline __m256i denibble(const uint8_t *p) { __m128i x = _mm_loadu_si128((const __m128i *)p); return _mm256_and_si256(_mm256_set1_epi8(15), _mm256_insertf128_si256(_mm256_castsi128_si256(x), _mm_srli_epi16(x, 4), 1)); } const TA *const A; const TB *const B; TC *const C; const int64_t k; const int64_t lda; const int64_t ldb; const int64_t ldc; const int ith; const int nth; }; #endif // __AVX__ } // namespace /** * Performs optimized matrix multiplication on CPU. * * This subroutine may compute C = Aᵀ * B with column major ordering. * Despite its name, this isn't a generalized implementation. Work is * only performed when a handwritten kernel is written and available. * Otherwise the caller should fall back to a general matmul routine. 
* * For example, for single-threaded single-precision GEMM you can say * * llamafile_sgemm(m, n, k, A, lda, B, ldb, C, ldc, * 0, 1, * GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32); * * @param m is rows in `A` and `C` * @param n is cols in `B` and `C` * @param k is cols in `A` and rows in `B` * @param A is first input matrix (always transposed) * @param lda is row stride of `A` * @param B is second input matrix (never transposed) * @param ldb is row stride of `B` * @param C is input/output array of output matrices * @param ldc is row stride of `C` * @param ith is thread id (must be less than `nth`) * @param nth is number of threads (must be greater than zero) * @param Atype is GGML data type of `A` * @param Btype is GGML data type of `B` * @param Ctype is GGML data type of `C` * @return true if this function was able to service the matmul request */ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C, int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) { assert(m >= 0); assert(n >= 0); assert(k >= 0); assert(lda >= k); assert(ldb >= k); assert(ldc >= m); assert(nth > 0); assert(ith < nth); if (Ctype != GGML_TYPE_F32) return false; switch (Atype) { case GGML_TYPE_F32: { if (Btype != GGML_TYPE_F32) return false; #if defined(__AVX512F__) if (k % 16) return false; tinyBLAS<16, __m512, __m512, float, float, float> tb{ k, (const float *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n); return true; #elif defined(__AVX__) || defined(__AVX2__) if (k % 8) return false; tinyBLAS<8, __m256, __m256, float, float, float> tb{ k, (const float *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n); return true; #elif defined(__ARM_NEON) if (n < 4) return false; if (k % 4) return false; tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ k, (const float *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n); return true; #else 
return false; #endif } case GGML_TYPE_F16: { #if defined(__AVX512F__) if (k % 16) return false; if (Btype != GGML_TYPE_F32) return false; tinyBLAS<16, __m512, __m512, ggml_fp16_t, float, float> tb{ k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n); return true; #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) if (k % 8) return false; if (Btype != GGML_TYPE_F32) return false; tinyBLAS<8, __m256, __m256, ggml_fp16_t, float, float> tb{ k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n); return true; #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) if (n < 8) return false; if (k % 8) return false; if (Btype != GGML_TYPE_F16) return false; tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ k, (const ggml_fp16_t *)A, lda, (const ggml_fp16_t *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n); return true; #elif defined(__ARM_NEON) && !defined(_MSC_VER) if (k % 4) return false; if (Btype != GGML_TYPE_F32) return false; tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ k, (const ggml_fp16_t *)A, lda, (const float *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n); return true; #else return false; #endif } case GGML_TYPE_Q8_0: { if (Btype != GGML_TYPE_Q8_0) return false; #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) tinyBLAS_Q0_AVX tb{ k, (const block_q8_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n); return true; #elif defined(__ARM_FEATURE_DOTPROD) tinyBLAS_Q0_ARM tb{ k, (const block_q8_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n); return true; #else return false; #endif } case GGML_TYPE_Q4_0: { if (Btype != GGML_TYPE_Q8_0) return false; #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__) tinyBLAS_Q0_AVX tb{ k, (const block_q4_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, 
ldc, ith, nth}; tb.matmul(m, n); return true; #elif defined(__ARM_FEATURE_DOTPROD) tinyBLAS_Q0_ARM tb{ k, (const block_q4_0 *)A, lda, (const block_q8_0 *)B, ldb, (float *)C, ldc, ith, nth}; tb.matmul(m, n); return true; #else return false; #endif } default: return false; } (void)m; (void)n; (void)k; (void)A; (void)lda; (void)B; (void)ldb; (void)C; (void)ldc; (void)ith; (void)nth; (void)Atype; (void)Btype; (void)Ctype; }