diff --git a/llama.cpp b/llama.cpp index 741b0f0bd..e4d1a530a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -4223,7 +4223,7 @@ struct llm_build_context { cb(ffn_inp, "ffn_inp", il); // feed-forward network - { + if (model.layers[il].ffn_gate_inp == nullptr) { cur = llm_build_norm(ctx0, ffn_inp, hparams, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, cb, il); @@ -4235,6 +4235,51 @@ struct llm_build_context { model.layers[il].ffn_down, NULL, LLM_FFN_SILU, LLM_FFN_PAR, cb, il); cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + const int n_experts_per_tok = 2; // TODO: param + + ggml_tensor * logits = ggml_mul_mat(ctx0, model.layers[il].ffn_gate_inp, cur); // [n_tokens, num_experts] + ggml_tensor * probs = ggml_soft_max(ctx0, logits); // [n_tokens, num_experts] + + // select experts + ggml_tensor * selected_experts = ggml_top_k(ctx0, probs, n_experts_per_tok); // [n_tokens, num_experts_per_tok] + ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [n_tokens, num_experts_per_tok, 1] + weights = ggml_div(ctx0, weights, ggml_sum_rows(ctx0, weights)); // [n_tokens, num_experts_per_tok, 1] + + // compute expert outputs + ggml_tensor * moe_out; + + for (int i = 0; i < n_experts_per_tok; ++i) { + ggml_tensor * cur_expert; + + // TODO: fix + ggml_tensor ** ffn_up_exp = (ggml_tensor **) model.layers[il].ffn_up_exp; + ggml_tensor ** ffn_gate_exp = (ggml_tensor **) model.layers[il].ffn_gate_exp; + ggml_tensor ** ffn_down_exp = (ggml_tensor **) model.layers[il].ffn_down_exp; + + cur_expert = ggml_mul(ctx0, + ggml_mul_mat_id(ctx0, ffn_up_exp, selected_experts, i, cur), + ggml_silu(ctx0, + ggml_mul_mat_id(ctx0, ffn_gate_exp, selected_experts, i, cur))); // [n_tokens, n_embd] + + cur_expert = ggml_mul_mat_id(ctx0, ffn_down_exp, selected_experts, i, cur_expert); // [n_tokens, n_embd] + cur_expert = ggml_mul(ctx0, cur, + ggml_view_2d(ctx0, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0])); + + if (i == 0) { + moe_out = cur_expert; + } else { + moe_out = ggml_add(ctx0, moe_out, cur_expert); + } + } + + cur = moe_out; } cur = ggml_add(ctx0, cur, ffn_inp);