sgemm: add M blocs.

This commit is contained in:
Djip007 2024-12-19 01:21:09 +01:00
parent d732874114
commit ac2b53c564

View File

@ -291,10 +291,13 @@ static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
// FLOATING POINT MATRIX MULTIPLICATION // FLOATING POINT MATRIX MULTIPLICATION
template <int M> template <int M>
static int64_t BLOCK_SIZE(size_t m) { static inline int64_t BLOCK_SIZE(size_t m) {
const int64_t NB_BLOC_M = (m + M - 1) / M; const int64_t NB_BLOC_M = (m + M - 1) / M;
int64_t res = (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1; return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
return res; }
static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
} }
template <int KN, typename D, typename V, typename TA, typename TB, typename TC> template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
@ -310,32 +313,37 @@ class tinyBLAS {
bool matmul(int64_t m, int64_t n) { bool matmul(int64_t m, int64_t n) {
if (k % KN != 0) if (k % KN != 0)
return false; return false;
// compute RN/RM for only tile with size RN&RN-1/RM&RM-1 // compute RM for only need tile with size RM&RM-1
#if VECTOR_REGISTERS == 32 #if VECTOR_REGISTERS == 32
if (m % 16 == 0) { if (m % 16 == 0 && (m/16 >= params->nth)) {
const int64_t SIZE_N = BLOCK_SIZE<6>(n); const int64_t SIZE_N = BLOCK_SIZE<6>(n);
mnpack<4, 6, 4>(m, n, SIZE_N); mnpack<4, 6, 4>(m, n, SIZE_N, 12);
return true; return true;
} }
if (m % 8 == 0) { if (m % 8 == 0 ) {
const int64_t SIZE_N = BLOCK_SIZE<6>(n); const int64_t SIZE_N = BLOCK_SIZE<6>(n);
mnpack<4, 6, 2>(m, n, SIZE_N); mnpack<4, 6, 2>(m, n, SIZE_N, 12);
return true; return true;
} }
if (m % 4 == 0) { if (m % 4 == 0) {
const int64_t SIZE_N = BLOCK_SIZE<6>(n); const int64_t SIZE_N = BLOCK_SIZE<6>(n);
mnpack<4, 6, 1>(m, n, SIZE_N); mnpack<4, 6, 1>(m, n, SIZE_N, 12);
return true; return true;
} }
#else // VECTOR_REGISTERS == 16 #else // VECTOR_REGISTERS == 16
if (m % 8 == 0) { if (m % 16 == 0 && (m/16 >= params->nth)) {
const int64_t SIZE_N = BLOCK_SIZE<3>(n); const int64_t SIZE_N = BLOCK_SIZE<3>(n);
mnpack<4, 3, 2>(m, n, SIZE_N); mnpack<4, 3, 4>(m, n, SIZE_N, 24);
return true;
}
if (m % 8 == 0 ) {
const int64_t SIZE_N = BLOCK_SIZE<3>(n);
mnpack<4, 3, 2>(m, n, SIZE_N, 24);
return true; return true;
} }
if (m % 4 == 0) { if (m % 4 == 0) {
const int64_t SIZE_N = BLOCK_SIZE<3>(n); const int64_t SIZE_N = BLOCK_SIZE<3>(n);
mnpack<4, 3, 1>(m, n, SIZE_N); mnpack<4, 3, 1>(m, n, SIZE_N, 24);
return true; return true;
} }
#endif #endif
@ -344,12 +352,12 @@ class tinyBLAS {
private: private:
template <int RM, int RN, int BM> template <int RM, int RN, int BM>
inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N) { inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
if (SIZE_N == RN) { if (SIZE_N == RN) {
return gemm<RM, RN, BM>(m, n); return gemm<RM, RN, BM>(m, n, BN);
} }
if constexpr (RN > 1) { if constexpr (RN > 1) {
return mnpack<RM, RN-1, BM>(m, n, SIZE_N); return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
} else { } else {
GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N); GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
GGML_ASSERT(false); // we have miss something. GGML_ASSERT(false); // we have miss something.
@ -391,39 +399,58 @@ class tinyBLAS {
} }
template <int RM, int RN, int BM> template <int RM, int RN, int BM>
NOINLINE void gemm(int64_t m, int64_t n) { NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
GGML_ASSERT(m % (RM * BM) == 0);
// const int64_t ytiles = m / (RM * BM);
const int64_t xtiles = (n + RN -1) / RN;
const int64_t jj_RN = (xtiles - (xtiles * RN - n)) * RN;
static std::atomic<int64_t> current_chunk; static std::atomic<int64_t> current_chunk;
if (params->ith == 0) {
GGML_ASSERT((xtiles * RN - n) >= 0);
GGML_ASSERT((xtiles * RN - n) < RN);
GGML_ASSERT(m % (RM * BM) == 0);
const int64_t ytiles = m / (RM * BM);
const int64_t xtiles = (n + RN -1) / RN;
const int64_t jj_RN = (xtiles - (xtiles * RN - n));
// "round" bloc_size to "nearest" BN
const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
const int64_t nb_job = ytiles * NB_BN;
if (params->ith == 0) {
GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start. // Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
std::atomic_store_explicit(&current_chunk, (int64_t)params->nth, std::memory_order_relaxed); std::atomic_store_explicit(&current_chunk, (int64_t)params->nth, std::memory_order_relaxed);
} }
ggml_barrier(params->threadpool);
int64_t ii = params->ith * RM * BM;
while (ii < m) { ggml_barrier(params->threadpool);
for (int64_t bi = 0; bi < BM * RM; bi+=RM) {
int64_t jj = 0; int64_t job = params->ith;
for (; jj<jj_RN; jj+=RN) { while (job < nb_job) {
const int64_t ii = (job % ytiles) * RM * BM;
const int64_t jb = job / ytiles;
const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN);
const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
for (int64_t bi = 0; bi < BM * RM; bi += RM) {
int64_t jj = jj0;
for (; jj < jj1; jj += RN) {
gemm_bloc<RM, RN>(ii + bi, jj); gemm_bloc<RM, RN>(ii + bi, jj);
} }
if constexpr (RN > 1) { if constexpr (RN > 1) {
for (; jj<n; jj+=RN-1) { for (; jj < jj2; jj += RN - 1) {
gemm_bloc<RM, RN-1>(ii + bi, jj); gemm_bloc<RM, RN-1>(ii + bi, jj);
} }
} }
GGML_ASSERT(jj == n); GGML_ASSERT(jj == jj2);
} }
ii = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed) * RM * BM;
// next step.
job = std::atomic_fetch_add_explicit(&current_chunk, (int64_t)1, std::memory_order_relaxed);
} }
ggml_barrier(params->threadpool); ggml_barrier(params->threadpool);
return;
} }
const ggml_compute_params * params; const ggml_compute_params * params;
@ -1650,7 +1677,6 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
assert(params->nth > 0); assert(params->nth > 0);
assert(params->ith < params->nth); assert(params->ith < params->nth);
// OK avec moins de thread 4 max en zen3 / 16 coeurs?
// only enable sgemm for prompt processing // only enable sgemm for prompt processing
if (n < 2) if (n < 2)
return false; return false;