metal : trying bs = 512 performance (wip)

This commit is contained in:
Georgi Gerganov 2024-02-12 19:21:57 +02:00
parent e8b00e2941
commit 5a668ea000
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 44 additions and 15 deletions

View File

@ -1301,7 +1301,7 @@ static bool ggml_metal_graph_compute(
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
if (src1t == GGML_TYPE_F32 && ne11 <= 8) {
if (src1t == GGML_TYPE_F32) {
id<MTLComputePipelineState> pipeline = nil;
switch (src0->type) {
@ -1340,12 +1340,12 @@ static bool ggml_metal_graph_compute(
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
const int nsg = 8;
const int nsg = 4;
const int nsg0 = 1;
const int nsh0 = 16;
const int nsg1 = 1;
const int nsh1 = 64;
const int nsg0 = 4;
const int nsh0 = 4;
const int nsg1 = 2;
const int nsh1 = 4;
GGML_ASSERT(ne00 % 4 == 0); // for zeroing shared memory with half4 / float4
//GGML_ASSERT(ne00 % 16 == 0); // dequantize in chunks of 16

View File

@ -4784,10 +4784,10 @@ void kernel_mul_mm_impl(
}
}
#define NSG0 1
#define NSH0 16
#define NSG1 1
#define NSH1 64
#define NSG0 4
#define NSH0 4
#define NSG1 2
#define NSH1 4
// each block_q contains 16*nl weights
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
@ -4870,6 +4870,8 @@ void kernel_mul_mm2_impl(
}
for (int i00 = 0; i00 < ne00; i00 += 8*NSH1) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory
{
const int nload = MIN(8*NSG1, ne11 - i11) * (8*NSH1);
@ -4896,10 +4898,10 @@ void kernel_mul_mm2_impl(
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int b0 = 0; b0 < NSH1/NSH0; ++b0) {
// load NSG0*NSH0 8x8 blocks of src0 to threadgroup memory
{
@ -4945,6 +4947,7 @@ void kernel_mul_mm2_impl(
simdgroup_barrier(mem_flags::mem_none);
#if 0
#pragma unroll(NSH0)
for (int k = 0; k < NSH0; ++k) {
for (int j = 0; j < NSG0; ++j) {
@ -4961,9 +4964,22 @@ void kernel_mul_mm2_impl(
}
}
}
}
#else
#pragma unroll(NSH0)
for (int k = 0; k < NSH0; ++k) {
for (int i = 0; i < NSG1; ++i) {
simdgroup_load(m1[i], s1 + (8*i)*(8*NSH1) + 8*NSH0*b0 + 8*k, 8*NSH1, 0, true);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int j = 0; j < NSG0; ++j) {
simdgroup_load(m0[j], s0 + (8*j)*(8*NSH0) + 8*k, 8*NSH0);
for (int i = 0; i < NSG1; ++i) {
simdgroup_multiply_accumulate(mr[j][i], m0[j], m1[i], mr[j][i]);
}
}
}
#endif
}
}
// write the mr to shared memory

View File

@ -2075,7 +2075,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
}
}
#else
#elif 0
for (int r0 = 0; r0 < 32; ++r0) {
for (int c0 = 0; c0 < 4096; c0 += 512) {
for (ggml_type type_a : {GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_F16}) {
@ -2092,6 +2092,19 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
}
}
#elif 1
for (ggml_type type_a : {GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_F16}) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1}));
}
}
#endif
for (ggml_type type_a : all_types) {