diff --git a/ggml/include/ggml-metal.h b/ggml/include/ggml-metal.h index 6c3226c37..d483cf1ac 100644 --- a/ggml/include/ggml-metal.h +++ b/ggml/include/ggml-metal.h @@ -50,6 +50,8 @@ GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void GGML_API void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb); +GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data); + GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); // helper to check if the device supports a specific family diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index b512eb0be..c19274176 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -224,6 +224,10 @@ struct ggml_metal_context { bool support_simdgroup_mm; bool should_capture_next_compute; + + // abort ggml_metal_graph_compute if callback returns true + ggml_abort_callback abort_callback; + void * abort_callback_data; }; // MSL code @@ -878,8 +882,11 @@ static enum ggml_status ggml_metal_graph_compute( id command_buffer = [ctx->queue commandBufferWithUnretainedReferences]; command_buffer_builder[cb_idx] = command_buffer; - // enqueue the command buffers in order to specify their execution order - [command_buffer enqueue]; + // always enqueue the first two command buffers + // enqueue all of the command buffers if we don't need to abort + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [command_buffer enqueue]; + } } const id *command_buffers = command_buffer_builder; @@ -2827,7 +2834,9 @@ static enum ggml_status ggml_metal_graph_compute( [encoder endEncoding]; - [command_buffer commit]; + if (cb_idx < 2 || ctx->abort_callback == NULL) { + [command_buffer commit]; + } }); // Wait for completion and check status of each command buffer @@ -2847,6 +2856,23 @@ static enum ggml_status ggml_metal_graph_compute( return GGML_STATUS_FAILED; } + + id next_buffer = (i + 1 < n_cb ? command_buffers[i + 1] : nil); + if (!next_buffer) { + continue; + } + + bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued); + if (next_queued) { + continue; + } + + if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) { + GGML_METAL_LOG_INFO("%s: command buffer %d aborted", __func__, i); + return GGML_STATUS_ABORTED; + } + + [next_buffer commit]; } if (should_capture) { @@ -3242,6 +3268,15 @@ void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { ctx->n_cb = MIN(n_cb, GGML_METAL_MAX_BUFFERS); } +void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data) { + GGML_ASSERT(ggml_backend_is_metal(backend)); + + struct ggml_metal_context * ctx = (struct ggml_metal_context *)backend->context; + + ctx->abort_callback = abort_callback; + ctx->abort_callback_data = user_data; +} + bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) { GGML_ASSERT(ggml_backend_is_metal(backend));