look forward more

commit 3b03df5c05
parent 921296c0d5
Author: Howard Su
Date:   2023-04-08 19:55:29 +08:00

ggml.c

@@ -9249,16 +9249,10 @@ typedef int ggml_lock_t;
#endif
struct ggml_compute_state_shared {
int n_threads;
};
struct ggml_compute_state {
struct ggml_compute_params params;
struct ggml_tensor * node;
struct ggml_compute_state_shared * shared;
};
static void ggml_graph_compute_thread(void * data) {
@@ -9284,9 +9278,6 @@ static void ggml_graph_compute_thread(void * data) {
void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) {
const int n_threads = cgraph->n_threads;
struct ggml_compute_state_shared state_shared = {
/*.n_threads =*/ n_threads,
};
struct ggml_compute_state * workers = n_threads > 1 ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1)) : NULL;
// create thread pool
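As this hunk shows, setting up the pool workers before walking the graph comes down to recording the thread count and reserving one ggml_compute_state per extra thread on the stack. A condensed, re-indented sketch of the same lines (identifiers exactly as in the diff; alloca keeps the slots alive for the duration of the call):

    // condensed from this hunk: shared thread count plus one stack slot per extra thread
    const int n_threads = cgraph->n_threads;

    struct ggml_compute_state_shared state_shared = {
        /*.n_threads =*/ n_threads,
    };

    struct ggml_compute_state * workers = n_threads > 1
        ? alloca(sizeof(struct ggml_compute_state)*(n_threads - 1))
        : NULL;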
@@ -9302,7 +9293,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
.wdata = cgraph->work ? cgraph->work->data : NULL,
},
.node = NULL,
.shared = &state_shared,
};
}
}
@@ -9520,6 +9510,11 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
struct ggml_tensor * node = cgraph->nodes[i];
if (node->n_tasks == 0)
{
// no work needs to be done.
continue;
}
// TODO: this could be used to avoid unnecessary computations, but it needs to be improved
//if (node->grad == NULL && node->perf_runs > 0) {
// continue;
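The n_tasks == 0 check added here is the other half of the look-forward pass further down: nodes that have already been handed to a pool worker are marked by zeroing n_tasks, so the main loop can skip them cheaply. A condensed sketch of that skip, re-indented from the lines above:

    for (int i = 0; i < cgraph->n_nodes; i++) {
        struct ggml_tensor * node = cgraph->nodes[i];

        if (node->n_tasks == 0) {
            // already dispatched by the look-forward pass; no work needs to be done here
            continue;
        }

        // ... normal scheduling of this node follows
    }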
@@ -9558,46 +9553,45 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
}
else
{
if (i + 1 < cgraph->n_nodes)
int start = i;
int end = i + 1;
while (end < cgraph->n_nodes && next_task < n_threads && (end - start) < n_threads * 2)
{
struct ggml_tensor * next = cgraph->nodes[i + 1];
if (next->src0 != node && next->src1 != node && next->n_tasks == 1)
{
workers[next_task].params = (struct ggml_compute_params) {
.type = GGML_TASK_COMPUTE | GGML_TASK_INIT,
.ith = 0,
.nth = 1,
.wsize = 0,
.wdata = NULL,
};
workers[next_task].node = next;
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
next_task++;
struct ggml_tensor * next = cgraph->nodes[end];
end++;
if (i + 2 < cgraph->n_nodes)
if (next->n_tasks != 1)
continue;
// check src dependency
bool is_dep = false;
for (int k = start; k < end; k++)
{
struct ggml_tensor * node = cgraph->nodes[k];
if (next->src0 == node || next->src1 == node)
{
struct ggml_tensor * prev = cgraph->nodes[i + 1];
struct ggml_tensor * next = cgraph->nodes[i + 2];
if (next->src0 != node && next->src1 != node && next->n_tasks == 1 &&
next->src0 != prev && next->src1 != prev
)
{
workers[next_task].params = (struct ggml_compute_params) {
.type = GGML_TASK_COMPUTE | GGML_TASK_INIT,
.ith = 0,
.nth = 1,
.wsize = 0,
.wdata = NULL,
};
workers[next_task].node = next;
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
next_task++;
}
is_dep = true;
break;
}
}
if (is_dep)
continue;
workers[next_task].params = (struct ggml_compute_params) {
.type = GGML_TASK_COMPUTE | GGML_TASK_INIT,
.ith = 0,
.nth = 1,
.wsize = 0,
.wdata = NULL,
};
workers[next_task].node = next;
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
next->n_tasks = 0; // indicate this node is calculated
next_task++;
//printf("Combine task [%d, %d]\n", start, end);
}
}
params.type = GGML_TASK_COMPUTE;
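Taken together, the last hunk replaces the old peek-at-the-next-one-or-two-nodes logic with a loop that keeps scanning ahead while the current node runs: up to n_threads * 2 nodes past the current one, as long as idle worker slots remain. Each candidate must be a single-task node and must not read the output of any node in the window already covered; independent candidates are pushed onto the thread pool and marked with n_tasks = 0 so the outer loop skips them later. A condensed, re-indented sketch of the added code (identifiers as in the diff, with the inner loop variable renamed to avoid shadowing node; thpool_add_work and ctx->tpool come from this fork's thread-pool integration):

    int start = i;
    int end   = i + 1;

    // look ahead while pool slots remain, at most n_threads * 2 nodes past the current one
    while (end < cgraph->n_nodes && next_task < n_threads && (end - start) < n_threads * 2) {
        struct ggml_tensor * next = cgraph->nodes[end];
        end++;

        // only single-task nodes can be offloaded to one pool worker
        if (next->n_tasks != 1) {
            continue;
        }

        // skip candidates that depend on any node in the window [start, end)
        bool is_dep = false;
        for (int k = start; k < end; k++) {
            struct ggml_tensor * prev = cgraph->nodes[k];
            if (next->src0 == prev || next->src1 == prev) {
                is_dep = true;
                break;
            }
        }
        if (is_dep) {
            continue;
        }

        // hand the independent node to a pool worker as a single-threaded INIT+COMPUTE task
        workers[next_task].params = (struct ggml_compute_params) {
            .type  = GGML_TASK_COMPUTE | GGML_TASK_INIT,
            .ith   = 0,
            .nth   = 1,
            .wsize = 0,
            .wdata = NULL,
        };
        workers[next_task].node = next;
        thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);

        next->n_tasks = 0; // tell the outer loop this node is already handled
        next_task++;
    }

Checking each candidate against every node in the window, rather than only the immediately preceding one as before, is what lets the pass jump over dependent neighbors and still pick up independent nodes further ahead.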