llama.cpp : print kv general.name

commit 306070c896 (parent d9e6890a51)
Author: klosax
Date:   2023-08-18 01:06:27 +02:00 (committed by GitHub)

diff --git a/llama.cpp b/llama.cpp
--- a/llama.cpp
+++ b/llama.cpp
@@ -1329,6 +1329,8 @@ static void llama_model_load_internal(
 
     auto & hparams = model.hparams;
 
+    std::string general_name = "n/a";
+
     // read hparams
     {
         struct gguf_context * ctx = ml->ctx_gguf;
@@ -1347,6 +1349,10 @@ static void llama_model_load_internal(
         } \
     }
 
+        // get general kv
+        GGUF_GET(general_name, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.name");
+
+        // get hparams kv
         GGUF_GET(hparams.n_vocab,     gguf_get_arr_n,   GGUF_TYPE_ARRAY,  true, "tokenizer.ggml.tokens");
         GGUF_GET(hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.context_length");
         GGUF_GET(hparams.n_embd,      gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.embedding_length");
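
For context, the GGUF_GET macro used above wraps the raw gguf key/value accessors. Below is a minimal sketch of what the new general.name lookup does, assuming the gguf_find_key / gguf_get_kv_type / gguf_get_val_str accessors declared alongside ggml; the helper name and error handling are illustrative, not the macro's exact expansion. Because the key is passed as optional (required == false), a missing key is not an error and the "n/a" default survives.

    #include <stdexcept>
    #include <string>
    #include "ggml.h" // gguf_* accessors (gguf.h in newer trees)

    // Hypothetical helper mirroring the optional "general.name" lookup.
    static std::string read_general_name(struct gguf_context * ctx) {
        std::string general_name = "n/a";
        const int idx = gguf_find_key(ctx, "general.name"); // -1 if absent
        if (idx >= 0) {
            if (gguf_get_kv_type(ctx, idx) != GGUF_TYPE_STRING) {
                throw std::runtime_error("key general.name has wrong type");
            }
            general_name = gguf_get_val_str(ctx, idx);
        }
        return general_name;
    }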
@@ -1359,6 +1365,7 @@ static void llama_model_load_internal(
         // n_head_kv is optional, default to n_head
         hparams.n_head_kv = hparams.n_head;
         GGUF_GET(hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "llama.attention.head_count_kv");
+
 #undef GGUF_GET
 
         switch (hparams.n_layer) {
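
The n_head_kv default above is what keeps grouped-query attention metadata optional: a model without llama.attention.head_count_kv behaves as plain multi-head attention. The diff does not show the body of the n_gqa() accessor logged below, but a plausible sketch of the relationship it encodes, under that assumption, is:

    #include <cstdint>

    // Illustrative stand-in for the hparams accessor logged as "n_gqa";
    // its actual body is not part of this diff.
    struct hparams_sketch {
        uint32_t n_head    = 32;  // e.g. a 7B-class model
        uint32_t n_head_kv = 32;  // defaults to n_head when the kv key is absent

        uint32_t n_gqa() const { return n_head / n_head_kv; }  // 1 => plain MHA
    };

For example, LLaMA-2 70B ships n_head = 64 and n_head_kv = 8, so n_gqa() would report 8.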
@@ -1422,22 +1429,24 @@ static void llama_model_load_internal(
     }
 
     {
+        LLAMA_LOG_INFO("%s: general.name = %s\n",    __func__, general_name.c_str());
+
         LLAMA_LOG_INFO("%s: format       = %s\n",    __func__, llama_file_version_name(ml->file_version));
         LLAMA_LOG_INFO("%s: n_vocab      = %u\n",    __func__, hparams.n_vocab);
         LLAMA_LOG_INFO("%s: n_ctx_train  = %u\n",    __func__, hparams.n_ctx_train);
         LLAMA_LOG_INFO("%s: n_ctx        = %u\n",    __func__, hparams.n_ctx);
         LLAMA_LOG_INFO("%s: n_embd       = %u\n",    __func__, hparams.n_embd);
         LLAMA_LOG_INFO("%s: n_head       = %u\n",    __func__, hparams.n_head);
         LLAMA_LOG_INFO("%s: n_head_kv    = %u\n",    __func__, hparams.n_head_kv);
         LLAMA_LOG_INFO("%s: n_layer      = %u\n",    __func__, hparams.n_layer);
         LLAMA_LOG_INFO("%s: n_rot        = %u\n",    __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim
         LLAMA_LOG_INFO("%s: n_gqa        = %u\n",    __func__, hparams.n_gqa());
         LLAMA_LOG_INFO("%s: f_norm_eps   = %.1e\n",  __func__, hparams.f_norm_rms_eps);
         LLAMA_LOG_INFO("%s: n_ff         = %u\n",    __func__, hparams.n_ff);
         LLAMA_LOG_INFO("%s: freq_base    = %.1f\n",  __func__, hparams.rope_freq_base);
         LLAMA_LOG_INFO("%s: freq_scale   = %g\n",    __func__, hparams.rope_freq_scale);
         LLAMA_LOG_INFO("%s: model type   = %s\n",    __func__, llama_model_type_name(model.type));
-        LLAMA_LOG_INFO("%s: model size   = %.2fB\n", __func__, ml->n_elements*1e-9);
+        LLAMA_LOG_INFO("%s: model size   = %.2f B\n", __func__, ml->n_elements*1e-9);
         // TODO: print number of tensors for each quantization
     }
 
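
With this change, loading a GGUF file that carries a general.name key prints the model's name ahead of the hyperparameter dump, roughly like the excerpt below (the name and values shown are hypothetical, and the format line is elided; only the line shapes follow from the format strings above):

    llama_model_load_internal: general.name = Example 7B Chat
    llama_model_load_internal: n_vocab      = 32000
    ...
    llama_model_load_internal: model size   = 6.74 B

Files without the key fall back to the "n/a" default, so the new line is always printed and keeps the log layout stable.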