sgemm : AVX Q4_0 and Q8_0 (#6891)

* basic avx implementation

* style

* combine denibble with load

* reduce 256 to 128 (and back!) conversions

* sse load

* Update sgemm.cpp

* oops

oops
This commit is contained in:
Eve 2024-05-08 14:29:23 +00:00 committed by GitHub
parent 911b3900dd
commit 465263d0cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,3 @@
// -*- mode:c++;indent-tabs-mode:nil;c-basic-offset:4;coding:utf-8 -*-
// vi: set et ft=c++ ts=4 sts=4 sw=4 fenc=utf-8 :vi
//
// Copyright 2024 Mozilla Foundation // Copyright 2024 Mozilla Foundation
// //
// Permission is hereby granted, free of charge, to any person obtaining // Permission is hereby granted, free of charge, to any person obtaining
@ -585,11 +582,11 @@ class tinyBLAS_Q0_ARM {
}; };
#endif // __ARM_FEATURE_DOTPROD #endif // __ARM_FEATURE_DOTPROD
#if defined(__AVX2__) || defined(__AVX512F__) #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
template <typename TA, typename TB, typename TC> template <typename TA, typename TB, typename TC>
class tinyBLAS_Q0_AVX2 { class tinyBLAS_Q0_AVX {
public: public:
tinyBLAS_Q0_AVX2(int64_t k, tinyBLAS_Q0_AVX(int64_t k,
const TA *A, int64_t lda, const TA *A, int64_t lda,
const TB *B, int64_t ldb, const TB *B, int64_t ldb,
TC *C, int64_t ldc, TC *C, int64_t ldc,
@ -728,14 +725,34 @@ class tinyBLAS_Q0_AVX2 {
__m256 Cv[RN][RM] = {}; __m256 Cv[RN][RM] = {};
for (int64_t l = 0; l < k; ++l) for (int64_t l = 0; l < k; ++l)
for (int64_t j = 0; j < RN; ++j) for (int64_t j = 0; j < RN; ++j)
for (int64_t i = 0; i < RM; ++i) for (int64_t i = 0; i < RM; ++i) {
Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * #if defined(__AVX2__)
unhalf(B[ldb * (jj + j) + l].d)), __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l),
load(A + lda * (ii + i) + l)), load(A + lda * (ii + i) + l)),
_mm256_sign_epi8(load(B + ldb * (jj + j) + l), _mm256_sign_epi8(load(B + ldb * (jj + j) + l),
load(A + lda * (ii + i) + l))), load(A + lda * (ii + i) + l)));
#else
__m128i ali0 = load0(A + lda * (ii + i) + l);
__m128i ali1 = load1(A + lda * (ii + i) + l);
__m128i blj0 = load0(B + ldb * (jj + j) + l);
__m128i blj1 = load1(B + ldb * (jj + j) + l);
__m128i sepAA0 = _mm_sign_epi8(ali0, ali0);
__m128i sepAA1 = _mm_sign_epi8(ali1, ali1);
__m128i sepBA0 = _mm_sign_epi8(blj0, ali0);
__m128i sepBA1 = _mm_sign_epi8(blj1, ali1);
// updot
const __m128i oneFill = _mm_set1_epi16(1);
__m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0);
__m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1);
__m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0)));
#endif
Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) *
unhalf(B[ldb * (jj + j) + l].d)),
udTmp,
Cv[j][i]); Cv[j][i]);
}
for (int64_t j = 0; j < RN; ++j) for (int64_t j = 0; j < RN; ++j)
for (int64_t i = 0; i < RM; ++i) for (int64_t i = 0; i < RM; ++i)
C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]);
@ -746,10 +763,28 @@ class tinyBLAS_Q0_AVX2 {
return _mm256_loadu_si256((const __m256i *)b->qs); return _mm256_loadu_si256((const __m256i *)b->qs);
} }
inline __m128i load0(const block_q8_0 *b) {
return _mm_loadu_si128((const __m128i *)b->qs);
}
inline __m128i load1(const block_q8_0 *b) {
return _mm_loadu_si128(((const __m128i *)b->qs) + 1);
}
inline __m256i load(const block_q4_0 *b) { inline __m256i load(const block_q4_0 *b) {
return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8));
} }
inline __m128i load0(const block_q4_0 *b) {
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8));
}
inline __m128i load1(const block_q4_0 *b) {
const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs));
return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8));
}
inline __m256 updot(__m256i u, __m256i s) { inline __m256 updot(__m256i u, __m256i s) {
__m256i res; __m256i res;
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) #if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
@ -777,7 +812,7 @@ class tinyBLAS_Q0_AVX2 {
const int ith; const int ith;
const int nth; const int nth;
}; };
#endif // __AVX2__ #endif // __AVX__
} // namespace } // namespace
@ -928,8 +963,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
case GGML_TYPE_Q8_0: { case GGML_TYPE_Q8_0: {
if (Btype != GGML_TYPE_Q8_0) if (Btype != GGML_TYPE_Q8_0)
return false; return false;
#if defined(__AVX2__) || defined(__AVX512F__) #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
tinyBLAS_Q0_AVX2<block_q8_0, block_q8_0, float> tb{ tinyBLAS_Q0_AVX<block_q8_0, block_q8_0, float> tb{
k, (const block_q8_0 *)A, lda, k, (const block_q8_0 *)A, lda,
(const block_q8_0 *)B, ldb, (const block_q8_0 *)B, ldb,
(float *)C, ldc, (float *)C, ldc,
@ -952,8 +987,8 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
case GGML_TYPE_Q4_0: { case GGML_TYPE_Q4_0: {
if (Btype != GGML_TYPE_Q8_0) if (Btype != GGML_TYPE_Q8_0)
return false; return false;
#if defined(__AVX2__) || defined(__AVX512F__) #if defined(__AVX2__) || defined(__AVX512F__) || defined(__AVX__)
tinyBLAS_Q0_AVX2<block_q4_0, block_q8_0, float> tb{ tinyBLAS_Q0_AVX<block_q4_0, block_q8_0, float> tb{
k, (const block_q4_0 *)A, lda, k, (const block_q4_0 *)A, lda,
(const block_q8_0 *)B, ldb, (const block_q8_0 *)B, ldb,
(float *)C, ldc, (float *)C, ldc,