llm_build_lora_mm_id

ngxson 2024-07-11 00:30:07 +02:00
parent e68344cb06
commit 916e95928b


@@ -7882,7 +7882,6 @@ static struct ggml_tensor * llm_build_lora_mm(
         if (lora == nullptr) {
             continue;
         }
-        // TODO: check if lora_a need transpose
         struct ggml_tensor * ab_cur = ggml_mul_mat(
             ctx0, lora->b,
             ggml_mul_mat(ctx0, lora->a, cur)
@@ -7893,6 +7892,31 @@ static struct ggml_tensor * llm_build_lora_mm(
     return res;
 }
 
+// do mat_mul_id, while optionally applying lora
+static struct ggml_tensor * llm_build_lora_mm_id(
+        struct llama_context & lctx,
+         struct ggml_context * ctx0,
+          struct ggml_tensor * w,   // struct ggml_tensor * as
+          struct ggml_tensor * cur, // struct ggml_tensor * b
+          struct ggml_tensor * ids) {
+    struct ggml_tensor * res = ggml_mul_mat_id(ctx0, w, cur, ids);
+    for (auto & it : lctx.lora_adapters) {
+        struct llama_lora_weight * lora = it.first->get_weight(w);
+        float scale = it.second;
+        if (lora == nullptr) {
+            continue;
+        }
+        struct ggml_tensor * ab_cur = ggml_mul_mat_id(
+            ctx0, lora->b,
+            ggml_mul_mat_id(ctx0, lora->a, cur, ids),
+            ids
+        );
+        ab_cur = ggml_scale_inplace(ctx0, ab_cur, scale);
+        res = ggml_add(ctx0, res, ab_cur);
+    }
+    return res;
+}
+
 static struct ggml_tensor * llm_build_norm(
         struct ggml_context * ctx,
          struct ggml_tensor * cur,
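
For reference, the helper added above mirrors the existing llm_build_lora_mm, but routes every product through ggml_mul_mat_id so that weight rows are gathered per expert. Schematically, summarizing the code in this hunk rather than quoting any additional API:

    res = mul_mat_id(w, cur, ids)
          + sum over (adapter, scale) of scale * mul_mat_id(lora->b, mul_mat_id(lora->a, cur, ids), ids)

Adapters that do not provide weights for w (get_weight returns nullptr) are skipped.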
@@ -8103,10 +8127,10 @@ static struct ggml_tensor * llm_build_moe_ffn(
     }
 
     cur = ggml_reshape_3d(ctx, cur, n_embd, 1, n_tokens);
-    ggml_tensor * up = ggml_mul_mat_id(ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    ggml_tensor * up = llm_build_lora_mm_id(lctx, ctx, up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);
 
-    ggml_tensor * gate = ggml_mul_mat_id(ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+    ggml_tensor * gate = llm_build_lora_mm_id(lctx, ctx, gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(gate, "ffn_moe_gate", il);
 
     switch (type_op) {
@@ -8127,7 +8151,7 @@ static struct ggml_tensor * llm_build_moe_ffn(
     ggml_tensor * par = ggml_mul(ctx, up, gate); // [n_ff, n_expert_used, n_tokens]
     cb(par, "ffn_moe_gate_par", il);
 
-    ggml_tensor * experts = ggml_mul_mat_id(ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
+    ggml_tensor * experts = llm_build_lora_mm_id(lctx, ctx, down_exps, par, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);
 
     experts = ggml_mul(ctx, experts, weights);
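
For intuition, here is a standalone toy sketch, in plain C++ rather than the ggml API, of the computation llm_build_lora_mm_id performs for one token: each selected expert's base product plus the scaled low-rank LoRA correction. The sizes, weights, and the matvec helper are hypothetical, chosen only to make the example self-contained and runnable.

#include <cstdio>
#include <vector>

using Mat = std::vector<std::vector<float>>; // row-major: Mat[row][col]

// y = M * v (hypothetical helper, not a ggml call)
static std::vector<float> matvec(const Mat & m, const std::vector<float> & v) {
    std::vector<float> out(m.size(), 0.0f);
    for (size_t i = 0; i < m.size(); ++i) {
        for (size_t j = 0; j < v.size(); ++j) {
            out[i] += m[i][j] * v[j];
        }
    }
    return out;
}

int main() {
    const float scale = 0.5f;                  // adapter strength (it.second in the diff)
    const std::vector<float> x = {1.0f, 2.0f}; // one token's activations (cur)

    // two experts: full 2x2 weights plus rank-1 LoRA factors A (1x2) and B (2x1)
    const std::vector<Mat> W = {{{1, 0}, {0, 1}}, {{2, 0}, {0, 2}}};
    const std::vector<Mat> A = {{{1, 1}},         {{0, 1}}};
    const std::vector<Mat> B = {{{1}, {0}},       {{0}, {1}}};

    const std::vector<int> ids = {0, 1};       // experts selected for this token

    for (int e : ids) {
        std::vector<float> y  = matvec(W[e], x);  // base ggml_mul_mat_id analogue
        std::vector<float> ax = matvec(A[e], x);  // A * cur, gathered by ids
        std::vector<float> d  = matvec(B[e], ax); // B * (A * cur)
        for (size_t i = 0; i < y.size(); ++i) {
            y[i] += scale * d[i];                 // res = res + scale * (B * (A * cur))
        }
        std::printf("expert %d: y = [%g, %g]\n", e, y[0], y[1]);
    }
    return 0;
}

Run as written, this prints y = [2.5, 2] for expert 0 and y = [2, 5] for expert 1: the identity/doubling base weights plus each adapter's rank-1 update.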