mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 19:50:17 +00:00
metal : add abort callback (ggml/905)
This commit is contained in:
parent
ebd541a570
commit
85fca8deb6
@ -50,6 +50,8 @@ GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void
|
|||||||
|
|
||||||
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
|
GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb);
|
||||||
|
|
||||||
|
// Register a callback that can abort ggml_metal_graph_compute.
// While graph compute waits on its command buffers, it calls abort_callback(user_data)
// before committing a command buffer that has not been enqueued yet; if the callback
// returns true, graph compute returns GGML_STATUS_ABORTED.
// With a NULL abort_callback, all command buffers are enqueued and committed without
// consulting any callback.
GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
|
||||||
|
|
||||||
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
|
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
|
||||||
|
|
||||||
// helper to check if the device supports a specific family
|
// helper to check if the device supports a specific family
|
||||||
|
@ -224,6 +224,10 @@ struct ggml_metal_context {
|
|||||||
bool support_simdgroup_mm;
|
bool support_simdgroup_mm;
|
||||||
|
|
||||||
bool should_capture_next_compute;
|
bool should_capture_next_compute;
|
||||||
|
|
||||||
|
// abort ggml_metal_graph_compute if callback returns true
|
||||||
|
ggml_abort_callback abort_callback;
|
||||||
|
void * abort_callback_data;
|
||||||
};
|
};
|
||||||
|
|
||||||
// MSL code
|
// MSL code
|
||||||
@ -878,8 +882,11 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
|
||||||
command_buffer_builder[cb_idx] = command_buffer;
|
command_buffer_builder[cb_idx] = command_buffer;
|
||||||
|
|
||||||
// enqueue the command buffers in order to specify their execution order
|
// always enqueue the first two command buffers
|
||||||
[command_buffer enqueue];
|
// enqueue all of the command buffers if we don't need to abort
|
||||||
|
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
||||||
|
[command_buffer enqueue];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
|
const id<MTLCommandBuffer> *command_buffers = command_buffer_builder;
|
||||||
@ -2827,7 +2834,9 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
|
|
||||||
[encoder endEncoding];
|
[encoder endEncoding];
|
||||||
|
|
||||||
[command_buffer commit];
|
if (cb_idx < 2 || ctx->abort_callback == NULL) {
|
||||||
|
[command_buffer commit];
|
||||||
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
// Wait for completion and check status of each command buffer
|
// Wait for completion and check status of each command buffer
|
||||||
@ -2847,6 +2856,23 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||||||
|
|
||||||
return GGML_STATUS_FAILED;
|
return GGML_STATUS_FAILED;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil);
|
||||||
|
if (!next_buffer) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
|
||||||
|
if (next_queued) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
|
||||||
|
GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i);
|
||||||
|
return GGML_STATUS_ABORTED;
|
||||||
|
}
|
||||||
|
|
||||||
|
[next_buffer commit];
|
||||||
}
|
}
|
||||||
|
|
||||||
if (should_capture) {
|
if (should_capture) {
|
||||||
@ -3242,6 +3268,15 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
|
|||||||
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Store an abort callback and its user data on the Metal backend context.
//   backend        - must be a Metal backend (asserted below)
//   abort_callback - stored as ctx->abort_callback; ggml_metal_graph_compute checks it
//                    and returns GGML_STATUS_ABORTED when it returns true
//   user_data      - stored as ctx->abort_callback_data and passed to the callback
void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) {
|
||||||
|
// reject non-Metal backends before reinterpreting backend->context
GGML_ASSERT(ggml_backend_is_metal(backend));
|
||||||
|
|
||||||
|
struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context;
|
||||||
|
|
||||||
|
// a NULL callback makes graph compute enqueue and commit every command buffer up front
ctx->abort_callback = abort_callback;
|
||||||
|
ctx->abort_callback_data = user_data;
|
||||||
|
}
|
||||||
|
|
||||||
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
|
||||||
GGML_ASSERT(ggml_backend_is_metal(backend));
|
GGML_ASSERT(ggml_backend_is_metal(backend));
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user