tinyblas dynamic dispatching

This commit is contained in:
Djip007 2024-12-14 14:10:28 +01:00
parent 3f2bc659e7
commit d732874114
3 changed files with 74 additions and 82 deletions

View File

@@ -7419,14 +7419,14 @@ static void ggml_compute_forward_mul_mat(
if (src1_cont) { if (src1_cont) {
for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++) for (int64_t i12 = 0; i12 < ne12; i12++)
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), if (!llamafile_sgemm(params,
ne01, ne11, ne00/ggml_blck_size(src0->type),
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
nb01/ggml_type_size(src0->type), nb01/ggml_type_size(src0->type),
(const char *)src1->data + i12*nb12 + i13*nb13, (const char *)src1->data + i12*nb12 + i13*nb13,
nb11/ggml_type_size(src1->type), nb11/ggml_type_size(src1->type),
(char *)dst->data + i12*nb2 + i13*nb3, (char *)dst->data + i12*nb2 + i13*nb3,
nb1/ggml_type_size(dst->type), nb1/ggml_type_size(dst->type),
ith, nth,
src0->type, src0->type,
src1->type, src1->type,
dst->type)) dst->type))
@@ -7471,14 +7471,14 @@ UseGgmlGemm1:;
for (int64_t i13 = 0; i13 < ne13; i13++) for (int64_t i13 = 0; i13 < ne13; i13++)
for (int64_t i12 = 0; i12 < ne12; i12++) for (int64_t i12 = 0; i12 < ne12; i12++)
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type), if (!llamafile_sgemm(params,
ne01, ne11, ne00/ggml_blck_size(src0->type),
(const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
nb01/ggml_type_size(src0->type), nb01/ggml_type_size(src0->type),
(const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
row_size/ggml_type_size(vec_dot_type), row_size/ggml_type_size(vec_dot_type),
(char *)dst->data + i12*nb2 + i13*nb3, (char *)dst->data + i12*nb2 + i13*nb3,
nb1/ggml_type_size(dst->type), nb1/ggml_type_size(dst->type),
ith, nth,
src0->type, src0->type,
vec_dot_type, vec_dot_type,
dst->type)) dst->type))

View File

