From d73287411475b11f8aa88a9ac89a69f233e45a0f Mon Sep 17 00:00:00 2001
From: Djip007 <3705339+Djip007@users.noreply.github.com>
Date: Sat, 14 Dec 2024 14:10:28 +0100
Subject: [PATCH] tinyblas dynamic dispatching

---
 ggml/src/ggml-cpu/ggml-cpu.c          |   8 +-
 ggml/src/ggml-cpu/llamafile/sgemm.cpp | 144 ++++++++++++--------------
 ggml/src/ggml-cpu/llamafile/sgemm.h   |   4 +-
 3 files changed, 74 insertions(+), 82 deletions(-)

diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index 67e67a089..4f9de1180 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -7419,14 +7419,14 @@ static void ggml_compute_forward_mul_mat(
     if (src1_cont) {
         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+                if (!llamafile_sgemm(params,
+                                     ne01, ne11, ne00/ggml_blck_size(src0->type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
                                      nb01/ggml_type_size(src0->type),
                                      (const char *)src1->data + i12*nb12 + i13*nb13,
                                      nb11/ggml_type_size(src1->type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
-                                     ith, nth,
                                      src0->type,
                                      src1->type,
                                      dst->type))
@@ -7471,14 +7471,14 @@ UseGgmlGemm1:;

         for (int64_t i13 = 0; i13 < ne13; i13++)
             for (int64_t i12 = 0; i12 < ne12; i12++)
-                if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
+                if (!llamafile_sgemm(params,
+                                     ne01, ne11, ne00/ggml_blck_size(src0->type),
                                      (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03,
                                      nb01/ggml_type_size(src0->type),
                                      (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size,
                                      row_size/ggml_type_size(vec_dot_type),
                                      (char *)dst->data + i12*nb2 + i13*nb3,
                                      nb1/ggml_type_size(dst->type),
-                                     ith, nth,
                                      src0->type,
                                      vec_dot_type,
                                      dst->type))
diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
index 3a9afa8bd..9b6a2eda0 100644
--- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp
+++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp
@@ -53,6 +53,8 @@
 #include "ggml-cpu-impl.h"
 #include "ggml-quants.h"

+#include <atomic>
+
 #ifdef _MSC_VER
 #define NOINLINE __declspec(noinline)
 #else
@@ -298,12 +300,11 @@ static int64_t BLOCK_SIZE(size_t m) {
 template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
 class tinyBLAS {
   public:
-    tinyBLAS(int64_t k,
+    tinyBLAS(const ggml_compute_params * params, int64_t k,
              const TA *A, int64_t lda,
              const TB *B, int64_t ldb,
-             TC *C, int64_t ldc,
-             int ith, int nth)
-        : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
+             TC *C, int64_t ldc)
+        : params(params), A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc) {
     }

     bool matmul(int64_t m, int64_t n) {
@@ -311,10 +312,6 @@ class tinyBLAS {
             return false;
         // compute RN/RM for only tile with size RN&RN-1/RM&RM-1
 #if VECTOR_REGISTERS == 32
-        if (m % 8 == 0 && n < 4) {
-            mnpack<8, 3, 1>(m, n, n);
-            return true;
-        }
         if (m % 16 == 0) {
             const int64_t SIZE_N = BLOCK_SIZE<6>(n);
             mnpack<4, 6, 4>(m, n, SIZE_N);
@@ -331,10 +328,6 @@ class tinyBLAS {
             return true;
         }
 #else // VECTOR_REGISTERS == 16
-        if (m % 8 == 0 && n == 1) {
-            gemm<8, 1, 1>(m, n);
-            return true;
-        }
         if (m % 8 == 0) {
             const int64_t SIZE_N = BLOCK_SIZE<3>(n);
             mnpack<4, 3, 2>(m, n, SIZE_N);
@@ -400,30 +393,40 @@ class tinyBLAS {
     template <int RM, int RN, int BM>
     NOINLINE void gemm(int64_t m, int64_t n) {
         GGML_ASSERT(m % (RM * BM) == 0);
-        const int64_t ytiles = m / (RM * BM);
+        // const int64_t ytiles = m / (RM * BM);
         const int64_t xtiles = (n + RN -1) / RN;
-        const int64_t jj_RN = (xtiles - (xtiles * RN - n));
-        GGML_ASSERT(jj_RN * RN + (xtiles - jj_RN) * (RN - 1) == n);
+        const int64_t jj_RN = (xtiles - (xtiles * RN - n)) * RN;

-        const int64_t tiles = xtiles * ytiles;
-        const int64_t duty = (tiles + nth - 1) / nth;
-        const int64_t start = duty * ith;
-        int64_t end = start + duty;
-        if (end > tiles)
-            end = tiles;
-        for (int64_t job = start; job < end; ++job) {
-            const int64_t ii = job / xtiles;
-            const int64_t jj = job % xtiles;
-            for (int64_t bi = 0; bi < BM; ++bi) {
-                if (jj < jj_RN) {
-                    gemm_bloc<RM, RN>((ii * BM + bi) * RM, jj * RN);
-                } else if constexpr (RN > 1) {
-                    gemm_bloc<RM, RN - 1>((ii * BM + bi) * RM, jj_RN * RN + (jj - jj_RN) * (RN - 1));
-                }
-            }
+        static std::atomic<int64_t> current_chunk;
+        if (params->ith == 0) {
+            GGML_ASSERT((xtiles * RN - n) >= 0);
+            GGML_ASSERT((xtiles * RN - n) < RN);
+
+            // Every thread starts at ith, so the first unprocessed chunk is nth. This saves a bit of coordination right at the start.
+            std::atomic_store_explicit(&current_chunk, (int64_t)params->nth, std::memory_order_relaxed);
         }
+        ggml_barrier(params->threadpool);
+        int64_t ii = params->ith * RM * BM;
+
+        while (ii < m) {
+            for (int64_t bi = 0; bi < BM * RM; bi += RM) {
+                int64_t jj = 0;
+                for (; jj < jj_RN; jj += RN) {
+                    gemm_bloc<RM, RN>(ii + bi, jj);
+                }
+                if constexpr (RN > 1) {
+                    for (; jj < n; jj += RN - 1) {
+                        gemm_bloc<RM, RN - 1>(ii + bi, jj);
+                    }
+                }
+                GGML_ASSERT(jj == n);
+            }
+            ii = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed) * RM * BM;
+        }
+        ggml_barrier(params->threadpool);
     }

+    const ggml_compute_params * params;
     const TA *const A;
     const TB *const B;
     TC *const C;
@@ -431,8 +434,6 @@ class tinyBLAS {
     const int64_t lda;
     const int64_t ldb;
     const int64_t ldc;
-    const int ith;
-    const int nth;
 };

 //////////////////////////////////////////////////////////////////////////////////////////
@@ -1636,8 +1637,9 @@ class tinyBLAS_PPC {
  * @param Ctype is GGML data type of `C`
  * @return true if this function was able to service the matmul request
  */
-bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
-                     int64_t ldc, int ith, int nth, int Atype, int Btype, int Ctype) {
+bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
+                     const void *A, int64_t lda, const void *B, int64_t ldb, void *C,
+                     int64_t ldc, int Atype, int Btype, int Ctype) {

     assert(m >= 0);
     assert(n >= 0);
@@ -1645,9 +1647,10 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     assert(lda >= k);
     assert(ldb >= k);
     assert(ldc >= m);
-    assert(nth > 0);
-    assert(ith < nth);
+    assert(params->nth > 0);
+    assert(params->ith < params->nth);

+    // OK with fewer threads? 4 max on Zen3 / 16 cores?
     // only enable sgemm for prompt processing
     if (n < 2)
         return false;
@@ -1661,27 +1664,24 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
         if (Btype != GGML_TYPE_F32)
             return false;
 #if defined(__AVX512F__)
-        tinyBLAS<16, __m512, __m512, float, float, float> tb{
+        tinyBLAS<16, __m512, __m512, float, float, float> tb{ params,
             k, (const float *)A, lda,
             (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
+            (float *)C, ldc};
         return tb.matmul(m, n);
 #elif defined(__AVX__) || defined(__AVX2__)
-        tinyBLAS<8, __m256, __m256, float, float, float> tb{
+        tinyBLAS<8, __m256, __m256, float, float, float> tb{ params,
             k, (const float *)A, lda,
             (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
+            (float *)C, ldc};
         return tb.matmul(m, n);
 #elif defined(__ARM_NEON)
         if (n < 4)
             return false;
-        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{
+        tinyBLAS<4, float32x4_t, float32x4_t, float, float, float> tb{ params,
             k, (const float *)A, lda,
             (const float *)B, ldb,
-            (float *)C, ldc,
-            ith, nth};
+            (float *)C, ldc};
         return tb.matmul(m, n);
 #elif defined(__MMA__)
         if (k % 8)
@@ -1690,7 +1690,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const float *)A, lda,
             (const float *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1701,29 +1701,26 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     case GGML_TYPE_BF16: {
 #if defined(__AVX512BF16__)
         if (Btype == GGML_TYPE_BF16) {
-            tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ k,
+            tinyBLAS<32, __m512, __m512bh, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                 (const ggml_bf16_t *)A, lda,
                 (const ggml_bf16_t *)B, ldb,
-                (float *)C, ldc,
-                ith, nth};
+                (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #elif defined(__AVX512F__)
         if (Btype == GGML_TYPE_BF16) {
-            tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ k,
+            tinyBLAS<16, __m512, __m512, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                 (const ggml_bf16_t *)A, lda,
                 (const ggml_bf16_t *)B, ldb,
-                (float *)C, ldc,
-                ith, nth};
+                (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #elif defined(__AVX2__)
         if (Btype == GGML_TYPE_BF16) {
-            tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ k,
+            tinyBLAS<8, __m256, __m256, ggml_bf16_t, ggml_bf16_t, float> tb{ params, k,
                 (const ggml_bf16_t *)A, lda,
                 (const ggml_bf16_t *)B, ldb,
-                (float *)C, ldc,
-                ith, nth};
+                (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #endif
@@ -1732,40 +1729,36 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     case GGML_TYPE_F16: {
 #if defined(__AVX512F__)
         if (Btype == GGML_TYPE_F16) {
-            tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ k,
+            tinyBLAS<16, __m512, __m512, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
                 (const ggml_fp16_t *)A, lda,
                 (const ggml_fp16_t *)B, ldb,
-                (float *)C, ldc,
-                ith, nth};
+                (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #elif (defined(__AVX__) || defined(__AVX2__)) && defined(__F16C__)
         if (Btype == GGML_TYPE_F16) {
-            tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ k,
+            tinyBLAS<8, __m256, __m256, ggml_fp16_t, ggml_fp16_t, float> tb{ params, k,
                 (const ggml_fp16_t *)A, lda,
                 (const ggml_fp16_t *)B, ldb,
-                (float *)C, ldc,
-                ith, nth};
+                (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #elif defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && !defined(_MSC_VER)
         if (n < 8)
             return false;
         if (Btype == GGML_TYPE_F16) {
-            tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{
+            tinyBLAS<8, float16x8_t, float16x8_t, ggml_fp16_t, ggml_fp16_t, float> tb{ params,
                 k, (const ggml_fp16_t *)A, lda,
                 (const ggml_fp16_t *)B, ldb,
-                (float *)C, ldc,
-                ith, nth};
+                (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #elif defined(__ARM_NEON) && !defined(_MSC_VER)
         if (Btype == GGML_TYPE_F32) {
-            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{
+            tinyBLAS<4, float32x4_t, float32x4_t, ggml_fp16_t, float, float> tb{ params,
                 k, (const ggml_fp16_t *)A, lda,
                 (const float *)B, ldb,
-                (float *)C, ldc,
-                ith, nth};
+                (float *)C, ldc};
             return tb.matmul(m, n);
         }
 #endif
@@ -1780,7 +1773,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q8_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
@@ -1788,7 +1781,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q8_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1804,7 +1797,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q4_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #elif defined(__ARM_FEATURE_DOTPROD)
@@ -1812,7 +1805,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q4_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1828,7 +1821,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_q5_0 *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1844,7 +1837,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
             k, (const block_iq4_nl *)A, lda,
             (const block_q8_0 *)B, ldb,
             (float *)C, ldc,
-            ith, nth};
+            params->ith, params->nth};
         tb.matmul(m, n);
         return true;
 #else
@@ -1856,6 +1849,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
         return false;
     }

+    (void)params;
     (void)m;
     (void)n;
     (void)k;
@@ -1865,8 +1859,6 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
     (void)ldb;
     (void)C;
     (void)ldc;
-    (void)ith;
-    (void)nth;
     (void)Atype;
     (void)Btype;
     (void)Ctype;
diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.h b/ggml/src/ggml-cpu/llamafile/sgemm.h
index caf6dd556..3d2909515 100644
--- a/ggml/src/ggml-cpu/llamafile/sgemm.h
+++ b/ggml/src/ggml-cpu/llamafile/sgemm.h
@@ -5,8 +5,8 @@
 extern "C" {
 #endif

-bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t,
-                     const void *, int64_t, void *, int64_t, int, int,
+bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t, int64_t, int64_t,
+                     const void *, int64_t, const void *, int64_t, void *, int64_t, int, int,
                      int);

 #ifdef __cplusplus
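
The heart of this patch is the scheduling change in gemm(): the old code statically carved the xtiles * ytiles tile jobs into one fixed duty range per thread, while the new code has thread ith start on row block ith and then claim further RM * BM row blocks from a shared atomic counter, so threads that finish early keep working instead of idling. Below is a minimal standalone sketch of that pattern (not from the patch itself; worker, nchunks and the std::thread scaffolding are illustrative stand-ins for ggml's threadpool and the two ggml_barrier() calls):

```cpp
#include <atomic>
#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

// Next unclaimed chunk; the patch keeps this as a function-local static
// inside gemm().
static std::atomic<int64_t> current_chunk;

static void worker(int ith, int nth, int64_t nchunks) {
    // Thread ith implicitly owns chunk ith, so the first chunk that must be
    // claimed through the counter is nth: the same trick the patch uses to
    // skip one atomic operation per thread right at the start.
    for (int64_t c = ith; c < nchunks;
         c = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1,
                                            std::memory_order_relaxed)) {
        // ... process chunk c; in the patch this is one RM*BM block of rows
        // (ii = c * RM * BM) swept across all column tiles ...
        std::printf("thread %d/%d took chunk %lld\n", ith, nth, (long long)c);
    }
}

int main() {
    const int     nth     = 4;  // worker count (params->nth in the patch)
    const int64_t nchunks = 10; // total row blocks, i.e. m / (RM * BM)

    // Thread 0 does this store before the first ggml_barrier() in the patch;
    // here, storing before spawning gives the same happens-before ordering.
    std::atomic_store_explicit(&current_chunk, (int64_t)nth,
                               std::memory_order_relaxed);
    std::vector<std::thread> threads;
    for (int ith = 0; ith < nth; ++ith)
        threads.emplace_back(worker, ith, nth, nchunks);
    for (auto & t : threads)
        t.join(); // stands in for the trailing ggml_barrier()
    return 0;
}
```

Built along the lines of `g++ -std=c++17 -pthread sketch.cpp`, this prints each chunk exactly once, with the distribution across threads varying from run to run, which is exactly the dynamic load balancing the patch is after.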