mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
Fixed incorrectly applying RMS norm twice (#1925)
This commit is contained in:
parent
8596af4277
commit
0ede372a51
@ -1657,11 +1657,7 @@ static bool llama_eval_internal(
|
||||
{
|
||||
cur = ggml_rms_norm(ctx0, inpL);
|
||||
offload_func_nr(cur);
|
||||
ggml_set_name(cur, "rms_norm_inpL");
|
||||
|
||||
cur = ggml_rms_norm(ctx0, cur);
|
||||
offload_func_nr(cur);
|
||||
ggml_set_name(cur, "rms_norm_after");
|
||||
ggml_set_name(cur, "rms_norm_2");
|
||||
|
||||
// cur = cur*norm(broadcasted)
|
||||
cur = ggml_mul(ctx0, cur, model.norm);
|
||||
|
Loading…
Reference in New Issue
Block a user