metal : support permuted matrix multiplicaions (#10033)

* metal : support permuted matrix multiplicaions

ggml-ci

* cont : use nb01 directly for row steps

ggml-ci

* cont : add comments [no ci]

* metal : minor refactor

* metal : minor
This commit is contained in:
Georgi Gerganov 2024-10-25 22:26:15 +03:00 committed by GitHub
parent ff252ea48e
commit 668750357e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 423 additions and 230 deletions

View File

@ -1015,19 +1015,21 @@ static void ggml_metal_encode_node(
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil; id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil; id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
//GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op)); #if 0
//if (src0) { GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
// GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, if (src0) {
// ggml_is_contiguous(src0), src0->name); GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
//} ggml_is_contiguous(src0), src0->name);
//if (src1) { }
// GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, if (src1) {
// ggml_is_contiguous(src1), src1->name); GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
//} ggml_is_contiguous(src1), src1->name);
//if (dst) { }
// GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, if (dst) {
// dst->name); GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
//} dst->name);
}
#endif
id<MTLDevice> device = ctx_dev->mtl_device; id<MTLDevice> device = ctx_dev->mtl_device;
@ -1810,14 +1812,16 @@ static void ggml_metal_encode_node(
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:15];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:16];
[encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else { } else {
@ -1986,16 +1990,18 @@ static void ggml_metal_encode_node(
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11]; [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12]; [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13]; [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:19];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:20];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
@ -2048,6 +2054,9 @@ static void ggml_metal_encode_node(
GGML_ASSERT(src1t == GGML_TYPE_F32); GGML_ASSERT(src1t == GGML_TYPE_F32);
GGML_ASSERT(ne03 == 1);
GGML_ASSERT(ne13 == 1);
// find the break-even point where the matrix-matrix kernel becomes more efficient compared // find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel // to the matrix-vector kernel
// ne20 = n_used_experts // ne20 = n_used_experts

File diff suppressed because it is too large Load Diff