llama : fix llm_build_k_shift to use correct n_rot (#4889)

* llama : fix llm_build_k_shift to use correct n_rot

ggml-ci

* llama : always use hparams.n_rot for ggml_rope_custom

ggml-ci

* convert : fix persimmon conversion to write correct n_rot
This commit is contained in:
Georgi Gerganov 2024-01-12 13:01:56 +02:00 committed by GitHub
parent 326b418b59
commit f445c0e68c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 51 additions and 33 deletions

View File

@ -1055,6 +1055,9 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
} }
static ggml_type kv_cache_type_from_str(const std::string & s) { static ggml_type kv_cache_type_from_str(const std::string & s) {
if (s == "f32") {
return GGML_TYPE_F32;
}
if (s == "f16") { if (s == "f16") {
return GGML_TYPE_F16; return GGML_TYPE_F16;
} }

View File

@ -817,10 +817,17 @@ class PersimmonModel(Model):
hidden_size = self.hparams["hidden_size"] hidden_size = self.hparams["hidden_size"]
self.gguf_writer.add_name('persimmon-8b-chat') self.gguf_writer.add_name('persimmon-8b-chat')
self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(hidden_size) self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count(hidden_size // head_count)
# NOTE: not sure about this change - why does the model not have a rope dimension count when it is smaller
# than the head size?
# ref: https://github.com/ggerganov/llama.cpp/pull/4889
#self.gguf_writer.add_rope_dimension_count(hidden_size // head_count)
self.gguf_writer.add_rope_dimension_count(hidden_size // head_count // 2)
self.gguf_writer.add_head_count(head_count) self.gguf_writer.add_head_count(head_count)
self.gguf_writer.add_head_count_kv(head_count_kv) self.gguf_writer.add_head_count_kv(head_count_kv)
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"]) self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])

View File

@ -57,6 +57,7 @@ class TensorNameMap:
"transformer.norm_f", # mpt "transformer.norm_f", # mpt
"ln_f", # refact bloom qwen gpt2 "ln_f", # refact bloom qwen gpt2
"language_model.encoder.final_layernorm", # persimmon "language_model.encoder.final_layernorm", # persimmon
"model.final_layernorm", # persimmon
"lm_head.ln", # phi2 "lm_head.ln", # phi2
), ),
@ -98,6 +99,7 @@ class TensorNameMap:
"transformer.h.{bid}.self_attention.query_key_value", # falcon "transformer.h.{bid}.self_attention.query_key_value", # falcon
"h.{bid}.self_attention.query_key_value", # bloom "h.{bid}.self_attention.query_key_value", # bloom
"language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon
"model.layers.{bid}.self_attn.query_key_value", # persimmon
"h.{bid}.attn.c_attn", # gpt2 "h.{bid}.attn.c_attn", # gpt2
"transformer.h.{bid}.mixer.Wqkv", # phi2 "transformer.h.{bid}.mixer.Wqkv", # phi2
), ),
@ -141,6 +143,7 @@ class TensorNameMap:
"encoder.layer.{bid}.attention.output.dense", # bert "encoder.layer.{bid}.attention.output.dense", # bert
"transformer.h.{bid}.attn.out_proj", # gpt-j "transformer.h.{bid}.attn.out_proj", # gpt-j
"language_model.encoder.layers.{bid}.self_attention.dense", # persimmon "language_model.encoder.layers.{bid}.self_attention.dense", # persimmon
"model.layers.{bid}.self_attn.dense", # persimmon
"h.{bid}.attn.c_proj", # gpt2 "h.{bid}.attn.c_proj", # gpt2
"transformer.h.{bid}.mixer.out_proj", # phi2 "transformer.h.{bid}.mixer.out_proj", # phi2
"model.layers.layers.{bid}.self_attn.o_proj", # plamo "model.layers.layers.{bid}.self_attn.o_proj", # plamo
@ -184,6 +187,7 @@ class TensorNameMap:
"encoder.layer.{bid}.intermediate.dense", # bert "encoder.layer.{bid}.intermediate.dense", # bert
"transformer.h.{bid}.mlp.fc_in", # gpt-j "transformer.h.{bid}.mlp.fc_in", # gpt-j
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
"model.layers.{bid}.mlp.dense_h_to_4h", # persimmon
"transformer.h.{bid}.mlp.w1", # qwen "transformer.h.{bid}.mlp.w1", # qwen
"h.{bid}.mlp.c_fc", # gpt2 "h.{bid}.mlp.c_fc", # gpt2
"transformer.h.{bid}.mlp.fc1", # phi2 "transformer.h.{bid}.mlp.fc1", # phi2
@ -225,6 +229,7 @@ class TensorNameMap:
"encoder.layer.{bid}.output.dense", # bert "encoder.layer.{bid}.output.dense", # bert
"transformer.h.{bid}.mlp.fc_out", # gpt-j "transformer.h.{bid}.mlp.fc_out", # gpt-j
"language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon "language_model.encoder.layers.{bid}.mlp.dense_4h_to_h", # persimmon
"model.layers.{bid}.mlp.dense_4h_to_h", # persimmon
"h.{bid}.mlp.c_proj", # gpt2 "h.{bid}.mlp.c_proj", # gpt2
"transformer.h.{bid}.mlp.fc2", # phi2 "transformer.h.{bid}.mlp.fc2", # phi2
"model.layers.layers.{bid}.mlp.down_proj", # plamo "model.layers.layers.{bid}.mlp.down_proj", # plamo
@ -237,10 +242,12 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_Q_NORM: ( MODEL_TENSOR.ATTN_Q_NORM: (
"language_model.encoder.layers.{bid}.self_attention.q_layernorm", "language_model.encoder.layers.{bid}.self_attention.q_layernorm",
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
), ),
MODEL_TENSOR.ATTN_K_NORM: ( MODEL_TENSOR.ATTN_K_NORM: (
"language_model.encoder.layers.{bid}.self_attention.k_layernorm", "language_model.encoder.layers.{bid}.self_attention.k_layernorm",
"model.layers.{bid}.self_attn.k_layernorm", # persimmon
), ),
MODEL_TENSOR.ROPE_FREQS: ( MODEL_TENSOR.ROPE_FREQS: (

View File

@ -4104,7 +4104,6 @@ static void llm_build_k_shift(
struct ggml_cgraph * graph, struct ggml_cgraph * graph,
llm_rope_type type, llm_rope_type type,
int64_t n_ctx, int64_t n_ctx,
int n_rot,
float freq_base, float freq_base,
float freq_scale, float freq_scale,
const llm_build_cb & cb) { const llm_build_cb & cb) {
@ -4112,14 +4111,13 @@ static void llm_build_k_shift(
const int64_t n_head_kv = hparams.n_head_kv; const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head_k = hparams.n_embd_head_k; const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int32_t n_rot = hparams.n_rot;
const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx; const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx;
const float ext_factor = cparams.yarn_ext_factor; const float ext_factor = cparams.yarn_ext_factor;
const float attn_factor = cparams.yarn_attn_factor; const float attn_factor = cparams.yarn_attn_factor;
const float beta_fast = cparams.yarn_beta_fast; const float beta_fast = cparams.yarn_beta_fast;
const float beta_slow = cparams.yarn_beta_slow; const float beta_slow = cparams.yarn_beta_slow;
GGML_ASSERT(n_embd_head_k % n_rot == 0);
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx); struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx);
cb(K_shift, "K_shift", -1); cb(K_shift, "K_shift", -1);
@ -4523,7 +4521,7 @@ struct llm_build_context {
// shift the entire K-cache if needed // shift the entire K-cache if needed
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -4561,14 +4559,14 @@ struct llm_build_context {
Qcur = ggml_rope_custom( Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow ext_factor, attn_factor, beta_fast, beta_slow
); );
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom( Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow ext_factor, attn_factor, beta_fast, beta_slow
); );
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
@ -4691,6 +4689,7 @@ struct llm_build_context {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
@ -4708,7 +4707,7 @@ struct llm_build_context {
// shift the entire K-cache if needed // shift the entire K-cache if needed
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -4734,12 +4733,12 @@ struct llm_build_context {
case MODEL_7B: case MODEL_7B:
Qcur = ggml_rope_custom( Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow ext_factor, attn_factor, beta_fast, beta_slow
); );
Kcur = ggml_rope_custom( Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
n_embd_head, 0, 0, n_orig_ctx, freq_base, freq_scale, hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow ext_factor, attn_factor, beta_fast, beta_slow
); );
break; break;
@ -4812,6 +4811,7 @@ struct llm_build_context {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
@ -4829,7 +4829,7 @@ struct llm_build_context {
// shift the entire K-cache if needed // shift the entire K-cache if needed
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -4870,13 +4870,13 @@ struct llm_build_context {
// using mode = 2 for neox mode // using mode = 2 for neox mode
Qcur = ggml_rope_custom( Qcur = ggml_rope_custom(
ctx0, Qcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx, ctx0, Qcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
); );
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom( Kcur = ggml_rope_custom(
ctx0, Kcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx, ctx0, Kcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
); );
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
@ -5034,8 +5034,7 @@ struct llm_build_context {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head/2 == hparams.n_rot);
const int64_t n_rot = n_embd_head_k / 2;
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
@ -5052,7 +5051,7 @@ struct llm_build_context {
cb(KQ_mask, "KQ_mask", -1); cb(KQ_mask, "KQ_mask", -1);
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -5112,7 +5111,7 @@ struct llm_build_context {
// RoPE the first n_rot of q/k, pass the other half, and concat. // RoPE the first n_rot of q/k, pass the other half, and concat.
struct ggml_tensor * qrot = ggml_view_3d( struct ggml_tensor * qrot = ggml_view_3d(
ctx0, tmpq, n_rot, n_head, n_tokens, ctx0, tmpq, hparams.n_rot, n_head, n_tokens,
ggml_element_size(tmpq) * n_embd_head, ggml_element_size(tmpq) * n_embd_head,
ggml_element_size(tmpq) * n_embd_head * n_head, ggml_element_size(tmpq) * n_embd_head * n_head,
0 0
@ -5120,7 +5119,7 @@ struct llm_build_context {
cb(qrot, "qrot", il); cb(qrot, "qrot", il);
struct ggml_tensor * krot = ggml_view_3d( struct ggml_tensor * krot = ggml_view_3d(
ctx0, tmpk, n_rot, n_head, n_tokens, ctx0, tmpk, hparams.n_rot, n_head, n_tokens,
ggml_element_size(tmpk) * n_embd_head, ggml_element_size(tmpk) * n_embd_head,
ggml_element_size(tmpk) * n_embd_head * n_head, ggml_element_size(tmpk) * n_embd_head * n_head,
0 0
@ -5129,29 +5128,29 @@ struct llm_build_context {
// get the second half of tmpq, e.g tmpq[n_rot:, :, :] // get the second half of tmpq, e.g tmpq[n_rot:, :, :]
struct ggml_tensor * qpass = ggml_view_3d( struct ggml_tensor * qpass = ggml_view_3d(
ctx0, tmpq, n_rot, n_head, n_tokens, ctx0, tmpq, hparams.n_rot, n_head, n_tokens,
ggml_element_size(tmpq) * n_embd_head, ggml_element_size(tmpq) * n_embd_head,
ggml_element_size(tmpq) * n_embd_head * n_head, ggml_element_size(tmpq) * n_embd_head * n_head,
ggml_element_size(tmpq) * n_rot ggml_element_size(tmpq) * hparams.n_rot
); );
cb(qpass, "qpass", il); cb(qpass, "qpass", il);
struct ggml_tensor * kpass = ggml_view_3d( struct ggml_tensor * kpass = ggml_view_3d(
ctx0, tmpk, n_rot, n_head, n_tokens, ctx0, tmpk, hparams.n_rot, n_head, n_tokens,
ggml_element_size(tmpk) * n_embd_head, ggml_element_size(tmpk) * n_embd_head,
ggml_element_size(tmpk) * n_embd_head * n_head, ggml_element_size(tmpk) * n_embd_head * n_head,
ggml_element_size(tmpk) * n_rot ggml_element_size(tmpk) * hparams.n_rot
); );
cb(kpass, "kpass", il); cb(kpass, "kpass", il);
struct ggml_tensor * qrotated = ggml_rope_custom( struct ggml_tensor * qrotated = ggml_rope_custom(
ctx0, qrot, inp_pos, n_rot, 2, 0, n_orig_ctx, ctx0, qrot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
); );
cb(qrotated, "qrotated", il); cb(qrotated, "qrotated", il);
struct ggml_tensor * krotated = ggml_rope_custom( struct ggml_tensor * krotated = ggml_rope_custom(
ctx0, krot, inp_pos, n_rot, 2, 0, n_orig_ctx, ctx0, krot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
); );
cb(krotated, "krotated", il); cb(krotated, "krotated", il);
@ -5531,6 +5530,7 @@ struct llm_build_context {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
@ -5548,7 +5548,7 @@ struct llm_build_context {
// shift the entire K-cache if needed // shift the entire K-cache if needed
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, hparams.n_rot, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -5661,7 +5661,7 @@ struct llm_build_context {
// shift the entire K-cache if needed // shift the entire K-cache if needed
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -5693,13 +5693,13 @@ struct llm_build_context {
// using mode = 2 for neox mode // using mode = 2 for neox mode
Qcur = ggml_rope_custom( Qcur = ggml_rope_custom(
ctx0, Qcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx, ctx0, Qcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
); );
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom( Kcur = ggml_rope_custom(
ctx0, Kcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx, ctx0, Kcur, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
); );
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
@ -5778,7 +5778,7 @@ struct llm_build_context {
// shift the entire K-cache if needed // shift the entire K-cache if needed
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -5874,6 +5874,7 @@ struct llm_build_context {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_head == hparams.n_rot);
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
@ -5891,7 +5892,7 @@ struct llm_build_context {
// shift the entire K-cache if needed // shift the entire K-cache if needed
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
@ -5917,13 +5918,13 @@ struct llm_build_context {
cb(Vcur, "Vcur", il); cb(Vcur, "Vcur", il);
Qcur = ggml_rope_custom( Qcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, ctx0, ggml_reshape_3d(ctx0, Qcur, hparams.n_rot, n_head, n_tokens), inp_pos,
n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale, n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
Kcur = ggml_rope_custom( Kcur = ggml_rope_custom(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, ctx0, ggml_reshape_3d(ctx0, Kcur, hparams.n_rot, n_head_kv, n_tokens), inp_pos,
n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale, n_embd_head, 2, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);