diff --git a/ggml-metal.m b/ggml-metal.m index 831d2c93a..47cf991ca 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1301,7 +1301,7 @@ static bool ggml_metal_graph_compute( // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - if (src1t == GGML_TYPE_F32 && ne11 <= 8) { + if (src1t == GGML_TYPE_F32) { id pipeline = nil; switch (src0->type) { @@ -1340,12 +1340,12 @@ static bool ggml_metal_graph_compute( [encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; [encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; - const int nsg = 8; + const int nsg = 4; - const int nsg0 = 1; - const int nsh0 = 16; - const int nsg1 = 1; - const int nsh1 = 64; + const int nsg0 = 4; + const int nsh0 = 4; + const int nsg1 = 2; + const int nsh1 = 4; GGML_ASSERT(ne00 % 4 == 0); // for zeroing shared memory with half4 / float4 //GGML_ASSERT(ne00 % 16 == 0); // dequantize in chunks of 16 diff --git a/ggml-metal.metal b/ggml-metal.metal index dba9935af..74e74da19 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -4784,10 +4784,10 @@ void kernel_mul_mm_impl( } } -#define NSG0 1 -#define NSH0 16 -#define NSG1 1 -#define NSH1 64 +#define NSG0 4 +#define NSH0 4 +#define NSG1 2 +#define NSH1 4 // each block_q contains 16*nl weights template @@ -4870,6 +4870,8 @@ void kernel_mul_mm2_impl( } for (int i00 = 0; i00 < ne00; i00 += 8*NSH1) { + threadgroup_barrier(mem_flags::mem_threadgroup); + // load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory { const int nload = MIN(8*NSG1, ne11 - i11) * (8*NSH1); @@ -4896,10 +4898,10 @@ void kernel_mul_mm2_impl( } } } - - threadgroup_barrier(mem_flags::mem_threadgroup); } + threadgroup_barrier(mem_flags::mem_threadgroup); + for (int b0 = 0; b0 < NSH1/NSH0; ++b0) { // load NSG0*NSH0 8x8 blocks of src0 to threadgroup memory { @@ -4945,6 +4947,7 @@ void kernel_mul_mm2_impl( simdgroup_barrier(mem_flags::mem_none); +#if 0 #pragma unroll(NSH0) for (int k = 0; k < NSH0; ++k) { for (int j = 0; j < NSG0; ++j) { @@ -4961,9 +4964,22 @@ void kernel_mul_mm2_impl( } } } - } +#else +#pragma unroll(NSH0) + for (int k = 0; k < NSH0; ++k) { + for (int i = 0; i < NSG1; ++i) { + simdgroup_load(m1[i], s1 + (8*i)*(8*NSH1) + 8*NSH0*b0 + 8*k, 8*NSH1, 0, true); + } - threadgroup_barrier(mem_flags::mem_threadgroup); + for (int j = 0; j < NSG0; ++j) { + simdgroup_load(m0[j], s0 + (8*j)*(8*NSH0) + 8*k, 8*NSH0); + for (int i = 0; i < NSG1; ++i) { + simdgroup_multiply_accumulate(mr[j][i], m0[j], m1[i], mr[j][i]); + } + } + } +#endif + } } // write the mr to shared memory diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index ae323a384..a15856eae 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2075,7 +2075,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2})); } } -#else +#elif 0 for (int r0 = 0; r0 < 32; ++r0) { for (int c0 = 0; c0 < 4096; c0 += 512) { for (ggml_type type_a : {GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_F16}) { @@ -2092,6 +2092,19 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } } +#elif 1 + for (ggml_type type_a : {GGML_TYPE_Q4_0, GGML_TYPE_Q8_0, GGML_TYPE_F16}) { + for (ggml_type type_b : {GGML_TYPE_F32}) { + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 4096, 512, 4096, { 1, 1}, {1, 1})); + } + } #endif for (ggml_type type_a : all_types) {