group norm

This commit is contained in:
Georgi Gerganov 2024-12-10 20:33:29 +02:00
parent ce49e6a2cd
commit ba07b35d6f
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -1629,16 +1629,16 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_CONV_NEXT_GAMMA, "conv_next.gamma" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_POS_NET_CONV1, "pos_net.conv1" },
{ LLM_TENSOR_POS_NET_CONV2, "pos_net.conv2" },
{ LLM_TENSOR_POS_NET_CONV1, "pos_net.%d.conv1" },
{ LLM_TENSOR_POS_NET_CONV2, "pos_net.%d.conv2" },
{ LLM_TENSOR_POS_NET_NORM, "pos_net.norm" },
{ LLM_TENSOR_POS_NET_NORM1, "pos_net.norm1" },
{ LLM_TENSOR_POS_NET_NORM2, "pos_net.norm2" },
{ LLM_TENSOR_POS_NET_ATTN_NORM, "pos_net.attn_norm" },
{ LLM_TENSOR_POS_NET_ATTN_Q, "pos_net.attn_q" },
{ LLM_TENSOR_POS_NET_ATTN_K, "pos_net.attn_k" },
{ LLM_TENSOR_POS_NET_ATTN_V, "pos_net.attn_v" },
{ LLM_TENSOR_POS_NET_ATTN_OUT, "pos_net.attn_output" },
{ LLM_TENSOR_POS_NET_NORM1, "pos_net.%d.norm1" },
{ LLM_TENSOR_POS_NET_NORM2, "pos_net.%d.norm2" },
{ LLM_TENSOR_POS_NET_ATTN_NORM, "pos_net.%d.attn_norm" },
{ LLM_TENSOR_POS_NET_ATTN_Q, "pos_net.%d.attn_q" },
{ LLM_TENSOR_POS_NET_ATTN_K, "pos_net.%d.attn_k" },
{ LLM_TENSOR_POS_NET_ATTN_V, "pos_net.%d.attn_v" },
{ LLM_TENSOR_POS_NET_ATTN_OUT, "pos_net.%d.attn_output" },
{ LLM_TENSOR_HANN_WINDOW, "hann_window" },
},
},
@ -3054,9 +3054,13 @@ struct llama_model {
struct ggml_tensor * cls_out = nullptr;
struct ggml_tensor * cls_out_b = nullptr;
// outetts vocoder
struct ggml_tensor * conv_1d = nullptr;
struct ggml_tensor * conv_1d_b = nullptr;
struct ggml_tensor * posnet_0_norm1 = nullptr;
struct ggml_tensor * posnet_0_norm1_b = nullptr;
std::vector<llama_layer> layers;
// gguf metadata
@ -7357,6 +7361,7 @@ static const std::map<llm_tensor, llm_tensor_info> llm_tensor_info_mapping = {
// this tensor is loaded for T5, but never used
{LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}},
{LLM_TENSOR_CONV1D, {LLM_TENSOR_LAYER_INPUT, GGML_OP_IM2COL}},
{LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
};
// checks if the weight tensor can be used with the specified buffer type and device
@ -9438,6 +9443,9 @@ static bool llm_load_tensors(
model.conv_1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, n_embd, 768}, 0);
model.conv_1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {768}, 0);
model.posnet_0_norm1 = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "weight", 0), {768}, 0);
model.posnet_0_norm1_b = create_tensor(tn(LLM_TENSOR_POS_NET_NORM1, "bias", 0), {768}, 0);
// output
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {768}, 0);
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {768, 1282}, llama_model_loader::TENSOR_NOT_REQUIRED);
@ -9661,6 +9669,7 @@ enum llm_ffn_gate_type {
enum llm_norm_type {
LLM_NORM,
LLM_NORM_RMS,
LLM_NORM_GROUP,
};
static struct ggml_tensor * llm_build_inp_embd(
@ -9802,8 +9811,15 @@ static struct ggml_tensor * llm_build_norm(
const llm_build_cb & cb,
int il) {
switch (type) {
case LLM_NORM: cur = ggml_norm (ctx, cur, hparams.f_norm_eps); break;
case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, hparams.f_norm_rms_eps); break;
case LLM_NORM: cur = ggml_norm (ctx, cur, hparams.f_norm_eps); break;
case LLM_NORM_RMS: cur = ggml_rms_norm (ctx, cur, hparams.f_norm_rms_eps); break;
case LLM_NORM_GROUP:
{
// TODO: these reshapes should be removed, fix ggml_group_norm
cur = ggml_reshape_3d(ctx, cur, cur->ne[0], 1, cur->ne[1]);
cur = ggml_group_norm(ctx, cur, 32, 1e-6); // TODO: add groups, eps params
cur = ggml_reshape_2d(ctx, cur, cur->ne[0], cur->ne[2]);
} break;
}
if (mw || mb) {
@ -17025,6 +17041,12 @@ struct llm_build_context {
printf("conv1d: %d %d %d\n", model.conv_1d->ne[0], model.conv_1d->ne[1], model.conv_1d->ne[2]);
cur = ggml_conv_1d_ph(ctx0, model.conv_1d, cur, 1, 1);
cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, model.conv_1d_b, 1, model.conv_1d_b->ne[0]));
cur = llm_build_norm(ctx0, cur, hparams,
ggml_reshape_2d(ctx0, model.posnet_0_norm1, 1, model.posnet_0_norm1->ne[0]),
ggml_reshape_2d(ctx0, model.posnet_0_norm1_b, 1, model.posnet_0_norm1_b->ne[0]),
LLM_NORM_GROUP, cb, 0);
printf("cur: %d %d %d\n", cur->ne[0], cur->ne[1], cur->ne[2]);
//cur = llm_build_norm(ctx0, cur, hparams,