@@ -53,6 +53,8 @@
#include "ggml-cpu-impl.h" #include "ggml-cpu-impl.h"
#include "ggml-quants.h" #include "ggml-quants.h"
#include <atomic>
#ifdef _MSC_VER #ifdef _MSC_VER
#define NOINLINE __declspec(noinline) #define NOINLINE __declspec(noinline)
#else #else
@@ -298,12 +300,11 @@ static int64_t BLOCK_SIZE(size_t m) {
template <int KN, typename D, typename V, typename TA, typename TB, typename TC> template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
class tinyBLAS { class tinyBLAS {
public: public:
tinyBLAS(int64_t k, tinyBLAS(const ggml_compute_params * params, int64_t k,
const TA *A, int64_t lda, const TA *A, int64_t lda,
const TB *B, int64_t ldb, const TB *B, int64_t ldb,
TC *C, int64_t ldc, TC *C, int64_t ldc)
int ith, int nth) : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
} }
bool matmul(int64_t m, int64_t n) { bool matmul(int64_t m, int64_t n) {
@@ -311,10 +312,6 @@ class tinyBLAS {
return false; return false;
// compute RN/RM for only tile with size RN&RN-1/RM&RM-1 // compute RN/RM for only tile with size RN&RN-1/RM&RM-1
#if VECTOR_REGISTERS == 32 #if VECTOR_REGISTERS == 32
if (m % 8 == 0 && n < 4) {
mnpack<8, 3, 1>(m, n, n);
return true;
}
if (m % 16 == 0) { if (m % 16 == 0) {
const int64_t SIZE_N = BLOCK_SIZE<6>(n); const int64_t SIZE_N = BLOCK_SIZE<6>(n);
mnpack<4, 6, 4>(m, n, SIZE_N); mnpack<4, 6, 4>(m, n, SIZE_N);
@@ -331,10 +328,6 @@ class tinyBLAS {
return true; return true;
} }
#else // VECTOR_REGISTERS == 16 #else // VECTOR_REGISTERS == 16
if (m % 8 == 0 && n == 1) {
gemm<8, 1, 1>(m, n);
return true;
}
if (m % 8 == 0) { if (m % 8 == 0) {
const int64_t SIZE_N = BLOCK_SIZE<3>(n); const int64_t SIZE_N = BLOCK_SIZE<3>(n);
mnpack<4, 3, 2>(m, n, SIZE_N); mnpack<4, 3, 2>(m, n, SIZE_N);
@@ -400,30 +393,40 @@ class tinyBLAS {
template <int RM, int RN, int BM> template <int RM, int RN, int BM>
NOINLINE void gemm(int64_t m, int64_t n) { NOINLINE void gemm(int64_t m, int64_t n) {
GGML_ASSERT(m % (RM * BM) == 0); GGML_ASSERT(m % (RM * BM) == 0);
const int64_t ytiles = m / (RM * BM); // const int64_t ytiles = m / (RM * BM);
const int64_t xtiles = (n + RN -1) / RN; const int64_t xtiles = (n + RN -1) / RN;
const int64_t jj_RN = (xtiles - (xtiles * RN - n)); const int64_t jj_RN = (xtiles - (xtiles * RN - n)) * RN;
GGML_ASSERT(jj_RN * RN + (xtiles - jj_RN) * (RN - 1) == n);
const int64_t tiles = xtiles * ytiles; static std::atomic<int64_t> current_chunk;
const int64_t duty = (tiles + nth - 1) / nth; if (params->ith == 0) {
const int64_t start = duty * ith; GGML_ASSERT((xtiles * RN - n) >= 0);
int64_t end = start + duty; GGML_ASSERT((xtiles * RN - n) < RN);
if (end > tiles)
end = tiles; // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
for (int64_t job = start; job < end; ++job) { std::atomic_store_explicit(&current_chunk, (int64_t)params->nth, std::memory_order_relaxed);
const int64_t ii = job / xtiles; }
const int64_t jj = job % xtiles; ggml_barrier(params->threadpool);
for (int64_t bi = 0; bi < BM; ++bi) { int64_t ii = params->ith * RM * BM;
if (jj < jj_RN) {
gemm_bloc<RM, RN>((ii * BM + bi) * RM, jj * RN); while (ii < m) {
} else if constexpr (RN > 1) { for (int64_t bi = 0; bi < BM * RM; bi+=RM) {
gemm_bloc<RM, RN - 1>((ii * BM + bi) * RM, jj_RN * RN + (jj - jj_RN) * (RN - 1)); int64_t jj = 0;
for (; jj<jj_RN; jj+=RN) {
gemm_bloc<RM, RN>(ii + bi, jj);
}
if constexpr (RN > 1) {
for (; jj<n; jj+=RN-1) {
gemm_bloc<RM, RN-1>(ii + bi, jj);
} }
} }
GGML_ASSERT(jj == n);
} }
ii = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed) * RM * BM;
}
ggml_barrier(params->threadpool);
} }
const ggml_compute_params * params;
const TA *const A; const TA *const A;
const TB *const B; const TB *const B;
TC *const C; TC *const C;
@@ -431,8 +434,6 @@ class tinyBLAS {
const int64_t lda; const int64_t lda;
const int64_t ldb; const int64_t ldb;
const int64_t ldc; const int64_t ldc;
const int ith;
const int nth;
}; };
////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////
@@ -1636,8 +1637,9 @@ class tinyBLAS_PPC {
* @param Ctype is GGML data type of `C` * @param Ctype is GGML data type of `C`
* @return true if this function was able to service the matmul request * @return true if this function was able to service the matmul request
*/ */
bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C, bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) { const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
int64_t ldc, int Atype, int Btype, int Ctype) {
assert(m >= 0); assert(m >= 0);
assert(n >= 0); assert(n >= 0);
@@ -1645,9 +1647,10 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
assert(lda >= k); assert(lda >= k);
assert(ldb >= k); assert(ldb >= k);
assert(ldc >= m); assert(ldc >= m);
assert(nth > 0); assert(params->nth > 0);
assert(ith < nth); assert(params->ith < params->nth);
// OK avec moins de thread 4 max en zen3 / 16 coeurs?
// only enable sgemm for prompt processing // only enable sgemm for prompt processing
if (n < 2) if (n < 2)
return false; return false;
@@ -1661,27 +1664,24 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
if (Btype != GGML_TYPE_F32) if (Btype != GGML_TYPE_F32)
return false; return false;
#if defined(__AVX512F__) #if defined(__AVX512F__)
tinyBLAS<16, __m512, __m512, float, float, float> tb{ tinyBLAS<16, __m512, __m512, float, float, float> tb{ params,
k, (const float *)A, lda, k, (const float *)A, lda,
(const float *)B, ldb, (const float *)B, ldb,
(float *)C, ldc, (float *)C, ldc};
ith, nth};
return tb.matmul(m, n); return tb.matmul(m, n);
#elif defined(__AVX__) || defined(__AVX2__) #elif defined(__AVX__) || defined(__AVX2__)
tinyBLAS<8, __m256, __m256, float, float, float> tb{ tinyBLAS<8, __m256, __m256, float, float, float> tb{ params,
k, (const float *)A, lda, k, (const float *)A, lda,
(const float *)B, ldb, (const float *)B, ldb,
(float *)C, ldc, (float *)C, ldc};
ith, nth};
return tb.matmul(m, n); return tb.matmul(m, n);
#elif defined(__ARM_NEON) #elif defined(__ARM_NEON)
if (n < 4) if (n < 4)
return false; return false;
tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
k, (const float *)A, lda, k, (const float *)A, lda,
(const float *)B, ldb, (const float *)B, ldb,
(float *)C, ldc, (float *)C, ldc};
ith, nth};
return tb.matmul(m, n); return tb.matmul(m, n);
#elif defined(__MMA__) #elif defined(__MMA__)
if (k % 8) if (k % 8)
@@ -1690,7 +1690,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
k, (const float *)A, lda, k, (const float *)A, lda,
(const float *)B, ldb, (const float *)B, ldb,
(float *)C, ldc, (float *)C, ldc,
ith, nth}; params->ith, params->nth};
tb.matmul(m, n); tb.matmul(m, n);
return true; return true;
#else #else
@@ -1701,29 +1701,26 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
case GGML_TYPE_BF16: { case GGML_TYPE_BF16: {
#if defined(__AVX512BF16__) #if defined(__AVX512BF16__)
if (Btype == GGML_TYPE_BF16) { if (Btype == GGML_TYPE_BF16) {
tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ k, tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
(const ggml_bf16_t *)A, lda, (const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb, (const ggml_bf16_t *)B, ldb,
(float *)C, ldc, (float *)C, ldc};
ith, nth};
return tb.matmul(m, n); return tb.matmul(m, n);
} }
#elif defined(__AVX512F__) #elif defined(__AVX512F__)
if (Btype == GGML_TYPE_BF16) { if (Btype == GGML_TYPE_BF16) {
tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ k, tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
(const ggml_bf16_t *)A, lda, (const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb, (const ggml_bf16_t *)B, ldb,
(float *)C, ldc, (float *)C, ldc};
ith, nth};
return tb.matmul(m, n); return tb.matmul(m, n);
} }
#elif defined(__AVX2__) #elif defined(__AVX2__)
if (Btype == GGML_TYPE_BF16) { if (Btype == GGML_TYPE_BF16) {
tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ k, tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
(const ggml_bf16_t *)A, lda, (const ggml_bf16_t *)A, lda,
(const ggml_bf16_t *)B, ldb, (const ggml_bf16_t *)B, ldb,
(float *)C, ldc, (float *)C, ldc};
ith, nth};
return tb.matmul(m, n); return tb.matmul(m, n);
} }
#endif #endif
@@ -1732,40 +1729,36 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
case GGML_TYPE_F16: { case GGML_TYPE_F16: {
#if defined(__AVX512F__) #if defined(__AVX512F__)
if (Btype == GGML_TYPE_F16) { if (Btype == GGML_TYPE_F16) {
tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ k, tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
(const ggml_fp16_t *)A, lda, (const ggml_fp16_t *)A, lda,
(const ggml_fp16_t *)B, ldb, (const ggml_fp16_t *)B, ldb,
(float *)C, ldc, (float *)C, ldc};
ith, nth};
return tb.matmul(m, n); return tb.matmul(m, n);
} }
#elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__) #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
if (Btype == GGML_TYPE_F16) { if (Btype == GGML_TYPE_F16) {
tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ k, tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
(const ggml_fp16_t *)A, lda, (const ggml_fp16_t *)A, lda,
(const ggml_fp16_t *)B, ldb, (const ggml_fp16_t *)B, ldb,
(float *)C, ldc, (float *)C, ldc};
ith, nth};
return tb.matmul(m, n); return tb.matmul(m, n);
} }
#elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER) #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
if (n < 8) if (n < 8)
return false; return false;
if (Btype == GGML_TYPE_F16) { if (Btype == GGML_TYPE_F16) {
tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
k, (const ggml_fp16_t *)A, lda, k, (const ggml_fp16_t *)A, lda,
(const ggml_fp16_t *)B, ldb, (const ggml_fp16_t *)B, ldb,
(float *)C, ldc, (float *)C, ldc};
ith, nth};
return tb.matmul(m, n); return tb.matmul(m, n);
} }
#elif defined(__ARM_NEON) && !defined(_MSC_VER) #elif defined(__ARM_NEON) && !defined(_MSC_VER)
if (Btype == GGML_TYPE_F32) { if (Btype == GGML_TYPE_F32) {
tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ params,
k, (const ggml_fp16_t *)A, lda, k, (const ggml_fp16_t *)A, lda,
(const float *)B, ldb, (const float *)B, ldb,
(float *)C, ldc, (float *)C, ldc};
ith, nth};
return tb.matmul(m, n); return tb.matmul(m, n);
} }
#endif #endif
@@ -1780,7 +1773,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
k, (const block_q8_0 *)A, lda, k, (const block_q8_0 *)A, lda,
(const block_q8_0 *)B, ldb, (const block_q8_0 *)B, ldb,
(float *)C, ldc, (float *)C, ldc,
ith, nth}; params->ith, params->nth};
tb.matmul(m, n); tb.matmul(m, n);
return true; return true;
#elif defined(__ARM_FEATURE_DOTPROD) #elif defined(__ARM_FEATURE_DOTPROD)
@@ -1788,7 +1781,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
k, (const block_q8_0 *)A, lda, k, (const block_q8_0 *)A, lda,
(const block_q8_0 *)B, ldb, (const block_q8_0 *)B, ldb,
(float *)C, ldc, (float *)C, ldc,
ith, nth}; params->ith, params->nth};
tb.matmul(m, n); tb.matmul(m, n);
return true; return true;
#else #else
@@ -1804,7 +1797,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
k, (const block_q4_0 *)A, lda, k, (const block_q4_0 *)A, lda,
(const block_q8_0 *)B, ldb, (const block_q8_0 *)B, ldb,
(float *)C, ldc, (float *)C, ldc,
ith, nth}; params->ith, params->nth};
tb.matmul(m, n); tb.matmul(m, n);
return true; return true;
#elif defined(__ARM_FEATURE_DOTPROD) #elif defined(__ARM_FEATURE_DOTPROD)
@@ -1812,7 +1805,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
k, (const block_q4_0 *)A, lda, k, (const block_q4_0 *)A, lda,
(const block_q8_0 *)B, ldb, (const block_q8_0 *)B, ldb,
(float *)C, ldc, (float *)C, ldc,
ith, nth}; params->ith, params->nth};
tb.matmul(m, n); tb.matmul(m, n);
return true; return true;
#else #else
@@ -1828,7 +1821,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
k, (const block_q5_0 *)A, lda, k, (const block_q5_0 *)A, lda,
(const block_q8_0 *)B, ldb, (const block_q8_0 *)B, ldb,
(float *)C, ldc, (float *)C, ldc,
ith, nth}; params->ith, params->nth};
tb.matmul(m, n); tb.matmul(m, n);
return true; return true;
#else #else
@@ -1844,7 +1837,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
k, (const block_iq4_nl *)A, lda, k, (const block_iq4_nl *)A, lda,
(const block_q8_0 *)B, ldb, (const block_q8_0 *)B, ldb,
(float *)C, ldc, (float *)C, ldc,
ith, nth}; params->ith, params->nth};
tb.matmul(m, n); tb.matmul(m, n);
return true; return true;
#else #else
@@ -1856,6 +1849,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
return false; return false;
} }
(void)params;
(void)m; (void)m;
(void)n; (void)n;
(void)k; (void)k;
@@ -1865,8 +1859,6 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
(void)ldb; (void)ldb;
(void)C; (void)C;
(void)ldc; (void)ldc;
(void)ith;
(void)nth;
(void)Atype; (void)Atype;
(void)Btype; (void)Btype;
(void)Ctype; (void)Ctype;

View File

@@ -5,8 +5,8 @@
extern "C" { extern "C" {
#endif #endif
bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t, bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t, int64_t, int64_t,
const void *, int64_t, void *, int64_t, int, int, const void *, int64_t, const void *, int64_t, void *, int64_t,
int, int, int); int, int, int);
#ifdef __cplusplus #ifdef __cplusplus