diff --git a/examples/tts/convert_pt_to_hf.py b/examples/tts/convert_pt_to_hf.py
index 4a0d4bcc8..501fc4d6a 100644
--- a/examples/tts/convert_pt_to_hf.py
+++ b/examples/tts/convert_pt_to_hf.py
@@ -90,14 +90,17 @@ def flatten_state_dict(state_dict, parent_key='', sep='.'):
 
             # these are the only rows used
             # ref: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/wav_tokenizer/audio_codec.py#L100
-            if new_key == "backbone.norm.scale.weight":
-                new_key = "backbone.norm.weight"
+            if new_key.endswith("norm.scale.weight"):
+                new_key = new_key.replace("norm.scale.weight", "norm.weight")
                 value = value[0]
 
-            if new_key == "backbone.norm.shift.weight":
-                new_key = "backbone.norm.bias"
+            if new_key.endswith("norm.shift.weight"):
+                new_key = new_key.replace("norm.shift.weight", "norm.bias")
                 value = value[0]
 
+            if new_key.endswith("gamma"):
+                new_key = new_key.replace("gamma", "gamma.weight")
+
             size_mb = value.element_size() * value.nelement() / (1024 * 1024)
             print(f"{size_mb:8.2f} MB - {new_key}: {value.shape}")
 
@@ -149,6 +152,7 @@ config = {
     "hidden_size": 512,
     "vocab_size": 4096,
     "n_head": 1,
+    "layer_norm_epsilon": 1e-6,
     "max_position_embeddings": 8192, # ?
     "num_hidden_layers": 12
 }
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index b0c1ac9ce..c714fc8c8 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -1564,17 +1564,6 @@ extern "C" {
             int                   d1, // dilation dimension 1
             bool                  is_2D);
 
-    GGML_API struct ggml_tensor * ggml_conv_depthwise_2d(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,  // convolution kernel
-            struct ggml_tensor  * b,  // data
-            int                   s0, // stride dimension 0
-            int                   s1, // stride dimension 1
-            int                   p0, // padding dimension 0
-            int                   p1, // padding dimension 1
-            int                   d0, // dilation dimension 0
-            int                   d1); // dilation dimension 1
-
     GGML_API struct ggml_tensor * ggml_conv_1d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,   // convolution kernel
@@ -1592,6 +1581,23 @@ extern "C" {
             int                   s,  // stride
             int                   d); // dilation
 
+    // depthwise
+    // TODO: this is very likely wrong for some cases! - needs more testing
+    GGML_API struct ggml_tensor * ggml_conv_1d_dw(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,   // convolution kernel
+            struct ggml_tensor  * b,   // data
+            int                   s0,  // stride
+            int                   p0,  // padding
+            int                   d0); // dilation
+
+    GGML_API struct ggml_tensor * ggml_conv_1d_dw_ph(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,   // convolution kernel
+            struct ggml_tensor  * b,   // data
+            int                   s0,  // stride
+            int                   d0); // dilation
+
     GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,   // convolution kernel
@@ -1611,7 +1617,6 @@ extern "C" {
             int                   d0, // dilation dimension 0
             int                   d1); // dilation dimension 1
 
-
     // kernel size is a->ne[0] x a->ne[1]
     // stride is equal to kernel size
     // padding is zero
@@ -1638,6 +1643,18 @@ extern "C" {
             struct ggml_tensor  * a,
             struct ggml_tensor  * b);
 
+    // depthwise
+    GGML_API struct ggml_tensor * ggml_conv_2d_dw(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,  // convolution kernel
+            struct ggml_tensor  * b,  // data
+            int                   s0, // stride dimension 0
+            int                   s1, // stride dimension 1
+            int                   p0, // padding dimension 0
+            int                   p1, // padding dimension 1
+            int                   d0, // dilation dimension 0
+            int                   d1); // dilation dimension 1
+
     GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
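For orientation, here is a minimal standalone sketch (not part of the patch) of how the new ggml_conv_1d_dw_ph entry point is meant to be driven. It assumes the same layouts the llama.cpp vocoder code further below uses — kernel {K, 1, C}, input {L, C} — and the tensor sizes and context size are illustrative only.

// not part of the patch: minimal usage sketch for ggml_conv_1d_dw_ph,
// assuming kernel layout {K, 1, C} and input layout {L, C}
#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024, // illustrative size
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // kernel: K = 7 taps, one filter per channel, C = 768 channels
    struct ggml_tensor * ker = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 7, 1, 768);
    // input: L = 32 time steps, C = 768 channels
    struct ggml_tensor * inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 32, 768);

    // "ph" = half padding (K/2), so with s0 = 1 the output length equals L
    struct ggml_tensor * out = ggml_conv_1d_dw_ph(ctx, ker, inp, 1 /*s0*/, 1 /*d0*/);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    // evaluate gf with the backend of your choice, then read out->data

    printf("out: %lld x %lld\n", (long long) out->ne[0], (long long) out->ne[1]);

    ggml_free(ctx);
    return 0;
}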
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 4c0cff48d..4f29ab0bf 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3760,104 +3760,10 @@ struct ggml_tensor * ggml_clamp(
     return result;
 }
 
-// ggml_conv_1d
-
 static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
     return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
 }
 
-GGML_API struct ggml_tensor * ggml_conv_1d(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b,
-        int                   s0,
-        int                   p0,
-        int                   d0) {
-    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]
-
-    struct ggml_tensor * result =
-        ggml_mul_mat(ctx,
-                ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
-                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2]));                    // [OC,IC, K] => [OC, IC * K]
-
-    result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
-
-    return result;
-}
-
-// ggml_conv_1d_ph
-
-struct ggml_tensor* ggml_conv_1d_ph(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b,
-        int                   s,
-        int                   d) {
-    return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
-}
-
-// ggml_conv_transpose_1d
-
-static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
-    return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
-}
-
-GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b,
-        int                   s0,
-        int                   p0,
-        int                   d0) {
-    GGML_ASSERT(ggml_is_matrix(b));
-    GGML_ASSERT(a->ne[2] == b->ne[1]);
-    GGML_ASSERT(a->ne[3] == 1);
-
-    GGML_ASSERT(p0 == 0);
-    GGML_ASSERT(d0 == 1);
-
-    const int64_t ne[4] = {
-        ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
-        a->ne[1], b->ne[2], 1,
-    };
-    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
-
-    int32_t params[] = { s0, p0, d0 };
-    ggml_set_op_params(result, params, sizeof(params));
-
-    result->op     = GGML_OP_CONV_TRANSPOSE_1D;
-    result->src[0] = a;
-    result->src[1] = b;
-
-    return result;
-}
-
-// ggml_conv_depthwise
-
-struct ggml_tensor * ggml_conv_depthwise_2d(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        struct ggml_tensor  * b,
-        int                   s0,
-        int                   s1,
-        int                   p0,
-        int                   p1,
-        int                   d0,
-        int                   d1) {
-    struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
-    struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
-                                        ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
-                                        s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
-    struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
-
-    new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
-    struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
-    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
-
-    return result;
-}
-// ggml_conv_2d
-
 // im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
 // a: [OC,IC, KH, KW]
 // b: [N, IC, IH, IW]
@@ -3877,7 +3783,8 @@ struct ggml_tensor * ggml_im2col(
     if (is_2D) {
         GGML_ASSERT(a->ne[2] == b->ne[2]);
     } else {
-        GGML_ASSERT(a->ne[1] == b->ne[1]);
+        //GGML_ASSERT(b->ne[1] % a->ne[1] == 0);
+        GGML_ASSERT(b->ne[1] == a->ne[1]);
         GGML_ASSERT(b->ne[3] == 1);
     }
 
@@ -3928,6 +3835,112 @@ struct ggml_tensor * ggml_im2col_back(
     return result;
 }
 
+// ggml_conv_1d
+
+struct ggml_tensor * ggml_conv_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   p0,
+        int                   d0) {
+    struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]
+
+    printf("a: %lld %lld %lld %lld\n", a->ne[0], a->ne[1], a->ne[2], a->ne[3]);
+    printf("b: %lld %lld %lld %lld\n", b->ne[0], b->ne[1], b->ne[2], b->ne[3]);
+    printf("im2col: %lld %lld %lld %lld\n", im2col->ne[0], im2col->ne[1], im2col->ne[2], im2col->ne[3]);
+
+    struct ggml_tensor * result =
+        ggml_mul_mat(ctx,
+                ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
+                ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2]));                    // [OC,IC, K] => [OC, IC * K]
+
+    result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
+
+    return result;
+}
+
+// ggml_conv_1d_ph
+
+struct ggml_tensor* ggml_conv_1d_ph(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s,
+        int                   d) {
+    return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
+}
+
+// ggml_conv_1d_dw
+
+struct ggml_tensor * ggml_conv_1d_dw(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   p0,
+        int                   d0) {
+    struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], 1, a->ne[1], a->ne[2]);
+    struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]);
+
+    struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16);
+
+    struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a);
+
+    result = ggml_reshape_3d(ctx, result, b->ne[0], b->ne[1], 1);
+
+    return result;
+}
+
+// ggml_conv_1d_dw_ph
+
+struct ggml_tensor * ggml_conv_1d_dw_ph(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   d0) {
+    return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0);
+}
+
+// ggml_conv_transpose_1d
+
+static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
+    return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
+}
+
+GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   p0,
+        int                   d0) {
+    GGML_ASSERT(ggml_is_matrix(b));
+    GGML_ASSERT(a->ne[2] == b->ne[1]);
+    GGML_ASSERT(a->ne[3] == 1);
+
+    GGML_ASSERT(p0 == 0);
+    GGML_ASSERT(d0 == 1);
+
+    const int64_t ne[4] = {
+        ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
+        a->ne[1], b->ne[2], 1,
+    };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    int32_t params[] = { s0, p0, d0 };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op     = GGML_OP_CONV_TRANSPOSE_1D;
+    result->src[0] = a;
+    result->src[1] = b;
+
+    return result;
+}
+
+// ggml_conv_2d
+
 // a: [OC,IC, KH, KW]
 // b: [N, IC, IH, IW]
 // result: [N, OC, OH, OW]
@@ -3973,6 +3986,31 @@ struct ggml_tensor * ggml_conv_2d_s1_ph(
     return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1);
 }
 
+// ggml_conv_2d_dw
+
+struct ggml_tensor * ggml_conv_2d_dw(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        int                   s0,
+        int                   s1,
+        int                   p0,
+        int                   p1,
+        int                   d0,
+        int                   d1) {
+    struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
+    struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
+                                        ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
+                                        s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
+    struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
+
+    new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC,1, KH, KW] => [1, OC, 1, KH * KW]
+    struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
+    result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
+
+    return result;
+}
+
 // ggml_conv_transpose_2d_p0
 
 static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
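Not part of the patch: a quick sanity check of the output-size arithmetic that ggml_calc_conv_output_size() above implements. It shows why the "_ph" (half padding, p = K/2) wrappers preserve the input length for odd K with s = 1 and d = 1; the helper below just restates the formula locally.

// not part of the patch: sanity check of the conv output-size formula
#include <assert.h>
#include <stdint.h>

static int64_t conv_out_size(int64_t ins, int64_t ks, int s, int p, int d) {
    return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
}

int main(void) {
    // K = 7, p = 7/2 = 3: (32 + 6 - 6 - 1)/1 + 1 = 32, i.e. "same" length
    assert(conv_out_size(32, 7, 1, 7 / 2, 1) == 32);
    // with stride 2 the length halves (rounding down): (32 + 6 - 6 - 1)/2 + 1 = 16
    assert(conv_out_size(32, 7, 2, 7 / 2, 1) == 16);
    return 0;
}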
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 81e434d11..ea74354a4 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -374,7 +374,6 @@ class MODEL_TENSOR(IntEnum):
     CONV1D          = auto()
     CONV_NEXT_DW    = auto()
     CONV_NEXT_NORM  = auto()
-    CONV_NEXT_SHIFT = auto()
     CONV_NEXT_PW1   = auto()
     CONV_NEXT_PW2   = auto()
     CONV_NEXT_GAMMA = auto()
@@ -557,7 +556,6 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
     MODEL_TENSOR.CONV1D:          "conv1d",
     MODEL_TENSOR.CONV_NEXT_DW:    "conv_next.{bid}.dw",
     MODEL_TENSOR.CONV_NEXT_NORM:  "conv_next.{bid}.norm",
-    MODEL_TENSOR.CONV_NEXT_SHIFT: "conv_next.{bid}.shift",
     MODEL_TENSOR.CONV_NEXT_PW1:   "conv_next.{bid}.pw1",
     MODEL_TENSOR.CONV_NEXT_PW2:   "conv_next.{bid}.pw2",
     MODEL_TENSOR.CONV_NEXT_GAMMA: "conv_next.{bid}.gamma",
@@ -1416,7 +1414,6 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.CONV1D,
         MODEL_TENSOR.CONV_NEXT_DW,
         MODEL_TENSOR.CONV_NEXT_NORM,
-        MODEL_TENSOR.CONV_NEXT_SHIFT,
         MODEL_TENSOR.CONV_NEXT_PW1,
         MODEL_TENSOR.CONV_NEXT_PW2,
         MODEL_TENSOR.CONV_NEXT_GAMMA,
diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 872205e77..ca7bb40fe 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -714,11 +714,7 @@ class TensorNameMap:
         ),
 
         MODEL_TENSOR.CONV_NEXT_NORM: (
-            "backbone.convnext.{bid}.norm.scale", # outetts
-        ),
-
-        MODEL_TENSOR.CONV_NEXT_SHIFT: (
-            "backbone.convnext.{bid}.norm.shift", # outetts
+            "backbone.convnext.{bid}.norm", # outetts
         ),
 
         MODEL_TENSOR.CONV_NEXT_PW1: (
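The llama.cpp hunk below wires each ConvNeXt block as depthwise conv → transpose → LayerNorm → pointwise MLP with GELU → gamma scale → transpose → residual. As a reading aid (not part of the patch), here is that block as plain C on float arrays; all names and the tiny T/C sizes are local to this sketch, and eps corresponds to the layer_norm_epsilon = 1e-6 added in the converter above.

// not part of the patch: plain-C reference of one ConvNeXt block
#include <math.h>

#define T 4         // time steps (768-channel, variable-T in the patch)
#define C 3         // channels  (768 in the patch)
#define K 7         // depthwise taps
#define H (3 * C)   // pw1 expands C -> 3C (768 -> 2304 in the patch)

static float gelu(float x) { // tanh approximation, as used by ggml
    return 0.5f * x * (1.0f + tanhf(0.79788456f * (x + 0.044715f * x * x * x)));
}

static void convnext_block(
        float x[T][C], float y[T][C],
        float dw[C][K], float dw_b[C],
        float norm_w[C], float norm_b[C],
        float pw1[H][C], float pw1_b[H],
        float pw2[C][H], float pw2_b[C],
        float gamma[C], float eps) {
    float h[T][C];

    // depthwise conv, half padding (K/2) so the length is preserved
    for (int t = 0; t < T; ++t) {
        for (int c = 0; c < C; ++c) {
            float acc = dw_b[c];
            for (int k = 0; k < K; ++k) {
                int ti = t + k - K/2;
                if (ti >= 0 && ti < T) acc += dw[c][k] * x[ti][c];
            }
            h[t][c] = acc;
        }
    }

    for (int t = 0; t < T; ++t) {
        // layer norm over channels (what the transpose + LLM_NORM achieves)
        float mean = 0.0f, var = 0.0f;
        for (int c = 0; c < C; ++c) mean += h[t][c];
        mean /= C;
        for (int c = 0; c < C; ++c) var += (h[t][c] - mean) * (h[t][c] - mean);
        var /= C;
        float n[C];
        for (int c = 0; c < C; ++c) {
            n[c] = (h[t][c] - mean) / sqrtf(var + eps) * norm_w[c] + norm_b[c];
        }

        // pointwise MLP pw2(gelu(pw1(n))), then gamma scale and residual
        float mid[H];
        for (int i = 0; i < H; ++i) {
            float acc = pw1_b[i];
            for (int c = 0; c < C; ++c) acc += pw1[i][c] * n[c];
            mid[i] = gelu(acc);
        }
        for (int c = 0; c < C; ++c) {
            float acc = pw2_b[c];
            for (int i = 0; i < H; ++i) acc += pw2[c][i] * mid[i];
            y[t][c] = x[t][c] + gamma[c] * acc; // residual connection
        }
    }
}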
diff --git a/src/llama.cpp b/src/llama.cpp
index aae7eb966..c923c5db9 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -614,7 +614,6 @@ enum llm_tensor {
    LLM_TENSOR_CONV1D,
    LLM_TENSOR_CONV_NEXT_DW,
    LLM_TENSOR_CONV_NEXT_NORM,
-    LLM_TENSOR_CONV_NEXT_SHIFT,
    LLM_TENSOR_CONV_NEXT_PW1,
    LLM_TENSOR_CONV_NEXT_PW2,
    LLM_TENSOR_CONV_NEXT_GAMMA,
@@ -1619,12 +1618,11 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_NAMES = {
         { LLM_TENSOR_TOKEN_EMBD,        "token_embd" },
         { LLM_TENSOR_TOKEN_EMBD_NORM,   "token_embd_norm" },
         { LLM_TENSOR_CONV1D,            "conv1d" },
-        { LLM_TENSOR_CONV_NEXT_DW,      "conv_next.dw" },
-        { LLM_TENSOR_CONV_NEXT_NORM,    "conv_next.norm" },
-        { LLM_TENSOR_CONV_NEXT_SHIFT,   "conv_next.shift" },
-        { LLM_TENSOR_CONV_NEXT_PW1,     "conv_next.pw1" },
-        { LLM_TENSOR_CONV_NEXT_PW2,     "conv_next.pw2" },
-        { LLM_TENSOR_CONV_NEXT_GAMMA,   "conv_next.gamma" },
+        { LLM_TENSOR_CONV_NEXT_DW,      "conv_next.%d.dw" },
+        { LLM_TENSOR_CONV_NEXT_NORM,    "conv_next.%d.norm" },
+        { LLM_TENSOR_CONV_NEXT_PW1,     "conv_next.%d.pw1" },
+        { LLM_TENSOR_CONV_NEXT_PW2,     "conv_next.%d.pw2" },
+        { LLM_TENSOR_CONV_NEXT_GAMMA,   "conv_next.%d.gamma" },
         { LLM_TENSOR_OUTPUT_NORM,       "output_norm" },
         { LLM_TENSOR_OUTPUT,            "output" },
         { LLM_TENSOR_POS_NET_CONV1,     "pos_net.%d.conv1" },
@@ -2922,6 +2920,21 @@ struct llama_layer {
     struct ggml_tensor * ffn_gate_scale;
     struct ggml_tensor * ffn_up_scale;
     struct ggml_tensor * ffn_down_scale;
+
+    // convnext
+    struct ggml_tensor * convnext_dw;
+    struct ggml_tensor * convnext_dw_b;
+
+    struct ggml_tensor * convnext_norm;
+    struct ggml_tensor * convnext_norm_b;
+
+    struct ggml_tensor * convnext_pw1;
+    struct ggml_tensor * convnext_pw1_b;
+
+    struct ggml_tensor * convnext_pw2;
+    struct ggml_tensor * convnext_pw2_b;
+
+    struct ggml_tensor * convnext_gamma;
 };
 
 // very similar to llama_batch,
@@ -6420,6 +6433,10 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_OUTETTS_VOC:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+            } break;
         default: (void)0;
     }
 
@@ -7439,6 +7456,11 @@ static const std::map<llm_tensor, llm_tensor_info> llm_tensor_info_mapping = {
     {LLM_TENSOR_POS_NET_ATTN_K,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_POS_NET_ATTN_V,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
     {LLM_TENSOR_POS_NET_ATTN_OUT,  {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CONV_NEXT_DW,      {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}},
+    {LLM_TENSOR_CONV_NEXT_NORM,    {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
+    {LLM_TENSOR_CONV_NEXT_PW1,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CONV_NEXT_PW2,     {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
+    {LLM_TENSOR_CONV_NEXT_GAMMA,   {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
 };
 
 // checks if the weight tensor can be used with the specified buffer type and device
@@ -9589,6 +9611,25 @@ static bool llm_load_tensors(
                     model.posnet_5_norm   = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", 5), {768}, 0);
                     model.posnet_5_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias",   5), {768}, 0);
 
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = model.layers[i];
+
+                        layer.convnext_dw   = create_tensor(tn(LLM_TENSOR_CONV_NEXT_DW, "weight", i), {7, 1, 768}, 0);
+                        layer.convnext_dw_b = create_tensor(tn(LLM_TENSOR_CONV_NEXT_DW, "bias",   i), {768}, 0);
+
+                        layer.convnext_norm   = create_tensor(tn(LLM_TENSOR_CONV_NEXT_NORM, "weight", i), {768}, 0);
+                        layer.convnext_norm_b = create_tensor(tn(LLM_TENSOR_CONV_NEXT_NORM, "bias",   i), {768}, 0);
+
+                        // TODO: n_ff
+                        layer.convnext_pw1   = create_tensor(tn(LLM_TENSOR_CONV_NEXT_PW1, "weight", i), {768, 2304}, 0);
+                        layer.convnext_pw1_b = create_tensor(tn(LLM_TENSOR_CONV_NEXT_PW1, "bias",   i), {2304}, 0);
+
+                        layer.convnext_pw2   = create_tensor(tn(LLM_TENSOR_CONV_NEXT_PW2, "weight", i), {2304, 768}, 0);
+                        layer.convnext_pw2_b = create_tensor(tn(LLM_TENSOR_CONV_NEXT_PW2, "bias",   i), {768}, 0);
+
+                        layer.convnext_gamma = create_tensor(tn(LLM_TENSOR_CONV_NEXT_GAMMA, "weight", i), {768}, 0);
+                    }
+
                     // output
                     model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {768}, 0);
                     model.output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {768, 1282}, llama_model_loader::TENSOR_NOT_REQUIRED);
@@ -17338,11 +17379,46 @@ struct llm_build_context {
                         LLM_NORM_GROUP, cb, 0);
             }
 
-            cur = llm_build_norm(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), hparams,
+            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+            cur = llm_build_norm(ctx0, cur, hparams,
                     model.tok_norm,
                     model.tok_norm_b,
                     LLM_NORM, cb, -1);
 
+            cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+            inpL = cur;
+
+            for (int il = 0; il < n_layer; ++il) {
+                cur = inpL;
+
+                cur = ggml_conv_1d_dw_ph(ctx0, model.layers[il].convnext_dw, cur, 1, 1);
+                cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, model.layers[il].convnext_dw_b, 1, model.layers[il].convnext_dw_b->ne[0]));
+
+                cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+                cur = llm_build_norm(ctx0, cur, hparams,
+                        model.layers[il].convnext_norm,
+                        model.layers[il].convnext_norm_b,
+                        LLM_NORM, cb, -1);
+
+                cur = llm_build_ffn(ctx0, lctx, cur,
+                        model.layers[il].convnext_pw1, model.layers[il].convnext_pw1_b, NULL,
+                        NULL,                          NULL,                            NULL,
+                        model.layers[il].convnext_pw2, model.layers[il].convnext_pw2_b, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+
+                cur = ggml_mul(ctx0, cur, model.layers[il].convnext_gamma);
+
+                cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
+
+                inpL = ggml_add(ctx0, cur, inpL);
+            }
+
+            cur = inpL;
+
+            printf("cur: %lld %lld %lld\n", cur->ne[0], cur->ne[1], cur->ne[2]);
 
             //cur = llm_build_norm(ctx0, cur, hparams,