mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-06 00:34:35 +00:00
Run second operator when possible
This commit is contained in:
parent
c640d2a4bd
commit
43dde039b0
45
ggml.c
45
ggml.c
@ -2725,9 +2725,9 @@ struct ggml_context_container {
|
|||||||
//
|
//
|
||||||
|
|
||||||
enum ggml_task_type {
|
enum ggml_task_type {
|
||||||
GGML_TASK_INIT = 0,
|
GGML_TASK_INIT = 1,
|
||||||
GGML_TASK_COMPUTE,
|
GGML_TASK_COMPUTE = 2,
|
||||||
GGML_TASK_FINALIZE,
|
GGML_TASK_FINALIZE = 4,
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_compute_params {
|
struct ggml_compute_params {
|
||||||
@ -9262,9 +9262,20 @@ struct ggml_compute_state {
|
|||||||
|
|
||||||
static void ggml_graph_compute_thread(void * data) {
|
static void ggml_graph_compute_thread(void * data) {
|
||||||
struct ggml_compute_state * state = (struct ggml_compute_state *) data;
|
struct ggml_compute_state * state = (struct ggml_compute_state *) data;
|
||||||
|
int type = state->params.type;
|
||||||
if (state->node) {
|
if (state->node) {
|
||||||
if (state->params.ith < state->params.nth) {
|
if (state->params.ith < state->params.nth) {
|
||||||
ggml_compute_forward(&state->params, state->node);
|
if (type & GGML_TASK_INIT)
|
||||||
|
{
|
||||||
|
state->params.type = GGML_TASK_INIT;
|
||||||
|
ggml_compute_forward(&state->params, state->node);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (type & GGML_TASK_COMPUTE)
|
||||||
|
{
|
||||||
|
state->params.type = GGML_TASK_COMPUTE;
|
||||||
|
ggml_compute_forward(&state->params, state->node);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
state->node = NULL;
|
state->node = NULL;
|
||||||
}
|
}
|
||||||
@ -9527,6 +9538,8 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||||||
|
|
||||||
ggml_compute_forward(¶ms, node);
|
ggml_compute_forward(¶ms, node);
|
||||||
|
|
||||||
|
int next_task = 0;
|
||||||
|
|
||||||
// COMPUTE
|
// COMPUTE
|
||||||
if (node->n_tasks > 1) {
|
if (node->n_tasks > 1) {
|
||||||
// launch thread pool
|
// launch thread pool
|
||||||
@ -9542,12 +9555,34 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||||||
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[j]);
|
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[j]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
int start = i;
|
||||||
|
int end = i;
|
||||||
|
if (i + 1 < cgraph->n_nodes)
|
||||||
|
{
|
||||||
|
struct ggml_tensor * next = cgraph->nodes[i + 1];
|
||||||
|
if (next->src0 != node && next->src1 != node && next->n_tasks == 1)
|
||||||
|
{
|
||||||
|
workers[next_task].params = (struct ggml_compute_params) {
|
||||||
|
.type = GGML_TASK_COMPUTE | GGML_TASK_INIT,
|
||||||
|
.ith = 0,
|
||||||
|
.nth = 1,
|
||||||
|
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
|
||||||
|
.wdata = cgraph->work ? cgraph->work->data : NULL,
|
||||||
|
};
|
||||||
|
|
||||||
|
thpool_add_work(ctx->tpool, ggml_graph_compute_thread, &workers[next_task]);
|
||||||
|
next_task++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
params.type = GGML_TASK_COMPUTE;
|
params.type = GGML_TASK_COMPUTE;
|
||||||
ggml_compute_forward(¶ms, node);
|
ggml_compute_forward(¶ms, node);
|
||||||
|
|
||||||
// wait for thread pool
|
// wait for thread pool
|
||||||
if (node->n_tasks > 1) {
|
if (node->n_tasks > 1 || next_task != 0) {
|
||||||
thpool_wait(ctx->tpool);
|
thpool_wait(ctx->tpool);
|
||||||
}
|
}
|
||||||
#if 0
|
#if 0
|
||||||
|
Loading…
Reference in New Issue
Block a user