falcon-main.cpp : fix for falcon 40b

This commit is contained in:
klosax 2023-08-19 01:03:37 +02:00 committed by GitHub
parent bd5a57901b
commit 1d80eea574
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -605,11 +605,10 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt2bpe_
const int n_block = hparams.n_block;
const int n_ctx = hparams.n_ctx;
const int n_head_kv = hparams.n_head_kv;
-const int head_dim = hparams.n_embd / hparams.n_head;
const int n_embd = hparams.n_embd;
const int64_t n_mem = n_block*n_ctx;
-const int64_t n_elements = head_dim*n_mem;
+const int64_t n_elements = n_embd*n_mem;
// create the ggml context
{
@@ -628,8 +627,8 @@ bool falcon_model_load(const std::string & fname, falcon_model & model, gpt2bpe_
}
-model.memory_k = ggml_new_tensor_1d(kvctx, GGML_TYPE_F16, n_head_kv * n_elements);
-model.memory_v = ggml_new_tensor_1d(kvctx, GGML_TYPE_F16, n_head_kv * n_elements);
+model.memory_k = ggml_new_tensor_1d(kvctx, GGML_TYPE_F16, n_elements);
+model.memory_v = ggml_new_tensor_1d(kvctx, GGML_TYPE_F16, n_elements);
const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);