metal : limit kernels to the maximum number of threads allowed per threadgroup

This commit is contained in:
Georgi Gerganov 2023-12-13 10:55:17 +02:00
parent ab558ac2b3
commit 109e7aa8ac
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -1080,6 +1080,8 @@ void ggml_metal_graph_compute(
int64_t nb = ne00;
id<MTLComputePipelineState> pipeline = nil;
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
GGML_ASSERT(ggml_is_contiguous(src0));
@@ -1088,21 +1090,23 @@ void ggml_metal_graph_compute(
nb = ne00 / 4;
switch (dst->op) {
case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
default: GGML_ASSERT(false);
}
bcast_row = true;
} else {
switch (dst->op) {
case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
default: GGML_ASSERT(false);
}
}
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
@@ -1137,7 +1141,7 @@ void ggml_metal_graph_compute(
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} else {
const int nth = MIN(1024, ne0);
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}