ggml : group all experts in a single ggml_mul_mat_id (#6505)

* ggml : group all experts in a single ggml_mul_mat_id
cuda : improve mmid row copy

* cuda : fix bin bcast with non-cont src0

* test-backend-ops : only run all mul mat tests for base types

* llama : disable moe offloading with SYCL

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
slaren 2024-04-18 15:18:48 +02:00 committed by GitHub
parent 03c0946d73
commit 0d56246f4b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 971 additions and 821 deletions

View File

@ -44,7 +44,7 @@ private:
std::mutex m_mutex;
int m_last_call = 0;
std::vector<float> m_src1_data;
std::vector<int> m_ids; // the expert ids from ggml_mul_mat_id
std::vector<char> m_ids; // the expert ids from ggml_mul_mat_id
//
void save_imatrix(const char * file_name) const;
void keep_imatrix(int ncall) const;
@ -81,6 +81,7 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
if (ask) {
if (t->op == GGML_OP_MUL_MAT_ID) return true; // collect all indirect matrix multiplications
if (t->op != GGML_OP_MUL_MAT) return false;
// why are small batches ignored (<16 tokens)?
if (src1->ne[1] < 16 || src1->type != GGML_TYPE_F32) return false;
if (!(wname.substr(0, 4) == "blk." || (m_params.collect_output_weight && wname == "output.weight"))) return false;
return true;
@ -101,14 +102,19 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
// this has been adapted to the new format of storing merged experts in a single 3d tensor
// ref: https://github.com/ggerganov/llama.cpp/pull/6387
if (t->op == GGML_OP_MUL_MAT_ID) {
const int idx = ((int32_t *) t->op_params)[0];
// ids -> [n_experts_used, n_tokens]
// src1 -> [cols, n_expert_used, n_tokens]
const ggml_tensor * ids = t->src[2];
const int n_as = src0->ne[2];
const int n_ids = ids->ne[0];
// the top-k selected expert ids are stored in the ids tensor
// for simplicity, always copy ids to host, because it is small
GGML_ASSERT(ids->ne[1] == src1->ne[1]);
m_ids.resize(ggml_nbytes(ids)/sizeof(int));
// take into account that ids is not contiguous!
GGML_ASSERT(ids->ne[1] == src1->ne[2]);
m_ids.resize(ggml_nbytes(ids));
ggml_backend_tensor_get(ids, m_ids.data(), 0, ggml_nbytes(ids));
auto & e = m_stats[wname];
@ -118,9 +124,6 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
// using the following line, we can correct for that if needed by replacing the line above with:
//if (idx == t->src[0]->ne[0] - 1) ++e.ncall;
// loop over all possible experts, regardless if they are used or not in the batch
for (int ex = 0; ex < n_as; ++ex) {
size_t e_start = ex*src1->ne[0];
if (e.values.empty()) {
e.values.resize(src1->ne[0]*n_as, 0);
}
@ -129,17 +132,29 @@ bool IMatrixCollector::collect_imatrix(struct ggml_tensor * t, bool ask, void *
exit(1); //GGML_ASSERT(false);
}
if (m_params.verbosity > 1) {
printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[1], (int)src1->type);
printf("%s[%d]: %32s, %s, %5d x %5d, %d\n", __func__, m_last_call, wname.c_str(), ggml_op_name(t->op), (int)src1->ne[0], (int)src1->ne[2], (int)src1->type);
}
for (int row = 0; row < (int)src1->ne[1]; ++row) {
const int excur = m_ids[row*n_as + idx];
// loop over all possible experts, regardless if they are used or not in the batch
for (int ex = 0; ex < n_as; ++ex) {
size_t e_start = ex*src1->ne[0];
for (int idx = 0; idx < n_ids; ++idx) {
for (int row = 0; row < (int)src1->ne[2]; ++row) {
const int excur = *(const int32_t *) (m_ids.data() + row*ids->nb[1] + idx*ids->nb[0]);
GGML_ASSERT(excur >= 0 && excur < n_as); // sanity check
if (excur != ex) continue;
const float * x = data + row * src1->ne[0];
const int64_t i11 = idx % src1->ne[1];
const int64_t i12 = row;
const float * x = (const float *)((const char *)data + i11*src1->nb[1] + i12*src1->nb[2]);
for (int j = 0; j < (int)src1->ne[0]; ++j) {
e.values[e_start + j] += x[j]*x[j];
}
}
}
if (e.ncall > m_last_call) {
m_last_call = e.ncall;
if (m_last_call % m_params.n_output_frequency == 0) {

View File

@ -1231,7 +1231,7 @@ static void ggml_cuda_op_mul_mat_cublas(
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool());
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
if (src0->type != GGML_TYPE_F16) {
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
@ -1241,7 +1241,7 @@ static void ggml_cuda_op_mul_mat_cublas(
}
const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool());
ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
if (src1->type != GGML_TYPE_F16) {
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
@ -1250,7 +1250,7 @@ static void ggml_cuda_op_mul_mat_cublas(
to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
}
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(), row_diff*src1_ncols);
ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;
@ -1960,20 +1960,73 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
}
}
struct mmid_row_mapping {
int32_t i1;
int32_t i2;
};
static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
const char * __restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
int64_t ne11, int64_t ne10,
size_t nb11, size_t nb12) {
int32_t iid1 = blockIdx.x;
int32_t id = blockIdx.y;
const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
if (row_id_i != i02) {
return;
}
const int64_t i11 = id % ne11;
const int64_t i12 = iid1;
__shared__ int src1_row;
if (threadIdx.x == 0) {
src1_row = atomicAdd(cur_src1_row, 1);
row_mapping[src1_row] = {id, iid1};
}
__syncthreads();
const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
src1_row_contiguous[i] = src1_row_original[i];
}
}
static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
const mmid_row_mapping * __restrict__ row_mapping,
int64_t ne0,
size_t nb1, size_t nb2) {
int32_t i = blockIdx.x;
const int32_t i1 = row_mapping[i].i1;
const int32_t i2 = row_mapping[i].i2;
const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
dst_row_original[j] = dst_row_contiguous[j];
}
}
static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * ids = dst->src[2];
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
cudaStream_t stream = ctx.stream();
const size_t nb11 = src1->nb[1];
const size_t nb1 = dst->nb[1];
const int32_t id = ((int32_t *) dst->op_params)[0];
const int32_t n_as = src0->ne[2];
const int64_t n_as = ne02;
const int64_t n_ids = ids->ne[0];
std::vector<char> ids_host(ggml_nbytes(ids));
const char * ids_dev = (const char *) ids->data;
@ -1990,20 +2043,40 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
src0_row.ne[2] = 1;
src0_row.ne[3] = 1;
src0_row.nb[3] = src0->nb[2];
src0_row.nb[3] = nb02;
if (src1->ne[1] == 1) {
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
src1_row.ne[1] = 1;
src1_row.ne[2] = 1;
src1_row.ne[3] = 1;
src1_row.nb[2] = nb11;
src1_row.nb[3] = nb11;
GGML_ASSERT(row_id >= 0 && row_id < n_as);
dst_row.ne[1] = 1;
dst_row.ne[2] = 1;
dst_row.ne[3] = 1;
dst_row.nb[2] = nb1;
dst_row.nb[3] = nb1;
src0_row.data = src0_original + row_id*src0->nb[2];
src1_row.data = src1_original + i01*src1->nb[1];
dst_row.data = dst_original + i01*dst->nb[1];
if (ne12 == 1) {
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(i02 >= 0 && i02 < n_as);
const int64_t i11 = id % ne11;
const int64_t i12 = iid1;
const int64_t i1 = id;
const int64_t i2 = i12;
src0_row.data = src0_original + i02*nb02;
src1_row.data = src1_original + i11*nb11 + i12*nb12;
dst_row.data = dst_original + i1*nb1 + i2*nb2;
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
}
}
} else {
ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
ggml_cuda_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
@ -2011,54 +2084,69 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
src1_row.data = src1_contiguous.get();
dst_row.data = dst_contiguous.get();
for (int32_t row_id = 0; row_id < n_as; ++row_id) {
for (int64_t i02 = 0; i02 < n_as; i02++) {
int64_t num_src1_rows = 0;
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
if (row_id_i != row_id) {
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
if (row_id_i != i02) {
continue;
}
GGML_ASSERT(row_id >= 0 && row_id < n_as);
CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
nb11, cudaMemcpyDeviceToDevice, stream));
num_src1_rows++;
}
}
if (num_src1_rows == 0) {
continue;
}
src0_row.data = src0_original + row_id*src0->nb[2];
ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
{
dim3 block_dims(std::min((unsigned int)ne10, 768u));
dim3 grid_dims(ids->ne[1], n_ids);
k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
src1_original, src1_contiguous.get(),
dev_cur_src1_row.get(), dev_row_mapping.get(),
ids_dev, i02, ids->nb[1], ids->nb[0],
ne11, ne10,
nb11, nb12);
CUDA_CHECK(cudaGetLastError());
}
src0_row.data = src0_original + i02*nb02;
GGML_ASSERT(nb11 == sizeof(float)*ne10);
GGML_ASSERT(nb1 == sizeof(float)*ne0);
src1_row.ne[1] = num_src1_rows;
dst_row.ne[1] = num_src1_rows;
src1_row.nb[1] = nb11;
src1_row.nb[2] = num_src1_rows*nb11;
src1_row.nb[3] = num_src1_rows*nb11;
dst_row.ne[1] = num_src1_rows;
dst_row.nb[1] = nb1;
dst_row.nb[2] = num_src1_rows*nb1;
dst_row.nb[3] = num_src1_rows*nb1;
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
num_src1_rows = 0;
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
if (row_id_i != row_id) {
continue;
}
GGML_ASSERT(row_id >= 0 && row_id < n_as);
CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
nb1, cudaMemcpyDeviceToDevice, stream));
num_src1_rows++;
{
dim3 block_dims(std::min((unsigned int)ne0, 768u));
dim3 grid_dims(num_src1_rows);
k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
dst_original, dst_contiguous.get(),
dev_row_mapping.get(),
ne0,
nb1, nb2);
CUDA_CHECK(cudaGetLastError());
}
}
}
@ -2487,7 +2575,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
const int min_batch_size = 32;
return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
(op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
GGML_UNUSED(backend);
}

