mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 12:10:18 +00:00
test-backend-ops : add moe test
This commit is contained in:
parent
e640cbe055
commit
cefebb3660
@ -51,7 +51,7 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
|
|||||||
t.join();
|
t.join();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tensor->type == GGML_TYPE_F32) {
|
if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
|
||||||
ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
|
ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
|
||||||
} else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) {
|
} else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) {
|
||||||
GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
|
GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
|
||||||
@ -233,6 +233,10 @@ static bool ggml_is_view_op(enum ggml_op op) {
|
|||||||
struct test_case {
|
struct test_case {
|
||||||
virtual ~test_case() {}
|
virtual ~test_case() {}
|
||||||
|
|
||||||
|
virtual std::string op_desc(ggml_tensor * t) {
|
||||||
|
return ggml_op_desc(t);
|
||||||
|
}
|
||||||
|
|
||||||
virtual std::string vars() {
|
virtual std::string vars() {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
@ -240,7 +244,7 @@ struct test_case {
|
|||||||
virtual ggml_tensor * build_graph(ggml_context * ctx) = 0;
|
virtual ggml_tensor * build_graph(ggml_context * ctx) = 0;
|
||||||
|
|
||||||
virtual double max_nmse_err() {
|
virtual double max_nmse_err() {
|
||||||
return 1e-6;
|
return 1e-7;
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual void initialize_tensors(ggml_context * ctx) {
|
virtual void initialize_tensors(ggml_context * ctx) {
|
||||||
@ -270,13 +274,13 @@ struct test_case {
|
|||||||
|
|
||||||
ggml_tensor * out = build_graph(ctx);
|
ggml_tensor * out = build_graph(ctx);
|
||||||
|
|
||||||
if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
|
if (op_name != nullptr && op_desc(out) != op_name) {
|
||||||
//printf(" %s: skipping\n", ggml_op_desc(out));
|
//printf(" %s: skipping\n", op_desc(out).c_str());
|
||||||
ggml_free(ctx);
|
ggml_free(ctx);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
printf(" %s(%s): ", ggml_op_desc(out), vars().c_str());
|
printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str());
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
|
|
||||||
// check if backends support op
|
// check if backends support op
|
||||||
@ -317,7 +321,7 @@ struct test_case {
|
|||||||
for (size_t i = 0; i < f1.size(); i++) {
|
for (size_t i = 0; i < f1.size(); i++) {
|
||||||
// check for nans
|
// check for nans
|
||||||
if (std::isnan(f1[i]) || std::isnan(f2[i])) {
|
if (std::isnan(f1[i]) || std::isnan(f2[i])) {
|
||||||
printf("NaN at index %zu ", i);
|
printf("[%s] NaN at index %zu ", ggml_op_desc(t1), i);
|
||||||
ud->ok = false;
|
ud->ok = false;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
@ -325,21 +329,32 @@ struct test_case {
|
|||||||
if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
|
if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
|
||||||
if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
|
if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
|
||||||
if (std::signbit(f1[i]) != std::signbit(f2[i])) {
|
if (std::signbit(f1[i]) != std::signbit(f2[i])) {
|
||||||
printf("inf sign mismatch: %f %f ", f1[i], f2[i]);
|
printf("[%s] inf sign mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
|
||||||
ud->ok = false;
|
ud->ok = false;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
printf("inf mismatch: %f %f ", f1[i], f2[i]);
|
printf("[%s] inf mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
|
||||||
ud->ok = false;
|
ud->ok = false;
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//if (t1->op == GGML_OP_SOFT_MAX) {
|
||||||
|
// printf("[%s] ", ggml_op_desc(t1));
|
||||||
|
// for (int i = 0; i < f1.size(); i++) {
|
||||||
|
// printf("(%x, %x) ", *(uint32_t*)&f1[i], *(uint32_t*)&f2[i]);
|
||||||
|
// }
|
||||||
|
// printf("\n");
|
||||||
|
//}
|
||||||
double err = nmse(f1.data(), f2.data(), f1.size());
|
double err = nmse(f1.data(), f2.data(), f1.size());
|
||||||
if (err > ud->max_err) {
|
if (err > ud->max_err) {
|
||||||
printf("NMSE = %f ", err);
|
printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
|
||||||
|
//for (int i = 0; i < f1.size(); i++) {
|
||||||
|
// printf("(%f, %f) ", f1[i], f2[i]);
|
||||||
|
//}
|
||||||
|
//printf("\n");
|
||||||
ud->ok = false;
|
ud->ok = false;
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
@ -374,13 +389,13 @@ struct test_case {
|
|||||||
|
|
||||||
ggml_tensor * out = build_graph(ctx);
|
ggml_tensor * out = build_graph(ctx);
|
||||||
|
|
||||||
if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
|
if (op_name != nullptr && op_desc(out) != op_name) {
|
||||||
//printf(" %s: skipping\n", ggml_op_desc(out));
|
//printf(" %s: skipping\n", op_desc(out).c_str());
|
||||||
ggml_free(ctx);
|
ggml_free(ctx);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
int len = printf(" %s(%s): ", ggml_op_desc(out), vars().c_str());
|
int len = printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str());
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
|
|
||||||
// check if backends support op
|
// check if backends support op
|
||||||
@ -1122,6 +1137,91 @@ struct test_sum_rows : public test_case {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct test_moe : public test_case {
|
||||||
|
const int n_experts = 8;
|
||||||
|
const int n_experts_per_tok = 2;
|
||||||
|
const int n_tokens = 1;
|
||||||
|
const int n_embd = 4096;
|
||||||
|
const int n_ff = 14336;
|
||||||
|
|
||||||
|
std::string op_desc(ggml_tensor * t) override {
|
||||||
|
return "MOE";
|
||||||
|
GGML_UNUSED(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string vars() override {
|
||||||
|
return VARS_TO_STR5(n_experts, n_experts_per_tok, n_tokens, n_embd, n_ff);
|
||||||
|
}
|
||||||
|
|
||||||
|
test_moe() {
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
|
ggml_tensor * ffn_gate_inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_experts);
|
||||||
|
|
||||||
|
std::vector<ggml_tensor *> ffn_up_exp(n_experts);
|
||||||
|
std::vector<ggml_tensor *> ffn_gate_exp(n_experts);
|
||||||
|
std::vector<ggml_tensor *> ffn_down_exp(n_experts);
|
||||||
|
|
||||||
|
for (int i = 0; i < n_experts; ++i) {
|
||||||
|
ffn_up_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
|
||||||
|
ffn_gate_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
|
||||||
|
ffn_down_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
|
||||||
|
|
||||||
|
ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur); // [n_tokens, num_experts]
|
||||||
|
ggml_tensor * probs = ggml_soft_max(ctx, logits); // [n_tokens, num_experts]
|
||||||
|
|
||||||
|
// select experts
|
||||||
|
ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
|
||||||
|
|
||||||
|
ggml_tensor * weights = ggml_get_rows(ctx,
|
||||||
|
ggml_reshape_3d(ctx, probs, 1, n_experts, n_tokens), selected_experts);
|
||||||
|
printf("get rows args %ld %ld %ld %ld, %ld %ld %ld %ld\n",
|
||||||
|
weights->src[0]->ne[0], weights->src[0]->ne[1], weights->src[0]->ne[2], weights->src[0]->ne[3],
|
||||||
|
weights->src[1]->ne[0], weights->src[1]->ne[1], weights->src[1]->ne[2], weights->src[1]->ne[3]);
|
||||||
|
|
||||||
|
|
||||||
|
weights = ggml_reshape_2d(ctx, weights, n_experts_per_tok, n_tokens); // [n_tokens, num_experts_per_tok]
|
||||||
|
|
||||||
|
ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights);
|
||||||
|
|
||||||
|
weights = ggml_div(ctx, weights, weights_sum); // [n_tokens, num_experts_per_tok]
|
||||||
|
|
||||||
|
// compute expert outputs
|
||||||
|
ggml_tensor * moe_out = nullptr;
|
||||||
|
|
||||||
|
for (int i = 0; i < n_experts_per_tok; ++i) {
|
||||||
|
ggml_tensor * cur_expert;
|
||||||
|
|
||||||
|
ggml_tensor * cur_up = ggml_mul_mat_id(ctx, ffn_up_exp.data(), n_experts, selected_experts, i, cur);
|
||||||
|
|
||||||
|
ggml_tensor * cur_gate = ggml_mul_mat_id(ctx, ffn_gate_exp.data(), n_experts, selected_experts, i, cur);
|
||||||
|
|
||||||
|
cur_gate = ggml_silu(ctx, cur_gate);
|
||||||
|
|
||||||
|
cur_expert = ggml_mul(ctx, cur_up, cur_gate); // [n_tokens, n_embd]
|
||||||
|
|
||||||
|
cur_expert = ggml_mul_mat_id(ctx, ffn_down_exp.data(), n_experts, selected_experts, i, cur_expert); // [n_tokens, n_embd]
|
||||||
|
|
||||||
|
cur_expert = ggml_mul(ctx, cur_expert,
|
||||||
|
ggml_view_2d(ctx, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
|
||||||
|
|
||||||
|
if (i == 0) {
|
||||||
|
moe_out = cur_expert;
|
||||||
|
} else {
|
||||||
|
moe_out = ggml_add(ctx, moe_out, cur_expert);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cur = moe_out;
|
||||||
|
|
||||||
|
return cur;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
enum test_mode {
|
enum test_mode {
|
||||||
MODE_TEST,
|
MODE_TEST,
|
||||||
MODE_PERF,
|
MODE_PERF,
|
||||||
@ -1140,11 +1240,14 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||||||
GGML_TYPE_Q6_K
|
GGML_TYPE_Q6_K
|
||||||
};
|
};
|
||||||
|
|
||||||
|
test_cases.emplace_back(new test_moe());
|
||||||
|
|
||||||
// unary ops
|
// unary ops
|
||||||
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
|
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
|
||||||
test_cases.emplace_back(new test_unary((ggml_unary_op) op));
|
test_cases.emplace_back(new test_unary((ggml_unary_op) op));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
|
||||||
for (ggml_type type : all_types) {
|
for (ggml_type type : all_types) {
|
||||||
for (int b : {1, 7}) {
|
for (int b : {1, 7}) {
|
||||||
for (bool v : {false, true}) {
|
for (bool v : {false, true}) {
|
||||||
@ -1265,6 +1368,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
|||||||
test_cases.emplace_back(new test_concat());
|
test_cases.emplace_back(new test_concat());
|
||||||
|
|
||||||
for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) {
|
for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) {
|
||||||
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
|
||||||
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
|
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user