test-backend-ops : add moe test

This commit is contained in:
slaren 2023-12-10 13:11:39 +01:00
parent e640cbe055
commit cefebb3660

View File

@ -51,7 +51,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
t.join(); t.join();
} }
if (tensor->type == GGML_TYPE_F32) { if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float)); ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
} else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) { } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) {
GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0); GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
@ -233,6 +233,10 @@ static bool ggml_is_view_op(enum ggml_op op) {
struct test_case { struct test_case {
virtual ~test_case() {} virtual ~test_case() {}
virtual std::string op_desc(ggml_tensor * t) {
return ggml_op_desc(t);
}
virtual std::string vars() { virtual std::string vars() {
return ""; return "";
} }
@ -240,7 +244,7 @@ struct test_case {
virtual ggml_tensor * build_graph(ggml_context * ctx) = 0; virtual ggml_tensor * build_graph(ggml_context * ctx) = 0;
virtual double max_nmse_err() { virtual double max_nmse_err() {
return 1e-6; return 1e-7;
} }
virtual void initialize_tensors(ggml_context * ctx) { virtual void initialize_tensors(ggml_context * ctx) {
@ -270,13 +274,13 @@ struct test_case {
ggml_tensor * out = build_graph(ctx); ggml_tensor * out = build_graph(ctx);
if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) { if (op_name != nullptr && op_desc(out) != op_name) {
//printf(" %s: skipping\n", ggml_op_desc(out)); //printf(" %s: skipping\n", op_desc(out).c_str());
ggml_free(ctx); ggml_free(ctx);
return true; return true;
} }
printf(" %s(%s): ", ggml_op_desc(out), vars().c_str()); printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str());
fflush(stdout); fflush(stdout);
// check if backends support op // check if backends support op
@ -317,7 +321,7 @@ struct test_case {
for (size_t i = 0; i < f1.size(); i++) { for (size_t i = 0; i < f1.size(); i++) {
// check for nans // check for nans
if (std::isnan(f1[i]) || std::isnan(f2[i])) { if (std::isnan(f1[i]) || std::isnan(f2[i])) {
printf("NaN at index %zu ", i); printf("[%s] NaN at index %zu ", ggml_op_desc(t1), i);
ud->ok = false; ud->ok = false;
return true; return true;
} }
@ -325,21 +329,32 @@ struct test_case {
if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) { if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) { if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
if (std::signbit(f1[i]) != std::signbit(f2[i])) { if (std::signbit(f1[i]) != std::signbit(f2[i])) {
printf("inf sign mismatch: %f %f ", f1[i], f2[i]); printf("[%s] inf sign mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
ud->ok = false; ud->ok = false;
return true; return true;
} }
} else { } else {
printf("inf mismatch: %f %f ", f1[i], f2[i]); printf("[%s] inf mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
ud->ok = false; ud->ok = false;
return true; return true;
} }
} }
} }
//if (t1->op == GGML_OP_SOFT_MAX) {
// printf("[%s] ", ggml_op_desc(t1));
// for (int i = 0; i < f1.size(); i++) {
// printf("(%x, %x) ", *(uint32_t*)&f1[i], *(uint32_t*)&f2[i]);
// }
// printf("\n");
//}
double err = nmse(f1.data(), f2.data(), f1.size()); double err = nmse(f1.data(), f2.data(), f1.size());
if (err > ud->max_err) { if (err > ud->max_err) {
printf("NMSE = %f ", err); printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
//for (int i = 0; i < f1.size(); i++) {
// printf("(%f, %f) ", f1[i], f2[i]);
//}
//printf("\n");
ud->ok = false; ud->ok = false;
} }
return true; return true;
@ -374,13 +389,13 @@ struct test_case {
ggml_tensor * out = build_graph(ctx); ggml_tensor * out = build_graph(ctx);
if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) { if (op_name != nullptr && op_desc(out) != op_name) {
//printf(" %s: skipping\n", ggml_op_desc(out)); //printf(" %s: skipping\n", op_desc(out).c_str());
ggml_free(ctx); ggml_free(ctx);
return true; return true;
} }
int len = printf(" %s(%s): ", ggml_op_desc(out), vars().c_str()); int len = printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str());
fflush(stdout); fflush(stdout);
// check if backends support op // check if backends support op
@ -1122,6 +1137,91 @@ struct test_sum_rows : public test_case {
} }
}; };
struct test_moe : public test_case {
const int n_experts = 8;
const int n_experts_per_tok = 2;
const int n_tokens = 1;
const int n_embd = 4096;
const int n_ff = 14336;
std::string op_desc(ggml_tensor * t) override {
return "MOE";
GGML_UNUSED(t);
}
std::string vars() override {
return VARS_TO_STR5(n_experts, n_experts_per_tok, n_tokens, n_embd, n_ff);
}
test_moe() {
}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * ffn_gate_inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_experts);
std::vector<ggml_tensor *> ffn_up_exp(n_experts);
std::vector<ggml_tensor *> ffn_gate_exp(n_experts);
std::vector<ggml_tensor *> ffn_down_exp(n_experts);
for (int i = 0; i < n_experts; ++i) {
ffn_up_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
ffn_gate_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
ffn_down_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
}
ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur); // [n_tokens, num_experts]
ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_tokens, num_experts]
// select experts
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
ggml_tensor * weights = ggml_get_rows(ctx,
ggml_reshape_3d(ctx, probs, 1, n_experts, n_tokens), selected_experts);
printf("get rows args %ld %ld %ld %ld, %ld %ld %ld %ld\n",
weights->src[0]->ne[0], weights->src[0]->ne[1], weights->src[0]->ne[2], weights->src[0]->ne[3],
weights->src[1]->ne[0], weights->src[1]->ne[1], weights->src[1]->ne[2], weights->src[1]->ne[3]);
weights = ggml_reshape_2d(ctx, weights, n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok]
ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights);
weights = ggml_div(ctx, weights, weights_sum); // [n_tokens, num_experts_per_tok]
// compute expert outputs
ggml_tensor * moe_out = nullptr;
for (int i = 0; i < n_experts_per_tok; ++i) {
ggml_tensor * cur_expert;
ggml_tensor * cur_up = ggml_mul_mat_id(ctx, ffn_up_exp.data(), n_experts, selected_experts, i, cur);
ggml_tensor * cur_gate = ggml_mul_mat_id(ctx, ffn_gate_exp.data(), n_experts, selected_experts, i, cur);
cur_gate = ggml_silu(ctx, cur_gate);
cur_expert = ggml_mul(ctx, cur_up, cur_gate); // [n_tokens, n_embd]
cur_expert = ggml_mul_mat_id(ctx, ffn_down_exp.data(), n_experts, selected_experts, i, cur_expert); // [n_tokens, n_embd]
cur_expert = ggml_mul(ctx, cur_expert,
ggml_view_2d(ctx, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
if (i == 0) {
moe_out = cur_expert;
} else {
moe_out = ggml_add(ctx, moe_out, cur_expert);
}
}
cur = moe_out;
return cur;
}
};
enum test_mode { enum test_mode {
MODE_TEST, MODE_TEST,
MODE_PERF, MODE_PERF,
@ -1140,11 +1240,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
GGML_TYPE_Q6_K GGML_TYPE_Q6_K
}; };
test_cases.emplace_back(new test_moe());
// unary ops // unary ops
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) { for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
test_cases.emplace_back(new test_unary((ggml_unary_op) op)); test_cases.emplace_back(new test_unary((ggml_unary_op) op));
} }
test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
for (ggml_type type : all_types) { for (ggml_type type : all_types) {
for (int b : {1, 7}) { for (int b : {1, 7}) {
for (bool v : {false, true}) { for (bool v : {false, true}) {
@ -1265,6 +1368,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
test_cases.emplace_back(new test_concat()); test_cases.emplace_back(new test_concat());
for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) { for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) {
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
} }