mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 10:24:35 +00:00
llm : fix llm_build_kqv taking unused tensor (benign, #3837)
This commit is contained in:
parent
523e49b111
commit
c43c2da8af
19
llama.cpp
19
llama.cpp
@ -3345,7 +3345,6 @@ static struct ggml_tensor * llm_build_ffn(
|
||||
// if max_alibi_bias > 0 then apply ALiBi
|
||||
static struct ggml_tensor * llm_build_kqv(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * cur,
|
||||
const llama_hparams & hparams,
|
||||
const llama_kv_cache & kv,
|
||||
struct ggml_tensor * wo,
|
||||
@ -3411,7 +3410,7 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
|
||||
cb(kqv_merged, "kqv_merged", il);
|
||||
|
||||
cur = ggml_cont_2d(ctx, kqv_merged, n_embd, n_tokens);
|
||||
struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd, n_tokens);
|
||||
cb(cur, "kqv_merged_cont", il);
|
||||
|
||||
cur = ggml_mul_mat(ctx, wo, cur);
|
||||
@ -3565,7 +3564,7 @@ struct llm_build_context {
|
||||
|
||||
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||
|
||||
cur = llm_build_kqv(ctx0, cur, hparams, kv_self,
|
||||
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
@ -3677,7 +3676,7 @@ struct llm_build_context {
|
||||
// apply ALiBi for 13B model
|
||||
const float max_alibi_bias = model.type == MODEL_13B ? 8.0f : -1.0f;
|
||||
|
||||
cur = llm_build_kqv(ctx0, cur, hparams, kv_self,
|
||||
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
@ -3795,7 +3794,7 @@ struct llm_build_context {
|
||||
|
||||
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||
|
||||
cur = llm_build_kqv(ctx0, attn_norm, hparams, kv_self,
|
||||
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
@ -3895,7 +3894,7 @@ struct llm_build_context {
|
||||
|
||||
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||
|
||||
cur = llm_build_kqv(ctx0, cur, hparams, kv_self,
|
||||
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
@ -4100,7 +4099,7 @@ struct llm_build_context {
|
||||
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||
|
||||
// TODO: not tested, could be broken
|
||||
cur = llm_build_kqv(ctx0, Q, hparams, kv_self,
|
||||
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
@ -4191,7 +4190,7 @@ struct llm_build_context {
|
||||
|
||||
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||
|
||||
cur = llm_build_kqv(ctx0, Qcur, hparams, kv_self,
|
||||
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
@ -4288,7 +4287,7 @@ struct llm_build_context {
|
||||
|
||||
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||
|
||||
cur = llm_build_kqv(ctx0, Qcur, hparams, kv_self,
|
||||
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, 8.0f, cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
@ -4382,7 +4381,7 @@ struct llm_build_context {
|
||||
|
||||
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
|
||||
|
||||
cur = llm_build_kqv(ctx0, Qcur, hparams, kv_self,
|
||||
cur = llm_build_kqv(ctx0, hparams, kv_self,
|
||||
model.layers[il].wo, NULL,
|
||||
Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, hparams.f_max_alibi_bias, cb, il);
|
||||
cb(cur, "kqv_out", il);
|
||||
|
Loading…
Reference in New Issue
Block a user