Rework scheduling algorithm.

This commit is contained in:
Howard Su 2023-04-10 22:24:27 +08:00
parent 2035a3cc29
commit 6f2a61eb4f

119
ggml.c
View File

@ -9279,23 +9279,22 @@ static void ggml_graph_compute_thread(void * data) {
void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
const int n_threads = cgraph->n_threads; const int n_threads = cgraph->n_threads;
struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL; const int max_requests = n_threads * 5;
struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*(max_requests));
// create thread pool // create thread pool
if (n_threads > 1) { ctx->tpool = thpool_init(n_threads);
ctx->tpool = thpool_init(n_threads); for (int j = 0; j < n_threads - 1; j++) {
for (int j = 0; j < n_threads - 1; j++) { workers[j] = (struct ggml_compute_state) {
workers[j] = (struct ggml_compute_state) { .params = {
.params = { .type = GGML_TASK_COMPUTE,
.type = GGML_TASK_COMPUTE, .ith = j + 1,
.ith = j + 1, .nth = n_threads,
.nth = n_threads, .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, .wdata = cgraph->work ? cgraph->work->data : NULL,
.wdata = cgraph->work ? cgraph->work->data : NULL, },
}, .node = NULL,
.node = NULL, };
};
}
} }
// initialize tasks + work buffer // initialize tasks + work buffer
@ -9505,6 +9504,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
const int64_t perf_start_cycles = ggml_perf_cycles(); const int64_t perf_start_cycles = ggml_perf_cycles();
const int64_t perf_start_time_us = ggml_perf_time_us(); const int64_t perf_start_time_us = ggml_perf_time_us();
const size_t wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0;
const void* wdata = cgraph->work ? cgraph->work->data : NULL;
for (int i = 0; i < cgraph->n_nodes; i++) { for (int i = 0; i < cgraph->n_nodes; i++) {
GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes); GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes);
@ -9524,52 +9525,31 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
const int64_t perf_node_start_cycles = ggml_perf_cycles(); const int64_t perf_node_start_cycles = ggml_perf_cycles();
const int64_t perf_node_start_time_us = ggml_perf_time_us(); const int64_t perf_node_start_time_us = ggml_perf_time_us();
// INIT
struct ggml_compute_params params = {
.type = GGML_TASK_INIT,
.ith = 0,
.nth = node->n_tasks,
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
.wdata = cgraph->work ? cgraph->work->data : NULL,
};
ggml_compute_forward(&params, node);
int next_task = 0; int next_task = 0;
// COMPUTE // COMPUTE
if (node->n_tasks > 1) {
// launch thread pool
for (int j = 0; j < n_threads - 1; j++) {
workers[j].params = (struct ggml_compute_params) {
.type = GGML_TASK_COMPUTE,
.ith = j + 1,
.nth = node->n_tasks,
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
.wdata = cgraph->work ? cgraph->work->data : NULL,
};
workers[j].node = node;
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[j]);
}
}
else
{ {
int start = i; int start = i;
int end = i + 1; int end = i;
while (end < cgraph->n_nodes && next_task < n_threads && (end - start) < n_threads * 2) while (end < cgraph->n_nodes && (end - start) < n_threads * 2)
{ {
struct ggml_tensor * next = cgraph->nodes[end]; struct ggml_tensor * next = cgraph->nodes[end];
end++; end++;
if (next->n_tasks != 1) // already scheduled
if (next->n_tasks == 0)
continue;
// if we have slots
if (next_task + next->n_tasks > max_requests)
continue; continue;
// check src depedency // check src depedency
bool is_dep = false; bool is_dep = false;
for (int k = start; k < end; k++) for (int k = start; k < end; k++)
{ {
struct ggml_tensor * node = cgraph->nodes[k]; struct ggml_tensor * prev = cgraph->nodes[k];
if (next->src0 == node || next->src1 == node) if (next->src0 == prev || next->src1 == prev)
{ {
is_dep = true; is_dep = true;
break; break;
@ -9579,29 +9559,42 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
if (is_dep) if (is_dep)
continue; continue;
workers[next_task].params = (struct ggml_compute_params) { if (next->n_tasks > 1)
.type = GGML_TASK_INIT, {
.ith = 0, // run INIT in main thread if it is multi thread operator
.nth = 1, struct ggml_compute_params params = {
.wsize = 0, .type = GGML_TASK_INIT,
.wdata = NULL, .ith = 0,
}; .nth = next->n_tasks,
workers[next_task].node = next; .wsize = wsize,
.wdata = wdata,
};
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]); ggml_compute_forward(&params, next);
}
for (int j = 0; j < next->n_tasks; j++) {
workers[next_task].params = (struct ggml_compute_params){
// single thread operator runs INIT in worker thread
.type = next->n_tasks == 1 ? GGML_TASK_INIT : GGML_TASK_COMPUTE,
.ith = j,
.nth = next->n_tasks,
// TODO: Potential race on wdata
.wsize = wsize,
.wdata = wdata,
};
workers[next_task].node = next;
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
next_task++;
}
next->n_tasks = 0; // indicate this node is caculated next->n_tasks = 0; // indicate this node is caculated
next_task++;
//printf("Combine task [%d, %d]\n", start, end);
} }
} }
params.type = GGML_TASK_COMPUTE;
ggml_compute_forward(&params, node);
// wait for thread pool // wait for thread pool
if (node->n_tasks > 1 || next_task != 0) { thpool_wait(ctx->tpool);
thpool_wait(ctx->tpool);
}
#if 0 #if 0
// FINALIZE // FINALIZE
if (node->n_tasks > 1) { if (node->n_tasks > 1) {