mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 20:14:29 +00:00
ggml : optimize llamafile cpu matrix multiplication for ppc64le (#10156)
This change upstreams llamafile's cpu matrix multiplication kernels for ppc64le using MMA builtins for FP32 datatype. This change results in a consistent 90% improvement in input processing time, and 20% to 80% improvement in output processing time, across various batch sizes. The patch is tested with Meta-Lllama-3-8B, Mistral-7B, Llama-2-7B-chat-hf models on a IBM POWER10 machine. Signed-off-by: Amrita H S <amritahs@linux.vnet.ibm.com>
This commit is contained in:
parent
8fc393f246
commit
e89213492d
@ -1265,8 +1265,13 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
|
|||||||
endif()
|
endif()
|
||||||
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
|
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
|
||||||
message(STATUS "PowerPC detected")
|
message(STATUS "PowerPC detected")
|
||||||
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
|
execute_process(COMMAND bash -c "grep POWER10 /proc/cpuinfo | head -n 1"
|
||||||
list(APPEND ARCH_FLAGS -mcpu=powerpc64le)
|
OUTPUT_VARIABLE POWER10_M)
|
||||||
|
string(FIND ${POWER10_M} "POWER10" substring_index)
|
||||||
|
if(${substring_index} GREATER_EQUAL 0)
|
||||||
|
list(APPEND ARCH_FLAGS -mcpu=power10)
|
||||||
|
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64le")
|
||||||
|
list(APPEND ARCH_FLAGS -mcpu=powerpc64le)
|
||||||
else()
|
else()
|
||||||
list(APPEND ARCH_FLAGS -mcpu=native -mtune=native)
|
list(APPEND ARCH_FLAGS -mcpu=native -mtune=native)
|
||||||
#TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)
|
#TODO: Add targets for Power8/Power9 (Altivec/VSX) and Power10(MMA) and query for big endian systems (ppc64/le/be)
|
||||||
|
@ -106,6 +106,10 @@ inline float16x8_t sub(float16x8_t x, float16x8_t y) { return vsubq_f16(x, y); }
|
|||||||
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
|
inline float16x8_t mul(float16x8_t x, float16x8_t y) { return vmulq_f16(x, y); }
|
||||||
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
#endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
|
||||||
|
|
||||||
|
#if defined(__MMA__)
|
||||||
|
typedef vector unsigned char vec_t;
|
||||||
|
typedef __vector_quad acc_t;
|
||||||
|
#endif
|
||||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// VECTORIZED FUSED MULTIPLY ADD
|
// VECTORIZED FUSED MULTIPLY ADD
|
||||||
|
|
||||||
@ -1026,6 +1030,600 @@ class tinyBLAS_Q0_AVX {
|
|||||||
};
|
};
|
||||||
#endif // __AVX__
|
#endif // __AVX__
|
||||||
|
|
||||||
|
//PPC Implementation
|
||||||
|
#if defined(__MMA__)
|
||||||
|
|
||||||
|
#define SAVE_ACC(ACC, ii, jj) \
|
||||||
|
__builtin_mma_disassemble_acc(vec_C, ACC); \
|
||||||
|
for (int I = 0; I < 4; I++) { \
|
||||||
|
for (int J = 0; J < 4; J++) { \
|
||||||
|
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J); \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
|
||||||
|
template <typename TA, typename TB, typename TC>
|
||||||
|
class tinyBLAS_PPC {
|
||||||
|
public:
|
||||||
|
tinyBLAS_PPC(int64_t k,
|
||||||
|
const TA *A, int64_t lda,
|
||||||
|
const TB *B, int64_t ldb,
|
||||||
|
TC *C, int64_t ldc,
|
||||||
|
int ith, int nth)
|
||||||
|
: A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) {
|
||||||
|
}
|
||||||
|
|
||||||
|
void matmul(int64_t m, int64_t n) {
|
||||||
|
mnpack(0, m, 0, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
|
||||||
|
void (tinyBLAS_PPC::*kernel)(int64_t, int64_t);
|
||||||
|
|
||||||
|
void READ_BLOCK(const float* a, int64_t lda, int rows, int cols, float* vec) {
|
||||||
|
int64_t i, j;
|
||||||
|
float *aoffset = NULL, *boffset = NULL;
|
||||||
|
float *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL;
|
||||||
|
float *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL;
|
||||||
|
|
||||||
|
aoffset = const_cast<float*>(a);
|
||||||
|
boffset = vec;
|
||||||
|
j = (rows >> 3);
|
||||||
|
if (j > 0) {
|
||||||
|
do {
|
||||||
|
aoffset1 = aoffset;
|
||||||
|
aoffset2 = aoffset1 + lda;
|
||||||
|
aoffset3 = aoffset2 + lda;
|
||||||
|
aoffset4 = aoffset3 + lda;
|
||||||
|
aoffset5 = aoffset4 + lda;
|
||||||
|
aoffset6 = aoffset5 + lda;
|
||||||
|
aoffset7 = aoffset6 + lda;
|
||||||
|
aoffset8 = aoffset7 + lda;
|
||||||
|
aoffset += 8 * lda;
|
||||||
|
i = (cols >> 3);
|
||||||
|
if (i > 0) {
|
||||||
|
__vector_pair C1, C2, C3, C4, C5, C6, C7, C8;
|
||||||
|
vector float c1[2], c2[2], c3[2], c4[2], c5[2], c6[2], c7[2], c8[2];
|
||||||
|
vector float t1, t2, t3, t4, t5, t6, t7, t8;
|
||||||
|
do {
|
||||||
|
C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
|
||||||
|
C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
|
||||||
|
C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
|
||||||
|
C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
|
||||||
|
C5 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset5);
|
||||||
|
C6 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset6);
|
||||||
|
C7 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset7);
|
||||||
|
C8 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset8);
|
||||||
|
__builtin_vsx_disassemble_pair(c1, &C1);
|
||||||
|
__builtin_vsx_disassemble_pair(c2, &C2);
|
||||||
|
__builtin_vsx_disassemble_pair(c3, &C3);
|
||||||
|
__builtin_vsx_disassemble_pair(c4, &C4);
|
||||||
|
__builtin_vsx_disassemble_pair(c5, &C5);
|
||||||
|
__builtin_vsx_disassemble_pair(c6, &C6);
|
||||||
|
__builtin_vsx_disassemble_pair(c7, &C7);
|
||||||
|
__builtin_vsx_disassemble_pair(c8, &C8);
|
||||||
|
|
||||||
|
t1 = vec_mergeh(c1[0], c2[0]);
|
||||||
|
t2 = vec_mergeh(c3[0], c4[0]);
|
||||||
|
t3 = vec_mergeh(c5[0], c6[0]);
|
||||||
|
t4 = vec_mergeh(c7[0], c8[0]);
|
||||||
|
t5 = vec_xxpermdi(t1, t2, 0);
|
||||||
|
t6 = vec_xxpermdi(t3, t4, 0);
|
||||||
|
t7 = vec_xxpermdi(t1, t2, 3);
|
||||||
|
t8 = vec_xxpermdi(t3, t4, 3);
|
||||||
|
vec_xst(t5, 0, boffset);
|
||||||
|
vec_xst(t6, 0, boffset+4);
|
||||||
|
vec_xst(t7, 0, boffset+8);
|
||||||
|
vec_xst(t8, 0, boffset+12);
|
||||||
|
|
||||||
|
t1 = vec_mergel(c1[0], c2[0]);
|
||||||
|
t2 = vec_mergel(c3[0], c4[0]);
|
||||||
|
t3 = vec_mergel(c5[0], c6[0]);
|
||||||
|
t4 = vec_mergel(c7[0], c8[0]);
|
||||||
|
t5 = vec_xxpermdi(t1, t2, 0);
|
||||||
|
t6 = vec_xxpermdi(t3, t4, 0);
|
||||||
|
t7 = vec_xxpermdi(t1, t2, 3);
|
||||||
|
t8 = vec_xxpermdi(t3, t4, 3);
|
||||||
|
vec_xst(t5, 0, boffset+16);
|
||||||
|
vec_xst(t6, 0, boffset+20);
|
||||||
|
vec_xst(t7, 0, boffset+24);
|
||||||
|
vec_xst(t8, 0, boffset+28);
|
||||||
|
|
||||||
|
t1 = vec_mergeh(c1[1], c2[1]);
|
||||||
|
t2 = vec_mergeh(c3[1], c4[1]);
|
||||||
|
t3 = vec_mergeh(c5[1], c6[1]);
|
||||||
|
t4 = vec_mergeh(c7[1], c8[1]);
|
||||||
|
t5 = vec_xxpermdi(t1, t2, 0);
|
||||||
|
t6 = vec_xxpermdi(t3, t4, 0);
|
||||||
|
t7 = vec_xxpermdi(t1, t2, 3);
|
||||||
|
t8 = vec_xxpermdi(t3, t4, 3);
|
||||||
|
vec_xst(t5, 0, boffset+32);
|
||||||
|
vec_xst(t6, 0, boffset+36);
|
||||||
|
vec_xst(t7, 0, boffset+40);
|
||||||
|
vec_xst(t8, 0, boffset+44);
|
||||||
|
|
||||||
|
t1 = vec_mergel(c1[1], c2[1]);
|
||||||
|
t2 = vec_mergel(c3[1], c4[1]);
|
||||||
|
t3 = vec_mergel(c5[1], c6[1]);
|
||||||
|
t4 = vec_mergel(c7[1], c8[1]);
|
||||||
|
t5 = vec_xxpermdi(t1, t2, 0);
|
||||||
|
t6 = vec_xxpermdi(t3, t4, 0);
|
||||||
|
t7 = vec_xxpermdi(t1, t2, 3);
|
||||||
|
t8 = vec_xxpermdi(t3, t4, 3);
|
||||||
|
vec_xst(t5, 0, boffset+48);
|
||||||
|
vec_xst(t6, 0, boffset+52);
|
||||||
|
vec_xst(t7, 0, boffset+56);
|
||||||
|
vec_xst(t8, 0, boffset+60);
|
||||||
|
|
||||||
|
aoffset1 += 8*lda;
|
||||||
|
aoffset2 += 8*lda;
|
||||||
|
aoffset3 += 8*lda;
|
||||||
|
aoffset4 += 8*lda;
|
||||||
|
boffset += 64;
|
||||||
|
i--;
|
||||||
|
} while(i > 0);
|
||||||
|
}
|
||||||
|
if (cols & 4) {
|
||||||
|
vector float c1, c2, c3, c4, c5, c6, c7, c8;
|
||||||
|
vector float t1, t2, t3, t4, t5, t6, t7, t8;
|
||||||
|
c1 = vec_xl(0, aoffset1);
|
||||||
|
c2 = vec_xl(0, aoffset2);
|
||||||
|
c3 = vec_xl(0, aoffset3);
|
||||||
|
c4 = vec_xl(0, aoffset4);
|
||||||
|
c5 = vec_xl(0, aoffset5);
|
||||||
|
c6 = vec_xl(0, aoffset6);
|
||||||
|
c7 = vec_xl(0, aoffset7);
|
||||||
|
c8 = vec_xl(0, aoffset8);
|
||||||
|
|
||||||
|
t1 = vec_mergeh(c1, c2);
|
||||||
|
t2 = vec_mergeh(c3, c4);
|
||||||
|
t3 = vec_mergeh(c5, c6);
|
||||||
|
t4 = vec_mergeh(c7, c8);
|
||||||
|
t5 = vec_xxpermdi(t1, t2, 0);
|
||||||
|
t6 = vec_xxpermdi(t3, t4, 0);
|
||||||
|
t7 = vec_xxpermdi(t1, t2, 3);
|
||||||
|
t8 = vec_xxpermdi(t3, t4, 3);
|
||||||
|
vec_xst(t5, 0, boffset);
|
||||||
|
vec_xst(t6, 0, boffset+4);
|
||||||
|
vec_xst(t7, 0, boffset+8);
|
||||||
|
vec_xst(t8, 0, boffset+12);
|
||||||
|
|
||||||
|
t1 = vec_mergel(c1, c2);
|
||||||
|
t2 = vec_mergel(c3, c4);
|
||||||
|
t3 = vec_mergel(c5, c6);
|
||||||
|
t4 = vec_mergel(c7, c8);
|
||||||
|
t5 = vec_xxpermdi(t1, t2, 0);
|
||||||
|
t6 = vec_xxpermdi(t3, t4, 0);
|
||||||
|
t7 = vec_xxpermdi(t1, t2, 3);
|
||||||
|
t8 = vec_xxpermdi(t3, t4, 3);
|
||||||
|
vec_xst(t5, 0, boffset+16);
|
||||||
|
vec_xst(t6, 0, boffset+20);
|
||||||
|
vec_xst(t7, 0, boffset+24);
|
||||||
|
vec_xst(t8, 0, boffset+28);
|
||||||
|
}
|
||||||
|
j--;
|
||||||
|
} while(j > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (rows & 4) {
|
||||||
|
aoffset1 = aoffset;
|
||||||
|
aoffset2 = aoffset1 + lda;
|
||||||
|
aoffset3 = aoffset2 + lda;
|
||||||
|
aoffset4 = aoffset3 + lda;
|
||||||
|
aoffset += 4 * lda;
|
||||||
|
i = (cols >> 3);
|
||||||
|
if (i > 0) {
|
||||||
|
__vector_pair C1, C2, C3, C4;
|
||||||
|
vector float c1[2], c2[2], c3[2], c4[2];
|
||||||
|
vector float t1, t2, t3, t4, t5, t6, t7, t8;
|
||||||
|
do {
|
||||||
|
C1 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset1);
|
||||||
|
C2 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset2);
|
||||||
|
C3 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset3);
|
||||||
|
C4 = __builtin_vsx_lxvp(0, (__vector_pair*)aoffset4);
|
||||||
|
__builtin_vsx_disassemble_pair(c1, &C1);
|
||||||
|
__builtin_vsx_disassemble_pair(c2, &C2);
|
||||||
|
__builtin_vsx_disassemble_pair(c3, &C3);
|
||||||
|
__builtin_vsx_disassemble_pair(c4, &C4);
|
||||||
|
|
||||||
|
t1 = vec_mergeh(c1[0], c2[0]);
|
||||||
|
t2 = vec_mergeh(c3[0], c4[0]);
|
||||||
|
t3 = vec_mergel(c1[0], c2[0]);
|
||||||
|
t4 = vec_mergel(c3[0], c4[0]);
|
||||||
|
t5 = vec_xxpermdi(t1, t2, 0);
|
||||||
|
t6 = vec_xxpermdi(t1, t2, 3);
|
||||||
|
t7 = vec_xxpermdi(t3, t4, 0);
|
||||||
|
t8 = vec_xxpermdi(t3, t4, 3);
|
||||||
|
vec_xst(t5, 0, boffset);
|
||||||
|
vec_xst(t6, 0, boffset+4);
|
||||||
|
vec_xst(t7, 0, boffset+8);
|
||||||
|
vec_xst(t8, 0, boffset+12);
|
||||||
|
|
||||||
|
t1 = vec_mergeh(c1[1], c2[1]);
|
||||||
|
t2 = vec_mergeh(c3[1], c4[1]);
|
||||||
|
t3 = vec_mergel(c1[1], c2[1]);
|
||||||
|
t4 = vec_mergel(c3[1], c4[1]);
|
||||||
|
t5 = vec_xxpermdi(t1, t2, 0);
|
||||||
|
t6 = vec_xxpermdi(t1, t2, 3);
|
||||||
|
t7 = vec_xxpermdi(t3, t4, 0);
|
||||||
|
t8 = vec_xxpermdi(t3, t4, 3);
|
||||||
|
vec_xst(t5, 0, boffset+16);
|
||||||
|
vec_xst(t6, 0, boffset+20);
|
||||||
|
vec_xst(t7, 0, boffset+24);
|
||||||
|
vec_xst(t8, 0, boffset+28);
|
||||||
|
|
||||||
|
aoffset1 += 8*lda;
|
||||||
|
aoffset2 += 8*lda;
|
||||||
|
aoffset3 += 8*lda;
|
||||||
|
aoffset4 += 8*lda;
|
||||||
|
boffset += 32;
|
||||||
|
i--;
|
||||||
|
} while(i > 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cols & 4) {
|
||||||
|
vector float c1, c2, c3, c4;
|
||||||
|
vector float t1, t2, t3, t4;
|
||||||
|
c1 = vec_xl(0, aoffset1);
|
||||||
|
c2 = vec_xl(0, aoffset2);
|
||||||
|
c3 = vec_xl(0, aoffset3);
|
||||||
|
c4 = vec_xl(0, aoffset4);
|
||||||
|
|
||||||
|
t1 = vec_mergeh(c1, c2);
|
||||||
|
t2 = vec_mergeh(c3, c4);
|
||||||
|
t3 = vec_xxpermdi(t1, t2, 0);
|
||||||
|
t4 = vec_xxpermdi(t1, t2, 3);
|
||||||
|
vec_xst(t3, 0, boffset);
|
||||||
|
vec_xst(t4, 0, boffset+4);
|
||||||
|
|
||||||
|
t1 = vec_mergel(c1, c2);
|
||||||
|
t2 = vec_mergel(c3, c4);
|
||||||
|
t3 = vec_xxpermdi(t1, t2, 0);
|
||||||
|
t4 = vec_xxpermdi(t1, t2, 3);
|
||||||
|
vec_xst(t3, 0, boffset+8);
|
||||||
|
vec_xst(t4, 0, boffset+12);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (rows & 3) {
|
||||||
|
aoffset1 = aoffset;
|
||||||
|
aoffset2 = aoffset1 + lda;
|
||||||
|
aoffset3 = aoffset2 + lda;
|
||||||
|
if (cols & 4) {
|
||||||
|
vector float c1, c2, c3, c4 = {0};
|
||||||
|
vector float t1, t2, t3, t4;
|
||||||
|
c1 = vec_xl(0, aoffset1);
|
||||||
|
c2 = vec_xl(0, aoffset2);
|
||||||
|
c3 = vec_xl(0, aoffset3);
|
||||||
|
|
||||||
|
t1 = vec_mergeh(c1, c2);
|
||||||
|
t2 = vec_mergeh(c3, c4);
|
||||||
|
t3 = vec_xxpermdi(t1, t2, 0);
|
||||||
|
t4 = vec_xxpermdi(t1, t2, 3);
|
||||||
|
vec_xst(t3, 0, boffset);
|
||||||
|
vec_xst(t4, 0, boffset+4);
|
||||||
|
|
||||||
|
t1 = vec_mergel(c1, c2);
|
||||||
|
t2 = vec_mergel(c3, c4);
|
||||||
|
t3 = vec_xxpermdi(t1, t2, 0);
|
||||||
|
t4 = vec_xxpermdi(t1, t2, 3);
|
||||||
|
vec_xst(t3, 0, boffset+8);
|
||||||
|
vec_xst(t4, 0, boffset+12);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void KERNEL_4x4(int64_t ii, int64_t jj) {
|
||||||
|
vec_t vec_A[4], vec_B[4], vec_C[4];
|
||||||
|
acc_t acc_0;
|
||||||
|
__builtin_mma_xxsetaccz(&acc_0);
|
||||||
|
for (int l = 0; l < k; l+=4) {
|
||||||
|
READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
|
||||||
|
READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
|
||||||
|
}
|
||||||
|
SAVE_ACC(&acc_0, ii, jj);
|
||||||
|
}
|
||||||
|
|
||||||
|
void KERNEL_4x8(int64_t ii, int64_t jj) {
|
||||||
|
vec_t vec_A[4], vec_B[8], vec_C[4];
|
||||||
|
acc_t acc_0, acc_1;
|
||||||
|
__builtin_mma_xxsetaccz(&acc_0);
|
||||||
|
__builtin_mma_xxsetaccz(&acc_1);
|
||||||
|
for (int64_t l = 0; l < k; l+=4) {
|
||||||
|
READ_BLOCK(A+(ii*lda)+l, lda, 4, 4, (float*)vec_A);
|
||||||
|
READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 4, (float*)vec_B);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], (vec_t)vec_B[0]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_1, vec_A[0], (vec_t)vec_B[1]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], (vec_t)vec_B[2]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_1, vec_A[1], (vec_t)vec_B[3]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], (vec_t)vec_B[4]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_1, vec_A[2], (vec_t)vec_B[5]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], (vec_t)vec_B[6]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_1, vec_A[3], (vec_t)vec_B[7]);
|
||||||
|
}
|
||||||
|
SAVE_ACC(&acc_0, ii, jj);
|
||||||
|
SAVE_ACC(&acc_1, ii, jj+4);
|
||||||
|
}
|
||||||
|
|
||||||
|
void KERNEL_8x4(int64_t ii, int64_t jj) {
|
||||||
|
vec_t vec_A[8], vec_B[4], vec_C[4];
|
||||||
|
acc_t acc_0, acc_1;
|
||||||
|
__builtin_mma_xxsetaccz(&acc_0);
|
||||||
|
__builtin_mma_xxsetaccz(&acc_1);
|
||||||
|
for (int64_t l = 0; l < k; l+=4) {
|
||||||
|
READ_BLOCK(A+(ii*lda)+l, lda, 8, 4, (float*)vec_A);
|
||||||
|
READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[0], vec_B[0]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[1], vec_B[0]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[2], vec_B[1]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[3], vec_B[1]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[4], vec_B[2]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[5], vec_B[2]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[6], vec_B[3]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[7], vec_B[3]);
|
||||||
|
}
|
||||||
|
SAVE_ACC(&acc_0, ii, jj);
|
||||||
|
SAVE_ACC(&acc_1, ii+4, jj);
|
||||||
|
}
|
||||||
|
|
||||||
|
void KERNEL_8x8(int64_t ii, int64_t jj) {
|
||||||
|
vec_t vec_A[16], vec_B[16], vec_C[4];
|
||||||
|
acc_t acc_0, acc_1, acc_2, acc_3;
|
||||||
|
__builtin_mma_xxsetaccz(&acc_0);
|
||||||
|
__builtin_mma_xxsetaccz(&acc_1);
|
||||||
|
__builtin_mma_xxsetaccz(&acc_2);
|
||||||
|
__builtin_mma_xxsetaccz(&acc_3);
|
||||||
|
for (int l = 0; l < k; l+=8) {
|
||||||
|
READ_BLOCK(A+(ii*lda)+l, lda, 8, 8, (float*)vec_A);
|
||||||
|
READ_BLOCK(B+(jj*ldb)+l, ldb, 8, 8, (float*)vec_B);
|
||||||
|
for(int x = 0; x < 16; x+=2) {
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_0, (vec_t)vec_A[x], vec_B[x]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_1, (vec_t)vec_A[x], vec_B[x+1]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_2, (vec_t)vec_A[x+1], vec_B[x]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_3, (vec_t)vec_A[x+1], vec_B[x+1]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
SAVE_ACC(&acc_0, ii, jj);
|
||||||
|
SAVE_ACC(&acc_1, ii, jj+4);
|
||||||
|
SAVE_ACC(&acc_2, ii+4, jj);
|
||||||
|
SAVE_ACC(&acc_3, ii+4, jj+4);
|
||||||
|
}
|
||||||
|
|
||||||
|
void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||||
|
int64_t mc, nc, mp, np;
|
||||||
|
int m_rem = MIN(m - m0, 16);
|
||||||
|
int n_rem = MIN(n - n0, 16);
|
||||||
|
if (m_rem >= 16 && n_rem >= 8) {
|
||||||
|
mc = 8;
|
||||||
|
nc = 8;
|
||||||
|
gemm<8,8>(m0, m, n0, n);
|
||||||
|
} else if(m_rem >= 8 && n_rem >= 16) {
|
||||||
|
mc = 8;
|
||||||
|
nc = 8;
|
||||||
|
gemm<8,8>(m0, m, n0, n);
|
||||||
|
} else if (m_rem >= 8 && n_rem >= 8) {
|
||||||
|
mc = 8;
|
||||||
|
nc = 8;
|
||||||
|
gemm<8,8>(m0, m, n0, n);
|
||||||
|
} else if (m_rem >= 4 && n_rem >= 8) {
|
||||||
|
mc = 4;
|
||||||
|
nc = 8;
|
||||||
|
gemm<4,8>(m0, m, n0, n);
|
||||||
|
} else if (m_rem >= 8 && n_rem >= 4) {
|
||||||
|
mc = 8;
|
||||||
|
nc = 4;
|
||||||
|
gemm<8,4>(m0, m, n0, n);
|
||||||
|
} else if (m_rem >= 4 && n_rem >= 4) {
|
||||||
|
mc = 4;
|
||||||
|
nc = 4;
|
||||||
|
gemm<4,4>(m0, m, n0, n);
|
||||||
|
} else if ((m_rem < 4) && (n_rem > 4)) {
|
||||||
|
nc = 4;
|
||||||
|
switch(m_rem) {
|
||||||
|
case 1:
|
||||||
|
mc = 1;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
mc = 2;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
mc = 3;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} else if ((m_rem > 4) && (n_rem < 4)) {
|
||||||
|
mc = 4;
|
||||||
|
switch(n_rem) {
|
||||||
|
case 1:
|
||||||
|
nc = 1;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
case 2:
|
||||||
|
nc = 2;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
case 3:
|
||||||
|
nc = 3;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
switch((m_rem << 4) | n_rem) {
|
||||||
|
case 0x43:
|
||||||
|
mc = 4;
|
||||||
|
nc = 3;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
case 0x42:
|
||||||
|
mc = 4;
|
||||||
|
nc = 2;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
case 0x41:
|
||||||
|
mc = 4;
|
||||||
|
nc = 1;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
case 0x34:
|
||||||
|
mc = 3;
|
||||||
|
nc = 4;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
case 0x33:
|
||||||
|
mc = 3;
|
||||||
|
nc = 3;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
case 0x32:
|
||||||
|
mc = 3;
|
||||||
|
nc = 2;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
case 0x31:
|
||||||
|
mc = 3;
|
||||||
|
nc = 1;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
case 0x24:
|
||||||
|
mc = 2;
|
||||||
|
nc = 4;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
case 0x23:
|
||||||
|
mc = 2;
|
||||||
|
nc = 3;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
case 0x22:
|
||||||
|
mc = 2;
|
||||||
|
nc = 2;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
case 0x21:
|
||||||
|
mc = 2;
|
||||||
|
nc = 1;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
case 0x14:
|
||||||
|
mc = 1;
|
||||||
|
nc = 4;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
case 0x13:
|
||||||
|
mc = 1;
|
||||||
|
nc = 3;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
case 0x12:
|
||||||
|
mc = 1;
|
||||||
|
nc = 2;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
case 0x11:
|
||||||
|
mc = 1;
|
||||||
|
nc = 1;
|
||||||
|
gemm_small(m0, m, n0, n, mc, nc);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
mp = m0 + (m - m0) / mc * mc;
|
||||||
|
np = n0 + (n - n0) / nc * nc;
|
||||||
|
mnpack(mp, m, n0, np);
|
||||||
|
mnpack(m0, m, np, n);
|
||||||
|
}
|
||||||
|
|
||||||
|
void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) {
|
||||||
|
int64_t ytiles = (m - m0) / RM;
|
||||||
|
int64_t xtiles = (n - n0) / RN;
|
||||||
|
int64_t tiles = xtiles * ytiles;
|
||||||
|
int64_t duty = (tiles + nth - 1) / nth;
|
||||||
|
int64_t start = duty * ith;
|
||||||
|
int64_t end = start + duty;
|
||||||
|
if (end > tiles)
|
||||||
|
end = tiles;
|
||||||
|
for (int64_t job = start; job < end; ++job) {
|
||||||
|
int64_t ii = m0 + job / xtiles * RM;
|
||||||
|
int64_t jj = n0 + job % xtiles * RN;
|
||||||
|
vec_t vec_C[4];
|
||||||
|
acc_t acc_0;
|
||||||
|
__builtin_mma_xxsetaccz(&acc_0);
|
||||||
|
vec_t vec_A[4], vec_B[4];
|
||||||
|
for (int l=0; l<k; l+=4) {
|
||||||
|
if (RN >= 4 && RM == 1) {
|
||||||
|
float* a = const_cast<float*>(A+(ii)*lda+l);
|
||||||
|
READ_BLOCK(B+(jj*ldb)+l, ldb, 4, 4, (float*)vec_B);
|
||||||
|
vec_A[0] = (vec_t)vec_xl(0,a);
|
||||||
|
vec_A[1] = (vec_t)vec_splats(*((float*)&vec_A+1));
|
||||||
|
vec_A[2] = (vec_t)vec_splats(*((float*)&vec_A+2));
|
||||||
|
vec_A[3] = (vec_t)vec_splats(*((float*)&vec_A+3));
|
||||||
|
} else {
|
||||||
|
READ_BLOCK(A+(ii*lda)+l, lda, RM, 4, (float*)vec_A);
|
||||||
|
READ_BLOCK(B+(jj*ldb)+l, ldb, RN, 4, (float*)vec_B);
|
||||||
|
}
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_0, vec_A[0], vec_B[0]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_0, vec_A[1], vec_B[1]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_0, vec_A[2], vec_B[2]);
|
||||||
|
__builtin_mma_xvf32gerpp(&acc_0, vec_A[3], vec_B[3]);
|
||||||
|
}
|
||||||
|
__builtin_mma_disassemble_acc(vec_C, &acc_0);
|
||||||
|
for (int I = 0; I < RM; I++) {
|
||||||
|
for (int J = 0; J < RN; J++) {
|
||||||
|
*((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&vec_C[I]+J);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <int RM, int RN>
|
||||||
|
NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) {
|
||||||
|
int64_t ytiles = (m - m0) / RM;
|
||||||
|
int64_t xtiles = (n - n0) / RN;
|
||||||
|
int64_t tiles = xtiles * ytiles;
|
||||||
|
int64_t duty = (tiles + nth - 1) / nth;
|
||||||
|
int64_t start = duty * ith;
|
||||||
|
int64_t end = start + duty;
|
||||||
|
if (RM == 4 && RN == 4) {
|
||||||
|
kernel = &tinyBLAS_PPC::KERNEL_4x4;
|
||||||
|
} else if (RM == 4 && RN == 8) {
|
||||||
|
kernel = &tinyBLAS_PPC::KERNEL_4x8;
|
||||||
|
} else if (RM == 8 && RN == 4) {
|
||||||
|
kernel = &tinyBLAS_PPC::KERNEL_8x4;
|
||||||
|
} else if (RM == 8 && RN == 8) {
|
||||||
|
kernel = &tinyBLAS_PPC::KERNEL_8x8;
|
||||||
|
}
|
||||||
|
if (end > tiles)
|
||||||
|
end = tiles;
|
||||||
|
for (int64_t job = start; job < end; ++job) {
|
||||||
|
int64_t ii = m0 + job / xtiles * RM;
|
||||||
|
int64_t jj = n0 + job % xtiles * RN;
|
||||||
|
(this->*kernel)(ii, jj);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const TA *const A;
|
||||||
|
const TB *const B;
|
||||||
|
TC *C;
|
||||||
|
TA *At;
|
||||||
|
TB *Bt;
|
||||||
|
const int64_t k;
|
||||||
|
const int64_t lda;
|
||||||
|
const int64_t ldb;
|
||||||
|
const int64_t ldc;
|
||||||
|
const int ith;
|
||||||
|
const int nth;
|
||||||
|
};
|
||||||
|
#endif
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@ -1114,6 +1712,16 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda
|
|||||||
ith, nth};
|
ith, nth};
|
||||||
tb.matmul(m, n);
|
tb.matmul(m, n);
|
||||||
return true;
|
return true;
|
||||||
|
#elif defined(__MMA__)
|
||||||
|
if (k % 8)
|
||||||
|
return false;
|
||||||
|
tinyBLAS_PPC<float, float, float> tb{
|
||||||
|
k, (const float *)A, lda,
|
||||||
|
(const float *)B, ldb,
|
||||||
|
(float *)C, ldc,
|
||||||
|
ith, nth};
|
||||||
|
tb.matmul(m, n);
|
||||||
|
return true;
|
||||||
#else
|
#else
|
||||||
return false;
|
return false;
|
||||||
#endif
|
#endif
|
||||||
|
Loading…
Reference in New Issue
Block a user