From 916e95928b0757fb9e5e601ee5f325af29e5253e Mon Sep 17 00:00:00 2001 From: ngxson Date: Thu, 11 Jul 2024 00:30:07 +0200 Subject: [PATCH] llm_build_lora_mm_id --- src/llama.cpp | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 7ed80fcaf..30ecbb801 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7882,7 +7882,6 @@ static struct ggml_tensor * llm_build_lora_mm( if (lora == nullptr) { continue; } - // TODO: check if lora_a need transpose struct ggml_tensor * ab_cur = ggml_mul_mat( ctx0, lora->b, ggml_mul_mat(ctx0, lora->a, cur) @@ -7893,6 +7892,31 @@ static struct ggml_tensor * llm_build_lora_mm( return res; } +// do mat_mul_id, while optionally apply lora +static struct ggml_tensor * llm_build_lora_mm_id( + struct llama_context & lctx, + struct ggml_context * ctx0, + struct ggml_tensor * w, // struct ggml_tensor * as + struct ggml_tensor * cur, // struct ggml_tensor * b + struct ggml_tensor * ids) { + struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids); + for (auto & it : lctx.lora_adapters) { + struct llama_lora_weight * lora = it.first->get_weight(w); + float scale = it.second; + if (lora == nullptr) { + continue; + } + struct ggml_tensor * ab_cur = ggml_mul_mat_id( + ctx0, lora->b, + ggml_mul_mat_id(ctx0, lora->a, cur, ids), + ids + ); + ab_cur = ggml_scale_inplace(ctx0, ab_cur, scale); + res = ggml_add(ctx0, res, ab_cur); + } + return res; +} + static struct ggml_tensor * llm_build_norm( struct ggml_context * ctx, struct ggml_tensor * cur, @@ -8103,10 +8127,10 @@ static struct ggml_tensor * llm_build_moe_ffn( } cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens); - ggml_tensor * up = ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] cb(up, "ffn_moe_up", il); - ggml_tensor * gate = ggml_mul_mat_id(ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] cb(gate, "ffn_moe_gate", il); switch (type_op) { @@ -8127,7 +8151,7 @@ static struct ggml_tensor * llm_build_moe_ffn( ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens] cb(par, "ffn_moe_gate_par", il); - ggml_tensor * experts = ggml_mul_mat_id(ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] + ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens] cb(experts, "ffn_moe_down", il); experts = ggml_mul(ctx, experts, weights);