View File

@ -22,6 +22,7 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3,
/*int s00,*/ int s01, int s02, int s03,
/*int s10,*/ int s11, int s12, int s13) {
const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
@ -36,9 +37,9 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
const int i12 = i2 % ne12;
const int i13 = i3 % ne13;
const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i_src0;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
@ -55,6 +56,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3,
/*int s00,*/ int s01, int s02, int s03,
/*int s10,*/ int s11, int s12, int s13) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
@ -72,9 +74,9 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
const int i12 = i2 % ne12;
const int i13 = i3 % ne13;
const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i_src0;
const size_t i_dst = i3*s3 + i2*s2 + i1*s1;
const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
@ -101,10 +103,14 @@ struct bin_bcast_cuda {
int nr[4] = { nr0, nr1, nr2, nr3 };
// collapse dimensions until first broadcast dimension
int64_t cne0[] = {ne0, ne1, ne2, ne3};
int64_t cne[] = {ne0, ne1, ne2, ne3};
int64_t cne0[] = {ne00, ne01, ne02, ne03};
int64_t cne1[] = {ne10, ne11, ne12, ne13};
size_t cnb0[] = {nb0, nb1, nb2, nb3};
size_t cnb[] = {nb0, nb1, nb2, nb3};
size_t cnb0[] = {nb00, nb01, nb02, nb03};
size_t cnb1[] = {nb10, nb11, nb12, nb13};
auto collapse = [](int64_t cne[]) {
cne[0] *= cne[1];
cne[1] = cne[2];
@ -118,32 +124,47 @@ struct bin_bcast_cuda {
cnb[3] *= cne[3];
};
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
for (int i = 0; i < 4; i++) {
if (nr[i] != 1) {
break;
}
if (i > 0) {
collapse_nb(cnb, cne);
collapse_nb(cnb0, cne0);
collapse_nb(cnb1, cne1);
collapse(cne);
collapse(cne0);
collapse(cne1);
}
}
}
{
int64_t ne0 = cne0[0];
int64_t ne1 = cne0[1];
int64_t ne2 = cne0[2];
int64_t ne3 = cne0[3];
int64_t ne0 = cne[0];
int64_t ne1 = cne[1];
int64_t ne2 = cne[2];
int64_t ne3 = cne[3];
//int64_t ne00 = cne0[0]; GGML_UNUSED(ne00);
//int64_t ne01 = cne0[1]; GGML_UNUSED(ne01);
//int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
//int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
int64_t ne10 = cne1[0];
int64_t ne11 = cne1[1];
int64_t ne12 = cne1[2];
int64_t ne13 = cne1[3];
size_t nb0 = cnb0[0];
size_t nb1 = cnb0[1];
size_t nb2 = cnb0[2];
size_t nb3 = cnb0[3];
size_t nb0 = cnb[0];
size_t nb1 = cnb[1];
size_t nb2 = cnb[2];
size_t nb3 = cnb[3];
size_t nb00 = cnb0[0];
size_t nb01 = cnb0[1];
size_t nb02 = cnb0[2];
size_t nb03 = cnb0[3];
size_t nb10 = cnb1[0];
size_t nb11 = cnb1[1];
@ -160,7 +181,28 @@ struct bin_bcast_cuda {
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);
size_t s00 = nb00 / sizeof(src0_t);
size_t s01 = nb01 / sizeof(src0_t);
size_t s02 = nb02 / sizeof(src0_t);
size_t s03 = nb03 / sizeof(src0_t);
GGML_ASSERT(nb0 % sizeof(dst_t) == 0);
GGML_ASSERT(nb1 % sizeof(dst_t) == 0);
GGML_ASSERT(nb2 % sizeof(dst_t) == 0);
GGML_ASSERT(nb3 % sizeof(dst_t) == 0);
GGML_ASSERT(nb00 % sizeof(src0_t) == 0);
GGML_ASSERT(nb01 % sizeof(src0_t) == 0);
GGML_ASSERT(nb02 % sizeof(src0_t) == 0);
GGML_ASSERT(nb03 % sizeof(src0_t) == 0);
GGML_ASSERT(nb10 % sizeof(src1_t) == 0);
GGML_ASSERT(nb11 % sizeof(src1_t) == 0);
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
GGML_ASSERT(s0 == 1);
GGML_ASSERT(s00 == 1);
GGML_ASSERT(s10 == 1);
const int block_size = 128;
@ -179,13 +221,14 @@ struct bin_bcast_cuda {
);
if (block_nums.z > 65535) {
// this is the maximum number of blocks in z direction, fallback to 1D grid kernel
// this is the maximum number of blocks in z dimension, fallback to 1D grid kernel
int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
src0_dd, src1_dd, dst_dd,
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00, */ s01, s02, s03,
/* s10, */ s11, s12, s13);
} else {
k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
@ -193,6 +236,7 @@ struct bin_bcast_cuda {
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00, */ s01, s02, s03,
/* s10, */ s11, s12, s13);
}
}

