mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 03:14:35 +00:00
ggml : broadcast mul_mat + conv batch support (#2199)
* ggml : broadcast mul_mat + conv batch support * ggml : apply mul_mat broadcast fix by @jploski
This commit is contained in:
parent
4523d10d0c
commit
975221e954
112
ggml.c
112
ggml.c
@ -4168,10 +4168,9 @@ static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) {
|
||||
static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
||||
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
|
||||
|
||||
return
|
||||
(t0->ne[0] == t1->ne[0]) &&
|
||||
(t0->ne[2] == t1->ne[2]) &&
|
||||
(t0->ne[3] == t1->ne[3]);
|
||||
return (t0->ne[0] == t1->ne[0]) &&
|
||||
(t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
|
||||
(t1->ne[3]%t0->ne[3] == 0);
|
||||
}
|
||||
|
||||
static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
|
||||
@ -6036,8 +6035,8 @@ struct ggml_tensor * ggml_mul_mat(
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
const int64_t ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] };
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne);
|
||||
const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
|
||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
|
||||
|
||||
result->op = GGML_OP_MUL_MAT;
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
@ -7173,7 +7172,6 @@ struct ggml_tensor* ggml_conv_2d(
|
||||
int d0,
|
||||
int d1) {
|
||||
|
||||
GGML_ASSERT(b->ne[3] == 1);
|
||||
GGML_ASSERT(a->ne[2] == b->ne[2]);
|
||||
bool is_node = false;
|
||||
|
||||
@ -7185,7 +7183,7 @@ struct ggml_tensor* ggml_conv_2d(
|
||||
const int64_t ne[4] = {
|
||||
ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
|
||||
ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1),
|
||||
a->ne[3], 1,
|
||||
a->ne[3], b->ne[3],
|
||||
};
|
||||
struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
||||
|
||||
@ -10641,7 +10639,6 @@ static void ggml_compute_forward_rms_norm_back(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ggml_compute_forward_mul_mat
|
||||
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
@ -10685,17 +10682,17 @@ static void ggml_compute_forward_mul_mat(
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
GGML_ASSERT(ne02 == ne12);
|
||||
GGML_ASSERT(ne03 == ne13);
|
||||
GGML_ASSERT(ne2 == ne12);
|
||||
GGML_ASSERT(ne3 == ne13);
|
||||
|
||||
const enum ggml_type type = src0->type;
|
||||
|
||||
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
|
||||
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
||||
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
|
||||
|
||||
GGML_ASSERT(ne0 == ne01);
|
||||
GGML_ASSERT(ne1 == ne11);
|
||||
GGML_ASSERT(ne2 == ne12);
|
||||
GGML_ASSERT(ne3 == ne13);
|
||||
|
||||
// we don't support permuted src0 or src1
|
||||
GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
|
||||
GGML_ASSERT(nb10 == sizeof(float));
|
||||
@ -10706,16 +10703,16 @@ static void ggml_compute_forward_mul_mat(
|
||||
GGML_ASSERT(nb1 <= nb2);
|
||||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
GGML_ASSERT(ne0 == ne01);
|
||||
GGML_ASSERT(ne1 == ne11);
|
||||
GGML_ASSERT(ne2 == ne02);
|
||||
GGML_ASSERT(ne3 == ne03);
|
||||
|
||||
// nb01 >= nb00 - src0 is not transposed
|
||||
// compute by src0 rows
|
||||
|
||||
#if defined(GGML_USE_CLBLAST)
|
||||
if (ggml_cl_can_mul_mat(src0, src1, dst)) {
|
||||
// TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
|
||||
// ref: https://github.com/ggerganov/ggml/pull/224
|
||||
GGML_ASSERT(ne02 == ne12);
|
||||
GGML_ASSERT(ne03 == ne13);
|
||||
|
||||
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
||||
ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
||||
}
|
||||
@ -10725,6 +10722,11 @@ static void ggml_compute_forward_mul_mat(
|
||||
|
||||
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
|
||||
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
|
||||
// TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
|
||||
// ref: https://github.com/ggerganov/ggml/pull/224
|
||||
GGML_ASSERT(ne02 == ne12);
|
||||
GGML_ASSERT(ne03 == ne13);
|
||||
|
||||
if (params->ith != 0) {
|
||||
return;
|
||||
}
|
||||
@ -10794,41 +10796,44 @@ static void ggml_compute_forward_mul_mat(
|
||||
return;
|
||||
}
|
||||
|
||||
// parallelize by src0 rows using ggml_vec_dot_q
|
||||
// parallelize by src0 rows
|
||||
const int64_t dr = (ne01 + nth - 1)/nth;
|
||||
|
||||
// total rows in src0
|
||||
const int nr = ne01*ne02*ne03;
|
||||
const int64_t ir10 = dr*ith;
|
||||
const int64_t ir11 = MIN(ir10 + dr, ne01);
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
// src1 rows
|
||||
const int64_t nr1 = ne11*ne12*ne13;
|
||||
|
||||
void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
|
||||
const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
|
||||
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// src0 indices
|
||||
const int i03 = ir/(ne02*ne01);
|
||||
const int i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
|
||||
const int64_t i13 = (ir1/(ne12*ne11));
|
||||
const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
|
||||
const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
|
||||
|
||||
const int i13 = i03;
|
||||
const int i12 = i02;
|
||||
const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
|
||||
const int64_t i03 = (ir0/(ne02));
|
||||
// Hack for "Falcon multi-query-attention key stutter" / alternative to ggml_repeat2.
|
||||
// See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470:
|
||||
// GG: this is likely the correct way to broadcast, though need some more thought
|
||||
// therefore leaving the comments to remind us for now
|
||||
const int64_t i02 = (i12 / (ne12 / ne02));
|
||||
// Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon)
|
||||
// const int64_t i02 = (ir0 - i03*ne02);
|
||||
|
||||
const int i0 = i01;
|
||||
const int i2 = i02;
|
||||
const int i3 = i03;
|
||||
const int64_t i1 = i11;
|
||||
const int64_t i2 = i12;
|
||||
const int64_t i3 = i13;
|
||||
|
||||
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
|
||||
char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
|
||||
const char * src0_row = (const char *) src0->data + ( 0 + i02*nb02 + i03*nb03 );
|
||||
const char * src1_col = (const char *) wdata + (i11 + i12*ne11 + i13*ne12*ne11)*row_size;
|
||||
|
||||
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
|
||||
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
|
||||
|
||||
for (int64_t ic = 0; ic < ne11; ++ic) {
|
||||
vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
|
||||
for (int64_t ir = ir10; ir < ir11; ++ir) {
|
||||
vec_dot(ne00, &dst_col[ir], src0_row + ir*nb01, src1_col);
|
||||
}
|
||||
}
|
||||
|
||||
@ -13013,9 +13018,10 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
|
||||
{
|
||||
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
|
||||
|
||||
for (int i13 = 0; i13 < ne13; i13++) {
|
||||
for (int i12 = 0; i12 < ne12; i12++) {
|
||||
const float * const src = (float *)((char *) src1->data + i12*nb12);
|
||||
ggml_fp16_t * dst_data = wdata;
|
||||
const float * const src = (float *)((char *) src1->data + i13*nb13 + i12*nb12);
|
||||
ggml_fp16_t * dst_data = wdata + i13*(ne1*ne0*ew0);
|
||||
|
||||
for (int i1 = 0; i1 < ne1; i1++) {
|
||||
for (int i0 = 0; i0 < ne0; i0++) {
|
||||
@ -13029,6 +13035,7 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
@ -13049,14 +13056,16 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
|
||||
|
||||
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
|
||||
|
||||
for (int i3 = 0; i3 < ne3; i3++) {
|
||||
for (int i2 = ip0; i2 < ip1; i2++) {
|
||||
float * dst_data = (float *)((char *) dst->data + i2*nb2);
|
||||
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2);
|
||||
|
||||
for (int i1 = 0; i1 < ne1; ++i1) {
|
||||
for (int i0 = 0; i0 < ne0; ++i0) {
|
||||
ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
|
||||
(ggml_fp16_t *) ((char *) src0->data + i2*nb03),
|
||||
(ggml_fp16_t *) wdata + (i1*ne0 + i0)*ew0);
|
||||
(ggml_fp16_t *) wdata + i3*nb3 + (i1*ne0 + i0)*ew0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -13105,10 +13114,9 @@ static void ggml_compute_forward_conv_2d(
|
||||
|
||||
if (s0 == src0->ne[0] && s1 == src0->ne[1]) {
|
||||
ggml_compute_forward_conv_2d_sk_p0(params, src0, src1, dst);
|
||||
}
|
||||
else {
|
||||
} else {
|
||||
GGML_ASSERT(false); // only stride equal to kernel size is supported
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// ggml_compute_forward_pool_1d_sk_p0
|
||||
@ -16558,8 +16566,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
|
||||
GGML_ASSERT(node->src[1]->ne[3] == 1);
|
||||
|
||||
const int64_t ne00 = node->src[0]->ne[0]; // W
|
||||
const int64_t ne01 = node->src[0]->ne[1]; // H
|
||||
const int64_t ne02 = node->src[0]->ne[2]; // C
|
||||
|
Loading…
Reference in New Issue
Block a user