mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 13:30:35 +00:00
ggml : use dynamic thread scheduling for matrix multiplication (#6915)
* Just reordering some structs. * Adding in the calls to mm_pause * Passing around the state * Renaming and moving a bunch of variables around. * Extracting the logic to it's own function. * Moving some variable definitions into the chunk function. * Moving some variables around * moving src1_cont inside * Moving row_size * adding the current_chunk * Reorg the code. * Formatting to match the orig patch * starting to setup the chunking variables * Starting the buildup of the loop * The yield shouldn't be necessary. * adding the looping structure based on the chunk configuration. * Add in the re-chunking code. * Making it much more likely to rechunk. * disable resizing if numa is enabled. * Updating comments with what we've learned. * Fix formatting * Couple more formatting fixes. * More style fixes. * Fix Warnings * Going with unused because there's conditional logic that needs it. * Update ggml.c * Update ggml.c ---------
This commit is contained in:
parent
dc020985b8
commit
e1b40ac3b9
381
ggml.c
381
ggml.c
@ -112,6 +112,8 @@ typedef void * thread_ret_t;
|
||||
|
||||
#endif
|
||||
|
||||
typedef pthread_t ggml_thread_t;
|
||||
|
||||
#ifdef GGML_USE_CPU_HBM
|
||||
#include <hbwmalloc.h>
|
||||
#endif
|
||||
@ -1539,6 +1541,59 @@ static inline void __sse_f16x4_store(ggml_fp16_t *x, __m128 y) {
|
||||
#define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR)
|
||||
#endif
|
||||
|
||||
//
|
||||
// ggml context
|
||||
//
|
||||
|
||||
struct ggml_context {
|
||||
size_t mem_size;
|
||||
void* mem_buffer;
|
||||
bool mem_buffer_owned;
|
||||
bool no_alloc;
|
||||
bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
|
||||
|
||||
int n_objects;
|
||||
|
||||
struct ggml_object* objects_begin;
|
||||
struct ggml_object* objects_end;
|
||||
|
||||
struct ggml_scratch scratch;
|
||||
struct ggml_scratch scratch_save;
|
||||
};
|
||||
|
||||
struct ggml_context_container {
|
||||
bool used;
|
||||
|
||||
struct ggml_context context;
|
||||
};
|
||||
|
||||
struct ggml_compute_state_shared {
|
||||
const struct ggml_cgraph* cgraph;
|
||||
const struct ggml_cplan* cplan;
|
||||
|
||||
int64_t perf_node_start_cycles;
|
||||
int64_t perf_node_start_time_us;
|
||||
|
||||
const int n_threads;
|
||||
|
||||
// synchronization primitives
|
||||
atomic_int n_active; // num active threads
|
||||
atomic_int node_n; // active graph node
|
||||
atomic_int node_task; // active graph node task phase
|
||||
|
||||
ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
|
||||
void* abort_callback_data;
|
||||
|
||||
atomic_int current_chunk; // currently processing chunk during Mat_Mul, shared between all the threads.
|
||||
};
|
||||
|
||||
struct ggml_compute_state {
|
||||
ggml_thread_t thrd;
|
||||
int ith;
|
||||
struct ggml_compute_state_shared* shared;
|
||||
enum ggml_status ec;
|
||||
};
|
||||
|
||||
//
|
||||
// fundamental operations
|
||||
//
|
||||
@ -2385,32 +2440,6 @@ static void ggml_setup_op_has_task_pass(void) {
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// ggml context
|
||||
//
|
||||
|
||||
struct ggml_context {
|
||||
size_t mem_size;
|
||||
void * mem_buffer;
|
||||
bool mem_buffer_owned;
|
||||
bool no_alloc;
|
||||
bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
|
||||
|
||||
int n_objects;
|
||||
|
||||
struct ggml_object * objects_begin;
|
||||
struct ggml_object * objects_end;
|
||||
|
||||
struct ggml_scratch scratch;
|
||||
struct ggml_scratch scratch_save;
|
||||
};
|
||||
|
||||
struct ggml_context_container {
|
||||
bool used;
|
||||
|
||||
struct ggml_context context;
|
||||
};
|
||||
|
||||
//
|
||||
// NUMA support
|
||||
//
|
||||
@ -11815,9 +11844,101 @@ static bool ggml_compute_forward_mul_mat_use_blas(struct ggml_tensor * dst) {
|
||||
}
|
||||
#endif
|
||||
|
||||
static void ggml_compute_forward_mul_mat_one_chunk(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst,
|
||||
const int64_t num_rows_per_vec_dot,
|
||||
const int64_t ir0_start,
|
||||
const int64_t ir0_end,
|
||||
const int64_t ir1_start,
|
||||
const int64_t ir1_end) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const enum ggml_type type = src0->type;
|
||||
|
||||
const bool src1_cont = ggml_is_contiguous(src1);
|
||||
|
||||
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
|
||||
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
||||
|
||||
// broadcast factors
|
||||
const int64_t r2 = ne12 / ne02;
|
||||
const int64_t r3 = ne13 / ne03;
|
||||
|
||||
//printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
|
||||
|
||||
// threads with no work simply yield (not sure if it helps)
|
||||
if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
|
||||
return;
|
||||
}
|
||||
|
||||
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||
|
||||
assert(ne12 % ne02 == 0);
|
||||
assert(ne13 % ne03 == 0);
|
||||
|
||||
// block-tiling attempt
|
||||
const int64_t blck_0 = 16;
|
||||
const int64_t blck_1 = 16;
|
||||
|
||||
const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
|
||||
|
||||
// attempt to reduce false-sharing (does not seem to make a difference)
|
||||
// 16 * 2, accounting for mmla kernels
|
||||
float tmp[32];
|
||||
|
||||
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
|
||||
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
|
||||
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
|
||||
const int64_t i13 = (ir1 / (ne12 * ne1));
|
||||
const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
|
||||
const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
|
||||
|
||||
// broadcast src0 into src1
|
||||
const int64_t i03 = i13 / r3;
|
||||
const int64_t i02 = i12 / r2;
|
||||
|
||||
const int64_t i1 = i11;
|
||||
const int64_t i2 = i12;
|
||||
const int64_t i3 = i13;
|
||||
|
||||
const char * src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
|
||||
|
||||
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
|
||||
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
|
||||
// the original src1 data pointer, so we should index using the indices directly
|
||||
// TODO: this is a bit of a hack, we should probably have a better way to handle this
|
||||
const char * src1_col = (const char*)wdata +
|
||||
(src1_cont || src1->type != vec_dot_type
|
||||
? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
|
||||
: (i11 * nb11 + i12 * nb12 + i13 * nb13));
|
||||
float * dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
|
||||
|
||||
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
|
||||
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
|
||||
//}
|
||||
|
||||
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
|
||||
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
|
||||
}
|
||||
|
||||
for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
|
||||
memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_mul_mat(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
struct ggml_tensor * dst,
|
||||
struct ggml_compute_state * state) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src1 = dst->src[1];
|
||||
@ -11832,9 +11953,6 @@ static void ggml_compute_forward_mul_mat(
|
||||
|
||||
const enum ggml_type type = src0->type;
|
||||
|
||||
const bool src1_cont = ggml_is_contiguous(src1);
|
||||
|
||||
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
|
||||
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
||||
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
|
||||
int64_t const vec_dot_num_rows = type_traits[type].nrows;
|
||||
@ -11855,8 +11973,10 @@ static void ggml_compute_forward_mul_mat(
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
// broadcast factors
|
||||
const int64_t r2 = ne12/ne02;
|
||||
const int64_t r3 = ne13/ne03;
|
||||
const int64_t r2 = ne12 / ne02;
|
||||
const int64_t r3 = ne13 / ne03;
|
||||
UNUSED(r2);
|
||||
UNUSED(r3);
|
||||
|
||||
// nb01 >= nb00 - src0 is not transposed
|
||||
// compute by src0 rows
|
||||
@ -11938,6 +12058,8 @@ static void ggml_compute_forward_mul_mat(
|
||||
#endif
|
||||
|
||||
#if GGML_USE_LLAMAFILE
|
||||
const bool src1_cont = ggml_is_contiguous(src1);
|
||||
|
||||
if (src1_cont) {
|
||||
for (int64_t i13 = 0; i13 < ne13; i13++)
|
||||
for (int64_t i12 = 0; i12 < ne12; i12++)
|
||||
@ -11963,6 +12085,8 @@ UseGgmlGemm1:;
|
||||
if (ith != 0) {
|
||||
return;
|
||||
}
|
||||
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
|
||||
atomic_store(&state->shared->current_chunk, nth);
|
||||
if (src1->type != vec_dot_type) {
|
||||
char * wdata = params->wdata;
|
||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||
@ -11987,11 +12111,11 @@ UseGgmlGemm1:;
|
||||
return;
|
||||
}
|
||||
|
||||
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||
|
||||
#if GGML_USE_LLAMAFILE
|
||||
if (src1->type != vec_dot_type) {
|
||||
const void* wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||
|
||||
for (int64_t i13 = 0; i13 < ne13; i13++)
|
||||
for (int64_t i12 = 0; i12 < ne12; i12++)
|
||||
if (!llamafile_sgemm(ne01, ne11, ne00/ggml_blck_size(src0->type),
|
||||
@ -12012,98 +12136,87 @@ UseGgmlGemm1:;
|
||||
UseGgmlGemm2:;
|
||||
#endif
|
||||
|
||||
const int64_t nr0 = ne01; // src0 rows
|
||||
const int64_t nr1 = ne1*ne12*ne13; // src1 rows
|
||||
#ifdef GGML_PERF
|
||||
int chunks_executed = 0;
|
||||
UNUSED(chunks_executed);
|
||||
#endif
|
||||
|
||||
//printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
|
||||
// This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers)
|
||||
const int64_t nr0 = ne0;
|
||||
|
||||
// distribute the thread work across the inner or outer loop based on which one is larger
|
||||
|
||||
const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
|
||||
const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
|
||||
|
||||
const int64_t ith0 = ith % nth0;
|
||||
const int64_t ith1 = ith / nth0;
|
||||
|
||||
const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
|
||||
const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
|
||||
|
||||
const int64_t ir010 = dr0*ith0;
|
||||
const int64_t ir011 = MIN(ir010 + dr0, nr0);
|
||||
|
||||
const int64_t ir110 = dr1*ith1;
|
||||
const int64_t ir111 = MIN(ir110 + dr1, nr1);
|
||||
|
||||
//printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
|
||||
|
||||
// threads with no work simply yield (not sure if it helps)
|
||||
if (ir010 >= ir011 || ir110 >= ir111) {
|
||||
sched_yield();
|
||||
return;
|
||||
}
|
||||
|
||||
assert(ne12 % ne02 == 0);
|
||||
assert(ne13 % ne03 == 0);
|
||||
|
||||
// block-tiling attempt
|
||||
const int64_t blck_0 = 16;
|
||||
const int64_t blck_1 = 16;
|
||||
// This is the size of the rest of the dimensions of the result
|
||||
const int64_t nr1 = ne1 * ne2 * ne3;
|
||||
|
||||
// dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
|
||||
int64_t nrc = vec_dot_num_rows;
|
||||
int64_t num_rows_per_vec_dot = vec_dot_num_rows;
|
||||
// TODO: currently the mmla kernels support only even numbered rows/cols.
|
||||
// this check can be removed once they are extended to support odd numbered rows/cols too
|
||||
if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
|
||||
nrc = 1;
|
||||
num_rows_per_vec_dot = 1;
|
||||
}
|
||||
|
||||
const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
|
||||
// Now select a reasonable chunk size.
|
||||
int chunk_size = 16;
|
||||
|
||||
// attempt to reduce false-sharing (does not seem to make a difference)
|
||||
// 16 * 2, accounting for mmla kernels
|
||||
float tmp[32];
|
||||
// We need to step up the size if it's small
|
||||
if (nr0 == 1 || nr1 == 1) {
|
||||
chunk_size = 64;
|
||||
}
|
||||
|
||||
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
|
||||
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
|
||||
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) {
|
||||
const int64_t i13 = (ir1/(ne12*ne1));
|
||||
const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
|
||||
const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
|
||||
// distribute the work across the inner or outer loop based on which one is larger
|
||||
// The number of chunks in the 0/1 dim.
|
||||
// CEIL(nr0/chunk_size)
|
||||
int64_t nchunk0 = (nr0 + chunk_size - 1) / chunk_size;
|
||||
int64_t nchunk1 = (nr1 + chunk_size - 1) / chunk_size;
|
||||
|
||||
// broadcast src0 into src1
|
||||
const int64_t i03 = i13/r3;
|
||||
const int64_t i02 = i12/r2;
|
||||
// If the chunking is poor for the number of threads on this setup, scrap the whole plan. Re-chunk it by thread.
|
||||
// Also, chunking by thread was measured to have perform better on NUMA systems. See https://github.com/ggerganov/llama.cpp/pull/6915
|
||||
// In theory, chunking should be just as useful on NUMA and non NUMA systems, but testing disagreed with that.
|
||||
if (nchunk0 * nchunk1 < nth * 4 || ggml_is_numa()) {
|
||||
// distribute the thread work across the inner or outer loop based on which one is larger
|
||||
nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
|
||||
nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
|
||||
}
|
||||
|
||||
const int64_t i1 = i11;
|
||||
const int64_t i2 = i12;
|
||||
const int64_t i3 = i13;
|
||||
// The number of elements in each chunk
|
||||
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
|
||||
const int64_t dr1 = (nr1 + nchunk1 - 1) / nchunk1;
|
||||
|
||||
const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
|
||||
//if (ith == 0)
|
||||
// printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
|
||||
|
||||
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
|
||||
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
|
||||
// the original src1 data pointer, so we should index using the indices directly
|
||||
// TODO: this is a bit of a hack, we should probably have a better way to handle this
|
||||
const char * src1_col = (const char *) wdata +
|
||||
(src1_cont || src1->type != vec_dot_type
|
||||
? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
|
||||
: (i11*nb11 + i12*nb12 + i13*nb13));
|
||||
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
|
||||
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
||||
int current_chunk = ith;
|
||||
|
||||
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
|
||||
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
|
||||
//}
|
||||
while (current_chunk < nchunk0 * nchunk1) {
|
||||
const int64_t ith0 = current_chunk % nchunk0;
|
||||
const int64_t ith1 = current_chunk / nchunk0;
|
||||
|
||||
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) {
|
||||
vec_dot(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? src1_col_stride : 0), nrc);
|
||||
}
|
||||
const int64_t ir0_start = dr0 * ith0;
|
||||
const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
|
||||
|
||||
for (int cn = 0; cn < nrc; ++cn) {
|
||||
memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
|
||||
}
|
||||
}
|
||||
const int64_t ir1_start = dr1 * ith1;
|
||||
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
|
||||
|
||||
ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end);
|
||||
|
||||
#ifdef GGML_PERF
|
||||
chunks_executed++;
|
||||
#endif
|
||||
|
||||
if (nth >= nchunk0 * nchunk1) {
|
||||
break;
|
||||
}
|
||||
|
||||
current_chunk = atomic_fetch_add(&state->shared->current_chunk, 1);
|
||||
}
|
||||
|
||||
#ifdef GGML_PERF
|
||||
// These numbers are useful when trying to measure how well the threading scheduling works.
|
||||
//int64_t workSize = (ne01 * ne11 * ne12 * ne13 * ne00) / nchunk0 / nchunk1;
|
||||
//float time = (ggml_perf_time_us() - t0);
|
||||
//printf("MUL_MAT = %f ms, [%d, %d, %d, %d] x [%d, %d, %d, %d] = %I64u, %f ops/usec in %d chunks.\n", time / 1000.0, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, workSize, (float)workSize/time, chunks_executed);
|
||||
#endif
|
||||
}
|
||||
|
||||
// ggml_compute_forward_mul_mat_id
|
||||
@ -17358,7 +17471,7 @@ static void ggml_compute_forward_cross_entropy_loss_back(
|
||||
|
||||
/////////////////////////////////
|
||||
|
||||
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
|
||||
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_compute_state * state) {
|
||||
GGML_ASSERT(params);
|
||||
|
||||
if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
|
||||
@ -17456,7 +17569,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
ggml_compute_forward_mul_mat(params, tensor);
|
||||
ggml_compute_forward_mul_mat(params, tensor, state);
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
{
|
||||
@ -19072,8 +19185,6 @@ typedef int ggml_lock_t;
|
||||
|
||||
#define GGML_LOCK_INITIALIZER 0
|
||||
|
||||
typedef pthread_t ggml_thread_t;
|
||||
|
||||
#define ggml_thread_create pthread_create
|
||||
#define ggml_thread_join pthread_join
|
||||
|
||||
@ -19099,8 +19210,6 @@ typedef int ggml_lock_t;
|
||||
|
||||
#define GGML_LOCK_INITIALIZER 0
|
||||
|
||||
typedef pthread_t ggml_thread_t;
|
||||
|
||||
#define ggml_thread_create pthread_create
|
||||
#define ggml_thread_join pthread_join
|
||||
|
||||
@ -19180,31 +19289,6 @@ static void set_numa_thread_affinity(int thread_n) { UNUSED(thread_n); }
|
||||
static void clear_numa_thread_affinity(void) {}
|
||||
#endif
|
||||
|
||||
struct ggml_compute_state_shared {
|
||||
const struct ggml_cgraph * cgraph;
|
||||
const struct ggml_cplan * cplan;
|
||||
|
||||
int64_t perf_node_start_cycles;
|
||||
int64_t perf_node_start_time_us;
|
||||
|
||||
const int n_threads;
|
||||
|
||||
// synchronization primitives
|
||||
atomic_int n_active; // num active threads
|
||||
atomic_int node_n; // active graph node
|
||||
atomic_int node_task; // active graph node task phase
|
||||
|
||||
ggml_abort_callback abort_callback; // abort ggml_graph_compute when true
|
||||
void * abort_callback_data;
|
||||
};
|
||||
|
||||
struct ggml_compute_state {
|
||||
ggml_thread_t thrd;
|
||||
int ith;
|
||||
struct ggml_compute_state_shared * shared;
|
||||
enum ggml_status ec;
|
||||
};
|
||||
|
||||
static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const struct ggml_compute_state_shared * st) {
|
||||
int64_t cycles_cur = ggml_perf_cycles() - st->perf_node_start_cycles;
|
||||
int64_t time_us_cur = ggml_perf_time_us() - st->perf_node_start_time_us;
|
||||
@ -19477,6 +19561,10 @@ static void ggml_graph_compute_thread_sync_node(int * node_n, struct ggml_comput
|
||||
|
||||
* node_n = atomic_load(&state->shared->node_n);
|
||||
if (* node_n != last_node_n) break;
|
||||
#if defined(__SSE3__)
|
||||
// Tell the processor we're spinning. It's a processor hint for spinlocks.
|
||||
_mm_pause();
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@ -19491,6 +19579,10 @@ static void ggml_graph_compute_thread_sync_task(int * task_phase, struct ggml_co
|
||||
|
||||
* task_phase = atomic_load(&state->shared->node_task);
|
||||
if (* task_phase != last_task_phase) break;
|
||||
#if defined(__SSE3__)
|
||||
// Tell the processor we're spinning. It's a processor hint for spinlocks.
|
||||
_mm_pause();
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@ -19530,7 +19622,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
struct ggml_tensor * node = cgraph->nodes[node_n];
|
||||
if (GGML_OP_HAS_FINALIZE[node->op]) {
|
||||
params.nth = ggml_get_n_tasks(node, n_threads, state->shared->n_threads);
|
||||
ggml_compute_forward(¶ms, node);
|
||||
ggml_compute_forward(¶ms, node, state);
|
||||
}
|
||||
ggml_graph_compute_perf_stats_node(node, state->shared);
|
||||
}
|
||||
@ -19550,17 +19642,17 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
/* INIT */
|
||||
if (GGML_OP_HAS_INIT[node->op]) {
|
||||
params.type = GGML_TASK_TYPE_INIT;
|
||||
ggml_compute_forward(¶ms, node);
|
||||
ggml_compute_forward(¶ms, node, state);
|
||||
}
|
||||
|
||||
// TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1,
|
||||
// they do something more efficient than spinning (?)
|
||||
params.type = GGML_TASK_TYPE_COMPUTE;
|
||||
ggml_compute_forward(¶ms, node);
|
||||
ggml_compute_forward(¶ms, node, state);
|
||||
|
||||
if (GGML_OP_HAS_FINALIZE[node->op]) {
|
||||
params.type = GGML_TASK_TYPE_FINALIZE;
|
||||
ggml_compute_forward(¶ms, node);
|
||||
ggml_compute_forward(¶ms, node, state);
|
||||
}
|
||||
|
||||
ggml_graph_compute_perf_stats_node(node, state->shared);
|
||||
@ -19599,7 +19691,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
|
||||
if (state->ith < n_tasks) {
|
||||
if (GGML_OP_HAS_INIT[node->op]) {
|
||||
ggml_compute_forward(¶ms, node);
|
||||
ggml_compute_forward(¶ms, node, state);
|
||||
}
|
||||
}
|
||||
|
||||
@ -19620,7 +19712,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
|
||||
|
||||
if (state->ith < n_tasks) {
|
||||
params.type = GGML_TASK_TYPE_COMPUTE;
|
||||
ggml_compute_forward(¶ms, node);
|
||||
ggml_compute_forward(¶ms, node, state);
|
||||
}
|
||||
|
||||
if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
|
||||
@ -19871,6 +19963,7 @@ enum ggml_status ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cpl
|
||||
/*.node_task =*/ GGML_TASK_TYPE_FINALIZE,
|
||||
/*.abort_callback =*/ NULL,
|
||||
/*.abort_callback_data =*/ NULL,
|
||||
/*.current_chunk; =*/ 0,
|
||||
};
|
||||
struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user