View File

@ -45,6 +45,8 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
vals[ix] = x0[ix];
}
__syncthreads();
#pragma unroll
for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
if (need_check && i0 + iy + 2*threadIdx.x >= k) {

View File

@ -1732,15 +1732,10 @@ static enum ggml_status ggml_metal_graph_compute(
} break;
case GGML_OP_MUL_MAT_ID:
{
//GGML_ASSERT(ne00 == ne10);
//GGML_ASSERT(ne03 == ne13);
const int n_as = src0->ne[2];
// max size of the src1ids array in the kernel shared buffer
GGML_ASSERT(ne11 <= 4096);
// src2 = ids
const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20);
const int64_t ne20 = src2->ne[0];
const int64_t ne21 = src2->ne[1];
const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
@ -1761,15 +1756,13 @@ static enum ggml_status ggml_metal_graph_compute(
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
int ne11_mm_min = n_as;
// ne20 = n_used_experts
// ne21 = n_rows
const int dst_rows = ne20*ne21;
const int dst_rows_min = n_as;
const int idx = ((int32_t *) dst->op_params)[0];
// batch size
GGML_ASSERT(ne21 == ne11); // ?
GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
const uint r2 = 1;
const uint r3 = 1;
// max size of the rowids array in the kernel shared buffer
GGML_ASSERT(dst_rows <= 2048);
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
@ -1779,7 +1772,7 @@ static enum ggml_status ggml_metal_graph_compute(
// !!!
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne00 % 32 == 0 && ne00 >= 64 &&
ne11 > ne11_mm_min) {
dst_rows > dst_rows_min) {
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
@ -1821,26 +1814,26 @@ static enum ggml_status ggml_metal_graph_compute(
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
[encoder setBytes:&idx length:sizeof(idx) atIndex:19];
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else {
int nth0 = 32;
int nth1 = 1;
@ -1993,72 +1986,72 @@ static enum ggml_status ggml_metal_graph_compute(
GGML_ASSERT(ne00 >= nth0*nth1);
}
const int64_t _ne1 = 1; // kernels needs a reference in constant memory
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20];
[encoder setBytes:&r2 length:sizeof(r2) atIndex:21];
[encoder setBytes:&r3 length:sizeof(r3) atIndex:22];
[encoder setBytes:&idx length:sizeof(idx) atIndex:23];
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
const int64_t _ne1 = 1;
const int tgz = dst_rows;
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
const int mem_size = 32*sizeof(float);
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q3_K) {
#ifdef GGML_QKK_64
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#else
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#endif
}
else if (src0t == GGML_TYPE_Q5_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
const int64_t ny = (_ne1 + nrows - 1)/nrows;
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
} break;

File diff suppressed because it is too large Load Diff

View File

@ -17752,7 +17752,7 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
GGML_CALL static bool ggml_backend_sycl_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
const int min_batch_size = 32;
return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID;
GGML_UNUSED(backend);
}

121
ggml.c
View File

@ -4578,21 +4578,32 @@ void ggml_mul_mat_set_prec(
// ggml_mul_mat_id
// NOTE: id will be removed in the future and instead all the experts listed in ids will be computed
// this will allow computing all the used experts in a single matrix multiplication
/*
c = ggml_mul_mat_id(ctx, as, b, ids);
as -> [cols, rows, n_expert]
ids -> [n_experts_used, n_tokens] (i32)
b -> [cols, n_expert_used, n_tokens]
c -> [cols, n_expert_used, n_tokens]
in b, n_experts_used can be broadcasted to match the n_expert_used of ids
c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
*/
struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx,
struct ggml_tensor * as,
struct ggml_tensor * ids,
int id,
struct ggml_tensor * b) {
struct ggml_tensor * b,
struct ggml_tensor * ids) {
GGML_ASSERT(!ggml_is_transposed(as));
GGML_ASSERT(ids->type == GGML_TYPE_I32);
GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert)
GGML_ASSERT(b->ne[3] == 1); // b is 3d
GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
GGML_ASSERT(ids->ne[1] == b->ne[1]); // must have an expert per b row
GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id
GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row
GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
bool is_node = false;
@ -4600,11 +4611,9 @@ struct ggml_tensor * ggml_mul_mat_id(
is_node = true;
}
const int64_t ne[4] = { as->ne[1], b->ne[1], b->ne[2], b->ne[3] };
const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
ggml_set_op_params_i32(result, 0, id);
result->op = GGML_OP_MUL_MAT_ID;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = as;
@ -11009,11 +11018,6 @@ static void ggml_compute_forward_mul_mat_id(
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type));
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
@ -11024,22 +11028,21 @@ static void ggml_compute_forward_mul_mat_id(
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
// broadcast is not supported with mmid
assert(ne12 == 1);
assert(ne13 == 1);
// row groups
const int id = ggml_get_op_params_i32(dst, 0);
const int n_as = src0->ne[2];
const int n_ids = ids->ne[0]; // n_expert_used
const int n_as = ne02; // n_expert
char * wdata_src1_end = (src1->type == vec_dot_type) ?
(char *) params->wdata :
(char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11]
struct mmid_row_mapping {
int32_t i1;
int32_t i2;
};
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)]
int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne11]
if (params->type == GGML_TASK_TYPE_INIT) {
if (ith != 0) {
@ -11065,13 +11068,18 @@ static void ggml_compute_forward_mul_mat_id(
// initialize matrix_row_counts
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
// group rows by src0 matrix
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
GGML_ASSERT(row_id >= 0 && row_id < n_as);
MMID_MATRIX_ROW(row_id, matrix_row_counts[row_id]) = i01;
matrix_row_counts[row_id] += 1;
// group rows by src0 matrix
for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
for (int id = 0; id < n_ids; ++id) {
const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
assert(i02 >= 0 && i02 < n_as);
MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
matrix_row_counts[i02] += 1;
}
}
return;
@ -11089,15 +11097,13 @@ static void ggml_compute_forward_mul_mat_id(
continue;
}
size_t src0_offset = cur_a*src0->nb[2];
const char * src0_cur = (const char *) src0->data + cur_a*nb02;
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = cne1*ne12*ne13; // src1 rows
//printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
const int64_t nr1 = cne1; // src1 rows
// distribute the thread work across the inner or outer loop based on which one is larger
@ -11116,13 +11122,11 @@ static void ggml_compute_forward_mul_mat_id(
const int64_t ir110 = dr1*ith1;
const int64_t ir111 = MIN(ir110 + dr1, nr1);
//printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
// threads with no work simply yield (not sure if it helps)
if (ir010 >= ir011 || ir110 >= ir111) {
sched_yield();
continue;
}
//if (ir010 >= ir011 || ir110 >= ir111) {
// sched_yield();
// continue;
//}
// block-tiling attempt
const int64_t blck_0 = 16;
@ -11134,20 +11138,16 @@ static void ggml_compute_forward_mul_mat_id(
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
const int64_t i13 = (ir1/(ne12*cne1)); // Note: currently, src1 is always a matrix
const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1);
const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11);
const int64_t _i12 = ir1; // logical row index for this expert
// broadcast src0 into src1
//const int64_t i03 = i13/r3;
//const int64_t i02 = i12/r2;
struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
const int id = row_mapping.i1; // selected expert index
const int64_t i1 = i11;
const int64_t i2 = i12;
const int64_t i3 = i13;
const int64_t i11 = id % ne11;
const int64_t i12 = row_mapping.i2; // row index in src1
const char * src0_row = (const char *) src0->data + src0_offset;
const int64_t i1 = id; // selected expert index
const int64_t i2 = i12; // row
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
@ -11155,18 +11155,19 @@ static void ggml_compute_forward_mul_mat_id(
// TODO: this is a bit of a hack, we should probably have a better way to handle this
const char * src1_col = (const char *) wdata +
(src1_cont || src1->type != vec_dot_type
? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
: (i11*nb11 + i12*nb12 + i13*nb13));
? (i11 + i12*ne11)*row_size
: (i11*nb11 + i12*nb12));
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
//}
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_row + ir0*nb01, 0, src1_col, 0, 1);
vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
}
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
}
}
@ -18512,7 +18513,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
const int n_as = src0->ne[2];
cur += GGML_PAD(cur, sizeof(int64_t)); // align
cur += n_as * sizeof(int64_t); // matrix_row_counts
cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
} break;
case GGML_OP_OUT_PROD:
{
@ -20938,12 +20939,12 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
ok = ok && cur != NULL;
ggml_set_name(cur, ctx->infos[i].name.data);
if (!ok) {
break;
}
ggml_set_name(cur, ctx->infos[i].name.data);
// point the data member to the appropriate location in the binary blob using the tensor infos
if (!params.no_alloc) {
//cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file

6
ggml.h
View File

@ -1161,13 +1161,11 @@ extern "C" {
enum ggml_prec prec);
// indirect matrix multiplication
// ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
GGML_API struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx,
struct ggml_tensor * as,
struct ggml_tensor * ids,
int id,
struct ggml_tensor * b);
struct ggml_tensor * b,
struct ggml_tensor * ids);
// A: m columns, n rows,
// B: p columns, n rows,

