look forward more

This commit is contained in:
Howard Su 2023-04-08 19:55:29 +08:00
parent 921296c0d5
commit 3b03df5c05

68
ggml.c
View File

@ -9249,16 +9249,10 @@ typedef int ggml_lock_t;
#endif #endif
struct ggml_compute_state_shared {
int n_threads;
};
struct ggml_compute_state { struct ggml_compute_state {
struct ggml_compute_params params; struct ggml_compute_params params;
struct ggml_tensor * node; struct ggml_tensor * node;
struct ggml_compute_state_shared * shared;
}; };
static void ggml_graph_compute_thread(void * data) { static void ggml_graph_compute_thread(void * data) {
@ -9284,9 +9278,6 @@ static void ggml_graph_compute_thread(void * data) {
void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
const int n_threads = cgraph->n_threads; const int n_threads = cgraph->n_threads;
struct ggml_compute_state_shared state_shared = {
/*.n_threads =*/ n_threads,
};
struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL; struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL;
// create thread pool // create thread pool
@ -9302,7 +9293,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
.wdata = cgraph->work ? cgraph->work->data : NULL, .wdata = cgraph->work ? cgraph->work->data : NULL,
}, },
.node = NULL, .node = NULL,
.shared = &state_shared,
}; };
} }
} }
@ -9520,6 +9510,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
struct ggml_tensor * node = cgraph->nodes[i]; struct ggml_tensor * node = cgraph->nodes[i];
if (node->n_tasks == 0)
{
// no work need to be done.
continue;
}
// TODO: this could be used to avoid unnecessary computations, but it needs to be improved // TODO: this could be used to avoid unnecessary computations, but it needs to be improved
//if (node->grad == NULL && node->perf_runs > 0) { //if (node->grad == NULL && node->perf_runs > 0) {
// continue; // continue;
@ -9558,30 +9553,31 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} }
else else
{ {
if (i + 1 < cgraph->n_nodes) int start = i;
int end = i + 1;
while (end < cgraph->n_nodes && next_task < n_threads && (end - start) < n_threads * 2)
{ {
struct ggml_tensor * next = cgraph->nodes[i + 1]; struct ggml_tensor * next = cgraph->nodes[end];
if (next->src0 != node && next->src1 != node && next->n_tasks == 1) end++;
{
workers[next_task].params = (struct ggml_compute_params) { if (next->n_tasks != 1)
.type = GGML_TASK_COMPUTE | GGML_TASK_INIT, continue;
.ith = 0,
.nth = 1, // check src depedency
.wsize = 0, bool is_dep = false;
.wdata = NULL, for (int k = start; k < end; k++)
};
workers[next_task].node = next;
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
next_task++;
if (i + 2 < cgraph->n_nodes)
{
struct ggml_tensor * prev = cgraph->nodes[i + 1];
struct ggml_tensor * next = cgraph->nodes[i + 2];
if (next->src0 != node && next->src1 != node && next->n_tasks == 1 &&
next->src0 != prev && next->src1 != prev
)
{ {
struct ggml_tensor * node = cgraph->nodes[k];
if (next->src0 == node || next->src1 == node)
{
is_dep = true;
break;
}
}
if (is_dep)
continue;
workers[next_task].params = (struct ggml_compute_params) { workers[next_task].params = (struct ggml_compute_params) {
.type = GGML_TASK_COMPUTE | GGML_TASK_INIT, .type = GGML_TASK_COMPUTE | GGML_TASK_INIT,
.ith = 0, .ith = 0,
@ -9590,15 +9586,13 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
.wdata = NULL, .wdata = NULL,
}; };
workers[next_task].node = next; workers[next_task].node = next;
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]); thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
next->n_tasks = 0; // indicate this node is caculated
next_task++; next_task++;
//printf("Combine task [%d, %d]\n", start, end);
} }
} }
}
}
}
params.type = GGML_TASK_COMPUTE; params.type = GGML_TASK_COMPUTE;
ggml_compute_forward(&params, node); ggml_compute_forward(&params, node);