From e732423280ff72ece7bc8dfe8ec1fa7e3153714c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=2E=20Yusuf=20Sar=C4=B1g=C3=B6z?= Date: Fri, 11 Aug 2023 23:50:38 +0300 Subject: [PATCH] gguf : get rid of n_mult, read n_ff from file --- gguf-llama.cpp | 55 ++++++++++++++++++-------------------------------- 1 file changed, 20 insertions(+), 35 deletions(-) diff --git a/gguf-llama.cpp b/gguf-llama.cpp index a8bd242b0..40d5ffd14 100644 --- a/gguf-llama.cpp +++ b/gguf-llama.cpp @@ -177,15 +177,12 @@ struct llama_hparams { uint32_t n_vocab = 32000; uint32_t n_ctx = 512; // this is provided as user input? uint32_t n_embd = 4096; - uint32_t n_mult = 256; uint32_t n_head = 32; uint32_t n_head_kv = 32; uint32_t n_layer = 32; uint32_t n_rot = 64; + uint32_t n_ff = 11008; - // LLaMAv2 - // TODO: load from model data hparams - float f_ffn_mult = 1.0f; float f_rms_norm_eps = LLAMA_DEFAULT_RMS_EPS; float rope_freq_base = 10000.0f; @@ -467,7 +464,7 @@ struct llama_load_tensors_map { }; enum gguf_file_version { - GGUF_FILE_VERSION_V1 + GGUF_FILE_VERSION_V1 = 1, }; @@ -490,6 +487,7 @@ struct ggml_context * ctx_data = NULL; }; gguf_ctx = gguf_init_from_file(fname, params); + file_version = (enum gguf_file_version) gguf_get_version(gguf_ctx); read_hparams(); read_vocab(); @@ -505,6 +503,15 @@ struct ggml_context * ctx_data = NULL; return gguf_get_val_u32(gguf_ctx, i); } + float read_f32(const char * key) { + int i = gguf_find_key(gguf_ctx, key); + if (i == -1) { + throw std::runtime_error(format("cannot find param with key %s\n", key)); + } + + return gguf_get_val_f32(gguf_ctx, i); + } + int read_n_vocab() { int i = gguf_find_key(gguf_ctx, "tokenizer.ggml.tokens"); if (i == -1) { @@ -514,18 +521,6 @@ struct ggml_context * ctx_data = NULL; return gguf_get_arr_n(gguf_ctx, i); } - int find_n_mult(const int n_ff, const int n_embd) { - int n_mults[3] = {8192, 1, -1}; - for (int i = 0; i < 3; ++i) { - int calc_ff = (((8 * n_embd) / 3 + n_mults[i] - 1) / n_mults[i]) * n_mults[i]; - if (calc_ff == n_ff) { - return n_mults[i]; - } - } - - throw std::runtime_error(format("failed to find n_mult for n_ff = %d and n_embd = %d\n", n_ff, n_embd)); - } - void read_hparams() { // TODO make keysconstants in header @@ -533,14 +528,12 @@ struct ggml_context * ctx_data = NULL; hparams.n_vocab = read_n_vocab(); hparams.n_ctx = read_u32("llama.context_length"); hparams.n_embd = read_u32("llama.embedding_length"); - uint32_t n_ff = read_u32("llama.feed_forward_length"); - GGML_UNUSED(n_ff); - //hparams.n_mult = find_n_mult(n_ff, hparams.n_embd); + hparams.n_ff = read_u32("llama.feed_forward_length"); hparams.n_head = read_u32("llama.attention.head_count"); hparams.n_layer = read_u32("llama.layer_count"); - hparams.n_rot = hparams.n_embd / hparams.n_head; - //hparams.ftype = (enum llama_ftype) file.read_u32(); - + hparams.n_rot = read_u32("llama.rope.dimension_count"); + hparams.f_rms_norm_eps = read_f32("llama.attention.layer_norm_rms_epsilon"); + // LLaMAv2 // hparams.n_head_kv = read_u32("llama.attention.head_count_kv"); } @@ -1125,6 +1118,7 @@ static void llama_model_load_internal( bool vocab_only, llama_progress_callback progress_callback, void * progress_callback_user_data) { + GGML_UNUSED(rms_norm_eps); // TODO: update function signature to remove this model.t_start_us = ggml_time_us(); @@ -1137,9 +1131,6 @@ static void llama_model_load_internal( auto & hparams = model.hparams; - // TODO: read from file - hparams.f_rms_norm_eps = rms_norm_eps; - { switch (hparams.n_layer) { case 26: model.type = e_model::MODEL_3B; break; @@ -1162,25 +1153,19 @@ static void llama_model_load_internal( if (model.type == e_model::MODEL_65B && n_gqa == 8) { fprintf(stderr, "%s: warning: assuming 70B model based on GQA == %d\n", __func__, n_gqa); model.type = e_model::MODEL_70B; - hparams.f_ffn_mult = 1.3f; // from the params.json of the 70B model - } + } hparams.rope_freq_base = rope_freq_base; hparams.rope_freq_scale = rope_freq_scale; } - // ref: https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/model.py#L194-L199 - const uint32_t n_ff_raw = 2*(4*hparams.n_embd)/3; - const uint32_t n_ff_mult = hparams.f_ffn_mult*n_ff_raw; - const uint32_t n_ff = ((n_ff_mult + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult; - //const uint32_t n_ff = 28672; - + const uint32_t n_ff = hparams.n_ff; + { fprintf(stderr, "%s: format = %s\n", __func__, gguf_file_version_name(file_version)); fprintf(stderr, "%s: n_vocab = %u\n", __func__, hparams.n_vocab); fprintf(stderr, "%s: n_ctx = %u\n", __func__, hparams.n_ctx); fprintf(stderr, "%s: n_embd = %u\n", __func__, hparams.n_embd); - fprintf(stderr, "%s: n_mult = %u\n", __func__, hparams.n_mult); fprintf(stderr, "%s: n_head = %u\n", __func__, hparams.n_head); fprintf(stderr, "%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); fprintf(stderr, "%s: n_layer = %u\n", __func__, hparams.n_layer);