223
llama.cpp
View File

@ -4495,6 +4495,13 @@ static bool llm_load_tensors(
auto & hparams = model.hparams;
#ifdef GGML_USE_SYCL
// disable MoE with SYCL until mul_mat_id is updated
if (hparams.n_expert > 0) {
n_gpu_layers = 0;
}
#endif
model.split_mode = split_mode;
model.main_gpu = main_gpu;
model.n_gpu_layers = n_gpu_layers;
@ -6099,6 +6106,100 @@ static struct ggml_tensor * llm_build_ffn(
return cur;
}
static struct ggml_tensor * llm_build_moe_ffn(
struct ggml_context * ctx,
struct ggml_tensor * cur,
struct ggml_tensor * gate_inp,
struct ggml_tensor * up_exps,
struct ggml_tensor * gate_exps,
struct ggml_tensor * down_exps,
int64_t n_expert,
int64_t n_expert_used,
llm_ffn_op_type type_op,
bool norm_w,
const llm_build_cb & cb,
int il) {
int64_t n_embd = cur->ne[0];
int64_t n_tokens = cur->ne[1];
ggml_tensor * logits = ggml_mul_mat(ctx, gate_inp, cur); // [n_expert, n_tokens]
cb(logits, "ffn_moe_logits", il);
ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_expert, n_tokens]
cb(probs, "ffn_moe_probs", il);
// select experts
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_expert_used); // [n_expert_used, n_tokens]
cb(selected_experts->src[0], "ffn_moe_argsort", il);
cb(selected_experts, "ffn_moe_topk", il);
ggml_tensor * weights = ggml_get_rows(ctx,
ggml_reshape_3d(ctx, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights", il);
if (norm_w) {
weights = ggml_reshape_2d(ctx, weights, n_expert_used, n_tokens);
ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights); // [1, n_tokens]
cb(weights_sum, "ffn_moe_weights_sum", il);
weights = ggml_div(ctx, weights, weights_sum); // [n_expert_used, n_tokens]
cb(weights, "ffn_moe_weights_norm", il);
weights = ggml_reshape_3d(ctx, weights, 1, n_expert_used, n_tokens);
}
cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
ggml_tensor * up = ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(up, "ffn_moe_up", il);
ggml_tensor * gate = ggml_mul_mat_id(ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
cb(gate, "ffn_moe_gate", il);
switch (type_op) {
case LLM_FFN_SILU:
{
gate = ggml_silu(ctx, gate);
cb(gate, "ffn_moe_silu", il);
} break;
case LLM_FFN_GELU:
{
gate = ggml_gelu(ctx, gate);
cb(gate, "ffn_moe_gelu", il);
} break;
default:
GGML_ASSERT(false);
}
ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens]
cb(par, "ffn_moe_gate_par", il);
ggml_tensor * experts = ggml_mul_mat_id(ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
cb(experts, "ffn_moe_down", il);
experts = ggml_mul(ctx, experts, weights);
// aggregate experts
ggml_tensor * moe_out = nullptr;
for (int i = 0; i < n_expert_used; ++i) {
ggml_tensor * cur_expert = ggml_view_2d(ctx, experts, n_embd, n_tokens,
experts->nb[2], i*experts->nb[1]);
if (i == 0) {
moe_out = cur_expert;
} else {
moe_out = ggml_add(ctx, moe_out, cur_expert);
}
}
if (n_expert_used == 1) {
// avoid returning a non-contiguous tensor
moe_out = ggml_cont(ctx, moe_out);
}
return moe_out;
}
// if max_alibi_bias > 0 then apply ALiBi
static struct ggml_tensor * llm_build_kqv(
struct ggml_context * ctx,
@ -6642,7 +6743,15 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, true, il);
cur = llm_build_moe_ffn(ctx0, cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
n_expert, n_expert_used,
LLM_FFN_SILU, true,
cb, il);
cb(cur, "ffn_moe_out", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
@ -6674,80 +6783,6 @@ struct llm_build_context {
return gf;
}
// REVIEW: will be replaced by https://github.com/ggerganov/llama.cpp/pull/6505
ggml_tensor * build_moe_ffn(ggml_tensor * cur, int32_t n_tokens, llm_ffn_op_type type_op, bool norm_w, int il) {
ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
cb(logits, "ffn_moe_logits", il);
ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
cb(probs, "ffn_moe_probs", il);
// select experts
ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_expert_used); // [n_tokens, num_experts_per_tok]
cb(selected_experts->src[0], "ffn_moe_argsort", il);
ggml_tensor * weights = ggml_get_rows(ctx0,
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts);
cb(weights, "ffn_moe_weights", il);
weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens); // [n_tokens, num_experts_per_tok]
if (norm_w) {
ggml_tensor * weights_sum = ggml_sum_rows(ctx0, weights);
cb(weights_sum, "ffn_moe_weights_sum", il);
weights = ggml_div(ctx0, weights, weights_sum); // [n_tokens, num_experts_per_tok]
cb(weights, "ffn_moe_weights_norm", il);
}
// compute expert outputs
ggml_tensor * moe_out = nullptr;
for (int i = 0; i < n_expert_used; ++i) {
ggml_tensor * cur_expert;
ggml_tensor * cur_up = ggml_mul_mat_id(ctx0, model.layers[il].ffn_up_exps, selected_experts, i, cur);
cb(cur_up, "ffn_moe_up", il);
ggml_tensor * gate = ggml_mul_mat_id(ctx0, model.layers[il].ffn_gate_exps, selected_experts, i, cur);
cb(gate, "ffn_moe_gate", il);
switch (type_op) {
case LLM_FFN_SILU:
{
gate = ggml_silu(ctx0, gate);
cb(gate, "ffn_moe_silu", il);
} break;
case LLM_FFN_GELU:
{
gate = ggml_gelu(ctx0, gate);
cb(gate, "ffn_moe_gelu", il);
} break;
default:
GGML_ASSERT(false);
}
cur_expert = ggml_mul(ctx0, cur_up, gate);
cb(cur_expert, "ffn_moe_gate_par", il);
cur_expert = ggml_mul_mat_id(ctx0, model.layers[il].ffn_down_exps, selected_experts, i, cur_expert); // [n_tokens, n_embd]
cb(cur_expert, "ffn_moe_down", il);
cur_expert = ggml_mul(ctx0, cur_expert,
ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
cb(cur_expert, "ffn_moe_weighted", il);
if (i == 0) {
moe_out = cur_expert;
} else {
moe_out = ggml_add(ctx0, moe_out, cur_expert);
cb(moe_out, "ffn_moe_out", il);
}
}
return moe_out;
}
struct ggml_cgraph * build_baichuan() {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
@ -7195,7 +7230,15 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
cur = build_moe_ffn(cur, n_tokens, LLM_FFN_GELU, true, il);
cur = llm_build_moe_ffn(ctx0, cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
n_expert, n_expert_used,
LLM_FFN_GELU, true,
cb, il);
cb(cur, "ffn_moe_out", il);
// Grok
// if layer_out_norm is present then apply it before adding the input
@ -7207,7 +7250,6 @@ struct llm_build_context {
cb(cur, "layer_out_norm", il);
}
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
@ -7331,7 +7373,15 @@ struct llm_build_context {
LLM_NORM, cb, il);
cb(cur, "attn_out_norm", il);
cur = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, true, il);
cur = llm_build_moe_ffn(ctx0, cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
n_expert, n_expert_used,
LLM_FFN_SILU, true,
cb, il);
cb(cur, "ffn_moe_out", il);
cur = ggml_add(ctx0, cur, ffn_inp);
cb(cur, "ffn_out", il);
@ -8502,12 +8552,6 @@ struct llm_build_context {
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
cb(Vcur, "Vcur", il);
// these nodes are added to the graph together so that they are not reordered
// by doing so, the number of splits in the graph is reduced
ggml_build_forward_expand(gf, Qcur);
ggml_build_forward_expand(gf, Kcur);
ggml_build_forward_expand(gf, Vcur);
Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
@ -8658,7 +8702,16 @@ struct llm_build_context {
LLM_NORM_RMS, cb, il);
cb(cur, "ffn_norm", il);
ggml_tensor * moe_out = build_moe_ffn(cur, n_tokens, LLM_FFN_SILU, false, il);
ggml_tensor * moe_out =
llm_build_moe_ffn(ctx0, cur,
model.layers[il].ffn_gate_inp,
model.layers[il].ffn_up_exps,
model.layers[il].ffn_gate_exps,
model.layers[il].ffn_down_exps,
n_expert, n_expert_used,
LLM_FFN_SILU, false,
cb, il);
cb(cur, "ffn_moe_out", il);
// FFN shared expert
{

View File

@ -12,19 +12,7 @@ bench_args="${@:3}"
rm -f llama-bench.sqlite
backend="cpu"
if [[ "$OSTYPE" == "darwin"* ]]; then
backend="metal"
elif command -v nvcc &> /dev/null; then
backend="cuda"
fi
make_opts=""
if [[ "$backend" == "cuda" ]]; then
make_opts="LLAMA_CUDA=1"
fi
# to test a backend, call the script with the corresponding environment variable (e.g. LLAMA_CUDA=1 ./scripts/compare-commits.sh ...)
git checkout $1
make clean && make -j32 $make_opts llama-bench

View File

@ -101,7 +101,7 @@ static std::vector<float> tensor_to_float(const ggml_tensor * t) {
} else if (t->type == GGML_TYPE_I8) {
tv.push_back((float)*(int8_t *) &buf[i]);
} else if (quantized) {
tt.to_float(&buf[i], vq.data(), ggml_blck_size(t->type));
tt.to_float(&buf[i], vq.data(), bs);
tv.insert(tv.end(), vq.begin(), vq.end());
} else {
GGML_ASSERT(false);
@ -948,14 +948,14 @@ struct test_mul_mat_id : public test_case {
const ggml_type type_a;
const ggml_type type_b;
const int n_mats;
const int id;
const int n_used;
const bool b; // brodcast b matrix
const int64_t m;
const int64_t n;
const int64_t k;
const bool v; // view (non-contiguous ids)
std::string vars() override {
return VARS_TO_STR8(type_a, type_b, n_mats, id, m, n, k, v);
return VARS_TO_STR8(type_a, type_b, n_mats, n_used, b, m, n, k);
}
double max_nmse_err() override {
@ -972,20 +972,22 @@ struct test_mul_mat_id : public test_case {
}
test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
int n_mats = 2, int id = 0,
int64_t m = 32, int64_t n = 32, int64_t k = 32, bool v = false)
: type_a(type_a), type_b(type_b), n_mats(n_mats), id(id),
m(m), n(n), k(k), v(v) {}
int n_mats = 8, int n_used = 2, bool b = false,
int64_t m = 32, int64_t n = 32, int64_t k = 32)
: type_a(type_a), type_b(type_b), n_mats(n_mats), n_used(n_used), b(b),
m(m), n(n), k(k) {
GGML_ASSERT(n_used <= n_mats);
}
ggml_tensor * build_graph(ggml_context * ctx) override {
// C^T = A * B^T: (k, m) * (k, n) => (m, n)
ggml_tensor * mats = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
ggml_tensor * as = ggml_new_tensor_3d(ctx, type_a, k, m, n_mats);
ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
if (v) {
ids = ggml_view_2d(ctx, ids, n_mats/2, ids->ne[1], ids->nb[1], 0);
if (n_used != n_mats) {
ids = ggml_view_2d(ctx, ids, n_used, n, ids->nb[1], 0);
}
ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, k, n);
ggml_tensor * out = ggml_mul_mat_id(ctx, mats, ids, v ? id/2 : id, b);
ggml_tensor * b = ggml_new_tensor_3d(ctx, type_b, k, this->b ? 1 : n_used, n);
ggml_tensor * out = ggml_mul_mat_id(ctx, as, b, ids);
return out;
}
@ -1611,7 +1613,6 @@ public:
}
};
// Llama
struct test_llama : public test_llm {
static constexpr float freq_base = 10000.0f;
@ -1875,6 +1876,25 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
};
const ggml_type base_types[] = {
GGML_TYPE_F32, GGML_TYPE_F16,
GGML_TYPE_Q4_0,
GGML_TYPE_Q4_K,
GGML_TYPE_IQ2_XXS
};
const ggml_type other_types[] = {
GGML_TYPE_Q4_1,
GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
GGML_TYPE_Q8_0,
GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
GGML_TYPE_Q5_K,
GGML_TYPE_Q6_K,
GGML_TYPE_IQ2_XS, GGML_TYPE_IQ2_S,
GGML_TYPE_IQ3_XXS, GGML_TYPE_IQ1_S, GGML_TYPE_IQ1_M,
GGML_TYPE_IQ4_NL, GGML_TYPE_IQ3_S, GGML_TYPE_IQ4_XS,
};
// unary ops
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
test_cases.emplace_back(new test_unary((ggml_unary_op) op));
@ -1983,7 +2003,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
}
for (ggml_type type_a : all_types) {
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1}));
@ -2003,6 +2023,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
}
for (ggml_type type_a : other_types) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1}));
}
}
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 128, { 8, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 83, 2, 128, { 8, 1}, {4, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 2, 64, { 8, 1}, {4, 1}));
@ -2010,13 +2036,32 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 64, 45, 128, { 8, 1}, {4, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 45, 64, { 8, 1}, {4, 1}));
for (ggml_type type_a : all_types) {
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
for (int n_mats : {2, 4, 8}) {
for (int id = 0; id < n_mats; id++) {
for (bool v : {false, true}) {
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 1, 256, v));
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, v));
for (int n_mats : {4, 8}) {
for (int n_used : {1, 2, 4}) {
for (bool b : {false, true}) {
for (int n : {1, 32}) {
int m = 512;
int k = 256;
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
}
}
}
}
}
}
for (ggml_type type_a : other_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
for (int n_mats : {4}) {
for (int n_used : {2}) {
for (bool b : {false}) {
for (int n : {1}) {
int m = 512;
int k = 256;
test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k));
}
}
}
}