ggml : add abort_callback for cpu backend (ggml/725)

* a way to use abort_callback with the cpu backend

* whisper update
This commit is contained in:
Michael Podvitskiy 2024-02-09 10:42:27 +01:00 committed by Georgi Gerganov
parent 4b7b38bef5
commit 4633d93af0
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 33 additions and 9 deletions

View File

@ -653,6 +653,9 @@ struct ggml_backend_cpu_context {
int n_threads; int n_threads;
void * work_data; void * work_data;
size_t work_size; size_t work_size;
ggml_abort_callback abort_callback;
void * abort_callback_data;
}; };
GGML_CALL static const char * ggml_backend_cpu_name(ggml_backend_t backend) { GGML_CALL static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
@ -691,6 +694,9 @@ GGML_CALL static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(gg
cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size); cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size);
} }
cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback;
cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
return cpu_plan; return cpu_plan;
} }
@ -721,9 +727,11 @@ GGML_CALL static bool ggml_backend_cpu_graph_compute(ggml_backend_t backend, str
cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size); cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size);
cpu_ctx->work_size = cplan.work_size; cpu_ctx->work_size = cplan.work_size;
} }
cplan.work_data = cpu_ctx->work_data; cplan.work_data = cpu_ctx->work_data;
cplan.abort_callback = cpu_ctx->abort_callback;
cplan.abort_callback_data = cpu_ctx->abort_callback_data;
ggml_graph_compute(cgraph, &cplan); ggml_graph_compute(cgraph, &cplan);
return true; return true;
} }
@ -759,9 +767,11 @@ static struct ggml_backend_i cpu_backend_i = {
ggml_backend_t ggml_backend_cpu_init(void) { ggml_backend_t ggml_backend_cpu_init(void) {
struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context)); struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context));
ctx->n_threads = GGML_DEFAULT_N_THREADS; ctx->n_threads = GGML_DEFAULT_N_THREADS;
ctx->work_data = NULL; ctx->work_data = NULL;
ctx->work_size = 0; ctx->work_size = 0;
ctx->abort_callback = NULL;
ctx->abort_callback_data = NULL;
ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend)); ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend));
@ -783,6 +793,14 @@ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
ctx->n_threads = n_threads; ctx->n_threads = n_threads;
} }
void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data) {
GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
ctx->abort_callback = abort_callback;
ctx->abort_callback_data = abort_callback_data;
}
GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) { GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size) {
return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size); return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
} }

View File

@ -83,8 +83,9 @@ extern "C" {
GGML_API ggml_backend_t ggml_backend_cpu_init(void); GGML_API ggml_backend_t ggml_backend_cpu_init(void);
GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend); GGML_API GGML_CALL bool ggml_backend_is_cpu (ggml_backend_t backend);
GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads); GGML_API void ggml_backend_cpu_set_n_threads (ggml_backend_t backend_cpu, int n_threads);
GGML_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
// Create a backend buffer from an existing pointer // Create a backend buffer from an existing pointer
GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size); GGML_API GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, size_t size);

2
ggml.c
View File

@ -16649,7 +16649,7 @@ struct ggml_compute_state_shared {
atomic_int node_n; // active graph node atomic_int node_n; // active graph node
atomic_int node_task; // active graph node task phase atomic_int node_task; // active graph node task phase
bool (*abort_callback)(void * data); // abort ggml_graph_compute when true ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
void * abort_callback_data; void * abort_callback_data;
}; };

9
ggml.h
View File

@ -567,6 +567,11 @@ extern "C" {
static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
// Abort callback
// If not NULL, called before ggml computation
// If it returns true, the computation is aborted
typedef bool (*ggml_abort_callback)(void * data);
// the compute plan that needs to be prepared for ggml_graph_compute() // the compute plan that needs to be prepared for ggml_graph_compute()
// since https://github.com/ggerganov/ggml/issues/287 // since https://github.com/ggerganov/ggml/issues/287
struct ggml_cplan { struct ggml_cplan {
@ -576,8 +581,8 @@ extern "C" {
int n_threads; int n_threads;
// abort ggml_graph_compute when true // abort ggml_graph_compute when true
bool (*abort_callback)(void * data); ggml_abort_callback abort_callback;
void * abort_callback_data; void * abort_callback_data;
}; };
enum ggml_cgraph_eval_order { enum ggml_cgraph_eval_order {