llama : remove extra ; + deduplicate gate_b logic

This commit is contained in:
Georgi Gerganov 2023-10-31 16:28:09 +02:00
parent fc5a26aade
commit 2073347e3b
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -3135,7 +3135,7 @@ static void llm_build_k_shift(
case LLM_ROPE: rope_type = 0; break; case LLM_ROPE: rope_type = 0; break;
case LLM_ROPE_NEOX: rope_type = 2; break; case LLM_ROPE_NEOX: rope_type = 2; break;
case LLM_ROPE_GLM: rope_type = 4; break; case LLM_ROPE_GLM: rope_type = 4; break;
}; }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
struct ggml_tensor * tmp = struct ggml_tensor * tmp =
@ -3207,7 +3207,8 @@ static struct ggml_tensor * llm_build_norm(
switch (type) { switch (type) {
case LLM_NORM: cur = ggml_norm (ctx, cur, eps); break; case LLM_NORM: cur = ggml_norm (ctx, cur, eps); break;
case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, eps); break; case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, eps); break;
}; }
if (mw || mb) { if (mw || mb) {
cb(cur, "norm", il); cb(cur, "norm", il);
} }
@ -3265,23 +3266,18 @@ static struct ggml_tensor * llm_build_ffn(
{ {
cur = ggml_mul_mat(ctx, gate, tmp); cur = ggml_mul_mat(ctx, gate, tmp);
cb(cur, "ffn_gate", il); cb(cur, "ffn_gate", il);
if (gate_b) {
cur = ggml_add(ctx, cur, gate_b);
cb(cur, "ffn_gate_b", il);
}
} break; } break;
case LLM_FFN_PAR: case LLM_FFN_PAR:
{ {
cur = ggml_mul_mat(ctx, gate, cur); cur = ggml_mul_mat(ctx, gate, cur);
cb(cur, "ffn_gate", il); cb(cur, "ffn_gate", il);
if (gate_b) {
cur = ggml_add(ctx, cur, gate_b);
cb(cur, "ffn_gate_b", il);
}
} break; } break;
}; }
if (gate_b) {
cur = ggml_add(ctx, cur, gate_b);
cb(cur, "ffn_gate_b", il);
}
} else { } else {
cur = tmp; cur = tmp;
} }
@ -3310,7 +3306,7 @@ static struct ggml_tensor * llm_build_ffn(
cur = ggml_sqr(ctx, cur); cur = ggml_sqr(ctx, cur);
cb(cur, "ffn_sqr(relu)", il); cb(cur, "ffn_sqr(relu)", il);
} break; } break;
}; }
if (type_gate == LLM_FFN_PAR) { if (type_gate == LLM_FFN_PAR) {
cur = ggml_mul(ctx, cur, tmp); cur = ggml_mul(ctx, cur, tmp);
@ -4098,6 +4094,7 @@ static struct ggml_cgraph * llm_build_persimmon(
const bool do_rope_shift = worst_case || kv_self.has_shift; const bool do_rope_shift = worst_case || kv_self.has_shift;
auto & buf_compute = lctx.buf_compute; auto & buf_compute = lctx.buf_compute;
struct ggml_init_params params = { struct ggml_init_params params = {
/*.mem_size =*/ buf_compute.size, /*.mem_size =*/ buf_compute.size,
/*.mem_buffer =*/ buf_compute.data, /*.mem_buffer =*/ buf_compute.data,