llama : remove extra ; + deduplicate gate_b logic

This commit is contained in:
Georgi Gerganov 2023-10-31 16:28:09 +02:00
parent fc5a26aade
commit 2073347e3b
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -3135,7 +3135,7 @@ static void llm_build_k_shift(
case LLM_ROPE: rope_type = 0; break;
case LLM_ROPE_NEOX: rope_type = 2; break;
case LLM_ROPE_GLM: rope_type = 4; break;
};
}
for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * tmp =
@@ -3207,7 +3207,8 @@ static struct ggml_tensor * llm_build_norm(
switch (type) {
case LLM_NORM: cur = ggml_norm (ctx, cur, eps); break;
case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, eps); break;
};
}
if (mw || mb) {
cb(cur, "norm", il);
}
@@ -3265,23 +3266,18 @@ static struct ggml_tensor * llm_build_ffn(
{
cur = ggml_mul_mat(ctx, gate, tmp);
cb(cur, "ffn_gate", il);
if (gate_b) {
cur = ggml_add(ctx, cur, gate_b);
cb(cur, "ffn_gate_b", il);
}
} break;
case LLM_FFN_PAR:
{
cur = ggml_mul_mat(ctx, gate, cur);
cb(cur, "ffn_gate", il);
} break;
}
if (gate_b) {
cur = ggml_add(ctx, cur, gate_b);
cb(cur, "ffn_gate_b", il);
}
} break;
};
} else {
cur = tmp;
}
@@ -3310,7 +3306,7 @@ static struct ggml_tensor * llm_build_ffn(
cur = ggml_sqr(ctx, cur);
cb(cur, "ffn_sqr(relu)", il);
} break;
};
}
if (type_gate == LLM_FFN_PAR) {
cur = ggml_mul(ctx, cur, tmp);
@@ -4098,6 +4094,7 @@ static struct ggml_cgraph * llm_build_persimmon(
const bool do_rope_shift = worst_case || kv_self.has_shift;
auto & buf_compute = lctx.buf_compute;
struct ggml_init_params params = {
/*.mem_size =*/ buf_compute.size,
/*.mem_buffer =*/ buf_compute.data,