mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 12:24:35 +00:00
gguf : get rid of n_mult, read n_ff from file
This commit is contained in:
parent
f44bbd3d88
commit
e732423280
@ -177,15 +177,12 @@ struct llama_hparams {
|
||||
uint32_t n_vocab = 32000;
|
||||
uint32_t n_ctx = 512; // this is provided as user input?
|
||||
uint32_t n_embd = 4096;
|
||||
uint32_t n_mult = 256;
|
||||
uint32_t n_head = 32;
|
||||
uint32_t n_head_kv = 32;
|
||||
uint32_t n_layer = 32;
|
||||
uint32_t n_rot = 64;
|
||||
uint32_t n_ff = 11008;
|
||||
|
||||
// LLaMAv2
|
||||
// TODO: load from model data hparams
|
||||
float f_ffn_mult = 1.0f;
|
||||
float f_rms_norm_eps = LLAMA_DEFAULT_RMS_EPS;
|
||||
|
||||
float rope_freq_base = 10000.0f;
|
||||
@ -467,7 +464,7 @@ struct llama_load_tensors_map {
|
||||
};
|
||||
|
||||
enum gguf_file_version {
|
||||
GGUF_FILE_VERSION_V1
|
||||
GGUF_FILE_VERSION_V1 = 1,
|
||||
|
||||
};
|
||||
|
||||
@ -490,6 +487,7 @@ struct ggml_context * ctx_data = NULL;
|
||||
};
|
||||
|
||||
gguf_ctx = gguf_init_from_file(fname, params);
|
||||
file_version = (enum gguf_file_version) gguf_get_version(gguf_ctx);
|
||||
|
||||
read_hparams();
|
||||
read_vocab();
|
||||
@ -505,6 +503,15 @@ struct ggml_context * ctx_data = NULL;
|
||||
return gguf_get_val_u32(gguf_ctx, i);
|
||||
}
|
||||
|
||||
float read_f32(const char * key) {
|
||||
int i = gguf_find_key(gguf_ctx, key);
|
||||
if (i == -1) {
|
||||
throw std::runtime_error(format("cannot find param with key %s\n", key));
|
||||
}
|
||||
|
||||
return gguf_get_val_f32(gguf_ctx, i);
|
||||
}
|
||||
|
||||
int read_n_vocab() {
|
||||
int i = gguf_find_key(gguf_ctx, "tokenizer.ggml.tokens");
|
||||
if (i == -1) {
|
||||
@ -514,18 +521,6 @@ struct ggml_context * ctx_data = NULL;
|
||||
return gguf_get_arr_n(gguf_ctx, i);
|
||||
}
|
||||
|
||||
int find_n_mult(const int n_ff, const int n_embd) {
|
||||
int n_mults[3] = {8192, 1, -1};
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
int calc_ff = (((8 * n_embd) / 3 + n_mults[i] - 1) / n_mults[i]) * n_mults[i];
|
||||
if (calc_ff == n_ff) {
|
||||
return n_mults[i];
|
||||
}
|
||||
}
|
||||
|
||||
throw std::runtime_error(format("failed to find n_mult for n_ff = %d and n_embd = %d\n", n_ff, n_embd));
|
||||
}
|
||||
|
||||
void read_hparams() {
|
||||
|
||||
// TODO make keysconstants in header
|
||||
@ -533,14 +528,12 @@ struct ggml_context * ctx_data = NULL;
|
||||
hparams.n_vocab = read_n_vocab();
|
||||
hparams.n_ctx = read_u32("llama.context_length");
|
||||
hparams.n_embd = read_u32("llama.embedding_length");
|
||||
uint32_t n_ff = read_u32("llama.feed_forward_length");
|
||||
GGML_UNUSED(n_ff);
|
||||
//hparams.n_mult = find_n_mult(n_ff, hparams.n_embd);
|
||||
hparams.n_ff = read_u32("llama.feed_forward_length");
|
||||
hparams.n_head = read_u32("llama.attention.head_count");
|
||||
hparams.n_layer = read_u32("llama.layer_count");
|
||||
hparams.n_rot = hparams.n_embd / hparams.n_head;
|
||||
//hparams.ftype = (enum llama_ftype) file.read_u32();
|
||||
|
||||
hparams.n_rot = read_u32("llama.rope.dimension_count");
|
||||
hparams.f_rms_norm_eps = read_f32("llama.attention.layer_norm_rms_epsilon");
|
||||
|
||||
// LLaMAv2
|
||||
// hparams.n_head_kv = read_u32("llama.attention.head_count_kv");
|
||||
}
|
||||
@ -1125,6 +1118,7 @@ static void llama_model_load_internal(
|
||||
bool vocab_only,
|
||||
llama_progress_callback progress_callback,
|
||||
void * progress_callback_user_data) {
|
||||
GGML_UNUSED(rms_norm_eps); // TODO: update function signature to remove this
|
||||
|
||||
model.t_start_us = ggml_time_us();
|
||||
|
||||
@ -1137,9 +1131,6 @@ static void llama_model_load_internal(
|
||||
|
||||
auto & hparams = model.hparams;
|
||||
|
||||
// TODO: read from file
|
||||
hparams.f_rms_norm_eps = rms_norm_eps;
|
||||
|
||||
{
|
||||
switch (hparams.n_layer) {
|
||||
case 26: model.type = e_model::MODEL_3B; break;
|
||||
@ -1162,25 +1153,19 @@ static void llama_model_load_internal(
|
||||
if (model.type == e_model::MODEL_65B && n_gqa == 8) {
|
||||
fprintf(stderr, "%s: warning: assuming 70B model based on GQA == %d\n", __func__, n_gqa);
|
||||
model.type = e_model::MODEL_70B;
|
||||
hparams.f_ffn_mult = 1.3f; // from the params.json of the 70B model
|
||||
}
|
||||
}
|
||||
|
||||
hparams.rope_freq_base = rope_freq_base;
|
||||
hparams.rope_freq_scale = rope_freq_scale;
|
||||
}
|
||||
|
||||
// ref: https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/model.py#L194-L199
|
||||
const uint32_t n_ff_raw = 2*(4*hparams.n_embd)/3;
|
||||
const uint32_t n_ff_mult = hparams.f_ffn_mult*n_ff_raw;
|
||||
const uint32_t n_ff = ((n_ff_mult + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
|
||||
//const uint32_t n_ff = 28672;
|
||||
|
||||
const uint32_t n_ff = hparams.n_ff;
|
||||
|
||||
{
|
||||
fprintf(stderr, "%s: format = %s\n", __func__, gguf_file_version_name(file_version));
|
||||
fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
|
||||
fprintf(stderr, "%s: n_ctx = %u\n", __func__, hparams.n_ctx);
|
||||
fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd);
|
||||
fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult);
|
||||
fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
|
||||
fprintf(stderr, "%s: n_head_kv = %u\n", __func__, hparams.n_head_kv);
|
||||
fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
|
||||
|
Loading…
Reference in New Issue
Block a user