From 95385241a91a616788a3bb76d12c9b7b2379ca2d Mon Sep 17 00:00:00 2001 From: Olivier Chafik Date: Wed, 23 Aug 2023 20:33:05 +0100 Subject: [PATCH] examples : restore the functionality to import llama2.c models (#2685) * Fix import of llama2.c models that don't share weights between embedding layers * llama2c: reinstate ggmlv3 conversion output + update readme w/ gguf conv * llama2.c: comment out legacy "load from ggml model" logic * llama2.c: convert special-cased "<0xXX>" single byte tokens from tokenizer.bin --- examples/convert-llama2c-to-ggml/README.md | 14 +- .../convert-llama2c-to-ggml.cpp | 234 ++++++++++-------- 2 files changed, 144 insertions(+), 104 deletions(-) diff --git a/examples/convert-llama2c-to-ggml/README.md b/examples/convert-llama2c-to-ggml/README.md index 868f57d6d..fd561fcbc 100644 --- a/examples/convert-llama2c-to-ggml/README.md +++ b/examples/convert-llama2c-to-ggml/README.md @@ -12,15 +12,19 @@ usage: ./convert-llama2c-to-ggml [options] options: -h, --help show this help message and exit - --copy-vocab-from-model FNAME model path from which to copy vocab (default 'models/ggml-vocab.bin') + --copy-vocab-from-model FNAME model path from which to copy vocab (default 'tokenizer.bin') --llama2c-model FNAME [REQUIRED] model path from which to load Karpathy's llama2.c model --llama2c-output-model FNAME model path to save the converted llama2.c model (default ak_llama_model.bin') ``` -An example command is as follows: +An example command using a model from [karpathy/tinyllamas](https://huggingface.co/karpathy/tinyllamas) is as follows: -`$ ./convert-llama2c-to-ggml --copy-vocab-from-model --llama2c-model --llama2c-output-model ` +`$ ./convert-llama2c-to-ggml --copy-vocab-from-model ../llama2.c/tokenizer.bin --llama2c-model stories42M.bin --llama2c-output-model stories42M.ggmlv3.bin` -Now you can use the model with command like: +For now the generated model is in the legacy GGJTv3 format, so you need to convert it to gguf manually: -`$ ./main -m -p "One day, Lily met a Shoggoth" -n 500 -c 256 -eps 1e-5` +`$ python ./convert-llama-ggmlv3-to-gguf.py --eps 1e-5 --input stories42M.ggmlv3.bin --output stories42M.gguf.bin` + +Now you can use the model with a command like: + +`$ ./main -m stories42M.gguf.bin -p "One day, Lily met a Shoggoth" -n 500 -c 256` diff --git a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp index 469d6e3de..1551a85cd 100644 --- a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +++ b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp @@ -17,6 +17,9 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif +#define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt' +#define LLAMA_FILE_VERSION_GGJT_V3 3 + //////////////////////////////////////// llama2.c model structs and functions to load models, alloc memory etc. typedef struct { int dim; // transformer dimension @@ -49,10 +52,10 @@ typedef struct { // float* freq_cis_real; // (seq_len, dim/2) // float* freq_cis_imag; // (seq_len, dim/2) // (optional) classifier weights for the logits, on the last layer - //float* wcls; + float* wcls; } TransformerWeights; -void malloc_weights(TransformerWeights* w, Config* p) { +void malloc_weights(TransformerWeights* w, Config* p, bool shared_weights) { // we calloc instead of malloc to keep valgrind happy w->token_embedding_table = new float[p->vocab_size * p->dim](); printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->token_embedding_table\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim); @@ -86,9 +89,16 @@ void malloc_weights(TransformerWeights* w, Config* p) { w->rms_final_weight = new float[p->dim](); printf("[%s:AK] Allocating [%d] float space for w->rms_final_weight\n",__func__,p->dim); + + if (shared_weights) { + w->wcls = NULL; + } else { + w->wcls = new float[p->vocab_size * p->dim](); + printf("[%s:AK] Allocating [%d] x [%d] = [%d] float space for w->wcls\n",__func__,p->vocab_size , p->dim, p->vocab_size * p->dim); + } } -int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f) { +int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f, bool shared_weights) { if (fread(w->token_embedding_table, sizeof(float), p->vocab_size * p->dim, f) != static_cast(p->vocab_size * p->dim)) return 1; if (fread(w->rms_att_weight, sizeof(float), p->n_layers * p->dim, f) != static_cast(p->n_layers * p->dim)) return 1; if (fread(w->wq, sizeof(float), p->n_layers * p->dim * p->dim, f) != static_cast(p->n_layers * p->dim * p->dim)) return 1; @@ -100,6 +110,22 @@ int checkpoint_init_weights(TransformerWeights *w, Config* p, FILE* f) { if (fread(w->w2, sizeof(float), p->n_layers * p->hidden_dim * p->dim, f) != static_cast(p->n_layers * p->hidden_dim * p->dim)) return 1; if (fread(w->w3, sizeof(float), p->n_layers * p->dim * p->hidden_dim, f) != static_cast(p->n_layers * p->dim * p->hidden_dim)) return 1; if (fread(w->rms_final_weight, sizeof(float), p->dim, f) != static_cast(p->dim)) return 1; + + // Skip freq_cis_real & freq_cis_imag + int head_size = p->dim / p->n_heads; + fseek(f, p->seq_len * head_size * sizeof(float), SEEK_CUR); + + if (!shared_weights && fread(w->wcls, sizeof(float), p->vocab_size * p->dim, f) != static_cast(p->vocab_size * p->dim)) return 1; + + // Check we didn't forget to read anything + auto curr = ftell(f); + fseek(f, 0, SEEK_END); + auto end = ftell(f); + if (curr != end) { + printf("Error: failed to read the checkpoint file to the end (curr = %ld, end = %ld)\n", curr, end); + return 1; + } + return 0; } @@ -115,6 +141,7 @@ void free_weights(TransformerWeights* w) { delete w->w2; delete w->w3; delete w->rms_final_weight; + if (w->wcls) delete w->wcls; } void print_sample_weights(TransformerWeights *w){ @@ -131,6 +158,7 @@ void print_sample_weights(TransformerWeights *w){ printf("%f\n", w->w2[0]); printf("%f\n", w->w3[0]); printf("%f\n", w->rms_att_weight[0]); + if (w->wcls) printf("%f\n", w->wcls[0]); } //////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -509,26 +537,28 @@ bool is_ggml_file(const char *filename) { } void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab) { - // heuristic to infer whether vocab is from ggml or from llama2.c vocabulary - if (is_ggml_file(filename)) { - - struct llama_context_params llama_params = llama_context_default_params(); - llama_params.vocab_only = true; - - struct llama_model * lmodel = llama_load_model_from_file(filename, llama_params); - struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params); - - const int n_vocab = llama_n_vocab(lctx); - vocab->id_to_token.resize(n_vocab); - for (int i=0; iid_to_token[i].text = llama_token_get_text(lctx, i); - vocab->id_to_token[i].score = llama_token_get_score(lctx, i); - vocab->id_to_token[i].type = llama_token_get_type(lctx, i); - vocab->token_to_id.emplace(vocab->id_to_token[i].text, i); - } - llama_free(lctx); - llama_free_model(lmodel); - } else { // assume llama2.c vocabulary +#pragma message("TODO: implement reading vocabulary using gguf") +// // heuristic to infer whether vocab is from ggml or from llama2.c vocabulary +// if (is_ggml_file(filename)) { +// +// struct llama_context_params llama_params = llama_context_default_params(); +// llama_params.vocab_only = true; +// +// struct llama_model * lmodel = llama_load_model_from_file(filename, llama_params); +// struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params); +// +// const int n_vocab = llama_n_vocab(lctx); +// vocab->id_to_token.resize(n_vocab); +// for (int i=0; iid_to_token[i].text = llama_token_get_text(lctx, i); +// vocab->id_to_token[i].score = llama_token_get_score(lctx, i); +// vocab->id_to_token[i].type = llama_token_get_type(lctx, i); +// vocab->token_to_id.emplace(vocab->id_to_token[i].text, i); +// } +// llama_free(lctx); +// llama_free_model(lmodel); +// } else + { // assume llama2.c vocabulary printf("Assuming llama2.c vocabulary since %s is not a ggml file\n", filename); llama_file file(filename, "rb"); const int n_vocab = config->vocab_size; @@ -538,6 +568,12 @@ void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab) float_t score = file.read_f32(); uint32_t len = file.read_u32(); std::string text = file.read_string(len); + // Special-case handling of <0xXX> single byte tokens. + char byte_val; + if (sscanf(text.c_str(), "<0x%02hhX>", &byte_val) == 1) { + char cstr[2] = { byte_val, 0 }; + text = cstr; + } vocab->id_to_token[i].text = text; vocab->id_to_token[i].score = score; vocab->id_to_token[i].type = LLAMA_TOKEN_TYPE_UNDEFINED; @@ -589,83 +625,80 @@ void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * mod } #pragma message("TODO: implement file saving using gguf") - (void) vocab; - (void) model; - (void) w; -// // write_magic -// file.write_u32(LLAMA_FILE_MAGIC); // magic -// file.write_u32(LLAMA_FILE_VERSION); // version -// // write_hparams -// file.write_u32(model->hparams.n_vocab); -// file.write_u32(model->hparams.n_embd); -// file.write_u32(model->hparams.n_mult); -// file.write_u32(model->hparams.n_head); -// file.write_u32(model->hparams.n_layer); -// file.write_u32(model->hparams.n_rot); -// file.write_u32(LLAMA_FTYPE_ALL_F32); -// -// // write_vocab - for now we are just writing the existing BPE voc. assuming karpathy's vocabulary is the same. idk. -// uint32_t n_vocab = model->hparams.n_vocab; -// for (uint32_t i = 0; i < n_vocab; i++) { -// const auto & token_data = vocab->id_to_token.at(i); -// file.write_u32((uint32_t) token_data.tok.size()); -// file.write_raw(token_data.tok.data(), token_data.tok.size()); -// file.write_raw(&token_data.score, sizeof(token_data.score)); -// } -// -// // stuff AK weights into GG weights one by one. -// // w->token_embedding_table -> model->tok_embeddings -// // float* -> struct ggml_tensor -// stuff_karpathy_weights_into_gg(model->tok_embeddings, w->token_embedding_table); -// stuff_karpathy_weights_into_gg(model->output, w->token_embedding_table); -// -// stuff_karpathy_weights_into_gg(model->norm, w->rms_final_weight); -// //print_row(model->norm, 0); -// -// // for rms-att-weight -// int row_length = model->hparams.n_embd; -// const auto & hparams = model->hparams; -// //int n_ff = model->hparams.n_embd; -// int n_ff = get_n_ff(&hparams); -// -// for (uint32_t i = 0; i < model->hparams.n_layer; ++i){ -// auto & layer = model->layers[i]; -// // 1d -// stuff_karpathy_weights_into_gg(layer.attention_norm, &w->rms_att_weight[i*row_length]); -// stuff_karpathy_weights_into_gg(layer.ffn_norm , &w->rms_ffn_weight[i*row_length]); -// -// // from 3d matrix layer x dim x dim to 2d matrix dim x dim -// stuff_karpathy_weights_into_gg(layer.wq , &w->wq[i*row_length*row_length]); -// stuff_karpathy_weights_into_gg(layer.wk , &w->wk[i*row_length*row_length]); -// stuff_karpathy_weights_into_gg(layer.wv , &w->wv[i*row_length*row_length]); -// stuff_karpathy_weights_into_gg(layer.wo , &w->wo[i*row_length*row_length]); -// -// stuff_karpathy_weights_into_gg(layer.w1 , &w->w1[i*row_length*n_ff]); -// stuff_karpathy_weights_into_gg(layer.w2 , &w->w2[i*n_ff*row_length]); -// stuff_karpathy_weights_into_gg(layer.w3 , &w->w3[i*row_length*n_ff]); -// } -// // write tensors -// write_tensor(&file, model->tok_embeddings); -// write_tensor(&file, model->norm); -// write_tensor(&file, model->output); // ? -// for (uint32_t i = 0; i < model->hparams.n_layer; ++i) { -// auto & layer = model->layers[i]; -// -// write_tensor(&file, layer.attention_norm); -// write_tensor(&file, layer.wq); -// write_tensor(&file, layer.wk); -// write_tensor(&file, layer.wv); -// write_tensor(&file, layer.wo); -// write_tensor(&file, layer.ffn_norm); -// write_tensor(&file, layer.w1); -// write_tensor(&file, layer.w2); -// write_tensor(&file, layer.w3); -// } + // write_magic + file.write_u32(LLAMA_FILE_MAGIC_GGJT); // magic + file.write_u32(LLAMA_FILE_VERSION_GGJT_V3); // version + // write_hparams + file.write_u32(model->hparams.n_vocab); + file.write_u32(model->hparams.n_embd); + file.write_u32(model->hparams.n_mult); + file.write_u32(model->hparams.n_head); + file.write_u32(model->hparams.n_layer); + file.write_u32(model->hparams.n_rot); + file.write_u32(LLAMA_FTYPE_ALL_F32); + + // write_vocab - for now we are just writing the existing BPE voc. assuming karpathy's vocabulary is the same. idk. + uint32_t n_vocab = model->hparams.n_vocab; + for (uint32_t i = 0; i < n_vocab; i++) { + const auto & token_data = vocab->id_to_token.at(i); + file.write_u32((uint32_t) token_data.text.size()); + file.write_raw(token_data.text.data(), token_data.text.size()); + file.write_raw(&token_data.score, sizeof(token_data.score)); + } + + // stuff AK weights into GG weights one by one. + // w->token_embedding_table -> model->tok_embeddings + // float* -> struct ggml_tensor + stuff_karpathy_weights_into_gg(model->tok_embeddings, w->token_embedding_table); + stuff_karpathy_weights_into_gg(model->output, w->wcls ? w->wcls : w->token_embedding_table); + + stuff_karpathy_weights_into_gg(model->norm, w->rms_final_weight); + //print_row(model->norm, 0); + + // for rms-att-weight + int row_length = model->hparams.n_embd; + const auto & hparams = model->hparams; + //int n_ff = model->hparams.n_embd; + int n_ff = get_n_ff(&hparams); + + for (uint32_t i = 0; i < model->hparams.n_layer; ++i){ + auto & layer = model->layers[i]; + // 1d + stuff_karpathy_weights_into_gg(layer.attention_norm, &w->rms_att_weight[i*row_length]); + stuff_karpathy_weights_into_gg(layer.ffn_norm , &w->rms_ffn_weight[i*row_length]); + + // from 3d matrix layer x dim x dim to 2d matrix dim x dim + stuff_karpathy_weights_into_gg(layer.wq , &w->wq[i*row_length*row_length]); + stuff_karpathy_weights_into_gg(layer.wk , &w->wk[i*row_length*row_length]); + stuff_karpathy_weights_into_gg(layer.wv , &w->wv[i*row_length*row_length]); + stuff_karpathy_weights_into_gg(layer.wo , &w->wo[i*row_length*row_length]); + + stuff_karpathy_weights_into_gg(layer.w1 , &w->w1[i*row_length*n_ff]); + stuff_karpathy_weights_into_gg(layer.w2 , &w->w2[i*n_ff*row_length]); + stuff_karpathy_weights_into_gg(layer.w3 , &w->w3[i*row_length*n_ff]); + } + // write tensors + write_tensor(&file, model->tok_embeddings); + write_tensor(&file, model->norm); + write_tensor(&file, model->output); // ? + for (uint32_t i = 0; i < model->hparams.n_layer; ++i) { + auto & layer = model->layers[i]; + + write_tensor(&file, layer.attention_norm); + write_tensor(&file, layer.wq); + write_tensor(&file, layer.wk); + write_tensor(&file, layer.wv); + write_tensor(&file, layer.wo); + write_tensor(&file, layer.ffn_norm); + write_tensor(&file, layer.w1); + write_tensor(&file, layer.w2); + write_tensor(&file, layer.w3); + } } struct train_params get_default_train_params() { struct train_params params; - params.fn_vocab_model = "models/ggml-vocab.bin"; + params.fn_vocab_model = "tokenizer.bin"; params.fn_llama2c_output_model = "ak_llama_model.bin"; params.fn_train_data = "shakespeare.txt"; params.fn_checkpoint_in = "checkpoint.bin"; @@ -718,7 +751,7 @@ void print_usage(int /*argc*/, char ** argv, const struct train_params * params) fprintf(stderr, "\n"); fprintf(stderr, "options:\n"); fprintf(stderr, " -h, --help show this help message and exit\n"); - fprintf(stderr, " --copy-vocab-from-model FNAME llama2.c vocabulary or ggml model path from which to copy vocab (default '%s')\n", params->fn_vocab_model); + fprintf(stderr, " --copy-vocab-from-model FNAME llama2.c vocabulary or ggmlv3 model path from which to copy vocab (default '%s')\n", params->fn_vocab_model); fprintf(stderr, " --llama2c-model FNAME [REQUIRED] model path from which to load Karpathy's llama2.c model\n"); fprintf(stderr, " --llama2c-output-model FNAME model path to save the converted llama2.c model (default %s')\n", params->fn_llama2c_output_model); fprintf(stderr, "\n"); @@ -791,9 +824,12 @@ int main(int argc, char ** argv) { if (!file) { printf("Unable to open the checkpoint file %s!\n", params.fn_llama2c_model); return 1; } // read in the config header if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; } + auto shared_weights = config.vocab_size > 0; + config.vocab_size = abs(config.vocab_size); + // read in the Transformer weights - malloc_weights(&weights, &config); - if(checkpoint_init_weights(&weights, &config, file)) { return 1; } + malloc_weights(&weights, &config, shared_weights); + if(checkpoint_init_weights(&weights, &config, file, shared_weights)) { return 1; } fclose(file); }