diff --git a/ggml.c b/ggml.c index c33ad2ca9..cebf43d57 100644 --- a/ggml.c +++ b/ggml.c @@ -9279,23 +9279,22 @@ static void ggml_graph_compute_thread(void * data) { void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { const int n_threads = cgraph->n_threads; - struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL; + const int max_requests = n_threads * 5; + struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*(max_requests)); // create thread pool - if (n_threads > 1) { - ctx->tpool = thpool_init(n_threads); - for (int j = 0; j < n_threads - 1; j++) { - workers[j] = (struct ggml_compute_state) { - .params = { - .type = GGML_TASK_COMPUTE, - .ith = j + 1, - .nth = n_threads, - .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, - .wdata = cgraph->work ? cgraph->work->data : NULL, - }, - .node = NULL, - }; - } + ctx->tpool = thpool_init(n_threads); + for (int j = 0; j < n_threads - 1; j++) { + workers[j] = (struct ggml_compute_state) { + .params = { + .type = GGML_TASK_COMPUTE, + .ith = j + 1, + .nth = n_threads, + .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, + .wdata = cgraph->work ? cgraph->work->data : NULL, + }, + .node = NULL, + }; } // initialize tasks + work buffer @@ -9505,6 +9504,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) const int64_t perf_start_cycles = ggml_perf_cycles(); const int64_t perf_start_time_us = ggml_perf_time_us(); + const size_t wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0; + const void* wdata = cgraph->work ? 
cgraph->work->data : NULL; for (int i = 0; i < cgraph->n_nodes; i++) { GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, i, cgraph->n_nodes); @@ -9524,52 +9525,31 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) const int64_t perf_node_start_cycles = ggml_perf_cycles(); const int64_t perf_node_start_time_us = ggml_perf_time_us(); - // INIT - struct ggml_compute_params params = { - .type = GGML_TASK_INIT, - .ith = 0, - .nth = node->n_tasks, - .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, - .wdata = cgraph->work ? cgraph->work->data : NULL, - }; - - ggml_compute_forward(&params, node); - int next_task = 0; // COMPUTE - if (node->n_tasks > 1) { - // launch thread pool - for (int j = 0; j < n_threads - 1; j++) { - workers[j].params = (struct ggml_compute_params) { - .type = GGML_TASK_COMPUTE, - .ith = j + 1, - .nth = node->n_tasks, - .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, - .wdata = cgraph->work ? cgraph->work->data : NULL, - }; - workers[j].node = node; - thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[j]); - } - } - else { int start = i; - int end = i + 1; - while (end < cgraph->n_nodes && next_task < n_threads && (end - start) < n_threads * 2) + int end = i; + while (end < cgraph->n_nodes && (end - start) < n_threads * 2) { struct ggml_tensor * next = cgraph->nodes[end]; end++; - if (next->n_tasks != 1) + // already scheduled + if (next->n_tasks == 0) + continue; + + // if we have slots + if (next_task + next->n_tasks > max_requests) continue; // check src depedency bool is_dep = false; for (int k = start; k < end; k++) { - struct ggml_tensor * node = cgraph->nodes[k]; - if (next->src0 == node || next->src1 == node) + struct ggml_tensor * prev = cgraph->nodes[k]; + if (next->src0 == prev || next->src1 == prev) { is_dep = true; break; @@ -9579,29 +9559,42 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) if (is_dep) continue; - workers[next_task].params = (struct 
ggml_compute_params) { - .type = GGML_TASK_INIT, - .ith = 0, - .nth = 1, - .wsize = 0, - .wdata = NULL, - }; - workers[next_task].node = next; + if (next->n_tasks > 1) + { + // run INIT in main thread if it is multi thread operator + struct ggml_compute_params params = { + .type = GGML_TASK_INIT, + .ith = 0, + .nth = next->n_tasks, + .wsize = wsize, + .wdata = wdata, + }; - thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]); + ggml_compute_forward(&params, next); + } + + for (int j = 0; j < next->n_tasks; j++) { + workers[next_task].params = (struct ggml_compute_params){ + // single thread operator runs INIT in worker thread + .type = next->n_tasks == 1 ? GGML_TASK_INIT : GGML_TASK_COMPUTE, + .ith = j, + .nth = next->n_tasks, + + // TODO: Potential race on wdata + .wsize = wsize, + .wdata = wdata, + }; + workers[next_task].node = next; + + thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]); + next_task++; + } next->n_tasks = 0; // indicate this node is caculated - next_task++; - //printf("Combine task [%d, %d]\n", start, end); } } - params.type = GGML_TASK_COMPUTE; - ggml_compute_forward(&params, node); - // wait for thread pool - if (node->n_tasks > 1 || next_task != 0) { - thpool_wait(ctx->tpool); - } + thpool_wait(ctx->tpool); #if 0 // FINALIZE if (node->n_tasks > 1) {