mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 03:31:46 +00:00
metal : reduce command encoding overhead (#9698)
* metal : reduce command encoding overhead ggml-ci * metal : add comments
This commit is contained in:
parent
a90484c6d9
commit
cad341d889
@ -204,13 +204,6 @@ static ggml_status compute_piter(
|
|||||||
ggml_backend_cpu_set_n_threads(model.backend, params.n_threads);
|
ggml_backend_cpu_set_n_threads(model.backend, params.n_threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: enable GPU support when support for GGML_OP_SQRT is added
|
|
||||||
//#ifdef GGML_USE_METAL
|
|
||||||
// if (ggml_backend_is_metal(model.backend)) {
|
|
||||||
// ggml_backend_metal_set_n_cb(model.backend, params.n_threads);
|
|
||||||
// }
|
|
||||||
//#endif
|
|
||||||
|
|
||||||
ggml_status res = ggml_backend_graph_compute(model.backend, gf);
|
ggml_status res = ggml_backend_graph_compute(model.backend, gf);
|
||||||
if (res == GGML_STATUS_SUCCESS) {
|
if (res == GGML_STATUS_SUCCESS) {
|
||||||
auto extract_i = [](std::string prefix, std::string str) -> int {
|
auto extract_i = [](std::string prefix, std::string str) -> int {
|
||||||
|
@ -2444,12 +2444,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
|
|||||||
ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
|
ggml_backend_cpu_set_n_threads(ctx->backend, n_threads);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef GGML_USE_METAL
|
|
||||||
if (ggml_backend_is_metal(ctx->backend)) {
|
|
||||||
ggml_backend_metal_set_n_cb(ctx->backend, n_threads);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
ggml_backend_graph_compute(ctx->backend, gf);
|
ggml_backend_graph_compute(ctx->backend, gf);
|
||||||
|
|
||||||
// the last node is the embedding tensor
|
// the last node is the embedding tensor
|
||||||
|
@ -25,9 +25,6 @@
|
|||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdbool.h>
|
#include <stdbool.h>
|
||||||
|
|
||||||
// max memory buffers that can be mapped to the device
|
|
||||||
#define GGML_METAL_MAX_BUFFERS 64
|
|
||||||
|
|
||||||
struct ggml_tensor;
|
struct ggml_tensor;
|
||||||
struct ggml_cgraph;
|
struct ggml_cgraph;
|
||||||
|
|
||||||
@ -48,8 +45,6 @@ GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
|
|||||||
|
|
||||||
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size);
|
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size);
|
||||||
|
|
||||||
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
|
|
||||||
|
|
||||||
GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
|
GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
|
||||||
|
|
||||||
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
|
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
|
||||||
|
@ -12,6 +12,12 @@
|
|||||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||||
|
|
||||||
|
// max memory buffers that can be mapped to the device
|
||||||
|
#define GGML_METAL_MAX_BUFFERS 64
|
||||||
|
|
||||||
|
// max number of MTLCommandBuffer used to submit a graph for processing
|
||||||
|
#define GGML_METAL_MAX_COMMAND_BUFFERS 8
|
||||||
|
|
||||||
#ifdef GGML_METAL_NDEBUG
|
#ifdef GGML_METAL_NDEBUG
|
||||||
#define GGML_METAL_LOG(...)
|
#define GGML_METAL_LOG(...)
|
||||||
#define GGML_METAL_LOG_INFO(...)
|
#define GGML_METAL_LOG_INFO(...)
|
||||||
@ -221,11 +227,11 @@ enum ggml_metal_kernel_type {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_backend_metal_context {
|
struct ggml_backend_metal_context {
|
||||||
int n_cb;
|
|
||||||
|
|
||||||
id<MTLDevice> device;
|
id<MTLDevice> device;
|
||||||
id<MTLCommandQueue> queue;
|
id<MTLCommandQueue> queue;
|
||||||
|
|
||||||
|
MTLComputePassDescriptor * edesc;
|
||||||
|
|
||||||
dispatch_queue_t d_queue;
|
dispatch_queue_t d_queue;
|
||||||
|
|
||||||
struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
|
struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
|
||||||
@ -233,7 +239,27 @@ struct ggml_backend_metal_context {
|
|||||||
bool support_simdgroup_reduction;
|
bool support_simdgroup_reduction;
|
||||||
bool support_simdgroup_mm;
|
bool support_simdgroup_mm;
|
||||||
|
|
||||||
bool should_capture_next_compute;
|
// capture state
|
||||||
|
bool capture_next_compute;
|
||||||
|
bool capture_started;
|
||||||
|
|
||||||
|
id<MTLCaptureScope> capture_scope;
|
||||||
|
|
||||||
|
// command buffer state
|
||||||
|
int n_cb; // number of extra threads used to submit the command buffers
|
||||||
|
int n_nodes_0; // number of nodes submitted by the main thread
|
||||||
|
int n_nodes_1; // remaining number of nodes submitted by the n_cb threads
|
||||||
|
int n_nodes_per_cb;
|
||||||
|
|
||||||
|
struct ggml_cgraph * gf;
|
||||||
|
|
||||||
|
// the callback given to the thread pool
|
||||||
|
// TODO: ideally, this should be created once, utilizing the command buffer state above
|
||||||
|
// for some reason, doing it like this leads to a crash
|
||||||
|
void (^encode_async)(size_t ith);
|
||||||
|
|
||||||
|
// n_cb command buffers + 1 used by the main thread
|
||||||
|
id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
|
||||||
|
|
||||||
// abort ggml_metal_graph_compute if callback returns true
|
// abort ggml_metal_graph_compute if callback returns true
|
||||||
ggml_abort_callback abort_callback;
|
ggml_abort_callback abort_callback;
|
||||||
@ -303,7 +329,7 @@ static void * ggml_metal_host_malloc(size_t n) {
|
|||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
static struct ggml_backend_metal_context * ggml_metal_init(void) {
|
||||||
GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
|
GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
|
||||||
|
|
||||||
#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
|
#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
|
||||||
@ -322,8 +348,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
// Configure context
|
// Configure context
|
||||||
struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
|
struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
|
||||||
ctx->device = device;
|
ctx->device = device;
|
||||||
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
|
||||||
ctx->queue = [ctx->device newCommandQueue];
|
ctx->queue = [ctx->device newCommandQueue];
|
||||||
|
ctx->edesc = MTLComputePassDescriptor.computePassDescriptor;
|
||||||
|
ctx->edesc.dispatchType = MTLDispatchTypeSerial;
|
||||||
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
|
||||||
|
|
||||||
id<MTLLibrary> metal_library;
|
id<MTLLibrary> metal_library;
|
||||||
@ -455,7 +482,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
|
GGML_METAL_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx->support_simdgroup_mm ? "true" : "false");
|
||||||
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
|
||||||
|
|
||||||
ctx->should_capture_next_compute = false;
|
ctx->capture_next_compute = false;
|
||||||
|
ctx->capture_started = false;
|
||||||
|
ctx->capture_scope = nil;
|
||||||
|
|
||||||
|
ctx->gf = nil;
|
||||||
|
ctx->encode_async = nil;
|
||||||
|
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
|
||||||
|
ctx->command_buffers[i] = nil;
|
||||||
|
}
|
||||||
|
|
||||||
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
|
||||||
if (@available(macOS 10.12, iOS 16.0, *)) {
|
if (@available(macOS 10.12, iOS 16.0, *)) {
|
||||||
@ -686,6 +721,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(int n_cb) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
[metal_library release];
|
[metal_library release];
|
||||||
|
|
||||||
return ctx;
|
return ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -874,78 +910,23 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static enum ggml_status ggml_metal_graph_compute(
|
static void ggml_metal_encode_node(
|
||||||
struct ggml_backend_metal_context * ctx,
|
struct ggml_backend_metal_context * ctx,
|
||||||
struct ggml_cgraph * gf) {
|
int idx,
|
||||||
|
id<MTLComputeCommandEncoder> encoder) {
|
||||||
|
struct ggml_cgraph * gf = ctx->gf;
|
||||||
|
|
||||||
@autoreleasepool {
|
struct ggml_tensor * node = ggml_graph_node(gf, idx);
|
||||||
MTLComputePassDescriptor * edesc = MTLComputePassDescriptor.computePassDescriptor;
|
|
||||||
edesc.dispatchType = MTLDispatchTypeSerial;
|
|
||||||
|
|
||||||
// create multiple command buffers and enqueue them
|
//GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
|
||||||
// then, we encode the graph into the command buffers in parallel
|
|
||||||
|
|
||||||
const int n_nodes = gf->n_nodes;
|
struct ggml_tensor * src0 = node->src[0];
|
||||||
const int n_cb = ctx->n_cb;
|
struct ggml_tensor * src1 = node->src[1];
|
||||||
const int n_nodes_per_cb = (n_nodes + n_cb - 1) / n_cb;
|
struct ggml_tensor * src2 = node->src[2];
|
||||||
|
struct ggml_tensor * dst = node;
|
||||||
const bool should_capture = ctx->should_capture_next_compute;
|
|
||||||
if (should_capture) {
|
|
||||||
ctx->should_capture_next_compute = false;
|
|
||||||
|
|
||||||
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
|
|
||||||
descriptor.captureObject = ctx->queue;
|
|
||||||
|
|
||||||
NSError * error = nil;
|
|
||||||
if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
|
|
||||||
GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
|
|
||||||
GGML_ABORT("capture failed");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
id<MTLCommandBuffer> command_buffer_builder[n_cb];
|
|
||||||
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
|
||||||
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
|
||||||
command_buffer_builder[cb_idx] = command_buffer;
|
|
||||||
|
|
||||||
// always enqueue the first two command buffers
|
|
||||||
// enqueue all of the command buffers if we don't need to abort
|
|
||||||
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
|
||||||
[command_buffer enqueue];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
|
|
||||||
|
|
||||||
dispatch_apply(n_cb, ctx->d_queue, ^(size_t iter) {
|
|
||||||
const int cb_idx = iter;
|
|
||||||
|
|
||||||
size_t offs_src0 = 0;
|
|
||||||
size_t offs_src1 = 0;
|
|
||||||
size_t offs_src2 = 0;
|
|
||||||
size_t offs_dst = 0;
|
|
||||||
|
|
||||||
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
|
|
||||||
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
|
|
||||||
|
|
||||||
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
|
|
||||||
const int node_end = MIN((cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb, n_nodes);
|
|
||||||
|
|
||||||
for (int i = node_start; i < node_end; ++i) {
|
|
||||||
if (i == -1) {
|
|
||||||
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
//GGML_METAL_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
|
|
||||||
|
|
||||||
struct ggml_tensor * src0 = gf->nodes[i]->src[0];
|
|
||||||
struct ggml_tensor * src1 = gf->nodes[i]->src[1];
|
|
||||||
struct ggml_tensor * src2 = gf->nodes[i]->src[2];
|
|
||||||
struct ggml_tensor * dst = gf->nodes[i];
|
|
||||||
|
|
||||||
if (ggml_is_empty(dst)) {
|
if (ggml_is_empty(dst)) {
|
||||||
continue;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (dst->op) {
|
switch (dst->op) {
|
||||||
@ -956,7 +937,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
case GGML_OP_PERMUTE:
|
case GGML_OP_PERMUTE:
|
||||||
{
|
{
|
||||||
// noop -> next node
|
// noop -> next node
|
||||||
} continue;
|
} return;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
} break;
|
} break;
|
||||||
@ -967,10 +948,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
GGML_ABORT("unsupported op");
|
GGML_ABORT("unsupported op");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (should_capture) {
|
|
||||||
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(dst) encoding:NSUTF8StringEncoding]];
|
|
||||||
}
|
|
||||||
|
|
||||||
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
const int64_t ne00 = src0 ? src0->ne[0] : 0;
|
||||||
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
const int64_t ne01 = src0 ? src0->ne[1] : 0;
|
||||||
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
const int64_t ne02 = src0 ? src0->ne[2] : 0;
|
||||||
@ -1015,6 +992,11 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
|
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
|
||||||
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
|
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
|
||||||
|
|
||||||
|
size_t offs_src0 = 0;
|
||||||
|
size_t offs_src1 = 0;
|
||||||
|
size_t offs_src2 = 0;
|
||||||
|
size_t offs_dst = 0;
|
||||||
|
|
||||||
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
|
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
|
||||||
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
|
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
|
||||||
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
|
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
|
||||||
@ -1039,7 +1021,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
{
|
{
|
||||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CONCAT].pipeline;
|
||||||
|
|
||||||
const int32_t dim = ((int32_t *) dst->op_params)[0];
|
const int32_t dim = ((const int32_t *) dst->op_params)[0];
|
||||||
|
|
||||||
[encoder setComputePipelineState:pipeline];
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
@ -1203,12 +1185,12 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
GGML_ASSERT(ggml_is_contiguous(src1));
|
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||||
|
|
||||||
const size_t pnb1 = ((int32_t *) dst->op_params)[0];
|
const size_t pnb1 = ((const int32_t *) dst->op_params)[0];
|
||||||
const size_t pnb2 = ((int32_t *) dst->op_params)[1];
|
const size_t pnb2 = ((const int32_t *) dst->op_params)[1];
|
||||||
const size_t pnb3 = ((int32_t *) dst->op_params)[2];
|
const size_t pnb3 = ((const int32_t *) dst->op_params)[2];
|
||||||
const size_t offs = ((int32_t *) dst->op_params)[3];
|
const size_t offs = ((const int32_t *) dst->op_params)[3];
|
||||||
|
|
||||||
const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
|
const bool inplace = (bool) ((const int32_t *) dst->op_params)[4];
|
||||||
|
|
||||||
if (!inplace) {
|
if (!inplace) {
|
||||||
// run a separete kernel to cpy src->dst
|
// run a separete kernel to cpy src->dst
|
||||||
@ -1309,8 +1291,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
|
|
||||||
float min;
|
float min;
|
||||||
float max;
|
float max;
|
||||||
memcpy(&min, ((int32_t *) dst->op_params) + 0, sizeof(float));
|
memcpy(&min, ((const int32_t *) dst->op_params) + 0, sizeof(float));
|
||||||
memcpy(&max, ((int32_t *) dst->op_params) + 1, sizeof(float));
|
memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float));
|
||||||
|
|
||||||
[encoder setComputePipelineState:pipeline];
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
@ -1323,7 +1305,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
switch (ggml_get_unary_op(gf->nodes[i])) {
|
switch (ggml_get_unary_op(node)) {
|
||||||
// we are not taking into account the strides, so for now require contiguous tensors
|
// we are not taking into account the strides, so for now require contiguous tensors
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
|
||||||
@ -1422,7 +1404,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
GGML_METAL_LOG_WARN("%s: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
@ -1551,8 +1533,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
float scale;
|
float scale;
|
||||||
float max_bias;
|
float max_bias;
|
||||||
|
|
||||||
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
|
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
|
||||||
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
||||||
|
|
||||||
const int64_t nrows_x = ggml_nrows(src0);
|
const int64_t nrows_x = ggml_nrows(src0);
|
||||||
const int64_t nrows_y = src0->ne[1];
|
const int64_t nrows_y = src0->ne[1];
|
||||||
@ -1585,7 +1567,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
{
|
{
|
||||||
const int n_past = ((int32_t *)(dst->op_params))[0];
|
const int n_past = ((const int32_t *)(dst->op_params))[0];
|
||||||
|
|
||||||
id<MTLComputePipelineState> pipeline = nil;
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
@ -1644,9 +1626,9 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
case GGML_OP_SSM_SCAN:
|
case GGML_OP_SSM_SCAN:
|
||||||
{
|
{
|
||||||
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
|
struct ggml_tensor * src3 = node->src[3];
|
||||||
struct ggml_tensor * src4 = gf->nodes[i]->src[4];
|
struct ggml_tensor * src4 = node->src[4];
|
||||||
struct ggml_tensor * src5 = gf->nodes[i]->src[5];
|
struct ggml_tensor * src5 = node->src[5];
|
||||||
|
|
||||||
GGML_ASSERT(src3);
|
GGML_ASSERT(src3);
|
||||||
GGML_ASSERT(src4);
|
GGML_ASSERT(src4);
|
||||||
@ -2425,7 +2407,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
float eps;
|
float eps;
|
||||||
memcpy(&eps, dst->op_params + 1, sizeof(float));
|
memcpy(&eps, dst->op_params + 1, sizeof(float));
|
||||||
|
|
||||||
const int32_t n_groups = ((int32_t *) dst->op_params)[0];
|
const int32_t n_groups = ((const int32_t *) dst->op_params)[0];
|
||||||
|
|
||||||
int nth = 32; // SIMD width
|
int nth = 32; // SIMD width
|
||||||
|
|
||||||
@ -2479,11 +2461,11 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
|
|
||||||
const int nth = MIN(1024, ne00);
|
const int nth = MIN(1024, ne00);
|
||||||
|
|
||||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
const int n_past = ((const int32_t *) dst->op_params)[0];
|
||||||
const int n_dims = ((int32_t *) dst->op_params)[1];
|
const int n_dims = ((const int32_t *) dst->op_params)[1];
|
||||||
const int mode = ((int32_t *) dst->op_params)[2];
|
const int mode = ((const int32_t *) dst->op_params)[2];
|
||||||
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
|
||||||
const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
|
const int n_ctx_orig = ((const int32_t *) dst->op_params)[4];
|
||||||
|
|
||||||
float freq_base;
|
float freq_base;
|
||||||
float freq_scale;
|
float freq_scale;
|
||||||
@ -2492,12 +2474,12 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
float beta_fast;
|
float beta_fast;
|
||||||
float beta_slow;
|
float beta_slow;
|
||||||
|
|
||||||
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
|
memcpy(&freq_base, (const int32_t *) dst->op_params + 5, sizeof(float));
|
||||||
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
|
memcpy(&freq_scale, (const int32_t *) dst->op_params + 6, sizeof(float));
|
||||||
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
|
memcpy(&ext_factor, (const int32_t *) dst->op_params + 7, sizeof(float));
|
||||||
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
|
memcpy(&attn_factor, (const int32_t *) dst->op_params + 8, sizeof(float));
|
||||||
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
|
memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
|
||||||
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
|
memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
|
||||||
|
|
||||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||||
|
|
||||||
@ -2686,8 +2668,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
float start;
|
float start;
|
||||||
float step;
|
float step;
|
||||||
|
|
||||||
memcpy(&start, ((int32_t *) dst->op_params) + 0, sizeof(float));
|
memcpy(&start, ((const int32_t *) dst->op_params) + 0, sizeof(float));
|
||||||
memcpy(&step, ((int32_t *) dst->op_params) + 2, sizeof(float));
|
memcpy(&step, ((const int32_t *) dst->op_params) + 2, sizeof(float));
|
||||||
|
|
||||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
|
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline;
|
||||||
|
|
||||||
@ -2786,7 +2768,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
|
|
||||||
GGML_ASSERT(ggml_are_same_shape (src1, src2));
|
GGML_ASSERT(ggml_are_same_shape (src1, src2));
|
||||||
|
|
||||||
struct ggml_tensor * src3 = gf->nodes[i]->src[3];
|
struct ggml_tensor * src3 = node->src[3];
|
||||||
|
|
||||||
size_t offs_src3 = 0;
|
size_t offs_src3 = 0;
|
||||||
|
|
||||||
@ -2811,9 +2793,9 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
float scale;
|
float scale;
|
||||||
float max_bias;
|
float max_bias;
|
||||||
float logit_softcap;
|
float logit_softcap;
|
||||||
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
|
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
|
||||||
memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
||||||
memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
|
memcpy(&logit_softcap, ((const int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
|
||||||
|
|
||||||
if (logit_softcap != 0.0f) {
|
if (logit_softcap != 0.0f) {
|
||||||
scale /= logit_softcap;
|
scale /= logit_softcap;
|
||||||
@ -3014,10 +2996,87 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
} break;
|
} break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
|
GGML_METAL_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static enum ggml_status ggml_metal_graph_compute(
|
||||||
|
struct ggml_backend_metal_context * ctx,
|
||||||
|
struct ggml_cgraph * gf) {
|
||||||
|
// number of nodes encoded by the main thread (empirically determined)
|
||||||
|
const int n_main = 128;
|
||||||
|
|
||||||
|
// number of threads in addition to the main thread
|
||||||
|
const int n_cb = ctx->n_cb;
|
||||||
|
|
||||||
|
// submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
|
||||||
|
// the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
|
||||||
|
// while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
|
||||||
|
// each thread creates it's own command buffer and enqueues the ops in parallel
|
||||||
|
//
|
||||||
|
// tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2
|
||||||
|
|
||||||
|
@autoreleasepool {
|
||||||
|
ctx->gf = gf;
|
||||||
|
|
||||||
|
ctx->n_nodes_0 = MIN(n_main, gf->n_nodes);
|
||||||
|
ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;
|
||||||
|
|
||||||
|
ctx->n_nodes_per_cb = (ctx->n_nodes_1 + ctx->n_cb - 1) / ctx->n_cb;
|
||||||
|
|
||||||
|
const bool should_capture = ctx->capture_next_compute;
|
||||||
|
if (should_capture) {
|
||||||
|
ctx->capture_next_compute = false;
|
||||||
|
|
||||||
|
if (!ctx->capture_started) {
|
||||||
|
// create capture scope
|
||||||
|
ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];
|
||||||
|
|
||||||
|
MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
|
||||||
|
descriptor.captureObject = ctx->capture_scope;
|
||||||
|
descriptor.destination = MTLCaptureDestinationGPUTraceDocument;
|
||||||
|
descriptor.outputURL = [NSURL fileURLWithPath:[NSString stringWithFormat:@"/tmp/perf-metal.gputrace"]];
|
||||||
|
|
||||||
|
NSError * error = nil;
|
||||||
|
if (![[MTLCaptureManager sharedCaptureManager] startCaptureWithDescriptor:descriptor error:&error]) {
|
||||||
|
GGML_METAL_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);
|
||||||
|
GGML_ABORT("capture failed");
|
||||||
|
} else {
|
||||||
|
[ctx->capture_scope beginScope];
|
||||||
|
ctx->capture_started = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: how to avoid this allocation? I tried initializing it in ggml_backend_metal_set_n_cb but it crashes.
|
||||||
|
ctx->encode_async = ^(size_t iter) {
|
||||||
|
const int cb_idx = iter;
|
||||||
|
const int n_cb_l = ctx->n_cb;
|
||||||
|
|
||||||
|
const int n_nodes_0 = ctx->n_nodes_0;
|
||||||
|
const int n_nodes_1 = ctx->n_nodes_1;
|
||||||
|
|
||||||
|
const int n_nodes_per_cb = ctx->n_nodes_per_cb;
|
||||||
|
|
||||||
|
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
|
||||||
|
id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: ctx->edesc];
|
||||||
|
|
||||||
|
int node_start = 0;
|
||||||
|
int node_end = n_nodes_0;
|
||||||
|
|
||||||
|
if (cb_idx < n_cb_l) {
|
||||||
|
node_start = n_nodes_0 + ( (cb_idx + 0) * n_nodes_per_cb);
|
||||||
|
node_end = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int idx = node_start; idx < node_end; ++idx) {
|
||||||
|
if (should_capture) {
|
||||||
|
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(gf, idx)) encoding:NSUTF8StringEncoding]];
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_metal_encode_node(ctx, idx, encoder);
|
||||||
|
|
||||||
if (should_capture) {
|
if (should_capture) {
|
||||||
[encoder popDebugGroup];
|
[encoder popDebugGroup];
|
||||||
@ -3029,13 +3088,52 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
||||||
[command_buffer commit];
|
[command_buffer commit];
|
||||||
}
|
}
|
||||||
});
|
};
|
||||||
|
|
||||||
// Wait for completion and check status of each command buffer
|
// the main thread commits the first few commands immediately
|
||||||
|
// command_buffer[n_cb]
|
||||||
|
{
|
||||||
|
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
||||||
|
ctx->command_buffers[n_cb] = command_buffer;
|
||||||
|
|
||||||
|
[command_buffer enqueue];
|
||||||
|
ctx->encode_async(n_cb);
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepare the rest of the command buffers asynchronously
|
||||||
|
// command_buffer[0.. n_cb)
|
||||||
|
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
|
||||||
|
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
||||||
|
ctx->command_buffers[cb_idx] = command_buffer;
|
||||||
|
|
||||||
|
// always enqueue the first two command buffers
|
||||||
|
// enqueue all of the command buffers if we don't need to abort
|
||||||
|
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
||||||
|
[command_buffer enqueue];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
|
||||||
|
|
||||||
|
// wait for completion and check status of each command buffer
|
||||||
// needed to detect if the device ran out-of-memory for example (#1881)
|
// needed to detect if the device ran out-of-memory for example (#1881)
|
||||||
|
{
|
||||||
|
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
|
||||||
|
[command_buffer waitUntilCompleted];
|
||||||
|
|
||||||
|
MTLCommandBufferStatus status = [command_buffer status];
|
||||||
|
if (status != MTLCommandBufferStatusCompleted) {
|
||||||
|
GGML_METAL_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
|
||||||
|
if (status == MTLCommandBufferStatusError) {
|
||||||
|
GGML_METAL_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
|
||||||
|
}
|
||||||
|
|
||||||
|
return GGML_STATUS_FAILED;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (int i = 0; i < n_cb; ++i) {
|
for (int i = 0; i < n_cb; ++i) {
|
||||||
id<MTLCommandBuffer> command_buffer = command_buffers[i];
|
id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
|
||||||
[command_buffer waitUntilCompleted];
|
[command_buffer waitUntilCompleted];
|
||||||
|
|
||||||
MTLCommandBufferStatus status = [command_buffer status];
|
MTLCommandBufferStatus status = [command_buffer status];
|
||||||
@ -3048,12 +3146,12 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
return GGML_STATUS_FAILED;
|
return GGML_STATUS_FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil);
|
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
|
||||||
if (!next_buffer) {
|
if (!next_buffer) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
|
const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
|
||||||
if (next_queued) {
|
if (next_queued) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -3066,11 +3164,12 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
[next_buffer commit];
|
[next_buffer commit];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (should_capture) {
|
if (!should_capture && ctx->capture_started) {
|
||||||
|
[ctx->capture_scope endScope];
|
||||||
[[MTLCaptureManager sharedCaptureManager] stopCapture];
|
[[MTLCaptureManager sharedCaptureManager] stopCapture];
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return GGML_STATUS_SUCCESS;
|
return GGML_STATUS_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3405,6 +3504,25 @@ GGML_CALL static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, g
|
|||||||
UNUSED(backend);
|
UNUSED(backend);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
||||||
|
GGML_ASSERT(ggml_backend_is_metal(backend));
|
||||||
|
|
||||||
|
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
|
||||||
|
|
||||||
|
if (ctx->n_cb != n_cb) {
|
||||||
|
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS);
|
||||||
|
|
||||||
|
if (ctx->n_cb > 2) {
|
||||||
|
GGML_METAL_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, n_cb);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: setting encode_async here causes crash during the next ggml_metal_graph_compute call. why?
|
||||||
|
//ctx->encode_async = ^(size_t iter) {
|
||||||
|
// ...
|
||||||
|
//};
|
||||||
|
}
|
||||||
|
|
||||||
static struct ggml_backend_i ggml_backend_metal_i = {
|
static struct ggml_backend_i ggml_backend_metal_i = {
|
||||||
/* .get_name = */ ggml_backend_metal_name,
|
/* .get_name = */ ggml_backend_metal_name,
|
||||||
/* .free = */ ggml_backend_metal_free,
|
/* .free = */ ggml_backend_metal_free,
|
||||||
@ -3439,35 +3557,29 @@ static ggml_guid_t ggml_backend_metal_guid(void) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ggml_backend_t ggml_backend_metal_init(void) {
|
ggml_backend_t ggml_backend_metal_init(void) {
|
||||||
struct ggml_backend_metal_context * ctx = ggml_metal_init(GGML_DEFAULT_N_THREADS);
|
struct ggml_backend_metal_context * ctx = ggml_metal_init();
|
||||||
if (ctx == NULL) {
|
if (ctx == NULL) {
|
||||||
GGML_METAL_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
|
GGML_METAL_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_backend_t metal_backend = malloc(sizeof(struct ggml_backend));
|
ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
|
||||||
|
|
||||||
*metal_backend = (struct ggml_backend) {
|
*backend = (struct ggml_backend) {
|
||||||
/* .guid = */ ggml_backend_metal_guid(),
|
/* .guid = */ ggml_backend_metal_guid(),
|
||||||
/* .interface = */ ggml_backend_metal_i,
|
/* .interface = */ ggml_backend_metal_i,
|
||||||
/* .context = */ ctx,
|
/* .context = */ ctx,
|
||||||
};
|
};
|
||||||
|
|
||||||
return metal_backend;
|
ggml_backend_metal_set_n_cb(backend, 1);
|
||||||
|
|
||||||
|
return backend;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ggml_backend_is_metal(ggml_backend_t backend) {
|
bool ggml_backend_is_metal(ggml_backend_t backend) {
|
||||||
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
|
return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_metal_guid());
|
||||||
}
|
}
|
||||||
|
|
||||||
void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|
||||||
GGML_ASSERT(ggml_backend_is_metal(backend));
|
|
||||||
|
|
||||||
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
|
|
||||||
|
|
||||||
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
|
||||||
}
|
|
||||||
|
|
||||||
void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
|
void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
|
||||||
GGML_ASSERT(ggml_backend_is_metal(backend));
|
GGML_ASSERT(ggml_backend_is_metal(backend));
|
||||||
|
|
||||||
@ -3489,7 +3601,7 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
|
|||||||
GGML_ASSERT(ggml_backend_is_metal(backend));
|
GGML_ASSERT(ggml_backend_is_metal(backend));
|
||||||
|
|
||||||
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
|
struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
|
||||||
ctx->should_capture_next_compute = true;
|
ctx->capture_next_compute = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
|
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
|
||||||
|
@ -17025,12 +17025,6 @@ static void llama_graph_compute(
|
|||||||
ggml_cgraph * gf,
|
ggml_cgraph * gf,
|
||||||
int n_threads,
|
int n_threads,
|
||||||
ggml_threadpool * threadpool) {
|
ggml_threadpool * threadpool) {
|
||||||
#ifdef GGML_USE_METAL
|
|
||||||
if (ggml_backend_is_metal(lctx.backend_metal)) {
|
|
||||||
ggml_backend_metal_set_n_cb(lctx.backend_metal, n_threads);
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
if (lctx.backend_cpu != nullptr) {
|
if (lctx.backend_cpu != nullptr) {
|
||||||
ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
|
ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
|
||||||
ggml_backend_cpu_set_threadpool(lctx.backend_cpu, threadpool);
|
ggml_backend_cpu_set_threadpool(lctx.backend_cpu, threadpool);
|
||||||
|
Loading…
Reference in New Issue
Block a user