diff --git a/sgemm.cpp b/sgemm.cpp index 531e12af3..4e0159804 100644 --- a/sgemm.cpp +++ b/sgemm.cpp @@ -50,7 +50,6 @@ #pragma GCC diagnostic ignored "-Wignored-attributes" #include "sgemm.h" -#include #include "ggml-impl.h" #include "ggml-quants.h" @@ -243,23 +242,23 @@ template <> inline __m512 load(const ggml_fp16_t *p) { template class tinyBLAS { public: - tinyBLAS(int k, - const TA *A, int lda, - const TB *B, int ldb, - TC *C, int ldc, + tinyBLAS(int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } - void matmul(int m, int n, int task) { + void matmul(int64_t m, int64_t n, int task) { if (task == GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: - NOINLINE void mnpack(int m0, int m, int n0, int n) { - int mc, nc, mp, np; - switch ((std::min(m - m0, 5) << 4) | std::min(n - n0, 5)) { + NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + switch ((MIN(m - m0, 5) << 4) | MIN(n - n0, 5)) { #if VECTOR_REGISTERS == 32 case 0x55: mc = 5; @@ -409,27 +408,27 @@ class tinyBLAS { } template - NOINLINE void gemm(int m0, int m, int n0, int n) { - int ytiles = (m - m0) / RM; - int xtiles = (n - n0) / RN; - int tiles = xtiles * ytiles; - int duty = (tiles + nth - 1) / nth; - int start = duty * ith; - int end = start + duty; + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; if (end > tiles) end = tiles; - for (int job = start; job < end; ++job) { - int ii = m0 + job / xtiles * RM; - int jj = n0 + job % xtiles * RN; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; D Cv[RN][RM] = {}; - for (int l = 0; l < k; l += KN) - for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) + for (int64_t l = 0; l < k; l += KN) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) Cv[j][i] = madd(load(A + lda * (ii + i) + l), load(B + ldb * (jj + j) + l), Cv[j][i]); - for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); } } @@ -437,10 +436,10 @@ class tinyBLAS { const TA *const A; const TB *const B; TC *const C; - const int k; - const int lda; - const int ldb; - const int ldc; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; const int ith; const int nth; }; @@ -452,23 +451,23 @@ class tinyBLAS { template class tinyBLAS_Q0_ARM { public: - tinyBLAS_Q0_ARM(int k, - const TA *A, int lda, - const block_q8_0 *B, int ldb, - float *C, int ldc, + tinyBLAS_Q0_ARM(int64_t k, + const TA *A, int64_t lda, + const block_q8_0 *B, int64_t ldb, + float *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } - void matmul(int m, int n, int task) { + void matmul(int64_t m, int64_t n, int task) { if (task == GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: - NOINLINE void mnpack(int m0, int m, int n0, int n) { - int mc, nc, mp, np; - switch ((std::min(m - m0, 3) << 4) | std::min(n - n0, 3)) { + NOINLINE void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + switch ((MIN(m - m0, 3) << 4) | MIN(n - n0, 3ll)) { case 0x33: mc = 3; nc = 3; @@ -524,22 +523,22 @@ class tinyBLAS_Q0_ARM { } template - NOINLINE void gemm(int m0, int m, int n0, int n) { - int ytiles = (m - m0) / RM; - int xtiles = (n - n0) / RN; - int tiles = xtiles * ytiles; - int duty = (tiles + nth - 1) / nth; - int start = duty * ith; - int end = start + duty; + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; if (end > tiles) end = tiles; - for (int job = start; job < end; ++job) { - int ii = m0 + job / xtiles * RM; - int jj = n0 + job % xtiles * RN; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; float32x4_t Cv[RN][RM] = {}; - for (int l = 0; l < k; ++l) - for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) + for (int64_t l = 0; l < k; ++l) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) Cv[j][i] = vmlaq_n_f32(Cv[j][i], vcvtq_f32_s32(vdotq_s32( vdotq_s32(vdupq_n_s32(0), @@ -549,8 +548,8 @@ class tinyBLAS_Q0_ARM { load_hi(B + ldb * (jj + j) + l))), unhalf(A[lda * (ii + i) + l].d) * unhalf(B[ldb * (jj + j) + l].d)); - for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); } } @@ -577,10 +576,10 @@ class tinyBLAS_Q0_ARM { const TA *const A; const block_q8_0 *const B; float *const C; - const int k; - const int lda; - const int ldb; - const int ldc; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; const int ith; const int nth; }; @@ -590,23 +589,23 @@ class tinyBLAS_Q0_ARM { template class tinyBLAS_Q0_AVX2 { public: - tinyBLAS_Q0_AVX2(int k, - const TA *A, int lda, - const TB *B, int ldb, - TC *C, int ldc, + tinyBLAS_Q0_AVX2(int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc, int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } - void matmul(int m, int n, int task) { + void matmul(int64_t m, int64_t n, int task) { if (task == GGML_TASK_TYPE_COMPUTE) mnpack(0, m, 0, n); } private: - void mnpack(int m0, int m, int n0, int n) { - int mc, nc, mp, np; - switch ((std::min(m - m0, 4) << 4) | std::min(n - n0, 4)) { + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) { #if VECTOR_REGISTERS == 32 case 0x44: mc = 4; @@ -714,22 +713,22 @@ class tinyBLAS_Q0_AVX2 { } template - NOINLINE void gemm(int m0, int m, int n0, int n) { - int ytiles = (m - m0) / RM; - int xtiles = (n - n0) / RN; - int tiles = xtiles * ytiles; - int duty = (tiles + nth - 1) / nth; - int start = duty * ith; - int end = start + duty; + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; if (end > tiles) end = tiles; - for (int job = start; job < end; ++job) { - int ii = m0 + job / xtiles * RM; - int jj = n0 + job % xtiles * RN; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; __m256 Cv[RN][RM] = {}; - for (int l = 0; l < k; ++l) - for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) + for (int64_t l = 0; l < k; ++l) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) Cv[j][i] = madd(_mm256_set1_ps(unhalf(A[lda * (ii + i) + l].d) * unhalf(B[ldb * (jj + j) + l].d)), updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), @@ -737,8 +736,8 @@ class tinyBLAS_Q0_AVX2 { _mm256_sign_epi8(load(B + ldb * (jj + j) + l), load(A + lda * (ii + i) + l))), Cv[j][i]); - for (int j = 0; j < RN; ++j) - for (int i = 0; i < RM; ++i) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); } } @@ -771,10 +770,10 @@ class tinyBLAS_Q0_AVX2 { const TA *const A; const TB *const B; TC *const C; - const int k; - const int lda; - const int ldb; - const int ldc; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; const int ith; const int nth; }; @@ -813,8 +812,8 @@ class tinyBLAS_Q0_AVX2 { * @param Ctype is GGML data type of `C` * @return true if this function was able to service the matmul request */ -bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, int ldb, void *C, - int ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) { +bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda, const void *B, int64_t ldb, void *C, + int64_t ldc, int ith, int nth, int task, int Atype, int Btype, int Ctype) { assert(m >= 0); assert(n >= 0); @@ -824,9 +823,6 @@ bool llamafile_sgemm(int m, int n, int k, const void *A, int lda, const void *B, assert(ldc >= m); assert(nth > 0); assert(ith < nth); - assert(1ll * lda * m <= 0x7fffffff); - assert(1ll * ldb * n <= 0x7fffffff); - assert(1ll * ldc * n <= 0x7fffffff); if (Ctype != GGML_TYPE_F32) return false; diff --git a/sgemm.h b/sgemm.h index da23b209c..f29747d0a 100644 --- a/sgemm.h +++ b/sgemm.h @@ -1,11 +1,13 @@ #pragma once +#include #include #ifdef __cplusplus extern "C" { #endif -bool llamafile_sgemm(int, int, int, const void *, int, const void *, int, - void *, int, int, int, int, int, int, int); +bool llamafile_sgemm(int64_t, int64_t, int64_t, const void *, int64_t, + const void *, int64_t, void *, int64_t, int, int, + int, int, int, int); #ifdef __cplusplus }