Run second operator when possible

This commit is contained in:
Howard Su 2023-04-07 23:51:46 +08:00
parent c640d2a4bd
commit 43dde039b0

45
ggml.c
View File

@ -2725,9 +2725,9 @@ struct ggml_context_container {
//
// Phases of work that ggml_compute_forward() is asked to perform for a node;
// the selected phase is carried in ggml_compute_params.type.
// NOTE(review): this listing is a rendered diff, so both versions of the
// enumerators appear below. The first three (explicit 0, then implicit
// values) look like the removed lines and the last three (1, 2, 4) like
// their replacements - confirm against the actual ggml.c.
enum ggml_task_type {
GGML_TASK_INIT = 0,
GGML_TASK_COMPUTE,
GGML_TASK_FINALIZE,
// Power-of-two values: the worker thread tests the requested phases with
// `type & GGML_TASK_INIT` / `type & GGML_TASK_COMPUTE`, and a caller
// requests both at once with `GGML_TASK_COMPUTE | GGML_TASK_INIT`.
GGML_TASK_INIT = 1,
GGML_TASK_COMPUTE = 2,
GGML_TASK_FINALIZE = 4,
};
struct ggml_compute_params {
@ -9262,9 +9262,20 @@ struct ggml_compute_state {
static void ggml_graph_compute_thread(void * data) {
struct ggml_compute_state * state = (struct ggml_compute_state *) data;
int type = state->params.type;
if (state->node) {
if (state->params.ith < state->params.nth) {
ggml_compute_forward(&state->params, state->node);
if (type & GGML_TASK_INIT)
{
state->params.type = GGML_TASK_INIT;
ggml_compute_forward(&state->params, state->node);
}
if (type & GGML_TASK_COMPUTE)
{
state->params.type = GGML_TASK_COMPUTE;
ggml_compute_forward(&state->params, state->node);
}
}
state->node = NULL;
}
@ -9527,6 +9538,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
ggml_compute_forward(&params, node);
int next_task = 0;
// COMPUTE
if (node->n_tasks > 1) {
// launch thread pool
@ -9542,12 +9555,34 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[j]);
}
}
else
{
int start = i;
int end = i;
if (i + 1 < cgraph->n_nodes)
{
struct ggml_tensor * next = cgraph->nodes[i + 1];
if (next->src0 != node && next->src1 != node && next->n_tasks == 1)
{
workers[next_task].params = (struct ggml_compute_params) {
.type = GGML_TASK_COMPUTE | GGML_TASK_INIT,
.ith = 0,
.nth = 1,
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
.wdata = cgraph->work ? cgraph->work->data : NULL,
};
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
next_task++;
}
}
}
params.type = GGML_TASK_COMPUTE;
ggml_compute_forward(&params, node);
// wait for thread pool
if (node->n_tasks > 1) {
if (node->n_tasks > 1 || next_task != 0) {
thpool_wait(ctx->tpool);
}
#if 0