Simplify the logic of scheduling

This commit is contained in:
Howard Su 2023-04-10 22:37:37 +08:00
parent 6d18c6ea3e
commit 94ddd6204c

37
ggml.c
View File

@ -9258,22 +9258,17 @@ struct ggml_compute_state {
static void ggml_graph_compute_thread(void * data) { static void ggml_graph_compute_thread(void * data) {
struct ggml_compute_state * state = (struct ggml_compute_state *) data; struct ggml_compute_state * state = (struct ggml_compute_state *) data;
int type = state->params.type; int type = state->params.type;
if (state->node) { if (type == GGML_TASK_INIT)
if (state->params.ith < state->params.nth) { {
if (type == GGML_TASK_INIT) state->params.type = GGML_TASK_INIT;
{ ggml_compute_forward(&state->params, state->node);
state->params.type = GGML_TASK_INIT; type = GGML_TASK_COMPUTE;
ggml_compute_forward(&state->params, state->node); }
type = GGML_TASK_COMPUTE;
}
if (type == GGML_TASK_COMPUTE) if (type == GGML_TASK_COMPUTE)
{ {
state->params.type = GGML_TASK_COMPUTE; state->params.type = GGML_TASK_COMPUTE;
ggml_compute_forward(&state->params, state->node); ggml_compute_forward(&state->params, state->node);
}
}
state->node = NULL;
} }
} }
@ -9284,18 +9279,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
// create thread pool // create thread pool
ctx->tpool = thpool_init(n_threads); ctx->tpool = thpool_init(n_threads);
for (int j = 0; j < n_threads - 1; j++) {
workers[j] = (struct ggml_compute_state) {
.params = {
.type = GGML_TASK_COMPUTE,
.ith = j + 1,
.nth = n_threads,
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
.wdata = cgraph->work ? cgraph->work->data : NULL,
},
.node = NULL,
};
}
// initialize tasks + work buffer // initialize tasks + work buffer
{ {