mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 19:34:35 +00:00
sgemm: add M blocs.
This commit is contained in:
parent
d732874114
commit
ac2b53c564
@ -291,10 +291,13 @@ static const __m128i iq4nlt = _mm_loadu_si128((const __m128i *) kvalues_iq4nl);
|
|||||||
// FLOATING POINT MATRIX MULTIPLICATION
|
// FLOATING POINT MATRIX MULTIPLICATION
|
||||||
|
|
||||||
template <int M>
|
template <int M>
|
||||||
static int64_t BLOCK_SIZE(size_t m) {
|
static inline int64_t BLOCK_SIZE(size_t m) {
|
||||||
const int64_t NB_BLOC_M = (m + M - 1) / M;
|
const int64_t NB_BLOC_M = (m + M - 1) / M;
|
||||||
int64_t res = (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
|
return (m % NB_BLOC_M == 0) ? m / NB_BLOC_M : (m / NB_BLOC_M) + 1;
|
||||||
return res;
|
}
|
||||||
|
|
||||||
|
static constexpr inline int64_t BLOC_POS(int64_t ib, int64_t ibN, int64_t bloc_size) {
|
||||||
|
return ib < ibN ? ib * bloc_size : ibN * bloc_size + (ib - ibN) * (bloc_size - 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
|
template <int KN, typename D, typename V, typename TA, typename TB, typename TC>
|
||||||
@ -310,32 +313,37 @@ class tinyBLAS {
|
|||||||
bool matmul(int64_t m, int64_t n) {
|
bool matmul(int64_t m, int64_t n) {
|
||||||
if (k % KN != 0)
|
if (k % KN != 0)
|
||||||
return false;
|
return false;
|
||||||
// compute RN/RM for only tile with size RN&RN-1/RM&RM-1
|
// compute RM for only need tile with size RM&RM-1
|
||||||
#if VECTOR_REGISTERS == 32
|
#if VECTOR_REGISTERS == 32
|
||||||
if (m % 16 == 0) {
|
if (m % 16 == 0 && (m/16 >= params->nth)) {
|
||||||
const int64_t SIZE_N = BLOCK_SIZE<6>(n);
|
const int64_t SIZE_N = BLOCK_SIZE<6>(n);
|
||||||
mnpack<4, 6, 4>(m, n, SIZE_N);
|
mnpack<4, 6, 4>(m, n, SIZE_N, 12);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (m % 8 == 0) {
|
if (m % 8 == 0 ) {
|
||||||
const int64_t SIZE_N = BLOCK_SIZE<6>(n);
|
const int64_t SIZE_N = BLOCK_SIZE<6>(n);
|
||||||
mnpack<4, 6, 2>(m, n, SIZE_N);
|
mnpack<4, 6, 2>(m, n, SIZE_N, 12);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (m % 4 == 0) {
|
if (m % 4 == 0) {
|
||||||
const int64_t SIZE_N = BLOCK_SIZE<6>(n);
|
const int64_t SIZE_N = BLOCK_SIZE<6>(n);
|
||||||
mnpack<4, 6, 1>(m, n, SIZE_N);
|
mnpack<4, 6, 1>(m, n, SIZE_N, 12);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
#else // VECTOR_REGISTERS == 16
|
#else // VECTOR_REGISTERS == 16
|
||||||
if (m % 8 == 0) {
|
if (m % 16 == 0 && (m/16 >= params->nth)) {
|
||||||
const int64_t SIZE_N = BLOCK_SIZE<3>(n);
|
const int64_t SIZE_N = BLOCK_SIZE<3>(n);
|
||||||
mnpack<4, 3, 2>(m, n, SIZE_N);
|
mnpack<4, 3, 4>(m, n, SIZE_N, 24);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (m % 8 == 0 ) {
|
||||||
|
const int64_t SIZE_N = BLOCK_SIZE<3>(n);
|
||||||
|
mnpack<4, 3, 2>(m, n, SIZE_N, 24);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (m % 4 == 0) {
|
if (m % 4 == 0) {
|
||||||
const int64_t SIZE_N = BLOCK_SIZE<3>(n);
|
const int64_t SIZE_N = BLOCK_SIZE<3>(n);
|
||||||
mnpack<4, 3, 1>(m, n, SIZE_N);
|
mnpack<4, 3, 1>(m, n, SIZE_N, 24);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@ -344,12 +352,12 @@ class tinyBLAS {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
template <int RM, int RN, int BM>
|
template <int RM, int RN, int BM>
|
||||||
inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N) {
|
inline void mnpack(int64_t m, int64_t n, int64_t SIZE_N, int64_t BN) {
|
||||||
if (SIZE_N == RN) {
|
if (SIZE_N == RN) {
|
||||||
return gemm<RM, RN, BM>(m, n);
|
return gemm<RM, RN, BM>(m, n, BN);
|
||||||
}
|
}
|
||||||
if constexpr (RN > 1) {
|
if constexpr (RN > 1) {
|
||||||
return mnpack<RM, RN-1, BM>(m, n, SIZE_N);
|
return mnpack<RM, RN-1, BM>(m, n, SIZE_N, BN);
|
||||||
} else {
|
} else {
|
||||||
GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
|
GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N);
|
||||||
GGML_ASSERT(false); // we have miss something.
|
GGML_ASSERT(false); // we have miss something.
|
||||||
@ -391,39 +399,58 @@ class tinyBLAS {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <int RM, int RN, int BM>
|
template <int RM, int RN, int BM>
|
||||||
NOINLINE void gemm(int64_t m, int64_t n) {
|
NOINLINE void gemm(int64_t m, int64_t n, int64_t BN) {
|
||||||
GGML_ASSERT(m % (RM * BM) == 0);
|
|
||||||
// const int64_t ytiles = m / (RM * BM);
|
|
||||||
const int64_t xtiles = (n + RN -1) / RN;
|
|
||||||
const int64_t jj_RN = (xtiles - (xtiles * RN - n)) * RN;
|
|
||||||
|
|
||||||
static std::atomic<int64_t> current_chunk;
|
static std::atomic<int64_t> current_chunk;
|
||||||
if (params->ith == 0) {
|
|
||||||
GGML_ASSERT((xtiles * RN - n) >= 0);
|
|
||||||
GGML_ASSERT((xtiles * RN - n) < RN);
|
|
||||||
|
|
||||||
|
GGML_ASSERT(m % (RM * BM) == 0);
|
||||||
|
const int64_t ytiles = m / (RM * BM);
|
||||||
|
const int64_t xtiles = (n + RN -1) / RN;
|
||||||
|
const int64_t jj_RN = (xtiles - (xtiles * RN - n));
|
||||||
|
|
||||||
|
// "round" bloc_size to "nearest" BN
|
||||||
|
const int64_t NB_BN = xtiles < BN ? 1 : (xtiles + BN / 2) / BN;
|
||||||
|
const int64_t SIZE_BN = xtiles % NB_BN == 0 ? xtiles / NB_BN : xtiles / NB_BN + 1;
|
||||||
|
const int64_t jj_BN = (NB_BN - (NB_BN * SIZE_BN - xtiles));
|
||||||
|
const int64_t nb_job = ytiles * NB_BN;
|
||||||
|
|
||||||
|
if (params->ith == 0) {
|
||||||
|
GGML_ASSERT( jj_BN * SIZE_BN + (NB_BN - jj_BN) * (SIZE_BN - 1) == xtiles);
|
||||||
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
|
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
|
||||||
std::atomic_store_explicit(¤t_chunk, (int64_t)params->nth, std::memory_order_relaxed);
|
std::atomic_store_explicit(¤t_chunk, (int64_t)params->nth, std::memory_order_relaxed);
|
||||||
}
|
}
|
||||||
ggml_barrier(params->threadpool);
|
|
||||||
int64_t ii = params->ith * RM * BM;
|
|
||||||
|
|
||||||
while (ii < m) {
|
ggml_barrier(params->threadpool);
|
||||||
for (int64_t bi = 0; bi < BM * RM; bi+=RM) {
|
|
||||||
int64_t jj = 0;
|
int64_t job = params->ith;
|
||||||
for (; jj<jj_RN; jj+=RN) {
|
while (job < nb_job) {
|
||||||
|
const int64_t ii = (job % ytiles) * RM * BM;
|
||||||
|
const int64_t jb = job / ytiles;
|
||||||
|
const int64_t jr0 = BLOC_POS(jb , jj_BN, SIZE_BN);
|
||||||
|
const int64_t jrN = BLOC_POS(jb+1, jj_BN, SIZE_BN);
|
||||||
|
|
||||||
|
const int64_t jj0 = BLOC_POS(jr0, jj_RN, RN);
|
||||||
|
const int64_t jj2 = BLOC_POS(jrN, jj_RN, RN);
|
||||||
|
const int64_t jj1 = jj2 < jj_RN * RN ? jj2 : jj_RN * RN;
|
||||||
|
|
||||||
|
for (int64_t bi = 0; bi < BM * RM; bi += RM) {
|
||||||
|
int64_t jj = jj0;
|
||||||
|
for (; jj < jj1; jj += RN) {
|
||||||
gemm_bloc<RM, RN>(ii + bi, jj);
|
gemm_bloc<RM, RN>(ii + bi, jj);
|
||||||
}
|
}
|
||||||
if constexpr (RN > 1) {
|
if constexpr (RN > 1) {
|
||||||
for (; jj<n; jj+=RN-1) {
|
for (; jj < jj2; jj += RN - 1) {
|
||||||
gemm_bloc<RM, RN-1>(ii + bi, jj);
|
gemm_bloc<RM, RN-1>(ii + bi, jj);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
GGML_ASSERT(jj == n);
|
GGML_ASSERT(jj == jj2);
|
||||||
}
|
}
|
||||||
ii = std::atomic_fetch_add_explicit(¤t_chunk, (int64_t)1, std::memory_order_relaxed) * RM * BM;
|
|
||||||
|
// next step.
|
||||||
|
job = std::atomic_fetch_add_explicit(¤t_chunk, (int64_t)1, std::memory_order_relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_barrier(params->threadpool);
|
ggml_barrier(params->threadpool);
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const ggml_compute_params * params;
|
const ggml_compute_params * params;
|
||||||
@ -1650,7 +1677,6 @@ bool llamafile_sgemm(const struct ggml_compute_params * params, int64_t m, int64
|
|||||||
assert(params->nth > 0);
|
assert(params->nth > 0);
|
||||||
assert(params->ith < params->nth);
|
assert(params->ith < params->nth);
|
||||||
|
|
||||||
// OK avec moins de thread 4 max en zen3 / 16 coeurs?
|
|
||||||
// only enable sgemm for prompt processing
|
// only enable sgemm for prompt processing
|
||||||
if (n < 2)
|
if (n < 2)
|
||||||
return false;
|
return false;
|
||||||
|
Loading…
Reference in New Issue
Block a user