mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 20:04:35 +00:00
ggml : process mul mat rows in chunks
This commit is contained in:
parent
8a203f9fa1
commit
5a317898e8
75
ggml.c
75
ggml.c
@ -3590,6 +3590,9 @@ struct ggml_compute_params {
|
|||||||
// work buffer for all threads
|
// work buffer for all threads
|
||||||
size_t wsize;
|
size_t wsize;
|
||||||
void * wdata;
|
void * wdata;
|
||||||
|
|
||||||
|
// atomic counter used to distribute chunks of work
|
||||||
|
atomic_int * aic;
|
||||||
};
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -9754,6 +9757,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||||||
const int ith = params->ith;
|
const int ith = params->ith;
|
||||||
const int nth = params->nth;
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
UNUSED(ith);
|
||||||
|
|
||||||
GGML_ASSERT(ne02 == ne12);
|
GGML_ASSERT(ne02 == ne12);
|
||||||
GGML_ASSERT(ne03 == ne13);
|
GGML_ASSERT(ne03 == ne13);
|
||||||
GGML_ASSERT(ne2 == ne12);
|
GGML_ASSERT(ne2 == ne12);
|
||||||
@ -9867,6 +9872,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
atomic_store(params->aic, 0);
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -9874,43 +9881,48 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// parallelize by src0 rows using ggml_vec_dot_q
|
|
||||||
|
|
||||||
// total rows in src0
|
|
||||||
const int nr = ne01*ne02*ne03;
|
|
||||||
|
|
||||||
// rows per thread
|
|
||||||
const int dr = (nr + nth - 1)/nth;
|
|
||||||
|
|
||||||
// row range for this thread
|
|
||||||
const int ir0 = dr*ith;
|
|
||||||
const int ir1 = MIN(ir0 + dr, nr);
|
|
||||||
|
|
||||||
void * wdata = params->wdata;
|
void * wdata = params->wdata;
|
||||||
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
|
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
|
||||||
|
|
||||||
for (int ir = ir0; ir < ir1; ++ir) {
|
// parallelize by src0 rows using ggml_vec_dot_q
|
||||||
// src0 indices
|
|
||||||
const int i03 = ir/(ne02*ne01);
|
|
||||||
const int i02 = (ir - i03*ne02*ne01)/ne01;
|
|
||||||
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
|
||||||
|
|
||||||
const int i13 = i03;
|
const int nr = ggml_nrows(src0);
|
||||||
const int i12 = i02;
|
const int dr = (nr + 8*nth - 1)/(8*nth);
|
||||||
|
|
||||||
const int i0 = i01;
|
while (true) {
|
||||||
const int i2 = i02;
|
const int ir0 = atomic_fetch_add(params->aic, dr);
|
||||||
const int i3 = i03;
|
|
||||||
|
|
||||||
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
|
for (int ir = ir0; ir < ir0 + dr; ++ir) {
|
||||||
char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
|
if (ir >= nr) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
|
// src0 indices
|
||||||
|
const int i03 = ir/(ne02*ne01);
|
||||||
|
const int i02 = (ir - i03*ne02*ne01)/ne01;
|
||||||
|
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||||
|
|
||||||
assert(ne00 % 32 == 0);
|
const int i13 = i03;
|
||||||
|
const int i12 = i02;
|
||||||
|
|
||||||
for (int64_t ic = 0; ic < ne11; ++ic) {
|
const int i0 = i01;
|
||||||
vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
|
const int i2 = i02;
|
||||||
|
const int i3 = i03;
|
||||||
|
|
||||||
|
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
|
||||||
|
char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
|
||||||
|
|
||||||
|
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
|
||||||
|
|
||||||
|
assert(ne00 % 32 == 0);
|
||||||
|
|
||||||
|
for (int64_t ic = 0; ic < ne11; ++ic) {
|
||||||
|
vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ir0 + dr >= nr) {
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -13749,6 +13761,7 @@ struct ggml_compute_state_shared {
|
|||||||
|
|
||||||
// synchronization primitives
|
// synchronization primitives
|
||||||
atomic_int n_ready;
|
atomic_int n_ready;
|
||||||
|
atomic_int aic;
|
||||||
atomic_bool has_work;
|
atomic_bool has_work;
|
||||||
atomic_bool stop; // stop all threads
|
atomic_bool stop; // stop all threads
|
||||||
};
|
};
|
||||||
@ -13817,6 +13830,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||||||
/*.spin =*/ GGML_LOCK_INITIALIZER,
|
/*.spin =*/ GGML_LOCK_INITIALIZER,
|
||||||
/*.n_threads =*/ n_threads,
|
/*.n_threads =*/ n_threads,
|
||||||
/*.n_ready =*/ 0,
|
/*.n_ready =*/ 0,
|
||||||
|
/*.aic =*/ 0,
|
||||||
/*.has_work =*/ false,
|
/*.has_work =*/ false,
|
||||||
/*.stop =*/ false,
|
/*.stop =*/ false,
|
||||||
};
|
};
|
||||||
@ -13837,6 +13851,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||||||
.nth = n_threads,
|
.nth = n_threads,
|
||||||
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
|
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
|
||||||
.wdata = cgraph->work ? cgraph->work->data : NULL,
|
.wdata = cgraph->work ? cgraph->work->data : NULL,
|
||||||
|
.aic = &state_shared.aic,
|
||||||
},
|
},
|
||||||
.node = NULL,
|
.node = NULL,
|
||||||
.shared = &state_shared,
|
.shared = &state_shared,
|
||||||
@ -14126,6 +14141,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||||||
/*.nth =*/ node->n_tasks,
|
/*.nth =*/ node->n_tasks,
|
||||||
/*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
|
/*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
|
||||||
/*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
|
/*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
|
||||||
|
/*.aic =*/ &state_shared.aic,
|
||||||
};
|
};
|
||||||
|
|
||||||
ggml_compute_forward(¶ms, node);
|
ggml_compute_forward(¶ms, node);
|
||||||
@ -14149,6 +14165,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||||||
.nth = node->n_tasks,
|
.nth = node->n_tasks,
|
||||||
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
|
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
|
||||||
.wdata = cgraph->work ? cgraph->work->data : NULL,
|
.wdata = cgraph->work ? cgraph->work->data : NULL,
|
||||||
|
.aic = &state_shared.aic,
|
||||||
};
|
};
|
||||||
workers[j].node = node;
|
workers[j].node = node;
|
||||||
}
|
}
|
||||||
@ -14164,6 +14181,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||||||
}
|
}
|
||||||
|
|
||||||
params.type = GGML_TASK_COMPUTE;
|
params.type = GGML_TASK_COMPUTE;
|
||||||
|
params.aic = &state_shared.aic;
|
||||||
ggml_compute_forward(¶ms, node);
|
ggml_compute_forward(¶ms, node);
|
||||||
|
|
||||||
// wait for thread pool
|
// wait for thread pool
|
||||||
@ -14204,6 +14222,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||||||
.nth = node->n_tasks,
|
.nth = node->n_tasks,
|
||||||
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
|
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
|
||||||
.wdata = cgraph->work ? cgraph->work->data : NULL,
|
.wdata = cgraph->work ? cgraph->work->data : NULL,
|
||||||
|
.aic = &state_shared.aic,
|
||||||
};
|
};
|
||||||
workers[j].node = node;
|
workers[j].node = node;
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user