diff --git a/examples/gptneox-wip/falcon-main.cpp b/examples/gptneox-wip/falcon-main.cpp
index c97aa602a..5be192a48 100644
--- a/examples/gptneox-wip/falcon-main.cpp
+++ b/examples/gptneox-wip/falcon-main.cpp
@@ -605,11 +605,10 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt2bpe_
 
     const int n_block = hparams.n_block;
     const int n_ctx   = hparams.n_ctx;
-    const int n_head_kv = hparams.n_head_kv;
-    const int head_dim  = hparams.n_embd / hparams.n_head;
+    const int n_embd  = hparams.n_embd;
 
     const int64_t n_mem      = n_block*n_ctx;
-    const int64_t n_elements = head_dim*n_mem;
+    const int64_t n_elements = n_embd*n_mem;
 
     // create the ggml context
     {
@@ -628,8 +627,8 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt2bpe_
 
     }
 
-    model.memory_k = ggml_new_tensor_1d(kvctx, GGML_TYPE_F16, n_head_kv * n_elements);
-    model.memory_v = ggml_new_tensor_1d(kvctx, GGML_TYPE_F16, n_head_kv * n_elements);
+    model.memory_k = ggml_new_tensor_1d(kvctx, GGML_TYPE_F16, n_elements);
+    model.memory_v = ggml_new_tensor_1d(kvctx, GGML_TYPE_F16, n_elements);
 
     const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
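
For reference, a minimal standalone sketch of the KV-cache sizing arithmetic after this change. The formula mirrors the patched lines above; ggml_new_tensor_1d and GGML_TYPE_F16 are the real ggml API, but the hyperparameter values below are hypothetical (chosen to resemble Falcon-7B), and the program only reproduces the size computation, not the allocation itself.

    // kv_cache_size.cpp -- standalone sketch of the patched sizing logic.
    #include <cstdint>
    #include <cstdio>

    int main() {
        // Hypothetical Falcon-7B-like hyperparameters (not read from a model file).
        const int n_block = 32;    // transformer blocks
        const int n_ctx   = 2048;  // context length
        const int n_embd  = 4544;  // embedding width

        // After the patch: one n_embd-wide vector per position, per block.
        const int64_t n_mem      = (int64_t) n_block * n_ctx;
        const int64_t n_elements = (int64_t) n_embd  * n_mem;

        // GGML_TYPE_F16 is 2 bytes per element; memory_k and memory_v together
        // double the footprint, matching the memory_size sum in the diff.
        const int64_t memory_size = 2 * n_elements * sizeof(uint16_t);
        printf("n_elements = %lld, kv cache ~ %.1f MiB\n",
               (long long) n_elements, memory_size/(1024.0*1024.0));
        return 0;
    }

Note the size implication: since the removed line defines head_dim as n_embd / n_head, the old count n_head_kv * head_dim * n_mem is smaller than the new n_embd * n_mem whenever n_head_kv < n_head, so the patched code reserves a full n_head * head_dim elements per position rather than only the key/value heads' worth.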