mirror of https://github.com/ggerganov/llama.cpp.git
gguf : get rid of n_mult, read n_ff from file
commit e732423280
parent f44bbd3d88
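The point of the change: n_ff is now stored in the model file under the key "llama.feed_forward_length" and read back directly, instead of being reconstructed from n_embd and the magic multiplier n_mult. As a standalone illustration (not part of the commit), a minimal reader sketch, assuming the gguf C API as exposed by ggml at the time (gguf_init_from_file, gguf_find_key, gguf_get_val_u32, gguf_free):

// Standalone sketch, not from this commit: read "llama.feed_forward_length"
// straight out of a GGUF file using the gguf C API assumed above.
#include <cstdio>
#include "ggml.h"

int main(int argc, char ** argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s model.gguf\n", argv[0]);
        return 1;
    }

    struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ nullptr };
    struct gguf_context * ctx = gguf_init_from_file(argv[1], params);
    if (!ctx) {
        fprintf(stderr, "failed to load %s\n", argv[1]);
        return 1;
    }

    const int i = gguf_find_key(ctx, "llama.feed_forward_length");
    if (i == -1) {
        fprintf(stderr, "key not found\n");
    } else {
        printf("n_ff = %u\n", gguf_get_val_u32(ctx, i)); // e.g. 11008 for 7B
    }

    gguf_free(ctx);
    return 0;
}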
@@ -177,15 +177,12 @@ struct llama_hparams {
     uint32_t n_vocab = 32000;
     uint32_t n_ctx = 512; // this is provided as user input?
     uint32_t n_embd = 4096;
-    uint32_t n_mult = 256;
     uint32_t n_head = 32;
     uint32_t n_head_kv = 32;
     uint32_t n_layer = 32;
     uint32_t n_rot = 64;
+    uint32_t n_ff = 11008;
 
-    // LLaMAv2
-    // TODO: load from model data hparams
-    float f_ffn_mult = 1.0f;
     float f_rms_norm_eps = LLAMA_DEFAULT_RMS_EPS;
 
     float rope_freq_base = 10000.0f;
@@ -467,7 +464,7 @@ struct llama_load_tensors_map {
 };
 
 enum gguf_file_version {
-    GGUF_FILE_VERSION_V1
+    GGUF_FILE_VERSION_V1 = 1,
 
 };
 
@@ -490,6 +487,7 @@ struct ggml_context * ctx_data = NULL;
         };
 
         gguf_ctx = gguf_init_from_file(fname, params);
+        file_version = (enum gguf_file_version) gguf_get_version(gguf_ctx);
 
         read_hparams();
         read_vocab();
@@ -505,6 +503,15 @@ struct ggml_context * ctx_data = NULL;
         return gguf_get_val_u32(gguf_ctx, i);
     }
 
+    float read_f32(const char * key) {
+        int i = gguf_find_key(gguf_ctx, key);
+        if (i == -1) {
+            throw std::runtime_error(format("cannot find param with key %s\n", key));
+        }
+
+        return gguf_get_val_f32(gguf_ctx, i);
+    }
+
     int read_n_vocab() {
         int i = gguf_find_key(gguf_ctx, "tokenizer.ggml.tokens");
         if (i == -1) {
@@ -514,18 +521,6 @@ struct ggml_context * ctx_data = NULL;
         return gguf_get_arr_n(gguf_ctx, i);
     }
 
-    int find_n_mult(const int n_ff, const int n_embd) {
-        int n_mults[3] = {8192, 1, -1};
-        for (int i = 0; i < 3; ++i) {
-            int calc_ff = (((8 * n_embd) / 3 + n_mults[i] - 1) / n_mults[i]) * n_mults[i];
-            if (calc_ff == n_ff) {
-                return n_mults[i];
-            }
-        }
-
-        throw std::runtime_error(format("failed to find n_mult for n_ff = %d and n_embd = %d\n", n_ff, n_embd));
-    }
-
     void read_hparams() {
 
         // TODO make keys constants in header
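The removed find_n_mult helper was the inverse of the old derivation: the loader used to compute n_ff from n_embd by rounding 2*(4*n_embd)/3 up to a multiple of n_mult, and find_n_mult brute-forced n_mult back from a known n_ff. A worked sketch of that now-retired arithmetic, using the 7B defaults from the hparams struct above:

// Worked example of the derivation this commit retires, assuming the
// LLaMA-7B defaults n_embd = 4096 and n_mult = 256 from llama_hparams.
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t n_embd = 4096;
    const uint32_t n_mult = 256;

    const uint32_t n_ff_raw = 2*(4*n_embd)/3;                          // 10922
    const uint32_t n_ff     = ((n_ff_raw + n_mult - 1)/n_mult)*n_mult; // round up to a multiple of n_mult

    printf("n_ff = %u\n", n_ff); // 11008, the default now hard-coded in llama_hparams
    return 0;
}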
@@ -533,14 +528,12 @@ struct ggml_context * ctx_data = NULL;
         hparams.n_vocab = read_n_vocab();
         hparams.n_ctx = read_u32("llama.context_length");
         hparams.n_embd = read_u32("llama.embedding_length");
-        uint32_t n_ff = read_u32("llama.feed_forward_length");
-        GGML_UNUSED(n_ff);
-        //hparams.n_mult = find_n_mult(n_ff, hparams.n_embd);
+        hparams.n_ff = read_u32("llama.feed_forward_length");
         hparams.n_head = read_u32("llama.attention.head_count");
         hparams.n_layer = read_u32("llama.layer_count");
-        hparams.n_rot = hparams.n_embd / hparams.n_head;
-        //hparams.ftype = (enum llama_ftype) file.read_u32();
+        hparams.n_rot = read_u32("llama.rope.dimension_count");
+        hparams.f_rms_norm_eps = read_f32("llama.attention.layer_norm_rms_epsilon");
 
         // LLaMAv2
         // hparams.n_head_kv = read_u32("llama.attention.head_count_kv");
     }
@@ -1125,6 +1118,7 @@ static void llama_model_load_internal(
         bool vocab_only,
         llama_progress_callback progress_callback,
         void * progress_callback_user_data) {
+    GGML_UNUSED(rms_norm_eps); // TODO: update function signature to remove this
 
     model.t_start_us = ggml_time_us();
 
@@ -1137,9 +1131,6 @@ static void llama_model_load_internal(
 
     auto & hparams = model.hparams;
 
-    // TODO: read from file
-    hparams.f_rms_norm_eps = rms_norm_eps;
-
     {
         switch (hparams.n_layer) {
             case 26: model.type = e_model::MODEL_3B; break;
@@ -1162,25 +1153,19 @@ static void llama_model_load_internal(
         if (model.type == e_model::MODEL_65B && n_gqa == 8) {
             fprintf(stderr, "%s: warning: assuming 70B model based on GQA == %d\n", __func__, n_gqa);
             model.type = e_model::MODEL_70B;
-            hparams.f_ffn_mult = 1.3f; // from the params.json of the 70B model
         }
 
         hparams.rope_freq_base = rope_freq_base;
         hparams.rope_freq_scale = rope_freq_scale;
     }
 
-    // ref: https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/model.py#L194-L199
-    const uint32_t n_ff_raw = 2*(4*hparams.n_embd)/3;
-    const uint32_t n_ff_mult = hparams.f_ffn_mult*n_ff_raw;
-    const uint32_t n_ff = ((n_ff_mult + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult;
-    //const uint32_t n_ff = 28672;
+    const uint32_t n_ff = hparams.n_ff;
 
     {
         fprintf(stderr, "%s: format = %s\n", __func__, gguf_file_version_name(file_version));
         fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab);
         fprintf(stderr, "%s: n_ctx = %u\n", __func__, hparams.n_ctx);
         fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd);
-        fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult);
         fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head);
         fprintf(stderr, "%s: n_head_kv = %u\n", __func__, hparams.n_head_kv);
         fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);
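Reading n_ff from the file shifts the obligation to the producer: whatever converter writes the GGUF must now emit "llama.feed_forward_length" (and, per this commit, "llama.rope.dimension_count" and "llama.attention.layer_norm_rms_epsilon"). A writer-side sketch under the same gguf API assumptions (gguf_init_empty, gguf_set_val_u32/f32, gguf_write_to_file); the values are illustrative 7B numbers, not taken from this commit:

// Producer-side sketch, under the same gguf C API assumptions; the hparam
// values are hypothetical 7B numbers used only for illustration.
#include "ggml.h"

static void write_llama_hparams_example(void) {
    struct gguf_context * ctx = gguf_init_empty();

    gguf_set_val_u32(ctx, "llama.embedding_length", 4096);
    gguf_set_val_u32(ctx, "llama.feed_forward_length", 11008); // n_ff, now stored explicitly
    gguf_set_val_u32(ctx, "llama.rope.dimension_count", 128);  // n_rot = n_embd/n_head
    gguf_set_val_f32(ctx, "llama.attention.layer_norm_rms_epsilon", 1e-5f);

    gguf_write_to_file(ctx, "model-meta.gguf", /*only_meta =*/ true);
    gguf_free(ctx);
}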