metal : limit kernels to use no more than the allowed number of threads

This commit is contained in:
Georgi Gerganov 2023-12-13 10:55:17 +02:00
parent ab558ac2b3
commit 109e7aa8ac
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -1080,6 +1080,8 @@ void ggml_metal_graph_compute(
int64_t nb = ne00; int64_t nb = ne00;
id<MTLComputePipelineState> pipeline = nil;
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(src0));
@ -1088,21 +1090,23 @@ void ggml_metal_graph_compute(
nb = ne00 / 4; nb = ne00 / 4;
switch (dst->op) { switch (dst->op) {
case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break; case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break; case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break; case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
default: GGML_ASSERT(false); default: GGML_ASSERT(false);
} }
bcast_row = true; bcast_row = true;
} else { } else {
switch (dst->op) { switch (dst->op) {
case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break; case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break; case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break; case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
default: GGML_ASSERT(false); default: GGML_ASSERT(false);
} }
} }
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@ -1137,7 +1141,7 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} else { } else {
const int nth = MIN(1024, ne0); const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} }