ggml : process mul mat rows in chunks

This commit is contained in:
Georgi Gerganov 2023-05-17 19:46:09 +03:00
parent 8a203f9fa1
commit 5a317898e8

45
ggml.c
View File

@ -3590,6 +3590,9 @@ struct ggml_compute_params {
// work buffer for all threads
size_t wsize;
void * wdata;
// atomic counter used to distribute chunks of work
atomic_int * aic;
};
//
@ -9754,6 +9757,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
const int ith = params->ith;
const int nth = params->nth;
UNUSED(ith);
GGML_ASSERT(ne02 == ne12);
GGML_ASSERT(ne03 == ne13);
GGML_ASSERT(ne2 == ne12);
@ -9867,6 +9872,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
}
}
atomic_store(params->aic, 0);
return;
}
@ -9874,22 +9881,22 @@ static void ggml_compute_forward_mul_mat_q_f32(
return;
}
// parallelize by src0 rows using ggml_vec_dot_q
// total rows in src0
const int nr = ne01*ne02*ne03;
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
void * wdata = params->wdata;
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
for (int ir = ir0; ir < ir1; ++ir) {
// parallelize by src0 rows using ggml_vec_dot_q
const int nr = ggml_nrows(src0);
const int dr = (nr + 8*nth - 1)/(8*nth);
while (true) {
const int ir0 = atomic_fetch_add(params->aic, dr);
for (int ir = ir0; ir < ir0 + dr; ++ir) {
if (ir >= nr) {
break;
}
// src0 indices
const int i03 = ir/(ne02*ne01);
const int i02 = (ir - i03*ne02*ne01)/ne01;
@ -9914,6 +9921,11 @@ static void ggml_compute_forward_mul_mat_q_f32(
}
}
if (ir0 + dr >= nr) {
break;
}
}
//int64_t t1 = ggml_time_us();
//static int64_t acc = 0;
//acc += t1 - t0;
@ -13749,6 +13761,7 @@ struct ggml_compute_state_shared {
// synchronization primitives
atomic_int n_ready;
atomic_int aic;
atomic_bool has_work;
atomic_bool stop; // stop all threads
};
@ -13817,6 +13830,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
/*.spin =*/ GGML_LOCK_INITIALIZER,
/*.n_threads =*/ n_threads,
/*.n_ready =*/ 0,
/*.aic =*/ 0,
/*.has_work =*/ false,
/*.stop =*/ false,
};
@ -13837,6 +13851,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
.nth = n_threads,
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
.wdata = cgraph->work ? cgraph->work->data : NULL,
.aic = &state_shared.aic,
},
.node = NULL,
.shared = &state_shared,
@ -14126,6 +14141,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
/*.nth =*/ node->n_tasks,
/*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
/*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
/*.aic =*/ &state_shared.aic,
};
ggml_compute_forward(&params, node);
@ -14149,6 +14165,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
.nth = node->n_tasks,
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
.wdata = cgraph->work ? cgraph->work->data : NULL,
.aic = &state_shared.aic,
};
workers[j].node = node;
}
@ -14164,6 +14181,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
}
params.type = GGML_TASK_COMPUTE;
params.aic = &state_shared.aic;
ggml_compute_forward(&params, node);
// wait for thread pool
@ -14204,6 +14222,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
.nth = node->n_tasks,
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
.wdata = cgraph->work ? cgraph->work->data : NULL,
.aic = &state_shared.aic,
};
workers[j].node = node;
}