diff --git a/ggml.c b/ggml.c index 0808d7ff2..945f907cd 100644 --- a/ggml.c +++ b/ggml.c @@ -2725,9 +2725,9 @@ struct ggml_context_container { // enum ggml_task_type { - GGML_TASK_INIT = 0, - GGML_TASK_COMPUTE, - GGML_TASK_FINALIZE, + GGML_TASK_INIT = 1, + GGML_TASK_COMPUTE = 2, + GGML_TASK_FINALIZE = 4, }; struct ggml_compute_params { @@ -9262,9 +9262,20 @@ struct ggml_compute_state { static void ggml_graph_compute_thread(void * data) { struct ggml_compute_state * state = (struct ggml_compute_state *) data; + int type = state->params.type; if (state->node) { if (state->params.ith < state->params.nth) { - ggml_compute_forward(&state->params, state->node); + if (type & GGML_TASK_INIT) + { + state->params.type = GGML_TASK_INIT; + ggml_compute_forward(&state->params, state->node); + } + + if (type & GGML_TASK_COMPUTE) + { + state->params.type = GGML_TASK_COMPUTE; + ggml_compute_forward(&state->params, state->node); + } } state->node = NULL; } @@ -9527,6 +9538,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) ggml_compute_forward(&params, node); + int next_task = 0; + // COMPUTE if (node->n_tasks > 1) { // launch thread pool @@ -9542,12 +9555,34 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[j]); } } + else + { + int start = i; + int end = i; + if (i + 1 < cgraph->n_nodes) + { + struct ggml_tensor * next = cgraph->nodes[i + 1]; + if (next->src0 != node && next->src1 != node && next->n_tasks == 1) + { + workers[next_task].params = (struct ggml_compute_params) { + .type = GGML_TASK_COMPUTE | GGML_TASK_INIT, + .ith = 0, + .nth = 1, + .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, + .wdata = cgraph->work ? 
cgraph->work->data : NULL, + }; + + thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]); + next_task++; + } + } + } params.type = GGML_TASK_COMPUTE; ggml_compute_forward(&params, node); // wait for thread pool - if (node->n_tasks > 1) { + if (node->n_tasks > 1 || next_task != 0) { thpool_wait(ctx->tpool); } #if 0