mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 12:10:18 +00:00
llama : update graph to support MoE
This commit is contained in:
parent
861cd67899
commit
aedfad120a
47
llama.cpp
47
llama.cpp
@ -4223,7 +4223,7 @@ struct llm_build_context {
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network
|
||||
{
|
||||
if (model.layers[il].ffn_gate_inp == nullptr) {
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
@ -4235,6 +4235,51 @@ struct llm_build_context {
|
||||
model.layers[il].ffn_down, NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
} else {
|
||||
// MoE branch
|
||||
cur = llm_build_norm(ctx0, ffn_inp, hparams,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, cb, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
const int n_experts_per_tok = 2; // TODO: param
|
||||
|
||||
ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts]
|
||||
ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts]
|
||||
|
||||
// select experts
|
||||
ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok]
|
||||
ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [n_tokens, num_experts_per_tok, 1]
|
||||
weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [n_tokens, num_experts_per_tok, 1]
|
||||
|
||||
// compute expert outputs
|
||||
ggml_tensor * moe_out;
|
||||
|
||||
for (int i = 0; i < n_experts_per_tok; ++i) {
|
||||
ggml_tensor * cur_expert;
|
||||
|
||||
// TODO: fix
|
||||
ggml_tensor ** ffn_up_exp = (ggml_tensor **) model.layers[il].ffn_up_exp;
|
||||
ggml_tensor ** ffn_gate_exp = (ggml_tensor **) model.layers[il].ffn_gate_exp;
|
||||
ggml_tensor ** ffn_down_exp = (ggml_tensor **) model.layers[il].ffn_down_exp;
|
||||
|
||||
cur_expert = ggml_mul(ctx0,
|
||||
ggml_mul_mat_id(ctx0, ffn_up_exp, selected_experts, i, cur),
|
||||
ggml_silu(ctx0,
|
||||
ggml_mul_mat_id(ctx0, ffn_gate_exp, selected_experts, i, cur))); // [n_tokens, n_embd]
|
||||
|
||||
cur_expert = ggml_mul_mat_id(ctx0, ffn_down_exp, selected_experts, i, cur_expert); // [n_tokens, n_embd]
|
||||
cur_expert = ggml_mul(ctx0, cur,
|
||||
ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
|
||||
|
||||
if (i == 0) {
|
||||
moe_out = cur_expert;
|
||||
} else {
|
||||
moe_out = ggml_add(ctx0, moe_out, cur_expert);
|
||||
}
|
||||
}
|
||||
|
||||
cur = moe_out;
|
||||
}
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
Loading…
Reference in New Issue
Block a user