mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 04:00:16 +00:00
metal : limit kernels to not use more than the allowed threads
This commit is contained in:
parent
ab558ac2b3
commit
109e7aa8ac
18
ggml-metal.m
18
ggml-metal.m
@ -1080,6 +1080,8 @@ void ggml_metal_graph_compute(
|
|||||||
|
|
||||||
int64_t nb = ne00;
|
int64_t nb = ne00;
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
|
||||||
@ -1088,21 +1090,23 @@ void ggml_metal_graph_compute(
|
|||||||
|
|
||||||
nb = ne00 / 4;
|
nb = ne00 / 4;
|
||||||
switch (dst->op) {
|
switch (dst->op) {
|
||||||
case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
|
case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
|
||||||
case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
|
case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
|
||||||
case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
|
case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
|
||||||
default: GGML_ASSERT(false);
|
default: GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
|
|
||||||
bcast_row = true;
|
bcast_row = true;
|
||||||
} else {
|
} else {
|
||||||
switch (dst->op) {
|
switch (dst->op) {
|
||||||
case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
|
case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
|
||||||
case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
|
case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
|
||||||
case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
|
case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
|
||||||
default: GGML_ASSERT(false);
|
default: GGML_ASSERT(false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
@ -1137,7 +1141,7 @@ void ggml_metal_graph_compute(
|
|||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} else {
|
} else {
|
||||||
const int nth = MIN(1024, ne0);
|
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user