diff --git a/ggml.c b/ggml.c index d5a190f34..c33ad2ca9 100644 --- a/ggml.c +++ b/ggml.c @@ -2725,9 +2725,9 @@ struct ggml_context_container { // enum ggml_task_type { - GGML_TASK_INIT = 1, - GGML_TASK_COMPUTE = 2, - GGML_TASK_FINALIZE = 4, + GGML_TASK_INIT = 0, + GGML_TASK_COMPUTE, + GGML_TASK_FINALIZE, }; struct ggml_compute_params { @@ -9260,13 +9260,14 @@ static void ggml_graph_compute_thread(void * data) { int type = state->params.type; if (state->node) { if (state->params.ith < state->params.nth) { - if (type & GGML_TASK_INIT) + if (type == GGML_TASK_INIT) { state->params.type = GGML_TASK_INIT; ggml_compute_forward(&state->params, state->node); + type = GGML_TASK_COMPUTE; } - if (type & GGML_TASK_COMPUTE) + if (type == GGML_TASK_COMPUTE) { state->params.type = GGML_TASK_COMPUTE; ggml_compute_forward(&state->params, state->node); @@ -9579,7 +9580,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) continue; workers[next_task].params = (struct ggml_compute_params) { - .type = GGML_TASK_COMPUTE | GGML_TASK_INIT, + .type = GGML_TASK_INIT, .ith = 0, .nth = 1, .wsize = 0,