From ba07b35d6f83e096ade7ed2ee1aedbc55b5200d8 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 10 Dec 2024 20:33:29 +0200
Subject: [PATCH] group norm

---
 src/llama.cpp | 44 +++++++++++++++++++++++++++++++++-----------
 1 file changed, 33 insertions(+), 11 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index 7c4ad5691..254e9f868 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -1629,16 +1629,16 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_CONV_NEXT_GAMMA,   "conv_next.gamma" },
             { LLM_TENSOR_OUTPUT_NORM,       "output_norm" },
             { LLM_TENSOR_OUTPUT,            "output" },
-            { LLM_TENSOR_POS_NET_CONV1,     "pos_net.conv1" },
-            { LLM_TENSOR_POS_NET_CONV2,     "pos_net.conv2" },
+            { LLM_TENSOR_POS_NET_CONV1,     "pos_net.%d.conv1" },
+            { LLM_TENSOR_POS_NET_CONV2,     "pos_net.%d.conv2" },
             { LLM_TENSOR_POS_NET_NORM,      "pos_net.norm" },
-            { LLM_TENSOR_POS_NET_NORM1,     "pos_net.norm1" },
-            { LLM_TENSOR_POS_NET_NORM2,     "pos_net.norm2" },
-            { LLM_TENSOR_POS_NET_ATTN_NORM, "pos_net.attn_norm" },
-            { LLM_TENSOR_POS_NET_ATTN_Q,    "pos_net.attn_q" },
-            { LLM_TENSOR_POS_NET_ATTN_K,    "pos_net.attn_k" },
-            { LLM_TENSOR_POS_NET_ATTN_V,    "pos_net.attn_v" },
-            { LLM_TENSOR_POS_NET_ATTN_OUT,  "pos_net.attn_output" },
+            { LLM_TENSOR_POS_NET_NORM1,     "pos_net.%d.norm1" },
+            { LLM_TENSOR_POS_NET_NORM2,     "pos_net.%d.norm2" },
+            { LLM_TENSOR_POS_NET_ATTN_NORM, "pos_net.%d.attn_norm" },
+            { LLM_TENSOR_POS_NET_ATTN_Q,    "pos_net.%d.attn_q" },
+            { LLM_TENSOR_POS_NET_ATTN_K,    "pos_net.%d.attn_k" },
+            { LLM_TENSOR_POS_NET_ATTN_V,    "pos_net.%d.attn_v" },
+            { LLM_TENSOR_POS_NET_ATTN_OUT,  "pos_net.%d.attn_output" },
             { LLM_TENSOR_HANN_WINDOW,       "hann_window" },
         },
     },
@@ -3054,9 +3054,13 @@ struct llama_model {
     struct ggml_tensor * cls_out   = nullptr;
     struct ggml_tensor * cls_out_b = nullptr;
 
+    // outetts vocoder
     struct ggml_tensor * conv_1d   = nullptr;
     struct ggml_tensor * conv_1d_b = nullptr;
 
+    struct ggml_tensor * posnet_0_norm1   = nullptr;
+    struct ggml_tensor * posnet_0_norm1_b = nullptr;
+
     std::vector<llama_layer> layers;
 
     // gguf metadata
@@ -7357,6 +7361,7 @@ static const std::map<llm_tensor, llm_tensor_info> llm_tensor_info_mapping = {
     // this tensor is loaded for T5, but never used
     {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
     {LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
+    {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 };
 
 // checks if the weight tensor can be used with the specified buffer type and device
@@ -9438,6 +9443,9 @@ static bool llm_load_tensors(
                     model.conv_1d   = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, n_embd, 768}, 0);
                     model.conv_1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"),   {768}, 0);
 
+                    model.posnet_0_norm1   = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", 0), {768}, 0);
+                    model.posnet_0_norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias",   0), {768}, 0);
+
                     // output
                     model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {768}, 0);
                     model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {768, 1282}, llama_model_loader::TENSOR_NOT_REQUIRED);
@@ -9661,6 +9669,7 @@ enum llm_ffn_gate_type {
 enum llm_norm_type {
     LLM_NORM,
     LLM_NORM_RMS,
+    LLM_NORM_GROUP,
 };
 
 static struct ggml_tensor * llm_build_inp_embd(
@@ -9802,8 +9811,15 @@ static struct ggml_tensor * llm_build_norm(
          const llm_build_cb & cb,
                         int   il) {
     switch (type) {
-        case LLM_NORM:     cur = ggml_norm    (ctx, cur, hparams.f_norm_eps);     break;
-        case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hparams.f_norm_rms_eps); break;
+        case LLM_NORM:       cur = ggml_norm     (ctx, cur, hparams.f_norm_eps);     break;
+        case LLM_NORM_RMS:   cur = ggml_rms_norm (ctx, cur, hparams.f_norm_rms_eps); break;
+        case LLM_NORM_GROUP:
+            {
+                // TODO: these reshapes should be removed, fix ggml_group_norm
+                cur = ggml_reshape_3d(ctx, cur, cur->ne[0], 1, cur->ne[1]);
+                cur = ggml_group_norm(ctx, cur, 32, 1e-6); // TODO: add groups, eps params
+                cur = ggml_reshape_2d(ctx, cur, cur->ne[0], cur->ne[2]);
+            } break;
     }
 
     if (mw || mb) {
@@ -17025,6 +17041,12 @@ struct llm_build_context {
         printf("conv1d: %d %d %d\n", model.conv_1d->ne[0], model.conv_1d->ne[1], model.conv_1d->ne[2]);
         cur = ggml_conv_1d_ph(ctx0, model.conv_1d, cur, 1, 1);
         cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, model.conv_1d_b, 1, model.conv_1d_b->ne[0]));
+
+        cur = llm_build_norm(ctx0, cur, hparams,
+                ggml_reshape_2d(ctx0, model.posnet_0_norm1,   1, model.posnet_0_norm1->ne[0]),
+                ggml_reshape_2d(ctx0, model.posnet_0_norm1_b, 1, model.posnet_0_norm1_b->ne[0]),
+                LLM_NORM_GROUP, cb, 0);
+
         printf("cur: %d %d %d\n", cur->ne[0], cur->ne[1], cur->ne[2]);
 
         //cur = llm_build_norm(ctx0, cur, hparams,
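
Note (not part of the patch): the LLM_NORM_GROUP branch above works around the fact that ggml_group_norm groups channels along ne[2], while the vocoder activations at this point are 2D with channels in ne[1]; hence the temporary reshape in and out of a 3D layout, flagged by the TODOs. Below is a minimal standalone sketch of the same pattern in C. The helper name build_group_norm_2d is illustrative only, and the fixed n_groups = 32 and eps = 1e-6 simply mirror the hard-coded values in the patch.

#include "ggml.h"

// Sketch of the LLM_NORM_GROUP branch added to llm_build_norm above.
// cur is 2D with channels in ne[1]; mw/mb are optional [1, n_ch] scale/bias
// tensors, matching how posnet_0_norm1(_b) are reshaped at the call site.
static struct ggml_tensor * build_group_norm_2d(
        struct ggml_context * ctx,
        struct ggml_tensor  * cur,   // [n_t, n_ch]
        struct ggml_tensor  * mw,    // [1, n_ch] or NULL
        struct ggml_tensor  * mb) {  // [1, n_ch] or NULL
    // ggml_group_norm normalizes groups of channels along ne[2], so insert a
    // singleton middle dimension: [n_t, n_ch] -> [n_t, 1, n_ch]
    cur = ggml_reshape_3d(ctx, cur, cur->ne[0], 1, cur->ne[1]);
    cur = ggml_group_norm(ctx, cur, 32, 1e-6f);
    // drop the singleton dimension again: [n_t, 1, n_ch] -> [n_t, n_ch]
    cur = ggml_reshape_2d(ctx, cur, cur->ne[0], cur->ne[2]);

    // affine part, same as the existing mw/mb handling in llm_build_norm
    if (mw) {
        cur = ggml_mul(ctx, cur, mw);
    }
    if (mb) {
        cur = ggml_add(ctx, cur, mb);
    }
    return cur;
}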