llama : remove redundant GQA check

This commit is contained in:
Georgi Gerganov 2024-01-06 16:04:20 +02:00
parent 96e80dabc6
commit 7cfde78190
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -4772,7 +4772,6 @@ struct llm_build_context {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_gqa == n_embd);
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
@ -4896,7 +4895,6 @@ struct llm_build_context {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_gqa == n_embd);
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * pos; struct ggml_tensor * pos;
@ -4997,7 +4995,6 @@ struct llm_build_context {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_gqa == n_embd);
const int64_t n_rot = n_embd_head_k / 2; const int64_t n_rot = n_embd_head_k / 2;
@ -5211,7 +5208,6 @@ struct llm_build_context {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_gqa == n_embd);
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
@ -5304,7 +5300,6 @@ struct llm_build_context {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_gqa == n_embd);
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
@ -5400,7 +5395,6 @@ struct llm_build_context {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_gqa == n_embd);
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * inpL; struct ggml_tensor * inpL;
@ -5727,7 +5721,6 @@ struct llm_build_context {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_gqa == n_embd);
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * attn_norm_output; struct ggml_tensor * attn_norm_output;
@ -5951,7 +5944,6 @@ struct llm_build_context {
const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_head = hparams.n_embd_head_v;
const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); const int64_t n_embd_gqa = hparams.n_embd_v_gqa();
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
GGML_ASSERT(n_embd_gqa == n_embd);
struct ggml_tensor * cur; struct ggml_tensor * cur;
struct ggml_tensor * pos; struct ggml_tensor * pos;