Georgi Gerganov 2024-12-11 10:06:48 +02:00
parent 50904afb98
commit b09557dac6
6 changed files with 255 additions and 127 deletions


@ -90,14 +90,17 @@ def flatten_state_dict(state_dict, parent_key='', sep='.'):
# these are the only rows used
# ref: https://github.com/edwko/OuteTTS/blob/a613e79c489d8256dd657ea9168d78de75895d82/outetts/wav_tokenizer/audio_codec.py#L100
if new_key == "backbone.norm.scale.weight":
new_key = "backbone.norm.weight"
if new_key.endswith("norm.scale.weight"):
new_key = new_key.replace("norm.scale.weight", "norm.weight")
value = value[0]
if new_key == "backbone.norm.shift.weight":
new_key = "backbone.norm.bias"
if new_key.endswith("norm.shift.weight"):
new_key = new_key.replace("norm.shift.weight", "norm.bias")
value = value[0]
if new_key.endswith("gamma"):
new_key = new_key.replace("gamma", "gamma.weight")
size_mb = value.element_size() * value.nelement() / (1024 * 1024)
print(f"{size_mb:8.2f} MB - {new_key}: {value.shape}")
@ -149,6 +152,7 @@ config = {
"hidden_size": 512,
"vocab_size": 4096,
"n_head": 1,
"layer_norm_epsilon": 1e-6,
"max_position_embeddings": 8192, # ?
"num_hidden_layers": 12
}


@ -1564,17 +1564,6 @@ extern "C" {
int d1, // dilation dimension 1
bool is_2D);
GGML_API struct ggml_tensor * ggml_conv_depthwise_2d(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel
struct ggml_tensor * b, // data
int s0, // stride dimension 0
int s1, // stride dimension 1
int p0, // padding dimension 0
int p1, // padding dimension 1
int d0, // dilation dimension 0
int d1); // dilation dimension 1
GGML_API struct ggml_tensor * ggml_conv_1d(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel
@ -1592,6 +1581,23 @@ extern "C" {
int s, // stride
int d); // dilation
// depthwise
// TODO: this is very likely wrong for some cases! - needs more testing
GGML_API struct ggml_tensor * ggml_conv_1d_dw(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel
struct ggml_tensor * b, // data
int s0, // stride
int p0, // padding
int d0); // dilation
GGML_API struct ggml_tensor * ggml_conv_1d_dw_ph(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel
struct ggml_tensor * b, // data
int s0, // stride
int d0); // dilation
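
A minimal build-only sketch of the new depthwise API (assumptions: the kernel is laid out as [K, 1, C] and the input as [L, C], matching the ConvNeXt usage further below; no backend is needed just to construct the op, and the sizes are illustrative):

#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,   // only build the graph, no tensor data needed
    };
    struct ggml_context * ctx = ggml_init(params);

    // depthwise kernel: K = 7 taps, one filter per channel, C = 8 channels
    struct ggml_tensor * ker = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 7, 1, 8);
    // input: L = 16 samples, C = 8 channels
    struct ggml_tensor * inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 16, 8);

    // "ph" = half padding (K/2), so the output length matches the input length
    struct ggml_tensor * out = ggml_conv_1d_dw_ph(ctx, ker, inp, 1 /*s0*/, 1 /*d0*/);

    printf("out: %lld x %lld x %lld\n",
        (long long) out->ne[0], (long long) out->ne[1], (long long) out->ne[2]);

    ggml_free(ctx);
    return 0;
}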
GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel
@ -1611,7 +1617,6 @@ extern "C" {
int d0, // dilation dimension 0
int d1); // dilation dimension 1
// kernel size is a->ne[0] x a->ne[1]
// stride is equal to kernel size
// padding is zero
@ -1638,6 +1643,18 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
// depthwise
GGML_API struct ggml_tensor * ggml_conv_2d_dw(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel
struct ggml_tensor * b, // data
int s0, // stride dimension 0
int s1, // stride dimension 1
int p0, // padding dimension 0
int p1, // padding dimension 1
int d0, // dilation dimension 0
int d1); // dilation dimension 1
GGML_API struct ggml_tensor * ggml_conv_transpose_2d_p0(
struct ggml_context * ctx,
struct ggml_tensor * a,


@ -3760,104 +3760,10 @@ struct ggml_tensor * ggml_clamp(
return result;
}
// ggml_conv_1d
static int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
return (ins + 2 * p - d * (ks - 1) - 1) / s + 1;
}
GGML_API struct ggml_tensor * ggml_conv_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s0,
int p0,
int d0) {
struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]
struct ggml_tensor * result =
ggml_mul_mat(ctx,
ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC, IC, K] => [OC, IC * K]
result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
return result;
}
// ggml_conv_1d_ph
struct ggml_tensor* ggml_conv_1d_ph(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s,
int d) {
return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
}
// ggml_conv_transpose_1d
static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
}
GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s0,
int p0,
int d0) {
GGML_ASSERT(ggml_is_matrix(b));
GGML_ASSERT(a->ne[2] == b->ne[1]);
GGML_ASSERT(a->ne[3] == 1);
GGML_ASSERT(p0 == 0);
GGML_ASSERT(d0 == 1);
const int64_t ne[4] = {
ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
a->ne[1], b->ne[2], 1,
};
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
int32_t params[] = { s0, p0, d0 };
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_CONV_TRANSPOSE_1D;
result->src[0] = a;
result->src[1] = b;
return result;
}
// ggml_conv_depthwise
struct ggml_tensor * ggml_conv_depthwise_2d(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s0,
int s1,
int p0,
int p1,
int d0,
int d1) {
struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC, 1, KH, KW] => [1, OC, 1, KH * KW]
struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
return result;
}
// ggml_conv_2d
// im2col: [N, IC, IH, IW] => [N, OH, OW, IC*KH*KW]
// a: [OC, IC, KH, KW]
// b: [N, IC, IH, IW]
@ -3877,7 +3783,8 @@ struct ggml_tensor * ggml_im2col(
if (is_2D) {
GGML_ASSERT(a->ne[2] == b->ne[2]);
} else {
GGML_ASSERT(a->ne[1] == b->ne[1]);
//GGML_ASSERT(b->ne[1] % a->ne[1] == 0);
GGML_ASSERT(b->ne[1] == a->ne[1]);
GGML_ASSERT(b->ne[3] == 1);
}
@ -3928,6 +3835,112 @@ struct ggml_tensor * ggml_im2col_back(
return result;
}
// ggml_conv_1d
struct ggml_tensor * ggml_conv_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s0,
int p0,
int d0) {
struct ggml_tensor * im2col = ggml_im2col(ctx, a, b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16); // [N, OL, IC * K]
printf("a: %lld %lld %lld %lld\n", a->ne[0], a->ne[1], a->ne[2], a->ne[3]);
printf("b: %lld %lld %lld %lld\n", b->ne[0], b->ne[1], b->ne[2], b->ne[3]);
printf("im2col: %lld %lld %lld %lld\n", im2col->ne[0], im2col->ne[1], im2col->ne[2], im2col->ne[3]);
struct ggml_tensor * result =
ggml_mul_mat(ctx,
ggml_reshape_2d(ctx, im2col, im2col->ne[0], (im2col->ne[2] * im2col->ne[1])), // [N, OL, IC * K] => [N*OL, IC * K]
ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1]), a->ne[2])); // [OC, IC, K] => [OC, IC * K]
result = ggml_reshape_3d(ctx, result, im2col->ne[1], a->ne[2], im2col->ne[2]); // [N, OC, OL]
return result;
}
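
Shape trace for the graph built above, using ggml's ne ordering (a sketch; K = kernel width, IC/OC = in/out channels, L/OL = in/out length, N = batch):

// a               : [K, IC, OC]       convolution kernel
// b               : [L, IC, N]        input data
// im2col          : [IC*K, OL, N]     one unrolled window per output position
// reshape(im2col) : [IC*K, N*OL]
// reshape(a)      : [IC*K, OC]
// mul_mat         : [N*OL, OC]
// result          : [OL, OC, N]       reshaped back to 3 dims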
// ggml_conv_1d_ph
struct ggml_tensor* ggml_conv_1d_ph(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s,
int d) {
return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
}
// ggml_conv_1d_dw
struct ggml_tensor * ggml_conv_1d_dw(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s0,
int p0,
int d0) {
struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], 1, a->ne[1], a->ne[2]);
struct ggml_tensor * new_b = ggml_reshape_4d(ctx, b, b->ne[0], 1, b->ne[1], b->ne[2]);
struct ggml_tensor * im2col = ggml_im2col(ctx, new_a, new_b, s0, 0, p0, 0, d0, 0, false, GGML_TYPE_F16);
struct ggml_tensor * result = ggml_mul_mat(ctx, im2col, a);
result = ggml_reshape_3d(ctx, result, b->ne[0], b->ne[1], 1);
return result;
}
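
Shape trace for the depthwise variant (a sketch, assuming the [K, 1, C] kernel / [L, C] input layout and "same" padding so OL = L):

// a       : [K, 1, C]      one K-tap filter per channel
// b       : [L, C]         input signal
// new_a   : [K, 1, 1, C]
// new_b   : [L, 1, C, 1]
// im2col  : [K, OL, C, 1]  windows are per-channel, never mixed across channels
// mul_mat : [OL, 1, C, 1]  im2col x a, broadcast over the channel dimension
// result  : [L, C, 1]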
// ggml_conv_1d_dw_ph
struct ggml_tensor * ggml_conv_1d_dw_ph(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s0,
int d0) {
return ggml_conv_1d_dw(ctx, a, b, s0, a->ne[0] / 2, d0);
}
// ggml_conv_transpose_1d
static int64_t ggml_calc_conv_transpose_1d_output_size(int64_t ins, int64_t ks, int s, int p, int d) {
return (ins - 1) * s - 2 * p + d * (ks - 1) + 1;
}
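
A worked instance of this output-size formula (values are illustrative):

//   out = (ins - 1) * s - 2 * p + d * (ks - 1) + 1
// e.g. ins = 10, ks = 4, s = 2, p = 0, d = 1:
//   out = 9 * 2 - 0 + 1 * 3 + 1 = 22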
GGML_API struct ggml_tensor * ggml_conv_transpose_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s0,
int p0,
int d0) {
GGML_ASSERT(ggml_is_matrix(b));
GGML_ASSERT(a->ne[2] == b->ne[1]);
GGML_ASSERT(a->ne[3] == 1);
GGML_ASSERT(p0 == 0);
GGML_ASSERT(d0 == 1);
const int64_t ne[4] = {
ggml_calc_conv_transpose_1d_output_size(b->ne[0], a->ne[0], s0, 0 /*p0*/, 1 /*d0*/),
a->ne[1], b->ne[2], 1,
};
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
int32_t params[] = { s0, p0, d0 };
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_CONV_TRANSPOSE_1D;
result->src[0] = a;
result->src[1] = b;
return result;
}
// ggml_conv_2d
// a: [OC, IC, KH, KW]
// b: [N, IC, IH, IW]
// result: [N, OC, OH, OW]
@ -3973,6 +3986,31 @@ struct ggml_tensor * ggml_conv_2d_s1_ph(
return ggml_conv_2d(ctx, a, b, 1, 1, a->ne[0] / 2, a->ne[1] / 2, 1, 1);
}
// ggml_conv_2d_dw
struct ggml_tensor * ggml_conv_2d_dw(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int s0,
int s1,
int p0,
int p1,
int d0,
int d1) {
struct ggml_tensor * new_a = ggml_reshape_4d(ctx, a, a->ne[0], a->ne[1], 1, a->ne[2] * a->ne[3]);
struct ggml_tensor * im2col = ggml_im2col(ctx, new_a,
ggml_reshape_4d(ctx, b, b->ne[0], b->ne[1], 1, b->ne[2] * b->ne[3]),
s0, s1, p0, p1, d0, d1, true, GGML_TYPE_F16); // [N * IC, OH, OW, KH * KW]
struct ggml_tensor * new_b = ggml_reshape_4d(ctx, im2col, im2col->ne[0], im2col->ne[2] * im2col->ne[1], b->ne[2], b->ne[3]); // [N * IC, OH, OW, KH * KW] => [N, IC, OH * OW, KH * KW]
new_a = ggml_reshape_4d(ctx, new_a, (new_a->ne[0] * new_a->ne[1]), new_a->ne[2], new_a->ne[3], 1); // [OC, 1, KH, KW] => [1, OC, 1, KH * KW]
struct ggml_tensor * result = ggml_mul_mat(ctx, new_a, new_b);
result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], b->ne[2], b->ne[3]); // [N, OC, OH, OW]
return result;
}
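
Shape trace (a sketch, assuming a depthwise kernel [KW, KH, C] with one filter per channel and an input [W, H, C, N]):

// a        : [KW, KH, C]           depthwise kernel
// b        : [W, H, C, N]          input image(s)
// new_a    : [KW, KH, 1, C]
// b'       : [W, H, 1, N*C]        channels folded into the batch dimension
// im2col   : [KH*KW, OW, OH, N*C]
// new_b    : [KH*KW, OH*OW, C, N]
// new_a'   : [KW*KH, 1, C, 1]
// mul_mat  : [1, OH*OW, C, N]      new_a' x new_b, broadcast over C and N
// result   : [OW, OH, C, N]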
// ggml_conv_transpose_2d_p0
static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {


@ -374,7 +374,6 @@ class MODEL_TENSOR(IntEnum):
CONV1D = auto()
CONV_NEXT_DW = auto()
CONV_NEXT_NORM = auto()
CONV_NEXT_SHIFT = auto()
CONV_NEXT_PW1 = auto()
CONV_NEXT_PW2 = auto()
CONV_NEXT_GAMMA = auto()
@ -557,7 +556,6 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.CONV1D: "conv1d",
MODEL_TENSOR.CONV_NEXT_DW: "conv_next.{bid}.dw",
MODEL_TENSOR.CONV_NEXT_NORM: "conv_next.{bid}.norm",
MODEL_TENSOR.CONV_NEXT_SHIFT: "conv_next.{bid}.shift",
MODEL_TENSOR.CONV_NEXT_PW1: "conv_next.{bid}.pw1",
MODEL_TENSOR.CONV_NEXT_PW2: "conv_next.{bid}.pw2",
MODEL_TENSOR.CONV_NEXT_GAMMA: "conv_next.{bid}.gamma",
@ -1416,7 +1414,6 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.CONV1D,
MODEL_TENSOR.CONV_NEXT_DW,
MODEL_TENSOR.CONV_NEXT_NORM,
MODEL_TENSOR.CONV_NEXT_SHIFT,
MODEL_TENSOR.CONV_NEXT_PW1,
MODEL_TENSOR.CONV_NEXT_PW2,
MODEL_TENSOR.CONV_NEXT_GAMMA,


@ -714,11 +714,7 @@ class TensorNameMap:
),
MODEL_TENSOR.CONV_NEXT_NORM: (
"backbone.convnext.{bid}.norm.scale", # outetts
),
MODEL_TENSOR.CONV_NEXT_SHIFT: (
"backbone.convnext.{bid}.norm.shift", # outetts
"backbone.convnext.{bid}.norm", # outetts
),
MODEL_TENSOR.CONV_NEXT_PW1: (


@ -614,7 +614,6 @@ enum llm_tensor {
LLM_TENSOR_CONV1D,
LLM_TENSOR_CONV_NEXT_DW,
LLM_TENSOR_CONV_NEXT_NORM,
LLM_TENSOR_CONV_NEXT_SHIFT,
LLM_TENSOR_CONV_NEXT_PW1,
LLM_TENSOR_CONV_NEXT_PW2,
LLM_TENSOR_CONV_NEXT_GAMMA,
@ -1619,12 +1618,11 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_CONV1D, "conv1d" },
{ LLM_TENSOR_CONV_NEXT_DW, "conv_next.dw" },
{ LLM_TENSOR_CONV_NEXT_NORM, "conv_next.norm" },
{ LLM_TENSOR_CONV_NEXT_SHIFT, "conv_next.shift" },
{ LLM_TENSOR_CONV_NEXT_PW1, "conv_next.pw1" },
{ LLM_TENSOR_CONV_NEXT_PW2, "conv_next.pw2" },
{ LLM_TENSOR_CONV_NEXT_GAMMA, "conv_next.gamma" },
{ LLM_TENSOR_CONV_NEXT_DW, "conv_next.%d.dw" },
{ LLM_TENSOR_CONV_NEXT_NORM, "conv_next.%d.norm" },
{ LLM_TENSOR_CONV_NEXT_PW1, "conv_next.%d.pw1" },
{ LLM_TENSOR_CONV_NEXT_PW2, "conv_next.%d.pw2" },
{ LLM_TENSOR_CONV_NEXT_GAMMA, "conv_next.%d.gamma" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_POS_NET_CONV1, "pos_net.%d.conv1" },
@ -2922,6 +2920,21 @@ struct llama_layer {
struct ggml_tensor * ffn_gate_scale;
struct ggml_tensor * ffn_up_scale;
struct ggml_tensor * ffn_down_scale;
// convnext
struct ggml_tensor * convnext_dw;
struct ggml_tensor * convnext_dw_b;
struct ggml_tensor * convnext_norm;
struct ggml_tensor * convnext_norm_b;
struct ggml_tensor * convnext_pw1;
struct ggml_tensor * convnext_pw1_b;
struct ggml_tensor * convnext_pw2;
struct ggml_tensor * convnext_pw2_b;
struct ggml_tensor * convnext_gamma;
};
// very similar to llama_batch,
@ -6420,6 +6433,10 @@ static void llm_load_hparams(
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
case LLM_ARCH_OUTETTS_VOC:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
} break;
default: (void)0;
}
@ -7439,6 +7456,11 @@ static const std::map<llm_tensor, llm_tensor_info> llm_tensor_info_mapping = {
{LLM_TENSOR_POS_NET_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_POS_NET_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_POS_NET_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CONV_NEXT_DW, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_IM2COL}},
{LLM_TENSOR_CONV_NEXT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_CONV_NEXT_PW1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CONV_NEXT_PW2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CONV_NEXT_GAMMA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
};
// checks if the weight tensor can be used with the specified buffer type and device
@ -9589,6 +9611,25 @@ static bool llm_load_tensors(
model.posnet_5_norm = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "weight", 5), {768}, 0);
model.posnet_5_norm_b = create_tensor(tn(LLM_TENSOR_POS_NET_ATTN_NORM, "bias", 5), {768}, 0);
for (int i = 0; i < n_layer; ++i) {
auto & layer = model.layers[i];
layer.convnext_dw = create_tensor(tn(LLM_TENSOR_CONV_NEXT_DW, "weight", i), {7, 1, 768}, 0);
layer.convnext_dw_b = create_tensor(tn(LLM_TENSOR_CONV_NEXT_DW, "bias", i), {768}, 0);
layer.convnext_norm = create_tensor(tn(LLM_TENSOR_CONV_NEXT_NORM, "weight", i), {768}, 0);
layer.convnext_norm_b = create_tensor(tn(LLM_TENSOR_CONV_NEXT_NORM, "bias", i), {768}, 0);
// TODO: n_ff
layer.convnext_pw1 = create_tensor(tn(LLM_TENSOR_CONV_NEXT_PW1, "weight", i), {768, 2304}, 0);
layer.convnext_pw1_b = create_tensor(tn(LLM_TENSOR_CONV_NEXT_PW1, "bias", i), {2304}, 0);
layer.convnext_pw2 = create_tensor(tn(LLM_TENSOR_CONV_NEXT_PW2, "weight", i), {2304, 768}, 0);
layer.convnext_pw2_b = create_tensor(tn(LLM_TENSOR_CONV_NEXT_PW2, "bias", i), {768}, 0);
layer.convnext_gamma = create_tensor(tn(LLM_TENSOR_CONV_NEXT_GAMMA, "weight", i), {768}, 0);
}
// output
model.output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {768}, 0);
model.output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {768, 1282}, llama_model_loader::TENSOR_NOT_REQUIRED);
@ -17338,11 +17379,46 @@ struct llm_build_context {
LLM_NORM_GROUP, cb, 0);
}
cur = llm_build_norm(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), hparams,
cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
cur = llm_build_norm(ctx0, cur, hparams,
model.tok_norm,
model.tok_norm_b,
LLM_NORM, cb, -1);
cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
inpL = cur;
for (int il = 0; il < n_layer; ++il) {
cur = inpL;
cur = ggml_conv_1d_dw_ph(ctx0, model.layers[il].convnext_dw, cur, 1, 1);
cur = ggml_add(ctx0, cur, ggml_reshape_2d(ctx0, model.layers[il].convnext_dw_b, 1, model.layers[il].convnext_dw_b->ne[0]));
cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
cur = llm_build_norm(ctx0, cur, hparams,
model.layers[il].convnext_norm,
model.layers[il].convnext_norm_b,
LLM_NORM, cb, -1);
cur = llm_build_ffn(ctx0, lctx, cur,
model.layers[il].convnext_pw1, model.layers[il].convnext_pw1_b, NULL,
NULL, NULL, NULL,
model.layers[il].convnext_pw2, model.layers[il].convnext_pw2_b, NULL,
NULL,
LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
cur = ggml_mul(ctx0, cur, model.layers[il].convnext_gamma);
cur = ggml_cont(ctx0, ggml_transpose(ctx0, cur));
inpL = ggml_add(ctx0, cur, inpL);
}
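
Roughly, each iteration of this loop builds one ConvNeXt block (a sketch of the graph above; the [T, 768] layout is assumed, and the 768 -> 2304 -> 768 sizes come from the tensor shapes loaded earlier):

// x  : [T, 768]                            block input (inpL)
// x' = conv_1d_dw_ph(x) + dw_bias          depthwise 7-tap conv, same length
// x' = layer_norm(transpose(x'))           norm over the 768 channels
// x' = pw2(gelu(pw1(x')))                  pointwise MLP: 768 -> 2304 -> 768
// x' = x' * gamma                          learned per-channel scale
// x  = x + transpose(x')                   residual connection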
cur = inpL;
printf("cur: %d %d %d\n", cur->ne[0], cur->ne[1], cur->ne[2]);
//cur = llm_build_norm(ctx0, cur, hparams,