diff --git a/.gitignore b/.gitignore index 9c749f1ef..a4df837a4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ *.o *.a *.so +*.gguf *.bin .DS_Store .build/ @@ -47,6 +48,8 @@ models-mnt /server /Pipfile /embd-input-test +/gguf +/gguf-llama-simple /libllama.so /llama-bench build-info.h @@ -65,7 +68,6 @@ perf-*.txt examples/jeopardy/results.txt - pyproject.toml poetry.lock poetry.toml diff --git a/CMakeLists.txt b/CMakeLists.txt index 824d9f2cf..bb63ef98e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -497,9 +497,11 @@ else() endif() # -# Build libraries +# libraries # +# ggml + add_library(ggml OBJECT ggml.c ggml.h @@ -524,10 +526,11 @@ if (BUILD_SHARED_LIBS) install(TARGETS ggml_shared LIBRARY) endif() +# llama + add_library(llama llama.cpp llama.h - llama-util.h ) target_include_directories(llama PUBLIC .) @@ -546,6 +549,10 @@ if (BUILD_SHARED_LIBS) install(TARGETS llama LIBRARY) endif() +# +# install +# + include(GNUInstallDirs) install( FILES convert.py @@ -584,6 +591,8 @@ endif() # programs, examples and tests # +add_subdirectory(common) + if (LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION) include(CTest) add_subdirectory(tests) diff --git a/Makefile b/Makefile index 502781c69..d31acc450 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # Define the default target now so that it is always the first target -BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple server embd-input-test llama-bench +BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch convert-llama2c-to-ggml simple server embd-input-test gguf llama-bench # Binaries only useful for tests TEST_TARGETS = tests/test-llama-grammar tests/test-grammar-parser tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0 @@ -45,8 +45,8 @@ OPT = -Ofast else OPT = -O3 endif -CFLAGS = -I. $(OPT) -std=c11 -fPIC -CXXFLAGS = -I. -I./examples $(OPT) -std=c++11 -fPIC +CFLAGS = -I. $(OPT) -std=c11 -fPIC +CXXFLAGS = -I. -I./common $(OPT) -std=c++11 -fPIC LDFLAGS = ifdef LLAMA_DEBUG @@ -329,23 +329,23 @@ ggml-alloc.o: ggml-alloc.c ggml.h ggml-alloc.h OBJS += ggml-alloc.o -llama.o: llama.cpp ggml.h ggml-alloc.h ggml-cuda.h ggml-metal.h llama.h llama-util.h +llama.o: llama.cpp ggml.h ggml-alloc.h ggml-cuda.h ggml-metal.h llama.h $(CXX) $(CXXFLAGS) -c $< -o $@ -common.o: examples/common.cpp examples/common.h +common.o: common/common.cpp common/common.h $(CXX) $(CXXFLAGS) -c $< -o $@ -console.o: examples/console.cpp examples/console.h +console.o: common/console.cpp common/console.h $(CXX) $(CXXFLAGS) -c $< -o $@ -grammar-parser.o: examples/grammar-parser.cpp examples/grammar-parser.h +grammar-parser.o: common/grammar-parser.cpp common/grammar-parser.h $(CXX) $(CXXFLAGS) -c $< -o $@ libllama.so: llama.o ggml.o $(OBJS) $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS) clean: - rm -vf *.o *.so *.dll main quantize quantize-stats perplexity embedding benchmark-matmult save-load-state server simple vdot train-text-from-scratch convert-llama2c-to-ggml embd-input-test llama-bench build-info.h $(TEST_TARGETS) + rm -vf *.o *.so *.dll main quantize quantize-stats perplexity embedding benchmark-matmult save-load-state server simple vdot train-text-from-scratch convert-llama2c-to-ggml embd-input-test gguf llama-bench build-info.h $(TEST_TARGETS) # # Examples @@ -385,7 +385,10 @@ $(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-in embd-input-test: $(LIB_PRE)embdinput$(DSO_EXT) examples/embd-input/embd-input-test.cpp build-info.h ggml.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %$(DSO_EXT),$(filter-out %.h,$(filter-out %.hpp,$^))) -o $@ $(LDFLAGS) -L. -lembdinput -train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp build-info.h ggml.o llama.o $(OBJS) +gguf: examples/gguf/gguf.cpp build-info.h ggml.o llama.o $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) + +train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp build-info.h ggml.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) convert-llama2c-to-ggml: examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp build-info.h ggml.o llama.o $(OBJS) @@ -418,7 +421,7 @@ vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS) tests/test-llama-grammar: tests/test-llama-grammar.cpp build-info.h ggml.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS) -tests/test-grammar-parser: tests/test-grammar-parser.cpp examples/grammar-parser.cpp build-info.h ggml.o llama.o common.o $(OBJS) +tests/test-grammar-parser: tests/test-grammar-parser.cpp build-info.h ggml.o llama.o common.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS) tests/test-double-float: tests/test-double-float.cpp build-info.h ggml.o llama.o common.o $(OBJS) diff --git a/README.md b/README.md index 9f8512dc5..82e070ac3 100644 --- a/README.md +++ b/README.md @@ -9,11 +9,17 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++ -### 🚧 Incoming breaking change + refactoring: +### Hot topics -See PR https://github.com/ggerganov/llama.cpp/pull/2398 for more info. +A new file format has been introduced: [GGUF](https://github.com/ggerganov/llama.cpp/pull/2398) -To devs: avoid making big changes to `llama.h` / `llama.cpp` until merged +Last revision compatible with the old format: [dadbed9](https://github.com/ggerganov/llama.cpp/commit/dadbed99e65252d79f81101a392d0d6497b86caa) + +### Current `master` should be considered in Beta - expect some issues for a few days! + +### Be prepared to re-convert and / or re-quantize your GGUF models while this notice is up! + +### Issues with non-GGUF models will be considered with low priority! ---- @@ -291,7 +297,7 @@ When built with Metal support, you can enable GPU inference with the `--gpu-laye Any value larger than 0 will offload the computation to the GPU. For example: ```bash -./main -m ./models/7B/ggml-model-q4_0.bin -n 128 -ngl 1 +./main -m ./models/7B/ggml-model-q4_0.gguf -n 128 -ngl 1 ``` ### MPI Build @@ -330,7 +336,7 @@ The above will distribute the computation across 2 processes on the first host a Finally, you're ready to run a computation using `mpirun`: ```bash -mpirun -hostfile hostfile -n 3 ./main -m ./models/7B/ggml-model-q4_0.bin -n 128 +mpirun -hostfile hostfile -n 3 ./main -m ./models/7B/ggml-model-q4_0.gguf -n 128 ``` ### BLAS Build @@ -513,10 +519,10 @@ python3 convert.py models/7B/ python convert.py models/7B/ --vocabtype bpe # quantize the model to 4-bits (using q4_0 method) -./quantize ./models/7B/ggml-model-f16.bin ./models/7B/ggml-model-q4_0.bin q4_0 +./quantize ./models/7B/ggml-model-f16.gguf ./models/7B/ggml-model-q4_0.gguf q4_0 # run the inference -./main -m ./models/7B/ggml-model-q4_0.bin -n 128 +./main -m ./models/7B/ggml-model-q4_0.gguf -n 128 ``` When running the larger models, make sure you have enough disk space to store all the intermediate files. @@ -572,7 +578,7 @@ Here is an example of a few-shot interaction, invoked with the command ./examples/chat-13B.sh # custom arguments using a 13B model -./main -m ./models/13B/ggml-model-q4_0.bin -n 256 --repeat_penalty 1.0 --color -i -r "User:" -f prompts/chat-with-bob.txt +./main -m ./models/13B/ggml-model-q4_0.gguf -n 256 --repeat_penalty 1.0 --color -i -r "User:" -f prompts/chat-with-bob.txt ``` Note the use of `--color` to distinguish between user input and generated text. Other parameters are explained in more detail in the [README](examples/main/README.md) for the `main` example program. @@ -635,6 +641,8 @@ OpenLLaMA is an openly licensed reproduction of Meta's original LLaMA model. It ### Using [GPT4All](https://github.com/nomic-ai/gpt4all) +*Note: these instructions are likely obsoleted by the GGUF update* + - Obtain the `tokenizer.model` file from LLaMA model and put it to `models` - Obtain the `added_tokens.json` file from Alpaca model and put it to `models` - Obtain the `gpt4all-lora-quantized.bin` file from GPT4All model and put it to `models/gpt4all-7B` @@ -710,7 +718,7 @@ If your issue is with model generation quality, then please at least scan the fo #### How to run 1. Download/extract: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research -2. Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` +2. Run `./perplexity -m models/7B/ggml-model-q4_0.gguf -f wiki.test.raw` 3. Output: ``` perplexity : calculating perplexity over 655 chunks @@ -809,13 +817,13 @@ docker run -v /path/to/models:/models ghcr.io/ggerganov/llama.cpp:full --all-in- On completion, you are ready to play! ```bash -docker run -v /path/to/models:/models ghcr.io/ggerganov/llama.cpp:full --run -m /models/7B/ggml-model-q4_0.bin -p "Building a website can be done in 10 simple steps:" -n 512 +docker run -v /path/to/models:/models ghcr.io/ggerganov/llama.cpp:full --run -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 ``` or with a light image: ```bash -docker run -v /path/to/models:/models ghcr.io/ggerganov/llama.cpp:light -m /models/7B/ggml-model-q4_0.bin -p "Building a website can be done in 10 simple steps:" -n 512 +docker run -v /path/to/models:/models ghcr.io/ggerganov/llama.cpp:light -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 ``` ### Docker With CUDA @@ -846,8 +854,8 @@ The resulting images, are essentially the same as the non-CUDA images: After building locally, Usage is similar to the non-CUDA examples, but you'll need to add the `--gpus` flag. You will also want to use the `--n-gpu-layers` flag. ```bash -docker run --gpus all -v /path/to/models:/models local/llama.cpp:full-cuda --run -m /models/7B/ggml-model-q4_0.bin -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 -docker run --gpus all -v /path/to/models:/models local/llama.cpp:light-cuda -m /models/7B/ggml-model-q4_0.bin -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 +docker run --gpus all -v /path/to/models:/models local/llama.cpp:full-cuda --run -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 +docker run --gpus all -v /path/to/models:/models local/llama.cpp:light-cuda -m /models/7B/ggml-model-q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 512 --n-gpu-layers 1 ``` ### Contributing diff --git a/ci/run.sh b/ci/run.sh index 8dc394964..54ba6d710 100644 --- a/ci/run.sh +++ b/ci/run.sh @@ -159,17 +159,17 @@ function gg_run_open_llama_3b_v2 { python3 ../convert.py ${path_models} - model_f16="${path_models}/ggml-model-f16.bin" - model_q8_0="${path_models}/ggml-model-q8_0.bin" - model_q4_0="${path_models}/ggml-model-q4_0.bin" - model_q4_1="${path_models}/ggml-model-q4_1.bin" - model_q5_0="${path_models}/ggml-model-q5_0.bin" - model_q5_1="${path_models}/ggml-model-q5_1.bin" - model_q2_k="${path_models}/ggml-model-q2_k.bin" - model_q3_k="${path_models}/ggml-model-q3_k.bin" - model_q4_k="${path_models}/ggml-model-q4_k.bin" - model_q5_k="${path_models}/ggml-model-q5_k.bin" - model_q6_k="${path_models}/ggml-model-q6_k.bin" + model_f16="${path_models}/ggml-model-f16.gguf" + model_q8_0="${path_models}/ggml-model-q8_0.gguf" + model_q4_0="${path_models}/ggml-model-q4_0.gguf" + model_q4_1="${path_models}/ggml-model-q4_1.gguf" + model_q5_0="${path_models}/ggml-model-q5_0.gguf" + model_q5_1="${path_models}/ggml-model-q5_1.gguf" + model_q2_k="${path_models}/ggml-model-q2_k.gguf" + model_q3_k="${path_models}/ggml-model-q3_k.gguf" + model_q4_k="${path_models}/ggml-model-q4_k.gguf" + model_q5_k="${path_models}/ggml-model-q5_k.gguf" + model_q6_k="${path_models}/ggml-model-q6_k.gguf" wiki_test_60="${path_wiki}/wiki.test-60.raw" @@ -285,17 +285,17 @@ function gg_run_open_llama_7b_v2 { python3 ../convert.py ${path_models} - model_f16="${path_models}/ggml-model-f16.bin" - model_q8_0="${path_models}/ggml-model-q8_0.bin" - model_q4_0="${path_models}/ggml-model-q4_0.bin" - model_q4_1="${path_models}/ggml-model-q4_1.bin" - model_q5_0="${path_models}/ggml-model-q5_0.bin" - model_q5_1="${path_models}/ggml-model-q5_1.bin" - model_q2_k="${path_models}/ggml-model-q2_k.bin" - model_q3_k="${path_models}/ggml-model-q3_k.bin" - model_q4_k="${path_models}/ggml-model-q4_k.bin" - model_q5_k="${path_models}/ggml-model-q5_k.bin" - model_q6_k="${path_models}/ggml-model-q6_k.bin" + model_f16="${path_models}/ggml-model-f16.gguf" + model_q8_0="${path_models}/ggml-model-q8_0.gguf" + model_q4_0="${path_models}/ggml-model-q4_0.gguf" + model_q4_1="${path_models}/ggml-model-q4_1.gguf" + model_q5_0="${path_models}/ggml-model-q5_0.gguf" + model_q5_1="${path_models}/ggml-model-q5_1.gguf" + model_q2_k="${path_models}/ggml-model-q2_k.gguf" + model_q3_k="${path_models}/ggml-model-q3_k.gguf" + model_q4_k="${path_models}/ggml-model-q4_k.gguf" + model_q5_k="${path_models}/ggml-model-q5_k.gguf" + model_q6_k="${path_models}/ggml-model-q6_k.gguf" wiki_test="${path_wiki}/wiki.test.raw" diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt new file mode 100644 index 000000000..dead56118 --- /dev/null +++ b/common/CMakeLists.txt @@ -0,0 +1,20 @@ +# common + +set(TARGET common) + +add_library(${TARGET} OBJECT + common.h + common.cpp + console.h + console.cpp + grammar-parser.h + grammar-parser.cpp + ) + +if (BUILD_SHARED_LIBS) + set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) +endif() + +target_include_directories(${TARGET} PUBLIC .) +target_compile_features(${TARGET} PUBLIC cxx_std_11) +target_link_libraries(${TARGET} PRIVATE llama) diff --git a/examples/common.cpp b/common/common.cpp similarity index 92% rename from examples/common.cpp rename to common/common.cpp index bd39d9220..d7e1a5725 100644 --- a/examples/common.cpp +++ b/common/common.cpp @@ -170,18 +170,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.n_ctx = std::stoi(argv[i]); - } else if (arg == "-gqa" || arg == "--gqa") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.n_gqa = std::stoi(argv[i]); - } else if (arg == "-eps" || arg == "--rms-norm-eps") { - if (++i >= argc) { - invalid_param = true; - break; - } - params.rms_norm_eps = std::stof(argv[i]); } else if (arg == "--rope-freq-base") { if (++i >= argc) { invalid_param = true; @@ -439,7 +427,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } params.hellaswag_tasks = std::stoi(argv[i]); } else if (arg == "--ignore-eos") { - params.logit_bias[llama_token_eos()] = -INFINITY; + params.ignore_eos = true; } else if (arg == "--no-penalize-nl") { params.penalize_nl = false; } else if (arg == "-l" || arg == "--logit-bias") { @@ -561,8 +549,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); - fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa); - fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps); fprintf(stdout, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k); fprintf(stdout, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p); fprintf(stdout, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z); @@ -650,24 +636,15 @@ std::string gpt_random_prompt(std::mt19937 & rng) { return "The"; } -// TODO: not great allocating this every time -std::vector llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) { - // initialize to prompt numer of chars, since n_tokens <= n_prompt_chars - std::vector res(text.size() + (int) add_bos); - const int n = llama_tokenize(ctx, text.c_str(), res.data(), res.size(), add_bos); - assert(n >= 0); - res.resize(n); - - return res; -} +// +// Model utils +// struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) { auto lparams = llama_context_default_params(); lparams.n_ctx = params.n_ctx; lparams.n_batch = params.n_batch; - lparams.n_gqa = params.n_gqa; - lparams.rms_norm_eps = params.rms_norm_eps; lparams.n_gpu_layers = params.n_gpu_layers; lparams.main_gpu = params.main_gpu; lparams.tensor_split = params.tensor_split; @@ -685,7 +662,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param return lparams; } -std::tuple llama_init_from_gpt_params(const gpt_params & params) { +std::tuple llama_init_from_gpt_params(gpt_params & params) { auto lparams = llama_context_params_from_gpt_params(params); llama_model * model = llama_load_model_from_file(params.model.c_str(), lparams); @@ -714,5 +691,77 @@ std::tuple llama_init_from_gpt_par } } + if (params.ignore_eos) { + params.logit_bias[llama_token_eos(lctx)] = -INFINITY; + } + return std::make_tuple(model, lctx); } + +// +// Vocab utils +// + +std::vector llama_tokenize( + struct llama_context * ctx, + const std::string & text, + bool add_bos) { + // upper limit for the number of tokens + int n_tokens = text.length() + add_bos; + std::vector result(n_tokens); + n_tokens = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos); + if (n_tokens < 0) { + result.resize(-n_tokens); + int check = llama_tokenize(ctx, text.c_str(), result.data(), result.size(), add_bos); + GGML_ASSERT(check == -n_tokens); + } else { + result.resize(n_tokens); + } + return result; +} + +std::string llama_token_to_str(const struct llama_context * ctx, llama_token token) { + std::vector result(8, 0); + const int n_tokens = llama_token_to_str(ctx, token, result.data(), result.size()); + if (n_tokens < 0) { + result.resize(-n_tokens); + int check = llama_token_to_str(ctx, token, result.data(), result.size()); + GGML_ASSERT(check == -n_tokens); + } else { + result.resize(n_tokens); + } + + return std::string(result.data(), result.size()); +} + +std::vector llama_tokenize_bpe( + struct llama_context * ctx, + const std::string & text, + bool add_bos) { + int n_tokens = text.length() + add_bos; + std::vector result(n_tokens); + n_tokens = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos); + if (n_tokens < 0) { + result.resize(-n_tokens); + int check = llama_tokenize_bpe(ctx, text.c_str(), result.data(), result.size(), add_bos); + GGML_ASSERT(check == -n_tokens); + } else { + result.resize(n_tokens); + } + return result; +} + +std::string llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token) { + std::vector result(8, 0); + const int n_tokens = llama_token_to_str_bpe(ctx, token, result.data(), result.size()); + if (n_tokens < 0) { + result.resize(-n_tokens); + const int check = llama_token_to_str_bpe(ctx, token, result.data(), result.size()); + GGML_ASSERT(check == -n_tokens); + } else { + result.resize(n_tokens); + } + + return std::string(result.data(), result.size()); +} + diff --git a/examples/common.h b/common/common.h similarity index 88% rename from examples/common.h rename to common/common.h index 375bc0a3d..c50a6edfc 100644 --- a/examples/common.h +++ b/common/common.h @@ -22,19 +22,16 @@ struct gpt_params { int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 512; // context size int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS) - int32_t n_gqa = 1; // grouped-query attention factor (TODO: move to hparams) int32_t n_keep = 0; // number of tokens to keep from initial prompt int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited) int32_t n_gpu_layers = 0; // number of layers to store in VRAM int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. - float rms_norm_eps = LLAMA_DEFAULT_RMS_EPS; // rms norm epsilon float rope_freq_base = 10000.0f; // RoPE base frequency float rope_freq_scale = 1.0f; // RoPE frequency scaling factor // sampling parameters - std::unordered_map logit_bias; // logit bias for specific tokens int32_t top_k = 40; // <= 0 to use vocab size float top_p = 0.95f; // 1.0 = disabled float tfs_z = 1.00f; // 1.0 = disabled @@ -48,12 +45,14 @@ struct gpt_params { float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate + std::unordered_map logit_bias; // logit bias for specific tokens + // Classifier-Free Guidance // https://arxiv.org/abs/2306.17806 std::string cfg_negative_prompt; // string to help guidance float cfg_scale = 1.f; // How strong is guidance - std::string model = "models/7B/ggml-model.bin"; // model path + std::string model = "models/7B/ggml-model-f16.gguf"; // model path std::string model_alias = "unknown"; // model alias std::string prompt = ""; std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state @@ -83,6 +82,7 @@ struct gpt_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix + bool ignore_eos = false; // ignore generated EOS tokens bool instruct = false; // instruction mode (used for Alpaca models) bool penalize_nl = true; // consider newlines as a repeatable token bool perplexity = false; // compute perplexity over the prompt @@ -100,15 +100,31 @@ void gpt_print_usage(int argc, char ** argv, const gpt_params & params); std::string gpt_random_prompt(std::mt19937 & rng); -// -// Vocab utils -// - -std::vector llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos); - // // Model utils // -std::tuple llama_init_from_gpt_params(const gpt_params & params); +std::tuple llama_init_from_gpt_params(gpt_params & params); struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params); + +// +// Vocab utils +// + +std::vector llama_tokenize( + struct llama_context * ctx, + const std::string & text, + bool add_bos); + +std::vector llama_tokenize_bpe( + struct llama_context * ctx, + const std::string & text, + bool add_bos); + +std::string llama_token_to_str( + const struct llama_context * ctx, + llama_token token); + +std::string llama_token_to_str_bpe( + const struct llama_context * ctx, + llama_token token); diff --git a/examples/console.cpp b/common/console.cpp similarity index 100% rename from examples/console.cpp rename to common/console.cpp diff --git a/examples/console.h b/common/console.h similarity index 100% rename from examples/console.h rename to common/console.h diff --git a/examples/grammar-parser.cpp b/common/grammar-parser.cpp similarity index 100% rename from examples/grammar-parser.cpp rename to common/grammar-parser.cpp diff --git a/examples/grammar-parser.h b/common/grammar-parser.h similarity index 100% rename from examples/grammar-parser.h rename to common/grammar-parser.h diff --git a/convert-falcon-hf-to-gguf.py b/convert-falcon-hf-to-gguf.py new file mode 100644 index 000000000..b3e190a0f --- /dev/null +++ b/convert-falcon-hf-to-gguf.py @@ -0,0 +1,282 @@ +# HF falcon--> gguf conversion + +import gguf +import os +import sys +import struct +import json +import numpy as np +import torch + +from typing import Any, List +from pathlib import Path +from transformers import AutoTokenizer + +def bytes_to_unicode(): + # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def count_model_parts(dir_model: str) -> int: + num_parts = 0 + for filename in os.listdir(dir_model): + if filename.startswith("pytorch_model-"): + num_parts += 1 + + if num_parts > 0: + print("gguf: found " + str(num_parts) + " model parts") + return num_parts + + +if len(sys.argv) < 3: + print("Usage: convert-h5-to-ggml.py dir-model ftype\n") + print(" ftype == 0 -> float32") + print(" ftype == 1 -> float16") + sys.exit(1) + + +# output in the same directory as the model +dir_model = sys.argv[1] +last_dir = os.path.basename(os.path.normpath(dir_model)) + +# possible tensor data types +# ftype == 0 -> float32 +# ftype == 1 -> float16 + +# map from ftype to string +ftype_str = ["f32", "f16"] + +ftype = 1 +if len(sys.argv) > 2: + ftype = int(sys.argv[2]) + if ftype < 0 or ftype > 1: + print("Invalid ftype: " + str(ftype)) + + sys.exit(1) + +fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf" + +print("gguf: loading model "+last_dir) + +with open(dir_model + "/config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + +if hparams["architectures"][0] != "RWForCausalLM": + print("Model architecture not supported: " + hparams["architectures"][0]) + + sys.exit() + +# get number of model parts +num_parts = count_model_parts(dir_model) + +ARCH=gguf.MODEL_ARCH.FALCON +gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH]) + +print("gguf: get model metadata") + +block_count = hparams["n_layer"] + +gguf_writer.add_name(last_dir) +gguf_writer.add_context_length(2048) # not in config.json +gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform +gguf_writer.add_embedding_length(hparams["hidden_size"]) +gguf_writer.add_feed_forward_length(4 * hparams["hidden_size"]) +gguf_writer.add_block_count(block_count) +gguf_writer.add_head_count(hparams["n_head"]) +if "n_head_kv" in hparams: gguf_writer.add_head_count_kv(hparams["n_head_kv"]) +gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"]) + +# TOKENIZATION + +print("gguf: get tokenizer metadata") + +tokens: List[str] = [] +merges: List[str] = [] + + +if Path(dir_model + "/tokenizer.json").is_file(): + # gpt2 tokenizer + gguf_writer.add_tokenizer_model("gpt2") + + print("gguf: get gpt2 tokenizer merges") + + with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f: + tokenizer_json = json.load(f) + merges = tokenizer_json["model"]["merges"] + + gguf_writer.add_token_merges(merges) + + print("gguf: get gpt2 tokenizer vocab") + + vocab_size = len(tokenizer_json["model"]["vocab"]) + + # ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py + tokenizer = AutoTokenizer.from_pretrained(dir_model) + + reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()} + byte_encoder = bytes_to_unicode() + byte_decoder = {v: k for k, v in byte_encoder.items()} + + for i in range(vocab_size): + if i in reverse_vocab: + try: + text = bytearray([byte_decoder[c] for c in reverse_vocab[i]]) + except KeyError: + text = bytearray() + for c in reverse_vocab[i]: + if ord(c) < 256: # single byte character + text.append(byte_decoder[ord(c)]) + else: # multibyte special token character + text.extend(c.encode('utf-8')) + else: + print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.") + pad_token = f"[PAD{i}]".encode("utf8") + text = bytearray(pad_token) + + tokens.append(text) + + gguf_writer.add_token_list(tokens) + + if "added_tokens" in tokenizer_json and Path(dir_model + "/tokenizer_config.json").is_file(): + print("gguf: get special token ids") + + with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f: + tokenizer_config = json.load(f) + + # find special token ids + + if "bos_token" in tokenizer_config: + for key in tokenizer_json["added_tokens"]: + if key["content"] == tokenizer_config["bos_token"]: + gguf_writer.add_bos_token_id(key["id"]) + + if "eos_token" in tokenizer_config: + for key in tokenizer_json["added_tokens"]: + if key["content"] == tokenizer_config["eos_token"]: + gguf_writer.add_eos_token_id(key["id"]) + + if "unk_token" in tokenizer_config: + for key in tokenizer_json["added_tokens"]: + if key["content"] == tokenizer_config["unk_token"]: + gguf_writer.add_unk_token_id(key["id"]) + + if "sep_token" in tokenizer_config: + for key in tokenizer_json["added_tokens"]: + if key["content"] == tokenizer_config["sep_token"]: + gguf_writer.add_sep_token_id(key["id"]) + + if "pad_token" in tokenizer_config: + for key in tokenizer_json["added_tokens"]: + if key["content"] == tokenizer_config["pad_token"]: + gguf_writer.add_pad_token_id(key["id"]) + + +# TENSORS + +tensor_map = gguf.get_tensor_name_map(ARCH,block_count) + +# params for qkv transform +n_head = hparams["n_head"] +n_head_kv = hparams["n_head_kv"] if "n_head_kv" in hparams else 1 +head_dim = hparams["hidden_size"] // n_head + +# tensor info +print("gguf: get tensor metadata") + +if num_parts == 0: + part_names = ("pytorch_model.bin",) +else: + part_names = ( + f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1) + ) + +for part_name in part_names: + print("gguf: loading model part '" + part_name + "'") + model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu") + + for name in model_part.keys(): + data = model_part[name] + + old_dtype = data.dtype + + # convert any unsupported data types to float32 + if data.dtype != torch.float16 and data.dtype != torch.float32: + data = data.to(torch.float32) + + # QKV tensor transform + # The original query_key_value tensor contains n_head_kv "kv groups", + # each consisting of n_head/n_head_kv query weights followed by one key + # and one value weight (shared by all query heads in the kv group). + # This layout makes it a big pain to work with in GGML. + # So we rearrange them here,, so that we have n_head query weights + # followed by n_head_kv key weights followed by n_head_kv value weights, + # in contiguous fashion. + # ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py + + if "query_key_value" in name: + qkv = data.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head) + q = qkv[:, :-2 ].reshape(n_head * head_dim, head_dim * n_head) + k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head) + v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head) + data = torch.cat((q,k,v)).reshape_as(data) + + data = data.squeeze().numpy() + + # map tensor names + if name.endswith(".weight") and name[:-7] in tensor_map: + name = tensor_map[name[:-7]] + ".weight" + elif name.endswith(".bias") and name[:-5] in tensor_map: + name = tensor_map[name[:-5]] + ".bias" + else: + print("Can not map tensor '" + name + "'") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype)) + + gguf_writer.add_tensor(name, data) + + +print("gguf: write header") +gguf_writer.write_header_to_file() +print("gguf: write metadata") +gguf_writer.write_kv_data_to_file() +print("gguf: write tensors") +gguf_writer.write_tensors_to_file() + +gguf_writer.close() + +print("gguf: model successfully exported to '" + fname_out + "'") +print("") diff --git a/convert-gptneox-hf-to-gguf.py b/convert-gptneox-hf-to-gguf.py new file mode 100644 index 000000000..a7cefc6f3 --- /dev/null +++ b/convert-gptneox-hf-to-gguf.py @@ -0,0 +1,266 @@ +# HF gptneox--> gguf conversion + +import gguf +import os +import sys +import struct +import json +import numpy as np +import torch + +from typing import Any, List +from pathlib import Path +from transformers import AutoTokenizer + +# ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py + + +def bytes_to_unicode(): + """ + Returns list of utf-8 byte and a corresponding list of unicode strings. + The reversible bpe codes work on unicode strings. + This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. + When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. + This is a significant percentage of your normal, say, 32K bpe vocab. + To avoid that, we want lookup tables between utf-8 bytes and unicode strings. + And avoids mapping to whitespace/control characters the bpe code barfs on. + """ + bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) + cs = bs[:] + n = 0 + for b in range(2**8): + if b not in bs: + bs.append(b) + cs.append(2**8+n) + n += 1 + cs = [chr(n) for n in cs] + return dict(zip(bs, cs)) + + +def count_model_parts(dir_model: str) -> int: + num_parts = 0 + for filename in os.listdir(dir_model): + if filename.startswith("pytorch_model-"): + num_parts += 1 + + if num_parts > 0: + print("gguf: found " + str(num_parts) + " model parts") + return num_parts + + +if len(sys.argv) < 3: + print("Usage: convert-h5-to-ggml.py dir-model ftype\n") + print(" ftype == 0 -> float32") + print(" ftype == 1 -> float16") + sys.exit(1) + + +# output in the same directory as the model +dir_model = sys.argv[1] +last_dir = os.path.basename(os.path.normpath(dir_model)) + +# possible tensor data types +# ftype == 0 -> float32 +# ftype == 1 -> float16 + +# map from ftype to string +ftype_str = ["f32", "f16"] + +ftype = 1 +if len(sys.argv) > 2: + ftype = int(sys.argv[2]) + if ftype < 0 or ftype > 1: + print("Invalid ftype: " + str(ftype)) + + sys.exit(1) + +fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf" + +print("gguf: loading model "+last_dir) + +with open(dir_model + "/config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + +if hparams["architectures"][0] != "GPTNeoXForCausalLM": + print("Model architecture not supported: " + hparams["architectures"][0]) + + sys.exit() + +# get number of model parts +num_parts = count_model_parts(dir_model) + +ARCH=gguf.MODEL_ARCH.GPTNEOX +gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH]) + +print("gguf: get model metadata") + +block_count = hparams["num_hidden_layers"] + +gguf_writer.add_name(last_dir) +gguf_writer.add_context_length(hparams["max_position_embeddings"]) +gguf_writer.add_embedding_length(hparams["hidden_size"]) +gguf_writer.add_block_count(block_count) +gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) +gguf_writer.add_rope_dimension_count(int(hparams["rotary_pct"]*(hparams["hidden_size"]//hparams["num_attention_heads"]))) +gguf_writer.add_head_count(hparams["num_attention_heads"]) +gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True) +gguf_writer.add_layer_norm_eps(hparams["layer_norm_eps"]) + +# TOKENIZATION + +print("gguf: get tokenizer metadata") + +tokens: List[str] = [] +merges: List[str] = [] + + +if Path(dir_model + "/tokenizer.json").is_file(): + # gpt2 tokenizer + gguf_writer.add_tokenizer_model("gpt2") + + print("gguf: get gpt2 tokenizer merges") + + with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f: + tokenizer_json = json.load(f) + merges = tokenizer_json["model"]["merges"] + + gguf_writer.add_token_merges(merges) + + print("gguf: get gpt2 tokenizer vocab") + + vocab_size = len(tokenizer_json["model"]["vocab"]) + + # ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py + tokenizer = AutoTokenizer.from_pretrained(dir_model) + + reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()} + byte_encoder = bytes_to_unicode() + byte_decoder = {v: k for k, v in byte_encoder.items()} + + for i in range(vocab_size): + if i in reverse_vocab: + try: + text = bytearray([byte_decoder[c] for c in reverse_vocab[i]]) + except KeyError: + text = bytearray() + for c in reverse_vocab[i]: + if ord(c) < 256: # single byte character + text.append(byte_decoder[ord(c)]) + else: # multibyte special token character + text.extend(c.encode('utf-8')) + else: + print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.") + pad_token = f"[PAD{i}]".encode("utf8") + text = bytearray(pad_token) + + tokens.append(text) + + gguf_writer.add_token_list(tokens) + + if "added_tokens" in tokenizer_json and Path(dir_model + "/tokenizer_config.json").is_file(): + print("gguf: get special token ids") + + with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f: + tokenizer_config = json.load(f) + + # find special token ids + + if "bos_token" in tokenizer_config: + for key in tokenizer_json["added_tokens"]: + if key["content"] == tokenizer_config["bos_token"]: + gguf_writer.add_bos_token_id(key["id"]) + + if "eos_token" in tokenizer_config: + for key in tokenizer_json["added_tokens"]: + if key["content"] == tokenizer_config["eos_token"]: + gguf_writer.add_eos_token_id(key["id"]) + + if "unk_token" in tokenizer_config: + for key in tokenizer_json["added_tokens"]: + if key["content"] == tokenizer_config["unk_token"]: + gguf_writer.add_unk_token_id(key["id"]) + + if "sep_token" in tokenizer_config: + for key in tokenizer_json["added_tokens"]: + if key["content"] == tokenizer_config["sep_token"]: + gguf_writer.add_sep_token_id(key["id"]) + + if "pad_token" in tokenizer_config: + for key in tokenizer_json["added_tokens"]: + if key["content"] == tokenizer_config["pad_token"]: + gguf_writer.add_pad_token_id(key["id"]) + + +# TENSORS + +tensor_map = gguf.get_tensor_name_map(ARCH,block_count) + +# tensor info +print("gguf: get tensor metadata") + +if num_parts == 0: + part_names = ("pytorch_model.bin",) +else: + part_names = ( + f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1) + ) + +for part_name in part_names: + print("gguf: loading model part '" + part_name + "'") + model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu") + + for name in model_part.keys(): + data = model_part[name] + + # we don't need these + if name.endswith(".attention.masked_bias") or name.endswith(".attention.bias") or name.endswith(".attention.rotary_emb.inv_freq"): + continue + + old_dtype = data.dtype + + # convert any unsupported data types to float32 + if data.dtype != torch.float16 and data.dtype != torch.float32: + data = data.to(torch.float32) + + data = data.squeeze().numpy() + + # map tensor names + if name.endswith(".weight") and name[:-7] in tensor_map: + name = tensor_map[name[:-7]] + ".weight" + elif name.endswith(".bias") and name[:-5] in tensor_map: + name = tensor_map[name[:-5]] + ".bias" + else: + print("Can not map tensor '" + name + "'") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype)) + + gguf_writer.add_tensor(name, data) + + +print("gguf: write header") +gguf_writer.write_header_to_file() +print("gguf: write metadata") +gguf_writer.write_kv_data_to_file() +print("gguf: write tensors") +gguf_writer.write_tensors_to_file() + +gguf_writer.close() + +print("gguf: model successfully exported to '" + fname_out + "'") +print("") diff --git a/convert-llama-7b-pth-to-gguf.py b/convert-llama-7b-pth-to-gguf.py new file mode 100644 index 000000000..ab5c80b69 --- /dev/null +++ b/convert-llama-7b-pth-to-gguf.py @@ -0,0 +1,307 @@ +# 7b pth llama --> gguf conversion +# Only models with a single datafile are supported, like 7B +# HF files required in the model dir: config.json tokenizer_config.json tokenizer.json tokenizer.model + +import gguf +import os +import sys +import struct +import json +import numpy as np +import torch + +from typing import Any, List +from pathlib import Path +from sentencepiece import SentencePieceProcessor + +#NDArray = np.ndarray[Any, Any] +# compatible with python < 3.9 +NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]' + + +def count_model_parts(dir_model: str) -> int: + num_parts = 0 + for filename in os.listdir(dir_model): + if filename.startswith("consolidated."): + num_parts += 1 + + if num_parts > 0: + print("gguf: found " + str(num_parts) + " model parts") + return num_parts + + +if len(sys.argv) < 3: + print("Usage: convert-h5-to-ggml.py dir-model ftype\n") + print(" ftype == 0 -> float32") + print(" ftype == 1 -> float16") + + sys.exit(1) + + +# output in the same directory as the model +dir_model = sys.argv[1] +last_dir = os.path.basename(os.path.normpath(dir_model)) + + +# possible tensor data types +# ftype == 0 -> float32 +# ftype == 1 -> float16 + +# map from ftype to string +ftype_str = ["f32", "f16"] + +ftype = 1 +if len(sys.argv) > 2: + ftype = int(sys.argv[2]) + if ftype < 0 or ftype > 1: + print("Invalid ftype: " + str(ftype)) + + sys.exit(1) + +fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf" + +print("gguf: loading model "+last_dir) + +with open(dir_model + "/config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + +if hparams["architectures"][0] != "LlamaForCausalLM": + print("Model architecture not supported: " + hparams["architectures"][0]) + sys.exit() + +# get number of model parts +num_parts = count_model_parts(dir_model) + +if num_parts > 1: + print("gguf: Only models with a single datafile are supported.") + + sys.exit() + +ARCH=gguf.MODEL_ARCH.LLAMA +gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH]) + + +print("gguf: get model metadata") + +block_count = hparams["num_hidden_layers"] +head_count = hparams["num_attention_heads"] + +if "num_key_value_heads" in hparams: + head_count_kv = hparams["num_key_value_heads"] +else: + head_count_kv = head_count + +if "_name_or_path" in hparams: + hf_repo = hparams["_name_or_path"] +else: + hf_repo = "" + +if "max_sequence_length" in hparams: + ctx_length = hparams["max_sequence_length"] +elif "max_position_embeddings" in hparams: + ctx_length = hparams["max_position_embeddings"] +else: + print("gguf: can not find ctx length parameter.") + + sys.exit() + + +gguf_writer.add_name(last_dir) +gguf_writer.add_source_hf_repo(hf_repo) +gguf_writer.add_tensor_data_layout("Meta AI original pth") +gguf_writer.add_context_length(ctx_length) +gguf_writer.add_embedding_length(hparams["hidden_size"]) +gguf_writer.add_block_count(block_count) +gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) +gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"]) +gguf_writer.add_head_count(head_count) +gguf_writer.add_head_count_kv(head_count_kv) +gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) + +if "rope_scaling" in hparams and hparams["rope_scaling"] != None and "factor" in hparams["rope_scaling"]: + if "type" in hparams["rope_scaling"]: + if hparams["rope_scaling"]["type"] == "linear": + gguf_writer.add_rope_scale_linear(hparams["rope_scaling"]["factor"]) + + +# TOKENIZATION + +print("gguf: get tokenizer metadata") + +tokens: List[bytes] = [] +scores: List[float] = [] +toktypes: List[int] = [] + +if Path(dir_model + "/tokenizer.model").is_file(): + # vocab type sentencepiece + print("gguf: get sentencepiece tokenizer vocab and scores") + + tokenizer = SentencePieceProcessor(dir_model + "/tokenizer.model") + + for i in range(tokenizer.vocab_size()): + text: bytes + score: float + + piece = tokenizer.id_to_piece(i) + text = piece.encode("utf-8") + score = tokenizer.get_score(i) + + toktype = 1 # defualt to normal token type + if tokenizer.is_unknown(i): + toktype = 2 + if tokenizer.is_control(i): + toktype = 3 + + # toktype = 4 is user-defined = tokens from added_tokens.json + + if tokenizer.is_unused(i): + toktype = 5 + if tokenizer.is_byte(i): + toktype = 6 + + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + if Path(dir_model + "/added_tokens.json").is_file(): + with open(dir_model + "/added_tokens.json", "r", encoding="utf-8") as f: + addtokens_json = json.load(f) + + print("gguf: get added tokens") + + for key in addtokens_json: + tokens.append( key.encode("utf-8") ) + scores.append(-1000.0) + toktypes.append(4) # user-defined token type + + gguf_writer.add_tokenizer_model("llama") + gguf_writer.add_token_list(tokens) + gguf_writer.add_token_scores(scores) + gguf_writer.add_token_types(toktypes) + + +print("gguf: get special token ids") + +if Path(dir_model + "/tokenizer.json").is_file(): + # Look for special tokens in tokenizer.json if it exists + + with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f: + tokenizer = json.load(f) + + if "added_tokens" in tokenizer and Path(dir_model + "/tokenizer_config.json").is_file(): + + with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f: + tokenizer_config = json.load(f) + + if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] != None: + for key in tokenizer["added_tokens"]: + if key["content"] == tokenizer_config["bos_token"]["content"]: + gguf_writer.add_bos_token_id(key["id"]) + + if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] != None: + for key in tokenizer["added_tokens"]: + if key["content"] == tokenizer_config["eos_token"]["content"]: + gguf_writer.add_eos_token_id(key["id"]) + + if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] != None: + for key in tokenizer["added_tokens"]: + if key["content"] == tokenizer_config["unk_token"]["content"]: + gguf_writer.add_unk_token_id(key["id"]) + + if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] != None: + for key in tokenizer["added_tokens"]: + if key["content"] == tokenizer_config["sep_token"]["content"]: + gguf_writer.add_sep_token_id(key["id"]) + + if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] != None: + for key in tokenizer["added_tokens"]: + if key["content"] == tokenizer_config["pad_token"]["content"]: + gguf_writer.add_pad_token_id(key["id"]) +else: + # If no tokenizer.json: Look for special tokens in config.json + + if "bos_token_id" in hparams and hparams["bos_token_id"] != None: + gguf_writer.add_bos_token_id(hparams["bos_token_id"]) + + if "eos_token_id" in hparams and hparams["eos_token_id"] != None: + gguf_writer.add_eos_token_id(hparams["eos_token_id"]) + + if "unk_token_id" in hparams and hparams["unk_token_id"] != None: + gguf_writer.add_unk_token_id(hparams["unk_token_id"]) + + if "sep_token_id" in hparams and hparams["sep_token_id"] != None: + gguf_writer.add_sep_token_id(hparams["sep_token_id"]) + + if "pad_token_id" in hparams and hparams["pad_token_id"] != None: + gguf_writer.add_pad_token_id(hparams["pad_token_id"]) + + +# TENSORS + +tensor_map = gguf.get_tensor_name_map(ARCH,block_count) + +# tensor info +print("gguf: get tensor metadata") + +part_names = (f"consolidated.{n:02}.pth" for n in range(0, num_parts)) + +for part_name in part_names: + print("gguf: loading model part '" + part_name + "'") + model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu") + + for name in model_part.keys(): + data = model_part[name] + + # we don't need these + if name == "rope.freqs": + continue + + old_dtype = data.dtype + + # convert any unsupported data types to float32 + if data.dtype != torch.float16 and data.dtype != torch.float32: + data = data.to(torch.float32) + + data = data.squeeze().numpy() + + # map tensor names + if name.endswith(".weight") and name[:-7] in tensor_map: + name = tensor_map[name[:-7]] + ".weight" + elif name.endswith(".bias") and name[:-5] in tensor_map: + name = tensor_map[name[:-5]] + ".bias" + else: + print("Can not map tensor '" + name + "'") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype)) + + gguf_writer.add_tensor(name, data) + + +print("gguf: write header") +gguf_writer.write_header_to_file() +print("gguf: write metadata") +gguf_writer.write_kv_data_to_file() +print("gguf: write tensors") +gguf_writer.write_tensors_to_file() + +gguf_writer.close() + + +print("gguf: model successfully exported to '" + fname_out + "'") +print("") diff --git a/convert-llama-ggmlv3-to-gguf.py b/convert-llama-ggmlv3-to-gguf.py new file mode 100644 index 000000000..30038072f --- /dev/null +++ b/convert-llama-ggmlv3-to-gguf.py @@ -0,0 +1,334 @@ +import sys, struct, math, argparse +from pathlib import Path + +import numpy as np + +import gguf + +# Note: Does not support GGML_QKK_64 +QK_K = 256 +# Items here are (block size, type size) +GGML_QUANT_SIZES = { + gguf.GGMLQuantizationType.F32 : (1, 4), + gguf.GGMLQuantizationType.F16 : (1, 2), + gguf.GGMLQuantizationType.Q4_0 : (32, 2 + 16), + gguf.GGMLQuantizationType.Q4_1 : (32, 2 + 2 + 16), + gguf.GGMLQuantizationType.Q5_0 : (32, 2 + 4 + 16), + gguf.GGMLQuantizationType.Q5_1 : (32, 2 + 2 + 4 + 16), + gguf.GGMLQuantizationType.Q8_0 : (32, 2 + 32), + gguf.GGMLQuantizationType.Q8_1 : (32, 4 + 4 + 32), + gguf.GGMLQuantizationType.Q2_K : (256, 2 + 2 + QK_K // 16 + QK_K // 4), + gguf.GGMLQuantizationType.Q3_K : (256, 2 + QK_K // 4 + QK_K // 8 + 12), + gguf.GGMLQuantizationType.Q4_K : (256, 2 + 2 + QK_K // 2 + 12), + gguf.GGMLQuantizationType.Q5_K : (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12), + gguf.GGMLQuantizationType.Q6_K : (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16), + gguf.GGMLQuantizationType.Q8_K : (256, 4 + QK_K + QK_K // 8), +} + +class Hyperparameters: + def __init__(self): + self.n_vocab = self.n_embd = self.n_mult = self.n_head = self.n_layer = self.n_rot = self.ftype = 0 + self.n_ff = 0 + + def set_n_ff(self, model): + ff_tensor_idx = model.tensor_map.get(b'layers.0.feed_forward.w1.weight') + assert ff_tensor_idx is not None, 'Missing layer 0 FF tensor' + ff_tensor = model.tensors[ff_tensor_idx] + self.n_ff = ff_tensor.dims[1] + + def load(self, data, offset): + ( + self.n_vocab, + self.n_embd, + self.n_mult, + self.n_head, + self.n_layer, + self.n_rot, + self.ftype, + ) = struct.unpack('<7I', data[offset:offset + (4 * 7)]) + return 4 * 7 + + def __str__(self): + return f'' + +class Vocab: + def __init__(self): + self.items = [] + + def load(self, data, offset, n_vocab): + orig_offset = offset + for _ in range(n_vocab): + itemlen = struct.unpack('= 0 and n_dims <= 4, f'Invalid tensor dimensions {n_dims}' + assert name_len < 4096, 'Absurd tensor name length' + quant = GGML_QUANT_SIZES.get(dtype) + assert quant is not None, 'Unknown tensor type' + (blksize, tysize) = quant + offset += 12 + self.dtype= dtype + self.dims = struct.unpack(f'<{n_dims}I', data[offset:offset + (4 * n_dims)]) + offset += 4 * n_dims + self.name = bytes(data[offset:offset + name_len]) + offset += name_len + pad = ((offset + 31) & ~31) - offset + offset += pad + n_elems = np.prod(self.dims) + n_bytes = (n_elems * tysize) // blksize + self.start_offset = offset + self.len_bytes = n_bytes + offset += n_bytes + # print(n_dims, name_len, dtype, self.dims, self.name, pad) + return offset - orig_offset + +class GGMLV3Model: + def __init__(self): + self.hyperparameters = None + self.vocab = None + self.tensor_map = {} + self.tensors = [] + + def validate_header(self, data, offset): + if bytes(data[offset:offset + 4]) != b'tjgg' or struct.unpack(' 0: + gguf_writer.add_token_types(toktypes) + return + print(f'* Adding {hp.n_vocab} vocab item(s)') + for (tokid, (vbytes, vscore)) in enumerate(self.model.vocab.items): + tt = 1 # Normal + if len(vbytes) == 0: + tt = 3 # Control + elif tokid >= 3 and tokid <= 258 and len(vbytes) == 1: + hv = hex(vbytes[0])[2:].upper() + vbytes = bytes(f'<0x{hv}>', encoding = 'UTF-8') + tt = 6 # Byte + else: + vbytes = vbytes.replace(b' ', b'\xe2\x96\x81') + toktypes.append(tt) + tokens.append(vbytes) + scores.append(vscore) + gguf_writer.add_token_list(tokens) + gguf_writer.add_token_scores(scores) + gguf_writer.add_token_types(toktypes) + + def add_tensors(self, gguf_writer): + nm = self.name_map + data = self.data + print(f'* Adding {len(self.model.tensors)} tensor(s)') + for tensor in self.model.tensors: + name = str(tensor.name, 'UTF-8') + if name.endswith('.weight'): + name = name[:-7] + suffix = '.weight' + elif name.endswith('.bias'): + name = name[:-5] + suffix = '.bias' + mapped_name = nm.get(name) + assert mapped_name is not None, f'Bad name {name}' + mapped_name += suffix + tempdims = list(tensor.dims[:]) + if len(tempdims) > 1: + temp = tempdims[1] + tempdims[1] = tempdims[0] + tempdims[0] = temp + # print(f'+ {tensor.name} | {mapped_name} {tensor.dims} :: {tempdims}') + gguf_writer.add_tensor(mapped_name, data[tensor.start_offset:tensor.start_offset + tensor.len_bytes], raw_shape = tempdims, raw_dtype = tensor.dtype) + +def handle_metadata(cfg, hp): + import convert + assert cfg.model_metadata_dir.is_dir(), 'Metadata dir is not a directory' + hf_config_path = cfg.model_metadata_dir / "config.json" + orig_config_path = cfg.model_metadata_dir / "params.json" + # We pass a fake model here. "original" mode will check the shapes of some + # tensors if information is missing in the .json file: other than that, the + # model data isn't used so this should be safe (at least for now). + fakemodel = { + 'tok_embeddings.weight': convert.LazyTensor.__new__(convert.LazyTensor), + 'layers.0.feed_forward.w1.weight': convert.LazyTensor.__new__(convert.LazyTensor), + } + fakemodel['tok_embeddings.weight'].shape = [hp.n_vocab] + fakemodel['layers.0.feed_forward.w1.weight'].shape = [hp.n_ff] + if hf_config_path.exists(): + params = convert.Params.loadHFTransformerJson(fakemodel, hf_config_path) + elif orig_config_path.exists(): + params = convert.Params.loadOriginalParamsJson(fakemodel, orig_config_path) + else: + raise ValueError('Unable to load metadata') + vocab = convert.load_vocab(cfg.vocab_dir if cfg.vocab_dir is not None else cfg.model_metadata_dir, cfg.vocabtype) + convert.check_vocab_size(params, vocab) + return (params, vocab) + +def handle_args(): + parser = argparse.ArgumentParser(description = 'Convert GGMLv3 models to GGUF') + parser.add_argument('--input', '-i', type = Path, help = 'Input GGMLv3 filename') + parser.add_argument('--output', '-o', type = Path, help ='Output GGUF filename') + parser.add_argument('--name', help = 'Set model name') + parser.add_argument('--desc', help = 'Set model description') + parser.add_argument('--gqa', type = int, default = 1, help = 'grouped-query attention factor (use 8 for LLaMA2 70B)') + parser.add_argument('--eps', default = '5.0e-06', help = 'RMS norm eps: Use 1e-6 for LLaMA1 and OpenLLaMA, use 1e-5 for LLaMA2') + parser.add_argument('--context-length', '-c', type=int, default = 2048, help = 'Default max context length: LLaMA1 is typically 2048, LLaMA2 is typically 4096') + parser.add_argument('--model-metadata-dir', '-m', type = Path, help ='Load HuggingFace/.pth vocab and metadata from the specified directory') + parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file - only meaningful with --model-metadata-dir") + parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format - only meaningful with --model-metadata-dir and/or --vocab-dir (default: spm)", default="spm") + return parser.parse_args() + +def main(): + cfg = handle_args() + print(f'* Using config: {cfg}') + print('\n=== WARNING === Be aware that this conversion script is best-effort. Use a native GGUF model if possible. === WARNING ===\n') + data = np.memmap(cfg.input, mode = 'r') + model = GGMLV3Model() + print('* Scanning GGML input file') + offset = model.load(data, 0) + print(f'* GGML model hyperparameters: {model.hyperparameters}') + vocab_override = None + params_override = None + if cfg.model_metadata_dir is not None: + (params_override, vocab_override) = handle_metadata(cfg, model.hyperparameters) + print('!! Note: When overriding params the --gqa, --eps and --context-length options are ignored.') + print(f'* Overriding params: {params_override}') + print(f'* Overriding vocab: {vocab_override}') + else: + print('\n=== WARNING === Special tokens may not be converted correctly. Use --model-metadata-dir if possible === WARNING ===\n') + converter = GGMLToGGUF(model, data, cfg, params_override = params_override, vocab_override = vocab_override) + converter.save() + print(f'* Successful completion. Output saved to: {cfg.output}') + +main() diff --git a/convert-llama-hf-to-gguf.py b/convert-llama-hf-to-gguf.py new file mode 100644 index 000000000..f8cfdaa80 --- /dev/null +++ b/convert-llama-hf-to-gguf.py @@ -0,0 +1,327 @@ +# HF llama --> gguf conversion + +import gguf +import os +import sys +import struct +import json +import numpy as np +import torch + +from typing import Any, List, Optional +from pathlib import Path +from sentencepiece import SentencePieceProcessor + +#NDArray = np.ndarray[Any, Any] +# compatible with python < 3.9 +NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]' + +# reverse HF permute back to original pth layout +# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/convert_llama_weights_to_hf.py + + +def reverse_hf_permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray: + if n_kv_head is not None and n_head != n_kv_head: + n_head //= n_kv_head + + return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) + .swapaxes(1, 2) + .reshape(weights.shape)) + + +def count_model_parts(dir_model: str) -> int: + num_parts = 0 + + for filename in os.listdir(dir_model): + if filename.startswith("pytorch_model-"): + num_parts += 1 + + if num_parts > 0: + print("gguf: found " + str(num_parts) + " model parts") + + return num_parts + + +if len(sys.argv) < 3: + print("Usage: convert-h5-to-ggml.py dir-model ftype\n") + print(" ftype == 0 -> float32") + print(" ftype == 1 -> float16") + + sys.exit(1) + + +# output in the same directory as the model +dir_model = sys.argv[1] +last_dir = os.path.basename(os.path.normpath(dir_model)) + + +# possible tensor data types +# ftype == 0 -> float32 +# ftype == 1 -> float16 + + +# map from ftype to string +ftype_str = ["f32", "f16"] + +ftype = 1 +if len(sys.argv) > 2: + ftype = int(sys.argv[2]) + if ftype < 0 or ftype > 1: + print("Invalid ftype: " + str(ftype)) + + sys.exit(1) + +fname_out = sys.argv[1] + "/ggml-model-" + ftype_str[ftype] + ".gguf" + +print("gguf: loading model "+last_dir) + +with open(dir_model + "/config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + +if hparams["architectures"][0] != "LlamaForCausalLM": + print("Model architecture not supported: " + hparams["architectures"][0]) + + sys.exit() + +# get number of model parts +num_parts = count_model_parts(dir_model) + +ARCH=gguf.MODEL_ARCH.LLAMA +gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH]) + +print("gguf: get model metadata") + +block_count = hparams["num_hidden_layers"] +head_count = hparams["num_attention_heads"] + +if "num_key_value_heads" in hparams: + head_count_kv = hparams["num_key_value_heads"] +else: + head_count_kv = head_count + +if "_name_or_path" in hparams: + hf_repo = hparams["_name_or_path"] +else: + hf_repo = "" + +if "max_sequence_length" in hparams: + ctx_length = hparams["max_sequence_length"] +elif "max_position_embeddings" in hparams: + ctx_length = hparams["max_position_embeddings"] +else: + print("gguf: can not find ctx length parameter.") + + sys.exit() + + +gguf_writer.add_name(last_dir) +gguf_writer.add_source_hf_repo(hf_repo) +gguf_writer.add_tensor_data_layout("Meta AI original pth") +gguf_writer.add_context_length(ctx_length) +gguf_writer.add_embedding_length(hparams["hidden_size"]) +gguf_writer.add_block_count(block_count) +gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) +gguf_writer.add_rope_dimension_count(hparams["hidden_size"] // hparams["num_attention_heads"]) +gguf_writer.add_head_count(head_count) +gguf_writer.add_head_count_kv(head_count_kv) +gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"]) + +if "rope_scaling" in hparams and hparams["rope_scaling"] != None and "factor" in hparams["rope_scaling"]: + if "type" in hparams["rope_scaling"]: + if hparams["rope_scaling"]["type"] == "linear": + gguf_writer.add_rope_scale_linear(hparams["rope_scaling"]["factor"]) + + +# TOKENIZATION + +print("gguf: get tokenizer metadata") + +tokens: List[bytes] = [] +scores: List[float] = [] +toktypes: List[int] = [] + +if Path(dir_model + "/tokenizer.model").is_file(): + # vocab type sentencepiece + print("gguf: get sentencepiece tokenizer vocab, scores and token types") + + tokenizer = SentencePieceProcessor(dir_model + "/tokenizer.model") + + for i in range(tokenizer.vocab_size()): + text: bytes + score: float + + piece = tokenizer.id_to_piece(i) + text = piece.encode("utf-8") + score = tokenizer.get_score(i) + + toktype = 1 # defualt to normal token type + if tokenizer.is_unknown(i): + toktype = 2 + if tokenizer.is_control(i): + toktype = 3 + + # toktype = 4 is user-defined = tokens from added_tokens.json + + if tokenizer.is_unused(i): + toktype = 5 + if tokenizer.is_byte(i): + toktype = 6 + + tokens.append(text) + scores.append(score) + toktypes.append(toktype) + + if Path(dir_model + "/added_tokens.json").is_file(): + with open(dir_model + "/added_tokens.json", "r", encoding="utf-8") as f: + addtokens_json = json.load(f) + + print("gguf: get added tokens") + + for key in addtokens_json: + tokens.append( key.encode("utf-8") ) + scores.append(-1000.0) + toktypes.append(4) # user-defined token type + + + gguf_writer.add_tokenizer_model("llama") + gguf_writer.add_token_list(tokens) + gguf_writer.add_token_scores(scores) + gguf_writer.add_token_types(toktypes) + + +print("gguf: get special token ids") + +if Path(dir_model + "/tokenizer.json").is_file(): + # Look for special tokens in tokenizer.json if it exists + + with open(dir_model + "/tokenizer.json", "r", encoding="utf-8") as f: + tokenizer = json.load(f) + + if "added_tokens" in tokenizer and Path(dir_model + "/tokenizer_config.json").is_file(): + + with open(dir_model + "/tokenizer_config.json", "r", encoding="utf-8") as f: + tokenizer_config = json.load(f) + + if "bos_token" in tokenizer_config and tokenizer_config["bos_token"] != None: + for key in tokenizer["added_tokens"]: + if key["content"] == tokenizer_config["bos_token"]["content"]: + gguf_writer.add_bos_token_id(key["id"]) + + if "eos_token" in tokenizer_config and tokenizer_config["eos_token"] != None: + for key in tokenizer["added_tokens"]: + if key["content"] == tokenizer_config["eos_token"]["content"]: + gguf_writer.add_eos_token_id(key["id"]) + + if "unk_token" in tokenizer_config and tokenizer_config["unk_token"] != None: + for key in tokenizer["added_tokens"]: + if key["content"] == tokenizer_config["unk_token"]["content"]: + gguf_writer.add_unk_token_id(key["id"]) + + if "sep_token" in tokenizer_config and tokenizer_config["sep_token"] != None: + for key in tokenizer["added_tokens"]: + if key["content"] == tokenizer_config["sep_token"]["content"]: + gguf_writer.add_sep_token_id(key["id"]) + + if "pad_token" in tokenizer_config and tokenizer_config["pad_token"] != None: + for key in tokenizer["added_tokens"]: + if key["content"] == tokenizer_config["pad_token"]["content"]: + gguf_writer.add_pad_token_id(key["id"]) +else: + # If no tokenizer.json: Look for special tokens in config.json + + if "bos_token_id" in hparams and hparams["bos_token_id"] != None: + gguf_writer.add_bos_token_id(hparams["bos_token_id"]) + + if "eos_token_id" in hparams and hparams["eos_token_id"] != None: + gguf_writer.add_eos_token_id(hparams["eos_token_id"]) + + if "unk_token_id" in hparams and hparams["unk_token_id"] != None: + gguf_writer.add_unk_token_id(hparams["unk_token_id"]) + + if "sep_token_id" in hparams and hparams["sep_token_id"] != None: + gguf_writer.add_sep_token_id(hparams["sep_token_id"]) + + if "pad_token_id" in hparams and hparams["pad_token_id"] != None: + gguf_writer.add_pad_token_id(hparams["pad_token_id"]) + + +# TENSORS + +tensor_map = gguf.get_tensor_name_map(ARCH,block_count) + +# tensor info +print("gguf: get tensor metadata") + +if num_parts == 0: + part_names = ("pytorch_model.bin",) +else: + part_names = ( + f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1) + ) + +for part_name in part_names: + print("gguf: loading model part '" + part_name + "'") + model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu") + + for name in model_part.keys(): + data = model_part[name] + + # we don't need these + if name.endswith(".rotary_emb.inv_freq"): + continue + + old_dtype = data.dtype + + # convert any unsupported data types to float32 + if data.dtype != torch.float16 and data.dtype != torch.float32: + data = data.to(torch.float32) + + data = data.squeeze().numpy() + + # reverse permute these + if name.endswith(".q_proj.weight"): + data = reverse_hf_permute(data, head_count) + if name.endswith(".k_proj.weight"): + data = reverse_hf_permute(data, head_count, head_count_kv) + + # map tensor names + if name.endswith(".weight") and name[:-7] in tensor_map: + name = tensor_map[name[:-7]] + ".weight" + elif name.endswith(".bias") and name[:-5] in tensor_map: + name = tensor_map[name[:-5]] + ".bias" + else: + print("Can not map tensor '" + name + "'") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32 + if ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype)) + + gguf_writer.add_tensor(name, data) + + +print("gguf: write header") +gguf_writer.write_header_to_file() +print("gguf: write metadata") +gguf_writer.write_kv_data_to_file() +print("gguf: write tensors") +gguf_writer.write_tensors_to_file() + +gguf_writer.close() + + +print("gguf: model successfully exported to '" + fname_out + "'") +print("") diff --git a/convert.py b/convert.py index f3bf17980..c29c032cd 100644 --- a/convert.py +++ b/convert.py @@ -1,4 +1,6 @@ #!/usr/bin/env python + +import gguf import argparse import concurrent.futures import copy @@ -16,13 +18,12 @@ import signal import struct import sys import zipfile +import numpy as np + from abc import ABCMeta, abstractmethod from dataclasses import dataclass from pathlib import Path -from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, - Literal, Optional, Sequence, Tuple, TypeVar, Union) - -import numpy as np +from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Literal, Optional, Sequence, Tuple, TypeVar, Union) from sentencepiece import SentencePieceProcessor # type: ignore if TYPE_CHECKING: @@ -33,57 +34,44 @@ if hasattr(faulthandler, 'register') and hasattr(signal, 'SIGUSR1'): NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]' +ARCH=gguf.MODEL_ARCH.LLAMA +NAMES=gguf.MODEL_TENSOR_NAMES[ARCH] + +# +# data types +# @dataclass(frozen=True) class UnquantizedDataType: name: str - -DT_F16 = UnquantizedDataType('F16') -DT_F32 = UnquantizedDataType('F32') -DT_I32 = UnquantizedDataType('I32') +DT_F16 = UnquantizedDataType('F16') +DT_F32 = UnquantizedDataType('F32') +DT_I32 = UnquantizedDataType('I32') DT_BF16 = UnquantizedDataType('BF16') - -@dataclass(frozen=True) -class QuantizedDataType: - groupsize: int - have_addends: bool - have_g_idx: bool - - -DT_Q4_0 = QuantizedDataType(groupsize=32, have_addends=False, have_g_idx=False) -DT_Q4_1 = QuantizedDataType(groupsize=32, have_addends=True, have_g_idx=False) - -DataType = Union[UnquantizedDataType, QuantizedDataType] - -DATA_TYPE_TO_FTYPE: Dict[DataType, int] = { - DT_F32: 0, - DT_F16: 1, - DT_Q4_0: 2, - DT_Q4_1: 3, -} - -FTYPE_TO_DATA_TYPE: Dict[int, DataType] = \ - {ftype: dtype for (dtype, ftype) in DATA_TYPE_TO_FTYPE.items()} +DataType = Union[UnquantizedDataType] DATA_TYPE_TO_NUMPY: Dict[DataType, 'np.dtype[Any]'] = { DT_BF16: np.dtype(np.uint16), - DT_F16: np.dtype(np.float16), - DT_F32: np.dtype(np.float32), - DT_I32: np.dtype(np.int32), + DT_F16: np.dtype(np.float16), + DT_F32: np.dtype(np.float32), + DT_I32: np.dtype(np.int32), } NUMPY_TYPE_TO_DATA_TYPE: Dict['np.dtype[Any]', DataType] = \ {dtype: data_type for (data_type, dtype) in DATA_TYPE_TO_NUMPY.items()} +SAFETENSORS_DATA_TYPES: Dict[str, DataType] = { + 'BF16': DT_BF16, + 'F16': DT_F16, + 'F32': DT_F32, + 'I32': DT_I32, +} class GGMLFileType(enum.Enum): - AllF32 = 0 + AllF32 = 0 MostlyF16 = 1 # except 1d tensors - MostlyQ4_0 = 2 # except 1d tensors - MostlyQ4_1 = 3 # except 1d tensors - PerLayerIsQ4_1 = 4 # but tok_embeddings.weight and output.weight are F16 def type_for_tensor(self, name: str, tensor: 'LazyTensor') -> DataType: if len(tensor.shape) == 1: @@ -93,60 +81,34 @@ class GGMLFileType(enum.Enum): return DT_F32 elif self == GGMLFileType.MostlyF16: return DT_F16 - elif self == GGMLFileType.MostlyQ4_0: - return DT_Q4_0 - elif self == GGMLFileType.MostlyQ4_1: - return DT_Q4_1 - elif self == GGMLFileType.PerLayerIsQ4_1: - if name in ('output.weight', 'tok_embeddings.weight'): - return DT_F16 - else: - return DT_Q4_1 else: raise ValueError(self) -def make_tensors_list() -> List[str]: - ret = [ - 'tok_embeddings.weight', - 'norm.weight', - 'output.weight', - ] - for i in range(80): # maximum number of layer - ret += [ - f'layers.{i}.attention.wq.weight', - f'layers.{i}.attention.wk.weight', - f'layers.{i}.attention.wv.weight', - f'layers.{i}.attention.wo.weight', - f'layers.{i}.attention_norm.weight', - f'layers.{i}.feed_forward.w1.weight', - f'layers.{i}.feed_forward.w2.weight', - f'layers.{i}.feed_forward.w3.weight', - f'layers.{i}.ffn_norm.weight', - ] - return ret - - -TENSORS_LIST = make_tensors_list() -TENSORS_SET = set(TENSORS_LIST) - - -def find_n_mult(n_ff: int, n_embd: int) -> int: - # hardcoded magic range - for n_mult in range(8192, 1, -1): - calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult - if calc_ff == n_ff: - return n_mult - raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).") +# +# hparams loading +# @dataclass class Params: - n_vocab: int - n_embd: int - n_mult: int - n_head: int - n_layer: int - n_kv_head: Optional[int] # This parameter is only used for Llama 2 + n_vocab: int + n_embd: int + n_mult: int + n_layer: int + n_ctx: int + n_ff: int + n_head: int + n_head_kv: int + f_norm_eps: float + + @staticmethod + def find_n_mult(n_ff: int, n_embd: int) -> int: + # hardcoded magic range + for n_mult in range(8192, 1, -1): + calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult + if calc_ff == n_ff: + return n_mult + raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).") @staticmethod def guessed(model: 'LazyModel') -> 'Params': @@ -165,37 +127,57 @@ class Params: raise Exception("failed to guess 'n_layer'. This model is unknown or unsupported.\n" "Suggestion: provide 'config.json' of the model in the same directory containing model files.") - n_head=n_embd // 128 # guessed + n_head = n_embd // 128 # guessed + n_mult = 256 # guessed + + # TODO: verify this + n_ff = int(2 * (4 * n_embd) / 3) + n_ff = n_mult * ((n_ff + n_mult - 1) // n_mult) return Params( - n_vocab = n_vocab, - n_embd = n_embd, - n_mult = 256, - n_head = n_head, - n_layer = n_layer, - n_kv_head = None, + n_vocab = n_vocab, + n_embd = n_embd, + n_mult = n_mult, + n_layer = n_layer, + n_ctx = -1, + n_ff = n_ff, + n_head = n_head, + n_head_kv = n_head, + f_norm_eps = 1e-5, ) @staticmethod def loadHFTransformerJson(model: 'LazyModel', config_path: 'Path') -> 'Params': config = json.load(open(config_path)) - n_vocab = config["vocab_size"]; - n_embd = config["hidden_size"]; - n_head = config["num_attention_heads"]; - n_layer = config["num_hidden_layers"]; - n_ff = config["intermediate_size"]; - n_kv_head = config.get("num_key_value_heads") + n_vocab = config["vocab_size"] + n_embd = config["hidden_size"] + n_layer = config["num_hidden_layers"] + n_ff = config["intermediate_size"] + n_head = config["num_attention_heads"] + n_head_kv = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head + f_norm_eps = config["rms_norm_eps"] - n_mult = find_n_mult(n_ff, n_embd); + n_mult = Params.find_n_mult(n_ff, n_embd) + + if "max_sequence_length" in config: + n_ctx = config["max_sequence_length"] + elif "max_position_embeddings" in config: + n_ctx = config["max_position_embeddings"] + else: + raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n" + "Suggestion: provide 'config.json' of the model in the same directory containing model files.") return Params( - n_vocab = n_vocab, - n_embd = n_embd, - n_mult = n_mult, - n_head = n_head, - n_layer = n_layer, - n_kv_head = n_kv_head, + n_vocab = n_vocab, + n_embd = n_embd, + n_mult = n_mult, + n_layer = n_layer, + n_ctx = n_ctx, + n_ff = n_ff, + n_head = n_head, + n_head_kv = n_head_kv, + f_norm_eps = f_norm_eps, ) # LLaMA v2 70B params.json @@ -204,22 +186,32 @@ class Params: def loadOriginalParamsJson(model: 'LazyModel', config_path: 'Path') -> 'Params': config = json.load(open(config_path)) - n_vocab = config["vocab_size"]; - n_embd = config["dim"]; - n_head = config["n_heads"]; - n_layer = config["n_layers"]; - n_mult = config["multiple_of"]; + n_vocab = config["vocab_size"] + n_embd = config["dim"] + n_layer = config["n_layers"] + n_mult = config["multiple_of"] + n_ctx = 2048 if config["norm_eps"] == 1e-06 else 4096 # hack to determine LLaMA v1 vs v2 + n_ff = -1 + n_head = config["n_heads"] + n_head_kv = config["n_kv_heads"] if "n_kv_heads" in config else n_head + f_norm_eps = config["norm_eps"] if n_vocab == -1: n_vocab = model["tok_embeddings.weight"].shape[0] + if n_ff == -1: + n_ff = model["layers.0.feed_forward.w1.weight"].shape[0] + return Params( - n_vocab = n_vocab, - n_embd = n_embd, - n_mult = n_mult, - n_head = n_head, - n_layer = n_layer, - n_kv_head = None, + n_vocab = n_vocab, + n_embd = n_embd, + n_mult = n_mult, + n_layer = n_layer, + n_ctx = n_ctx, + n_ff = n_ff, + n_head = n_head, + n_head_kv = n_head_kv, + f_norm_eps = f_norm_eps, ) @staticmethod @@ -234,30 +226,73 @@ class Params: else: params = Params.guessed(model_plus.model) - print(f'params: n_vocab:{params.n_vocab} n_embd:{params.n_embd} n_mult:{params.n_mult} n_head:{params.n_head} n_layer:{params.n_layer}') return params -class SentencePieceVocab: - def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path], vocabtype: Optional[str]) -> None: - self.vocabtype = vocabtype - if self.vocabtype == "bpe": - self.sentencepiece_tokenizer = json.loads(open(str(fname_tokenizer)).read()) - else: - self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer)) +# +# vocab +# + +class BpeVocab: + def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None: + self.bpe_tokenizer = json.loads(open(str(fname_tokenizer), encoding="utf-8").read()) added_tokens: Dict[str, int] if fname_added_tokens is not None: - added_tokens = json.load(open(fname_added_tokens)) + added_tokens = json.load(open(fname_added_tokens, encoding="utf-8")) else: added_tokens = {} - if self.vocabtype == "bpe": - vocab_size: int = len(self.sentencepiece_tokenizer) - else: - vocab_size: int = self.sentencepiece_tokenizer.vocab_size() - expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) - actual_ids = sorted(added_tokens.values()) + + vocab_size: int = len(self.bpe_tokenizer) + expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) + actual_ids = sorted(added_tokens.values()) if expected_ids != actual_ids: raise Exception(f"Expected added token IDs to be sequential and start at {len(added_tokens)}; got {actual_ids}") + + items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1]) + self.added_tokens_list = [text for (text, idx) in items] + self.vocab_size_base: int = vocab_size + self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_list) + self.fname_tokenizer = fname_tokenizer + self.fname_added_tokens = fname_added_tokens + + def bpe_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]: + tokenizer = self.bpe_tokenizer + from transformers.models.gpt2 import tokenization_gpt2 + byte_encoder = tokenization_gpt2.bytes_to_unicode() + byte_decoder = {v: k for k, v in byte_encoder.items()} + for i, item in enumerate(tokenizer): + text: bytes = item.encode("utf-8") + score: float = -i + yield text, score, gguf.TokenType.USER_DEFINED + + def added_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]: + for text in self.added_tokens_list: + score = -1000.0 + yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED + + def all_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]: + yield from self.bpe_tokens() + yield from self.added_tokens() + + def __repr__(self) -> str: + return f"BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>" + + +class SentencePieceVocab: + def __init__(self, fname_tokenizer: Path, fname_added_tokens: Optional[Path]) -> None: + self.sentencepiece_tokenizer = SentencePieceProcessor(str(fname_tokenizer)) + added_tokens: Dict[str, int] + if fname_added_tokens is not None: + added_tokens = json.load(open(fname_added_tokens, encoding="utf-8")) + else: + added_tokens = {} + + vocab_size: int = self.sentencepiece_tokenizer.vocab_size() + expected_ids = list(range(vocab_size, vocab_size + len(added_tokens))) + actual_ids = sorted(added_tokens.values()) + if expected_ids != actual_ids: + raise Exception(f"Expected added token IDs to be sequential and start at {len(added_tokens)}; got {actual_ids}") + items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1]) self.added_tokens_list = [text for (text, idx) in items] self.vocab_size_base: int = vocab_size @@ -265,117 +300,66 @@ class SentencePieceVocab: self.fname_tokenizer = fname_tokenizer self.fname_added_tokens = fname_added_tokens - def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float]]: + def sentencepiece_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]: tokenizer = self.sentencepiece_tokenizer - if self.vocabtype == "bpe": - from transformers.models.gpt2 import tokenization_gpt2 - byte_encoder = tokenization_gpt2.bytes_to_unicode() - byte_decoder = {v: k for k, v in byte_encoder.items()} - for i, item in enumerate(tokenizer): - text: bytes - text = b''.join([x.to_bytes(1, byteorder='big') for x in [byte_decoder[y] for y in item]]) - score: float = -i - yield text, score - else: - for i in range(tokenizer.vocab_size()): - text: bytes - if tokenizer.is_unknown(i): - text = " \u2047 ".encode("utf-8") - elif tokenizer.is_control(i): - text = b"" - elif tokenizer.is_byte(i): - piece = tokenizer.id_to_piece(i) - if len(piece) != 6: - raise Exception(f"Invalid token: {piece}") - byte_value = int(piece[3:-1], 16) - text = struct.pack("B", byte_value) - else: - text = tokenizer.id_to_piece(i).replace("\u2581", " ").encode("utf-8") - score: float = tokenizer.get_score(i) - yield text, score + for i in range(tokenizer.vocab_size()): + piece = tokenizer.id_to_piece(i) + text: bytes = piece.encode("utf-8") + score: float = tokenizer.get_score(i) - def added_tokens(self) -> Iterable[Tuple[bytes, float]]: + toktype = gguf.TokenType.NORMAL + if tokenizer.is_unknown(i): + toktype = gguf.TokenType.UNKNOWN + if tokenizer.is_control(i): + toktype = gguf.TokenType.CONTROL + + # NOTE: I think added_tokens are user defined. + # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto + # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED + + if tokenizer.is_unused(i): + toktype = gguf.TokenType.UNUSED + if tokenizer.is_byte(i): + toktype = gguf.TokenType.BYTE + + yield text, score, toktype + + def added_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]: for text in self.added_tokens_list: score = -1000.0 - yield text.encode("utf-8"), score + yield text.encode("utf-8"), score, gguf.TokenType.USER_DEFINED - def all_tokens(self) -> Iterable[Tuple[bytes, float]]: + def all_tokens(self) -> Iterable[Tuple[bytes, float, gguf.TokenType]]: yield from self.sentencepiece_tokens() yield from self.added_tokens() def __repr__(self) -> str: return f"" - -class GGMLVocab: - def __init__(self, tokens: List[Tuple[bytes, float]]): - self.tokens = tokens - self.vocab_size = len(tokens) - - def all_tokens(self) -> Iterable[Tuple[bytes, float]]: - return self.tokens - - def __repr__(self) -> str: - return f"" +Vocab = Union[BpeVocab, SentencePieceVocab] -Vocab = Union[SentencePieceVocab, GGMLVocab] +# +# data loading +# TODO: reuse (probably move to gguf.py?) +# - -def permute(weights: NDArray, n_head: int, n_kv_head: Optional[int] = None) -> NDArray: - if n_kv_head is not None and n_head != n_kv_head: - n_head //= n_kv_head +def permute(weights: NDArray, n_head: int, n_head_kv: int) -> NDArray: + #print( "permute debug " + str(weights.shape[0]) + " x " + str(weights.shape[1]) + " nhead " + str(n_head) + " nheadkv " + str(n_kv_head) ) + if n_head_kv is not None and n_head != n_head_kv: + n_head //= n_head_kv return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:]) .swapaxes(1, 2) .reshape(weights.shape)) -def dequantize_q4(qvalues_pack32: NDArray, scales: NDArray, addends: Optional[NDArray], g_idx: Optional[NDArray]) -> NDArray: - # First reinterpret each row from a list of int32s containing 8 values each - # to a list of uint8s containing 2 values each. - qvalues_pack8 = qvalues_pack32.view(np.uint8) - - # Then split out the two values per int8 (which requires an actual - # conversion because numpy doesn't natively support int4s). - qvalues = np.zeros([qvalues_pack8.shape[0], qvalues_pack8.shape[1] * 2], dtype=np.uint8) - qvalues[:, 0::2] = qvalues_pack8 & 0xf - qvalues[:, 1::2] = qvalues_pack8 >> 4 - - assert addends is None or addends.shape == scales.shape - assert qvalues.shape[0] == scales.shape[0] - assert qvalues.shape[1] % scales.shape[1] == 0 - if g_idx is None: - repeat_count = qvalues.shape[1] // scales.shape[1] - scales = scales[:, :, np.newaxis] - if addends is not None: - addends = addends[:, :, np.newaxis] - # Reshape so that the below computation broadcasts over scales and addends: - qvalues.shape = (qvalues.shape[0], scales.shape[1], int(repeat_count)) - else: - # In this case the scale and addend is selected for each column by g_idx: - assert addends is not None - scales = scales[:, g_idx] - addends = addends[:, g_idx] - if addends is None: - # Q4_0 - qvalues = qvalues.view(np.int8) - qvalues -= 8 - # And do the actual 'value = scale * qvalue + addend' computation. - values = scales * qvalues - if addends is not None: - values += addends - if g_idx is None: - values.shape = (values.shape[0], values.shape[1] * values.shape[2]) - return values - - class Tensor(metaclass=ABCMeta): data_type: DataType @abstractmethod def astype(self, data_type: DataType) -> 'Tensor': ... @abstractmethod - def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'Tensor': ... + def permute(self, n_head: int, n_head_kv: int) -> 'Tensor': ... @abstractmethod def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': ... @abstractmethod @@ -413,8 +397,8 @@ class UnquantizedTensor(Tensor): r = self.ndarray.shape[0] // 3 return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...]) - def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'UnquantizedTensor': - return UnquantizedTensor(permute(self.ndarray, n_head, n_kv_head)) + def permute(self, n_head: int, n_head_kv: int) -> 'UnquantizedTensor': + return UnquantizedTensor(permute(self.ndarray, n_head, n_head_kv)) def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, convert: bool = False) -> NDArray: @@ -433,183 +417,25 @@ def load_unquantized(lazy_tensor: 'LazyTensor', expected_dtype: Any = None, conv return tensor.ndarray -class GGMLQuantizedTensor(Tensor): - data_type: QuantizedDataType - - def __init__(self, ndarray: NDArray, shape: List[int], data_type: DataType) -> None: - rows, columns = shape - assert data_type in (DT_Q4_1, DT_Q4_0) # for now - assert isinstance(data_type, QuantizedDataType) # redundant, but mypy complains without this - assert columns % data_type.groupsize == 0 - words_in_block = 6 if data_type == DT_Q4_1 else 5 - self.ndarray = ndarray.view(dtype=np.uint32).reshape((rows, columns // data_type.groupsize, words_in_block)) - self.shape = shape[:] - self.data_type = data_type - - def astype(self, data_type: DataType) -> Tensor: - if data_type == self.data_type: - return self - scales = self.ndarray[:, :, 0].view(np.float32) - if self.data_type.have_addends: - addends = self.ndarray[:, :, 1].view(np.float32) - else: - addends = None - qweights = self.ndarray[:, :, -4:].reshape([self.shape[0], self.shape[1] // 8]) - - dq = dequantize_q4(qweights, scales, addends, g_idx=None) - return UnquantizedTensor(dq).astype(data_type) - - def to_ggml(self) -> 'GGMLQuantizedTensor': - return self - - def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> 'GGMLQuantizedTensor': - return GGMLQuantizedTensor(permute(self.ndarray, n_head, n_kv_head), self.shape, self.data_type) - - def permute_part(self, n_part: int, n_head: int) -> 'UnquantizedTensor': - r = self.ndarray.shape[0] // 3 - return UnquantizedTensor(permute(self.ndarray[r * n_part : r * n_part + r, ...], n_head)) - - def part(self, n_part: int) -> 'UnquantizedTensor': - r = self.ndarray.shape[0] // 3 - return UnquantizedTensor(self.ndarray[r * n_part : r * n_part + r, ...]) - -GGMLCompatibleTensor = Union[UnquantizedTensor, GGMLQuantizedTensor] +GGMLCompatibleTensor = Union[UnquantizedTensor] class DeferredPermutedTensor(Tensor): - def __init__(self, base: Tensor, n_head: int, n_kv_head: Optional[int] = None) -> None: + def __init__(self, base: Tensor, n_head: int, n_head_kv: int) -> None: self.base = base self.n_head = n_head - self.n_kv_head = n_kv_head self.data_type = self.base.data_type def astype(self, data_type: DataType) -> Tensor: - return self.base.astype(data_type).permute(self.n_head, self.n_kv_head) + return self.base.astype(data_type).permute(self.n_head, self.n_head_kv) def to_ggml(self) -> GGMLCompatibleTensor: - return self.base.to_ggml().permute(self.n_head, self.n_kv_head) + return self.base.to_ggml().permute(self.n_head, self.n_head_kv) - def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> Tensor: + def permute(self, n_head: int, n_head_kv: int) -> Tensor: raise Exception("shouldn't permute twice") -class GPTQForLLaMaQuantizedTensor(Tensor): - def __init__(self, model: 'LazyModel', namebase: str) -> None: - qweight = load_unquantized(model[f"{namebase}.qweight"], np.int32) - scales = load_unquantized(model[f"{namebase}.scales"], np.float32, convert=True) - - bias = model.get(f"{namebase}.bias") - if bias is not None: - # Q4_1 does not support bias; good thing the bias is always all zeros. - assert not np.any(load_unquantized(bias)) - - if f"{namebase}.zeros" in model: - zeros = load_unquantized(model[f"{namebase}.zeros"], np.float32) - else: - qzeros = load_unquantized(model[f"{namebase}.qzeros"], np.int32) - assert qzeros.dtype == np.int32 - zeros = dequantize_q4(qzeros, scales, scales, g_idx=None) - assert zeros.dtype == np.float32 - - assert zeros.shape == scales.shape - - # Output is transposed compared to the input, and addends have their sign flipped. - # Scales and zeros similarly must be transposed but only for newer - # versions of GPTQ-for-LLaMa; the older versions can be identified by - # having shape (n_embd, 1). - qweight = qweight.T - if scales.shape[1] != 1: - scales = scales.T - zeros = zeros.T - - # Output also has signs flipped for the addends. - self.qweight = qweight - self.scales = scales - self.addends = -zeros - - self.g_idx: Optional[NDArray] - if f"{namebase}.g_idx" in model: - self.g_idx = load_unquantized(model[f"{namebase}.g_idx"], np.int32) - assert self.g_idx.shape == (qweight.shape[1] * 8,) - else: - self.g_idx = None - - self.shape = [self.qweight.shape[0], self.qweight.shape[1] * 8] - self.data_type = QuantizedDataType(groupsize=self.groupsize(), have_addends=True, - have_g_idx=(self.g_idx is not None)) - - def inspect(self, row: int, col: int) -> None: - '''For debugging.''' - qweight = (self.qweight[row, col // 8] >> (4 * (col & 7))) & 0xf - if self.g_idx is not None: - group = self.g_idx[col] - else: - group = int(col // self.groupsize()) - scale = self.scales[row, group] - addend = self.addends[row, group] - with np.printoptions(precision=None, suppress=True): - print(f'scale:{scale} addend:{addend} qweight:{qweight}') - print('possible values:', np.arange(16) * scale + addend) - print('actual value:', qweight * scale + addend) - - def astype(self, data_type: DataType) -> Tensor: - if isinstance(data_type, QuantizedDataType): - assert self.g_idx is None and data_type.have_addends is True and data_type.have_g_idx is False - return self.regroup(data_type.groupsize) - - dequantized = dequantize_q4(np.ascontiguousarray(self.qweight), self.scales, self.addends, self.g_idx) - return UnquantizedTensor(dequantized).astype(data_type) - - def groupsize(self) -> int: - assert self.addends.shape == self.scales.shape - assert self.shape[1] % self.scales.shape[1] == 0 - return self.shape[1] // self.scales.shape[1] - - def regroup(self, new_groupsize: int = 32) -> 'GPTQForLLaMaQuantizedTensor': - # Old versions of GPTQ-for-LLaMa shared scales and addends between all the - # columns in a row. Newer versions share them between every set of N - # columns in a row, where N is the `groupsize` parameter, usually 128. The - # output format shares them between every set of 32 columns. To handle - # this, duplicate scales and addends for every smaller group. - # (In the above, 'row' and 'column' are in the sense of the output.) - assert self.g_idx is None - old_groupsize = self.groupsize() - assert old_groupsize >= new_groupsize and old_groupsize % new_groupsize == 0, old_groupsize - ret = copy.copy(self) - ret.addends = self.addends.repeat(old_groupsize // new_groupsize, axis=1) - ret.scales = self.scales.repeat(old_groupsize // new_groupsize, axis=1) - ret.data_type = QuantizedDataType(groupsize=new_groupsize, have_addends=True, have_g_idx=False) - return ret - - def permute(self, n_head: int, n_kv_head: Optional[int] = None) -> Tensor: - return DeferredPermutedTensor(self, n_head, n_kv_head) - - def to_ggml(self) -> GGMLQuantizedTensor: - # The output format looks like this: - # For each row: - # For each group of 32 columns: - # - addend (float32, 4 bytes) - # - scale (float32, 4 bytes) - # - weights (int4 * 32, 16 bytes) - - if self.groupsize() != 32: - raise Exception("should have been regrouped before converting to ggml") - - # Since the output format is mixed between integers and floats, we have - # to hackily view the floats as int32s just so numpy will let us - # concatenate them. - addends_view = self.addends.view(dtype=np.int32)[:, :, np.newaxis] - scales_view = self.scales.view(dtype=np.int32)[:, :, np.newaxis] - - # Split into groups of 4 columns (i.e. 32 columns of quantized data): - grouped = self.qweight.reshape([self.qweight.shape[0], self.qweight.shape[1] // 4, 4]) - - # And concatenate: - grouped = np.concatenate([scales_view, addends_view, grouped], axis=2, casting='no') - - return GGMLQuantizedTensor(grouped, self.shape, DT_Q4_1) - - @dataclass class LazyTensor: _load: Callable[[], Tensor] @@ -632,17 +458,6 @@ class LazyTensor: def validate_conversion_to(self, data_type: DataType) -> None: if data_type == self.data_type: return - if isinstance(data_type, QuantizedDataType): - if not isinstance(self.data_type, QuantizedDataType): - raise Exception(f"Can't turn an unquantized tensor into a quantized type ({data_type})") - if self.data_type.have_g_idx: - sys.stderr.write( - "Error: Input uses the newer GPTQ-for-LLaMa format (using g_idx), " - "which is not yet natively supported by GGML. " - "For now you can still convert this model by passing `--outtype f16` to dequantize, " - "but that will result in a much larger output file for no quality benefit.\n") - sys.exit(1) - assert not data_type.have_g_idx and self.data_type.have_addends and data_type.have_addends LazyModel = Dict[str, LazyTensor] @@ -713,10 +528,10 @@ def merge_multifile_models(models_plus: List[ModelPlus]) -> ModelPlus: return ModelPlus(model, paths, format, vocab) -def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_kv_head: Optional[int] = None) -> LazyTensor: +def permute_lazy(lazy_tensor: LazyTensor, n_head: int, n_head_kv: int) -> LazyTensor: def load() -> Tensor: - return lazy_tensor.load().permute(n_head, n_kv_head) - return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_kv_head}) ' + lazy_tensor.description) + return lazy_tensor.load().permute(n_head, n_head_kv) + return LazyTensor(load, lazy_tensor.shape, lazy_tensor.data_type, f'permute({n_head}, {n_head_kv}) ' + lazy_tensor.description) def permute_part_lazy(lazy_tensor: LazyTensor, n_part: int, n_head: int) -> LazyTensor: def load() -> Tensor: @@ -732,66 +547,6 @@ def part_lazy(lazy_tensor: LazyTensor, n_part: int) -> LazyTensor: s[0] = s[0] // 3 return LazyTensor(load, s, lazy_tensor.data_type, 'part ' + lazy_tensor.description) -def convert_transformers_to_orig(model: LazyModel, params: Params) -> LazyModel: - out: LazyModel = {} - out["tok_embeddings.weight"] = model["model.embed_tokens.weight"] - out["norm.weight"] = model["model.norm.weight"] - out["output.weight"] = model["lm_head.weight"] - - for i in itertools.count(): - if f"model.layers.{i}.self_attn.q_proj.weight" in model: - out[f"layers.{i}.attention.wq.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head) - out[f"layers.{i}.attention.wk.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_kv_head) - out[f"layers.{i}.attention.wv.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"] - elif f"model.layers.{i}.self_attn.W_pack.weight" in model: - out[f"layers.{i}.attention.wq.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head) - out[f"layers.{i}.attention.wk.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head) - out[f"layers.{i}.attention.wv.weight"] = part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 2) - else: - break - - out[f"layers.{i}.attention.wo.weight"] = model[f"model.layers.{i}.self_attn.o_proj.weight"] - - out[f"layers.{i}.feed_forward.w1.weight"] = model[f"model.layers.{i}.mlp.gate_proj.weight"] - out[f"layers.{i}.feed_forward.w2.weight"] = model[f"model.layers.{i}.mlp.down_proj.weight"] - out[f"layers.{i}.feed_forward.w3.weight"] = model[f"model.layers.{i}.mlp.up_proj.weight"] - - out[f"layers.{i}.attention_norm.weight"] = model[f"model.layers.{i}.input_layernorm.weight"] - out[f"layers.{i}.ffn_norm.weight"] = model[f"model.layers.{i}.post_attention_layernorm.weight"] - return out - - -def handle_quantization(model: LazyModel) -> LazyModel: - '''Convert a model with entries for 'foo.qweight', 'foo.scales', etc. - (which resolve to UnquantizedTensors with the raw data) to one with entries - for 'foo.weight' (which resolve to QuantizedTensors). - ''' - def convert(name: str) -> Tuple[str, LazyTensor]: - if name.endswith(".qweight"): - namebase = name.rsplit('.', 1)[0] - orig_name = namebase + ".weight" - - lazy_tensor = model[name] - assert len(lazy_tensor.shape) == 2 - real_shape = [lazy_tensor.shape[1], lazy_tensor.shape[0] * 8] - - # Calculate type. This replicates the logic in - # GPTQForLLaMaQuantizedTensor (which is executed when the modelis - # actually loaded). - lazy_scales = model[f"{namebase}.scales"] - scales_width = 1 if lazy_scales.shape[1] == 1 else lazy_scales.shape[0] - assert real_shape[1] % scales_width == 0 - groupsize = real_shape[1] // scales_width - have_g_idx = f"{namebase}.g_idx" in model - data_type = QuantizedDataType(groupsize=groupsize, have_addends=True, have_g_idx=have_g_idx) - - def load() -> Tensor: - return GPTQForLLaMaQuantizedTensor(model, namebase) - - return (orig_name, LazyTensor(load, real_shape, data_type, '[quantized]')) - else: - return (name, model[name]) - return dict(convert(name) for name in model) # Functionality that simulates `torch.load` but where individual tensors are # only loaded into memory on demand, not all at once. @@ -885,14 +640,6 @@ def lazy_load_torch_file(outer_fp: IO[bytes], path: Path) -> ModelPlus: return ModelPlus(model=as_dict, paths=[path], format='torch', vocab=None) -SAFETENSORS_DATA_TYPES: Dict[str, DataType] = { - 'BF16': DT_BF16, - 'F16': DT_F16, - 'F32': DT_F32, - 'I32': DT_I32, -} - - def lazy_load_safetensors_file(fp: IO[bytes], path: Path) -> ModelPlus: header_size, = struct.unpack(' bytes: return ret -def lazy_load_ggml_file(fp: io.BufferedReader, path: Path) -> ModelPlus: - magic = must_read(fp, 4)[::-1] - if magic in (b'ggmf', b'ggjt'): - version, = struct.unpack("i", must_read(fp, 4)) - assert version == 1 - else: - assert magic == b'ggml' - version = None - n_vocab, n_embd, n_mult, n_head, n_layer, rot, file_type = struct.unpack('<7i', must_read(fp, 28)) - - tokens: List[Tuple[bytes, float]] = [] - for i in range(n_vocab): - if i == 32000: - # HACK: GPT4All messed with the format without changing the magic - # number. Specifically, they changed the vocab section to contain - # `n_vocab - 1` tokens instead of `n_vocab` (i.e. omitting the - # extra pad token). Try to detect if we're reading a file like - # this. - orig_pos = fp.tell() - fp.seek(20, io.SEEK_CUR) - is_gpt4all = fp.read(21) == b'tok_embeddings.weight' - fp.seek(orig_pos) - if is_gpt4all: - break - - length, = struct.unpack("i", must_read(fp, 4)) - text = must_read(fp, length) - if magic != b'ggml': - score, = struct.unpack("f", must_read(fp, 4)) - tokens.append((text, score)) - vocab = GGMLVocab(tokens) if magic != b'ggml' else None - - model: LazyModel = {} - # Use mmap for the actual data to avoid race conditions with the file offset. - off = fp.raw.tell() - mapped = memoryview(mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ)) - fp.raw.seek(off) # needed on Windows - - def read_tensor() -> None: # this is a function so that variables captured in `load` don't change - shape_len, name_len, ftype = struct.unpack("iii", must_read(fp, 12)) - assert 0 <= shape_len <= 3 - shape: List[int] = list(struct.unpack(f"{shape_len}i", must_read(fp, 4 * shape_len))) - shape = shape[::-1] - name = must_read(fp, name_len).decode('utf-8') - data_type = FTYPE_TO_DATA_TYPE[ftype] - - if magic == b'ggjt': - fp.seek((fp.tell() + 31) & -32) - - if data_type == DT_Q4_1: - # See GPTQForLLaMaQuantizedTensor.ggml_ndarray() - size = 24 * (shape[1] // 32) * shape[0] - elif data_type == DT_Q4_0: - size = 20 * (shape[1] // 32) * shape[0] - else: - numpy_dtype = DATA_TYPE_TO_NUMPY[data_type] - elm_count = math.prod(shape) - size = elm_count * numpy_dtype.itemsize - offset = fp.tell() - buf = mapped[offset:offset+size] - fp.seek(size, io.SEEK_CUR) - - def load() -> Tensor: - if isinstance(data_type, QuantizedDataType): - ndarray = np.frombuffer(buf, dtype=np.uint32) - return GGMLQuantizedTensor(ndarray, shape, data_type) - else: - return UnquantizedTensor(np.frombuffer(buf, dtype=numpy_dtype).reshape(shape)) - description = f'ggml offset={offset} type={data_type} path={path}' - model[name] = LazyTensor(load, shape, data_type, description) - - while fp.read(1) != b'': - fp.seek(-1, io.SEEK_CUR) - read_tensor() - - return ModelPlus(model=model, paths=[path], format='ggml', vocab=vocab) - - @functools.lru_cache(maxsize=None) def lazy_load_file(path: Path) -> ModelPlus: fp = open(path, 'rb') @@ -1010,9 +679,6 @@ def lazy_load_file(path: Path) -> ModelPlus: if first8[:2] == b'PK': # A zip file, i.e. PyTorch format return lazy_load_torch_file(fp, path) - elif first8[2:4] == b'gg': - # GGML format - return lazy_load_ggml_file(fp, path) elif struct.unpack(' ModelPlus: In = TypeVar('In') Out = TypeVar('Out') - def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int) -> Iterable[Out]: '''Parallel map, but with backpressure. If the caller doesn't call `next` fast enough, this will stop calling `func` at some point rather than @@ -1043,8 +708,7 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc def check_vocab_size(params: Params, vocab: Vocab) -> None: if params.n_vocab != vocab.vocab_size: - # GGMLVocab comes from the same file as the model so shouldn't mismatch: - assert isinstance(vocab, SentencePieceVocab) + assert isinstance(vocab, BpeVocab) or isinstance(vocab, SentencePieceVocab) if params.n_vocab == vocab.vocab_size_base: print("Ignoring added_tokens.json since model matches vocab size without it.") vocab.added_tokens_list = [] @@ -1061,98 +725,154 @@ def check_vocab_size(params: Params, vocab: Vocab) -> None: class OutputFile: def __init__(self, fname_out: Path) -> None: - self.fout = open(fname_out, "wb") + self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH]) - def write_file_header(self, params: Params, file_type: GGMLFileType) -> None: - self.fout.write(b"ggjt"[::-1]) # magic - values = [ - 1, # file version - params.n_vocab, - params.n_embd, - params.n_mult, - params.n_head, - params.n_layer, - params.n_embd // params.n_head, # rot (obsolete) - file_type.value, - ] - self.fout.write(struct.pack("i" * len(values), *values)) + def add_meta_arch(self, params: Params) -> None: + self.gguf.add_name ("LLaMA") + self.gguf.add_context_length (params.n_ctx) + self.gguf.add_embedding_length (params.n_embd) + self.gguf.add_block_count (params.n_layer) + self.gguf.add_feed_forward_length (params.n_ff) + self.gguf.add_rope_dimension_count(params.n_embd // params.n_head) + self.gguf.add_head_count (params.n_head) + self.gguf.add_head_count_kv (params.n_head_kv) + self.gguf.add_layer_norm_rms_eps (params.f_norm_eps) - def write_tensor_header(self, name: str, shape: Sequence[int], data_type: DataType) -> None: - sname = name.encode('utf-8') - self.fout.write(struct.pack("iii", len(shape), len(sname), DATA_TYPE_TO_FTYPE[data_type])) - self.fout.write(struct.pack("i" * len(shape), *shape[::-1])) - self.fout.write(sname) - self.fout.seek((self.fout.tell() + 31) & -32) + def add_meta_vocab(self, vocab: Vocab) -> None: + tokens = [] + scores = [] + toktypes = [] + # NOTE: `all_tokens` returns the the base vocabulary and added tokens + # TODO: add special tokens? + for text, score, toktype in vocab.all_tokens(): + tokens.append(text) + scores.append(score) + toktypes.append(toktype) - def write_vocab(self, vocab: Vocab) -> None: - for text, score in vocab.all_tokens(): - self.fout.write(struct.pack("i", len(text))) - self.fout.write(text) - self.fout.write(struct.pack("f", score)) + self.gguf.add_tokenizer_model("llama") + self.gguf.add_token_list(tokens) + self.gguf.add_token_scores(scores) + self.gguf.add_token_types(toktypes) + + def add_tensor_info(self, name: str, tensor: LazyTensor) -> None: + n_elements = 1 + for dim in tensor.shape: + n_elements *= dim + data_type = DATA_TYPE_TO_NUMPY[tensor.data_type] + data_nbytes = n_elements * data_type.itemsize + self.gguf.add_tensor_info(name, tensor.shape, data_type, data_nbytes) + + def write_meta(self) -> None: + self.gguf.write_header_to_file() + self.gguf.write_kv_data_to_file() + + def write_tensor_info(self) -> None: + self.gguf.write_ti_data_to_file() + + def close(self) -> None: + self.gguf.close() @staticmethod - def write_vocab_only(fname_out: Path, vocab: Vocab) -> None: - of = OutputFile(fname_out) - params = Params(n_vocab=vocab.vocab_size, n_embd=0, n_mult=0, n_head=1, n_layer=0) - of = OutputFile(fname_out) - of.write_file_header(params, file_type=GGMLFileType.AllF32) - of.write_vocab(vocab) - of.fout.close() - - @staticmethod - def write_all(fname_out: Path, params: Params, file_type: GGMLFileType, model: LazyModel, vocab: Vocab) -> None: + def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab) -> None: check_vocab_size(params, vocab) + of = OutputFile(fname_out) - of.write_file_header(params, file_type) - print("Writing vocab...") - of.write_vocab(vocab) + + # meta data + of.add_meta_arch(params) + of.add_meta_vocab(vocab) + of.write_meta() + + of.close() + + @staticmethod + def write_all(fname_out: Path, params: Params, model: LazyModel, vocab: Vocab) -> None: + check_vocab_size(params, vocab) + + of = OutputFile(fname_out) + + # meta data + of.add_meta_arch(params) + of.add_meta_vocab(vocab) + + # tensor info + for name, lazy_tensor in model.items(): + of.add_tensor_info(name, lazy_tensor) + + of.write_meta() + of.write_tensor_info() def do_item(item: Tuple[str, LazyTensor]) -> NDArray: name, lazy_tensor = item return lazy_tensor.load().to_ggml().ndarray + # tensor data ndarrays = bounded_parallel_map(do_item, model.items(), concurrency=8) for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)): size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape) padi = len(str(len(model))) print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type}") - of.write_tensor_header(name, lazy_tensor.shape, lazy_tensor.data_type) - ndarray.tofile(of.fout) - of.fout.close() + of.gguf.write_tensor_data(ndarray) + of.close() def pick_output_type(model: LazyModel, output_type_str: Optional[str]) -> GGMLFileType: - wq_type = model["layers.0.attention.wq.weight"].data_type - if output_type_str == "f32" or (output_type_str is None and wq_type in (DT_F32, DT_BF16)): + wq_type = model[NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0)+".weight"].data_type + + if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32): return GGMLFileType.AllF32 - if output_type_str == "f16" or (output_type_str is None and wq_type == DT_F16): + if output_type_str == "f16" or (output_type_str is None and wq_type in (DT_F16, DT_BF16)): return GGMLFileType.MostlyF16 - if output_type_str == "q4_1" or (output_type_str is None and isinstance(wq_type, QuantizedDataType) and - wq_type.have_addends): - if isinstance(model["output.weight"].data_type, QuantizedDataType): - return GGMLFileType.MostlyQ4_1 - else: - return GGMLFileType.PerLayerIsQ4_1 - if output_type_str == "q4_0" or (output_type_str is None and isinstance(wq_type, QuantizedDataType)): - return GGMLFileType.MostlyQ4_0 + name_to_type = {name: lazy_tensor.data_type for (name, lazy_tensor) in model.items()} + raise Exception(f"Unexpected combination of types: {name_to_type}") - -def do_necessary_conversions(model: LazyModel, params: Params) -> LazyModel: - model = handle_quantization(model) - - if "lm_head.weight" in model: - model = convert_transformers_to_orig(model, params) - model = filter_and_sort_tensors(model) - - return model - - def convert_to_output_type(model: LazyModel, output_type: GGMLFileType) -> LazyModel: return {name: tensor.astype(output_type.type_for_tensor(name, tensor)) for (name, tensor) in model.items()} +def convert_model_names(model: LazyModel, params: Params) -> LazyModel: + tmap = gguf.get_tensor_name_map(ARCH, params.n_layer) + + tmp = model + + # HF models permut or pack some of the tensors, so we need to undo that + for i in itertools.count(): + if f"model.layers.{i}.self_attn.q_proj.weight" in model: + print(f"Permuting layer {i}") + tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.q_proj.weight"], params.n_head, params.n_head) + tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_lazy(model[f"model.layers.{i}.self_attn.k_proj.weight"], params.n_head, params.n_head_kv) + #tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = model[f"model.layers.{i}.self_attn.v_proj.weight"] + elif f"model.layers.{i}.self_attn.W_pack.weight" in model: + print(f"Unpacking and permuting layer {i}") + tmp[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 0, params.n_head, params.n_head) + tmp[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_part_lazy(model[f"model.layers.{i}.self_attn.W_pack.weight"], 1, params.n_head, params.n_head_kv) + tmp[f"model.layers.{i}.self_attn.v_proj.weight"] = part_lazy (model[f"model.layers.{i}.self_attn.W_pack.weight"], 2) + else: + break + + out: LazyModel = {} + for name, lazy_tensor in model.items(): + name_new = name + + if name in tmap: + name_new = tmap[name] + elif name.endswith(".weight") and name[:-7] in tmap: + name_new = tmap[name[:-7]] + ".weight" + elif name.endswith(".bias") and name[:-5] in tmap: + name_new = tmap[name[:-5]] + ".bias" + else: + raise Exception(f"Unexpected tensor name: {name}") + + if gguf.should_skip_tensor_TMP(ARCH, params.n_layer, name_new): + print(f"skipping tensor {name_new}") + continue + else: + print(f"{name:48s} -> {name_new:40s} | {lazy_tensor.data_type} | {lazy_tensor.shape}") + out[name_new] = lazy_tensor + + return out def nth_multifile_path(path: Path, n: int) -> Optional[Path]: '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return @@ -1203,11 +923,6 @@ def load_some_model(path: Path) -> ModelPlus: # Try the PyTorch patterns too, with lower priority globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"] files = [file for glob in globs for file in path.glob(glob)] - if not files: - # Try GGML too, but with lower priority, since if both a non-GGML - # model and a GGML model exist in the same directory, we assume the - # latter was converted from the former. - files = list(path.glob("ggml-model*.bin*")) if not files: raise Exception(f"Can't find model in directory {path}") if len(files) > 1: @@ -1224,19 +939,14 @@ def load_some_model(path: Path) -> ModelPlus: return model_plus -def filter_and_sort_tensors(model: LazyModel) -> LazyModel: - return {name: model[name] for name in TENSORS_LIST if name in model} - - -def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab: - print(f"vocabtype: {vocabtype}") +def load_vocab(path: Path, vocabtype: Optional[str]) -> Union[BpeVocab, SentencePieceVocab]: # Be extra-friendly and accept either a file or a directory. Also, if it's # a directory, it might be the model directory, and tokenizer.model might # be in the parent of that. if path.is_dir(): vocab_file = "tokenizer.model" if vocabtype == 'bpe': - vocab_file = "vocab.json" + vocab_file = "vocab.json" path2 = path / vocab_file # Use `.parent` instead of /.. to handle the symlink case better. path3 = path.parent / vocab_file @@ -1248,21 +958,24 @@ def load_vocab(path: Path, vocabtype: Optional[str]) -> SentencePieceVocab: raise FileNotFoundError( f"Could not find tokenizer.model in {path} or its parent; " "if it's in another directory, pass the directory as --vocab-dir") + + print(f"Loading vocab file '{path}', type '{vocabtype}'") + added_tokens_path = path.parent / "added_tokens.json" - print(f"Loading vocab file {path}") - return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None, - vocabtype) + if vocabtype == "bpe": + return BpeVocab(path, added_tokens_path if added_tokens_path.exists() else None) + elif vocabtype == "spm": + return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None) + else: + raise ValueError(f"Unsupported vocabulary type {vocabtype}") def default_outfile(model_paths: List[Path], file_type: GGMLFileType) -> Path: namestr = { - GGMLFileType.AllF32: "f32", + GGMLFileType.AllF32: "f32", GGMLFileType.MostlyF16: "f16", - GGMLFileType.MostlyQ4_0: "q4_0", - GGMLFileType.MostlyQ4_1: "q4_1", - GGMLFileType.PerLayerIsQ4_1: "q4_1", }[file_type] - ret = model_paths[0].parent / f"ggml-model-{namestr}.bin" + ret = model_paths[0].parent / f"ggml-model-{namestr}.gguf" if ret in model_paths: sys.stderr.write( f"Error: Default output path ({ret}) would overwrite the input. " @@ -1281,44 +994,59 @@ def do_dump_model(model_plus: ModelPlus) -> None: def main(args_in: Optional[List[str]] = None) -> None: parser = argparse.ArgumentParser(description="Convert a LLaMa model to a GGML compatible file") - parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") - parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") - parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") - parser.add_argument("--outtype", choices=["f32", "f16", "q4_1", "q4_0"], help="output format (default: based on input)") - parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") - parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") - parser.add_argument("model", type=Path, - help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)") - parser.add_argument("--vocabtype", default='spm', choices=["spm", "bpe"], help="vocab format (default: spm)") + parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model") + parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file") + parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") + parser.add_argument("--outtype", choices=["f32", "f16"], help="output format (default: based on input)") + parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") + parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") + parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)") + parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm") + parser.add_argument("--ctx", type=int, help="model training context (default: based on input)") args = parser.parse_args(args_in) - vocab: Vocab if args.dump_single: model_plus = lazy_load_file(args.model) do_dump_model(model_plus) - elif args.vocab_only: + + model_plus = load_some_model(args.model) + + params = Params.load(model_plus) + if params.n_ctx == -1: + if args.ctx is None: + raise Exception("The model doesn't have a context size, and you didn't specify one with --ctx\n" + "Please specify one with --ctx:\n" + " - LLaMA v1: --ctx 2048\n" + " - LLaMA v2: --ctx 4096\n") + params.n_ctx = args.ctx + + print(f"params = {params}") + + vocab: Vocab + if args.vocab_only: vocab = load_vocab(args.vocab_dir or args.model, args.vocabtype) assert args.outfile, "need --outfile if using --vocab-only" outfile = args.outfile - OutputFile.write_vocab_only(outfile, vocab) + OutputFile.write_vocab_only(outfile, params, vocab) print(f"Wrote {outfile}") else: - model_plus = load_some_model(args.model) if args.dump: do_dump_model(model_plus) return + if model_plus.vocab is not None and args.vocab_dir is None: vocab = model_plus.vocab else: vocab_dir = args.vocab_dir if args.vocab_dir else model_plus.paths[0].parent vocab = load_vocab(vocab_dir, args.vocabtype) - params = Params.load(model_plus) - model = model_plus.model - model = do_necessary_conversions(model, params) + + model = model_plus.model + model = convert_model_names(model, params) output_type = pick_output_type(model, args.outtype) - model = convert_to_output_type(model, output_type) - outfile = args.outfile or default_outfile(model_plus.paths, output_type) - OutputFile.write_all(outfile, params, output_type, model, vocab) + model = convert_to_output_type(model, output_type) + outfile = args.outfile or default_outfile(model_plus.paths, output_type) + + OutputFile.write_all(outfile, params, model, vocab) print(f"Wrote {outfile}") diff --git a/docs/token_generation_performance_tips.md b/docs/token_generation_performance_tips.md index 69ba6173c..c9acff7d4 100644 --- a/docs/token_generation_performance_tips.md +++ b/docs/token_generation_performance_tips.md @@ -3,7 +3,7 @@ ## Verifying that the model is running on the GPU with cuBLAS Make sure you compiled llama with the correct env variables according to [this guide](../README.md#cublas), so that llama accepts the `-ngl N` (or `--n-gpu-layers N`) flag. When running llama, you may configure `N` to be very large, and llama will offload the maximum possible number of layers to the GPU, even if it's less than the number you configured. For example: ```shell -./main -m "path/to/model.bin" -ngl 200000 -p "Please sir, may I have some " +./main -m "path/to/model.gguf" -ngl 200000 -p "Please sir, may I have some " ``` When running llama, before it starts the inference work, it will output diagnostic information that shows whether cuBLAS is offloading work to the GPU. Look for these lines: @@ -25,9 +25,9 @@ GPU: A6000 (48GB VRAM) CPU: 7 physical cores RAM: 32GB -Model: `TheBloke_Wizard-Vicuna-30B-Uncensored-GGML/Wizard-Vicuna-30B-Uncensored.ggmlv3.q4_0.bin` (30B parameters, 4bit quantization, GGML) +Model: `TheBloke_Wizard-Vicuna-30B-Uncensored-GGML/Wizard-Vicuna-30B-Uncensored.q4_0.gguf` (30B parameters, 4bit quantization, GGML) -Run command: `./main -m "path/to/model.bin" -p "-p "An extremely detailed description of the 10 best ethnic dishes will follow, with recipes: " -n 1000 [additional benchmark flags]` +Run command: `./main -m "path/to/model.gguf" -p "An extremely detailed description of the 10 best ethnic dishes will follow, with recipes: " -n 1000 [additional benchmark flags]` Result: diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index d53652815..d2176c910 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -6,27 +6,6 @@ find_package(Threads REQUIRED) # ... -# common - -set(TARGET common) - -add_library(${TARGET} OBJECT - common.h - common.cpp - console.h - console.cpp - grammar-parser.h - grammar-parser.cpp - ) - -if (BUILD_SHARED_LIBS) - set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) -endif() - -target_include_directories(${TARGET} PUBLIC .) -target_compile_features(${TARGET} PUBLIC cxx_std_11) -target_link_libraries(${TARGET} PRIVATE llama) - # examples include_directories(${CMAKE_CURRENT_SOURCE_DIR}) diff --git a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp index 1a238c4dd..469d6e3de 100644 --- a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +++ b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp @@ -1,5 +1,6 @@ #include "ggml.h" #include "llama.h" + #include #include #include @@ -138,14 +139,16 @@ void print_sample_weights(TransformerWeights *w){ struct llama_vocab { using id = int32_t; using token = std::string; + using ttype = llama_token_type; - struct token_score { - token tok; + struct token_data { + token text; float score; + ttype type; }; std::unordered_map token_to_id; - std::vector id_to_token; + std::vector id_to_token; }; struct my_llama_hparams { @@ -502,7 +505,7 @@ bool is_ggml_file(const char *filename) { return false; } uint32_t magic = file.read_u32(); - return magic == LLAMA_FILE_MAGIC; + return magic == GGUF_MAGIC; } void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab) { @@ -515,36 +518,30 @@ void load_vocab(const char *filename, Config *config, struct llama_vocab *vocab) struct llama_model * lmodel = llama_load_model_from_file(filename, llama_params); struct llama_context * lctx = llama_new_context_with_model(lmodel, llama_params); - std::vector strings; - std::vector scores; - int n_vocab = llama_n_vocab(lctx); - strings.resize(n_vocab, NULL); - scores.resize(n_vocab, 0); - n_vocab = llama_get_vocab(lctx, strings.data(), scores.data(), n_vocab); - GGML_ASSERT(n_vocab == llama_n_vocab(lctx)); + const int n_vocab = llama_n_vocab(lctx); vocab->id_to_token.resize(n_vocab); for (int i=0; iid_to_token[i].tok = tok; - vocab->id_to_token[i].score = score; - vocab->token_to_id.emplace(tok, i); + vocab->id_to_token[i].text = llama_token_get_text(lctx, i); + vocab->id_to_token[i].score = llama_token_get_score(lctx, i); + vocab->id_to_token[i].type = llama_token_get_type(lctx, i); + vocab->token_to_id.emplace(vocab->id_to_token[i].text, i); } llama_free(lctx); llama_free_model(lmodel); } else { // assume llama2.c vocabulary printf("Assuming llama2.c vocabulary since %s is not a ggml file\n", filename); llama_file file(filename, "rb"); - uint32_t n_vocab = config->vocab_size; + const int n_vocab = config->vocab_size; /* uint32_t max_token_length = */ file.read_u32(); // unused vocab->id_to_token.resize(n_vocab); - for (uint32_t i=0; iid_to_token[i].tok = tok; + std::string text = file.read_string(len); + vocab->id_to_token[i].text = text; vocab->id_to_token[i].score = score; - vocab->token_to_id.emplace(tok, i); + vocab->id_to_token[i].type = LLAMA_TOKEN_TYPE_UNDEFINED; + vocab->token_to_id.emplace(text, i); } } } @@ -590,75 +587,80 @@ void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * mod if (file.fp == NULL) { return; } - // write_magic - file.write_u32(LLAMA_FILE_MAGIC); // magic - file.write_u32(LLAMA_FILE_VERSION); // version - // write_hparams - file.write_u32(model->hparams.n_vocab); - file.write_u32(model->hparams.n_embd); - file.write_u32(model->hparams.n_mult); - file.write_u32(model->hparams.n_head); - file.write_u32(model->hparams.n_layer); - file.write_u32(model->hparams.n_rot); - file.write_u32(LLAMA_FTYPE_ALL_F32); - // write_vocab - for now we are just writing the existing BPE voc. assuming karpathy's vocabulary is the same. idk. - uint32_t n_vocab = model->hparams.n_vocab; - for (uint32_t i = 0; i < n_vocab; i++) { - const auto & token_score = vocab->id_to_token.at(i); - file.write_u32((uint32_t) token_score.tok.size()); - file.write_raw(token_score.tok.data(), token_score.tok.size()); - file.write_raw(&token_score.score, sizeof(token_score.score)); - } - - // stuff AK weights into GG weights one by one. - // w->token_embedding_table -> model->tok_embeddings - // float* -> struct ggml_tensor - stuff_karpathy_weights_into_gg(model->tok_embeddings, w->token_embedding_table); - stuff_karpathy_weights_into_gg(model->output, w->token_embedding_table); - - stuff_karpathy_weights_into_gg(model->norm, w->rms_final_weight); - //print_row(model->norm, 0); - - // for rms-att-weight - int row_length = model->hparams.n_embd; - const auto & hparams = model->hparams; - //int n_ff = model->hparams.n_embd; - int n_ff = get_n_ff(&hparams); - - for (uint32_t i = 0; i < model->hparams.n_layer; ++i){ - auto & layer = model->layers[i]; - // 1d - stuff_karpathy_weights_into_gg(layer.attention_norm, &w->rms_att_weight[i*row_length]); - stuff_karpathy_weights_into_gg(layer.ffn_norm , &w->rms_ffn_weight[i*row_length]); - - // from 3d matrix layer x dim x dim to 2d matrix dim x dim - stuff_karpathy_weights_into_gg(layer.wq , &w->wq[i*row_length*row_length]); - stuff_karpathy_weights_into_gg(layer.wk , &w->wk[i*row_length*row_length]); - stuff_karpathy_weights_into_gg(layer.wv , &w->wv[i*row_length*row_length]); - stuff_karpathy_weights_into_gg(layer.wo , &w->wo[i*row_length*row_length]); - - stuff_karpathy_weights_into_gg(layer.w1 , &w->w1[i*row_length*n_ff]); - stuff_karpathy_weights_into_gg(layer.w2 , &w->w2[i*n_ff*row_length]); - stuff_karpathy_weights_into_gg(layer.w3 , &w->w3[i*row_length*n_ff]); - } - // write tensors - write_tensor(&file, model->tok_embeddings); - write_tensor(&file, model->norm); - write_tensor(&file, model->output); // ? - for (uint32_t i = 0; i < model->hparams.n_layer; ++i) { - auto & layer = model->layers[i]; - - write_tensor(&file, layer.attention_norm); - write_tensor(&file, layer.wq); - write_tensor(&file, layer.wk); - write_tensor(&file, layer.wv); - write_tensor(&file, layer.wo); - write_tensor(&file, layer.ffn_norm); - write_tensor(&file, layer.w1); - write_tensor(&file, layer.w2); - write_tensor(&file, layer.w3); - } +#pragma message("TODO: implement file saving using gguf") + (void) vocab; + (void) model; + (void) w; +// // write_magic +// file.write_u32(LLAMA_FILE_MAGIC); // magic +// file.write_u32(LLAMA_FILE_VERSION); // version +// // write_hparams +// file.write_u32(model->hparams.n_vocab); +// file.write_u32(model->hparams.n_embd); +// file.write_u32(model->hparams.n_mult); +// file.write_u32(model->hparams.n_head); +// file.write_u32(model->hparams.n_layer); +// file.write_u32(model->hparams.n_rot); +// file.write_u32(LLAMA_FTYPE_ALL_F32); +// +// // write_vocab - for now we are just writing the existing BPE voc. assuming karpathy's vocabulary is the same. idk. +// uint32_t n_vocab = model->hparams.n_vocab; +// for (uint32_t i = 0; i < n_vocab; i++) { +// const auto & token_data = vocab->id_to_token.at(i); +// file.write_u32((uint32_t) token_data.tok.size()); +// file.write_raw(token_data.tok.data(), token_data.tok.size()); +// file.write_raw(&token_data.score, sizeof(token_data.score)); +// } +// +// // stuff AK weights into GG weights one by one. +// // w->token_embedding_table -> model->tok_embeddings +// // float* -> struct ggml_tensor +// stuff_karpathy_weights_into_gg(model->tok_embeddings, w->token_embedding_table); +// stuff_karpathy_weights_into_gg(model->output, w->token_embedding_table); +// +// stuff_karpathy_weights_into_gg(model->norm, w->rms_final_weight); +// //print_row(model->norm, 0); +// +// // for rms-att-weight +// int row_length = model->hparams.n_embd; +// const auto & hparams = model->hparams; +// //int n_ff = model->hparams.n_embd; +// int n_ff = get_n_ff(&hparams); +// +// for (uint32_t i = 0; i < model->hparams.n_layer; ++i){ +// auto & layer = model->layers[i]; +// // 1d +// stuff_karpathy_weights_into_gg(layer.attention_norm, &w->rms_att_weight[i*row_length]); +// stuff_karpathy_weights_into_gg(layer.ffn_norm , &w->rms_ffn_weight[i*row_length]); +// +// // from 3d matrix layer x dim x dim to 2d matrix dim x dim +// stuff_karpathy_weights_into_gg(layer.wq , &w->wq[i*row_length*row_length]); +// stuff_karpathy_weights_into_gg(layer.wk , &w->wk[i*row_length*row_length]); +// stuff_karpathy_weights_into_gg(layer.wv , &w->wv[i*row_length*row_length]); +// stuff_karpathy_weights_into_gg(layer.wo , &w->wo[i*row_length*row_length]); +// +// stuff_karpathy_weights_into_gg(layer.w1 , &w->w1[i*row_length*n_ff]); +// stuff_karpathy_weights_into_gg(layer.w2 , &w->w2[i*n_ff*row_length]); +// stuff_karpathy_weights_into_gg(layer.w3 , &w->w3[i*row_length*n_ff]); +// } +// // write tensors +// write_tensor(&file, model->tok_embeddings); +// write_tensor(&file, model->norm); +// write_tensor(&file, model->output); // ? +// for (uint32_t i = 0; i < model->hparams.n_layer; ++i) { +// auto & layer = model->layers[i]; +// +// write_tensor(&file, layer.attention_norm); +// write_tensor(&file, layer.wq); +// write_tensor(&file, layer.wk); +// write_tensor(&file, layer.wv); +// write_tensor(&file, layer.wo); +// write_tensor(&file, layer.ffn_norm); +// write_tensor(&file, layer.w1); +// write_tensor(&file, layer.w2); +// write_tensor(&file, layer.w3); +// } } struct train_params get_default_train_params() { diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp index 2185b9b0e..8a6ad882e 100644 --- a/examples/embd-input/embd-input-lib.cpp +++ b/examples/embd-input/embd-input-lib.cpp @@ -167,7 +167,7 @@ llama_token sampling_id(struct MyModel* mymodel) { llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; // TODO: Apply penalties - // float nl_logit = logits[llama_token_nl()]; + // float nl_logit = logits[llama_token_nl(ctx)]; // auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); // llama_sample_repetition_penalty(ctx, &candidates_p, // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, @@ -176,7 +176,7 @@ llama_token sampling_id(struct MyModel* mymodel) { // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, // last_n_repeat, alpha_frequency, alpha_presence); // if (!penalize_nl) { - // logits[llama_token_nl()] = nl_logit; + // logits[llama_token_nl(ctx)] = nl_logit; // } if (temp <= 0) { @@ -211,7 +211,7 @@ const char * sampling(struct MyModel * mymodel) { llama_context * ctx = mymodel->ctx; int id = sampling_id(mymodel); static std::string ret; - if (id == llama_token_eos()) { + if (id == llama_token_eos(ctx)) { ret = ""; } else { ret = llama_token_to_str(ctx, id); diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 5192d6df5..8788571cb 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -67,7 +67,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str()); fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); for (int i = 0; i < (int) embd_inp.size(); i++) { - fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i])); + fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]).c_str()); } fprintf(stderr, "\n"); } diff --git a/examples/gguf/gguf.cpp b/examples/gguf/gguf.cpp new file mode 100644 index 000000000..dee00df87 --- /dev/null +++ b/examples/gguf/gguf.cpp @@ -0,0 +1,246 @@ +#include "ggml.h" +#include "llama.h" + +#include +#include +#include +#include +#include +#include + +#undef MIN +#undef MAX +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define MAX(a, b) ((a) > (b) ? (a) : (b)) + +template +static std::string to_string(const T & val) { + std::stringstream ss; + ss << val; + return ss.str(); +} + +bool gguf_ex_write(const std::string & fname) { + struct gguf_context * ctx = gguf_init_empty(); + + gguf_set_val_u8 (ctx, "some.parameter.uint8", 0x12); + gguf_set_val_i8 (ctx, "some.parameter.int8", -0x13); + gguf_set_val_u16 (ctx, "some.parameter.uint16", 0x1234); + gguf_set_val_i16 (ctx, "some.parameter.int16", -0x1235); + gguf_set_val_u32 (ctx, "some.parameter.uint32", 0x12345678); + gguf_set_val_i32 (ctx, "some.parameter.int32", -0x12345679); + gguf_set_val_f32 (ctx, "some.parameter.float32", 0.123456789f); + gguf_set_val_bool(ctx, "some.parameter.bool", true); + gguf_set_val_str (ctx, "some.parameter.string", "hello world"); + + gguf_set_arr_data(ctx, "some.parameter.arr.i16", GGUF_TYPE_INT16, std::vector{ 1, 2, 3, 4, }.data(), 4); + gguf_set_arr_data(ctx, "some.parameter.arr.f32", GGUF_TYPE_FLOAT32, std::vector{ 3.145f, 2.718f, 1.414f, }.data(), 3); + gguf_set_arr_str (ctx, "some.parameter.arr.str", std::vector{ "hello", "world", "!" }.data(), 3); + + struct ggml_init_params params = { + /*.mem_size =*/ 128ull*1024ull*1024ull, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ false, + }; + + struct ggml_context * ctx_data = ggml_init(params); + + const int n_tensors = 10; + + // tensor infos + for (int i = 0; i < n_tensors; ++i) { + const std::string name = "tensor_" + to_string(i); + + int64_t ne[GGML_MAX_DIMS] = { 1 }; + int32_t n_dims = rand() % GGML_MAX_DIMS + 1; + + for (int j = 0; j < n_dims; ++j) { + ne[j] = rand() % 10 + 1; + } + + struct ggml_tensor * cur = ggml_new_tensor(ctx_data, GGML_TYPE_F32, n_dims, ne); + ggml_set_name(cur, name.c_str()); + + { + float * data = (float *) cur->data; + for (int j = 0; j < ggml_nelements(cur); ++j) { + data[j] = 100 + i; + } + } + + gguf_add_tensor(ctx, cur); + } + + gguf_write_to_file(ctx, fname.c_str(), false); + + fprintf(stdout, "%s: wrote file '%s;\n", __func__, fname.c_str()); + + ggml_free(ctx_data); + gguf_free(ctx); + + return true; +} + +// just read tensor info +bool gguf_ex_read_0(const std::string & fname) { + struct gguf_init_params params = { + /*.no_alloc = */ false, + /*.ctx = */ NULL, + }; + + struct gguf_context * ctx = gguf_init_from_file(fname.c_str(), params); + + fprintf(stdout, "%s: version: %d\n", __func__, gguf_get_version(ctx)); + fprintf(stdout, "%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx)); + fprintf(stdout, "%s: data offset: %zu\n", __func__, gguf_get_data_offset(ctx)); + + // kv + { + const int n_kv = gguf_get_n_kv(ctx); + + fprintf(stdout, "%s: n_kv: %d\n", __func__, n_kv); + + for (int i = 0; i < n_kv; ++i) { + const char * key = gguf_get_key(ctx, i); + + fprintf(stdout, "%s: kv[%d]: key = %s\n", __func__, i, key); + } + } + + // find kv string + { + const char * findkey = "some.parameter.string"; + + const int keyidx = gguf_find_key(ctx, findkey); + if (keyidx == -1) { + fprintf(stdout, "%s: find key: %s not found.\n", __func__, findkey); + } else { + const char * key_value = gguf_get_val_str(ctx, keyidx); + fprintf(stdout, "%s: find key: %s found, kv[%d] value = %s\n", __func__, findkey, keyidx, key_value); + } + } + + // tensor info + { + const int n_tensors = gguf_get_n_tensors(ctx); + + fprintf(stdout, "%s: n_tensors: %d\n", __func__, n_tensors); + + for (int i = 0; i < n_tensors; ++i) { + const char * name = gguf_get_tensor_name (ctx, i); + const size_t offset = gguf_get_tensor_offset(ctx, i); + + fprintf(stdout, "%s: tensor[%d]: name = %s, offset = %zu\n", __func__, i, name, offset); + } + } + + gguf_free(ctx); + + return true; +} + +// read and create ggml_context containing the tensors and their data +bool gguf_ex_read_1(const std::string & fname) { + struct ggml_context * ctx_data = NULL; + + struct gguf_init_params params = { + /*.no_alloc = */ false, + /*.ctx = */ &ctx_data, + }; + + struct gguf_context * ctx = gguf_init_from_file(fname.c_str(), params); + + fprintf(stdout, "%s: version: %d\n", __func__, gguf_get_version(ctx)); + fprintf(stdout, "%s: alignment: %zu\n", __func__, gguf_get_alignment(ctx)); + fprintf(stdout, "%s: data offset: %zu\n", __func__, gguf_get_data_offset(ctx)); + + // kv + { + const int n_kv = gguf_get_n_kv(ctx); + + fprintf(stdout, "%s: n_kv: %d\n", __func__, n_kv); + + for (int i = 0; i < n_kv; ++i) { + const char * key = gguf_get_key(ctx, i); + + fprintf(stdout, "%s: kv[%d]: key = %s\n", __func__, i, key); + } + } + + // tensor info + { + const int n_tensors = gguf_get_n_tensors(ctx); + + fprintf(stdout, "%s: n_tensors: %d\n", __func__, n_tensors); + + for (int i = 0; i < n_tensors; ++i) { + const char * name = gguf_get_tensor_name (ctx, i); + const size_t offset = gguf_get_tensor_offset(ctx, i); + + fprintf(stdout, "%s: tensor[%d]: name = %s, offset = %zu\n", __func__, i, name, offset); + } + } + + // data + { + const int n_tensors = gguf_get_n_tensors(ctx); + + for (int i = 0; i < n_tensors; ++i) { + fprintf(stdout, "%s: reading tensor %d data\n", __func__, i); + + const char * name = gguf_get_tensor_name(ctx, i); + + struct ggml_tensor * cur = ggml_get_tensor(ctx_data, name); + + fprintf(stdout, "%s: tensor[%d]: n_dims = %d, name = %s, data = %p\n", __func__, i, cur->n_dims, cur->name, cur->data); + + // print first 10 elements + const float * data = (const float *) cur->data; + + printf("%s data[:10] : ", name); + for (int j = 0; j < MIN(10, ggml_nelements(cur)); ++j) { + printf("%f ", data[j]); + } + printf("\n\n"); + + // check data + { + const float * data = (const float *) cur->data; + for (int j = 0; j < ggml_nelements(cur); ++j) { + if (data[j] != 100 + i) { + fprintf(stderr, "%s: tensor[%d]: data[%d] = %f\n", __func__, i, j, data[j]); + return false; + } + } + } + } + } + + fprintf(stdout, "%s: ctx_data size: %zu\n", __func__, ggml_get_mem_size(ctx_data)); + + ggml_free(ctx_data); + gguf_free(ctx); + + return true; +} + +int main(int argc, char ** argv) { + if (argc < 3) { + fprintf(stdout, "usage: %s data.gguf r|w\n", argv[0]); + return -1; + } + + const std::string fname(argv[1]); + const std::string mode (argv[2]); + + GGML_ASSERT((mode == "r" || mode == "w") && "mode must be r or w"); + + if (mode == "w") { + GGML_ASSERT(gguf_ex_write(fname) && "failed to write gguf file"); + } else if (mode == "r") { + GGML_ASSERT(gguf_ex_read_0(fname) && "failed to read gguf file"); + GGML_ASSERT(gguf_ex_read_1(fname) && "failed to read gguf file"); + } + + return 0; +} diff --git a/examples/gptneox-wip/cmpnct_gpt2bpe.hpp b/examples/gptneox-wip/cmpnct_gpt2bpe.hpp new file mode 100644 index 000000000..9d433f4b1 --- /dev/null +++ b/examples/gptneox-wip/cmpnct_gpt2bpe.hpp @@ -0,0 +1,1133 @@ +#ifndef CMPNCT_GPT2BPE +#define CMPNCT_GPT2BPE + +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +// Unicode GPT2 Byte Pair Encoding Tokenizer +// Adapted from https://github.com/cmp-nct/ggllm.cpp [MIT License] +// Removed loading of merges from HF json and parts made for a specific vocab + + +//----------------- +// Unicode library (from cmpnct_unicode.cpp) +//----------------- + +// Minimal library for high performance handling and categorization of UTF8 strings and characters +// Using std::string + +enum CNCTCharType { + DIGIT, // a numerical char in any language + LETTER, // a letter in any language + WHITESPACE, // any form of whitespace + ACCENT_MARK, // letter modifiers like ´ in é + PUNCTUATION, // punctuation including brackets + SYMBOL, // math, currency, other symbols + CONTROL, // control characters + MIXED, // a mix of the above + UNIDENTIFIED // something more exotic like emoji or separators +}; + +struct CNCTUnicode; + +struct CNCTString { + std::string str; + size_t utf8_chars; + + CNCTCharType char_type=UNIDENTIFIED; + bool is_sequential=false; + + size_t seq_offset_bytes=0; + size_t seq_offset_utf8_chars=0; + + bool operator==(const std::string &other) const; + bool operator==(const char other) const; + bool operator==(const CNCTString &other) const; + CNCTString &operator+=(const std::string &other); + CNCTString &operator+=(const char other); + friend CNCTString operator+(CNCTString lhs, const std::string &rhs); + friend CNCTString operator+(CNCTString lhs, const char rhs); + CNCTString& operator+=(const CNCTString& other); + friend CNCTString operator+(CNCTString lhs, const CNCTString& rhs); +}; + +struct CNCTUnicode { + static bool check_code_range(int c, const std::vector>& ranges); + static CNCTCharType get_code_type(int c); + static CNCTCharType get_code_type(const std::string &utf8_char); + static int utf8_len(const char c); + static int strlen_utf8(std::string src); + static std::vector split_utf8(const std::string &src); + static std::vector split_utf8_enhanced(const std::string &src); + static CNCTCharType string_identify(const std::string& str); + static bool string_test(const std::string& str, CNCTCharType chartype); +}; + +static const std::vector> digit_ranges = { +{0x30, 0x39}, {0xB2, 0xB3}, {0xB9, 0xB9}, {0x660, 0x669}, {0x6F0, 0x6F9}, {0x7C0, 0x7C9}, {0x966, 0x96F}, {0x9E6, 0x9EF}, {0xA66, 0xA6F}, {0xAE6, 0xAEF}, {0xB66, 0xB6F}, {0xBE6, 0xBEF}, {0xC66, 0xC6F}, +{0xCE6, 0xCEF}, {0xD66, 0xD6F}, {0xDE6, 0xDEF}, {0xE50, 0xE59}, {0xED0, 0xED9}, {0xF20, 0xF29}, {0x1040, 0x1049}, {0x1090, 0x1099}, {0x1369, 0x1371}, {0x17E0, 0x17E9}, {0x1810, 0x1819}, {0x1946, 0x194F}, +{0x19D0, 0x19DA}, {0x1A80, 0x1A89}, {0x1A90, 0x1A99}, {0x1B50, 0x1B59}, {0x1BB0, 0x1BB9}, {0x1C40, 0x1C49}, {0x1C50, 0x1C59}, {0x2070, 0x2070}, {0x2074, 0x2079}, {0x2080, 0x2089}, {0x2460, 0x2468}, +{0x2474, 0x247C}, {0x2488, 0x2490}, {0x24EA, 0x24EA}, {0x24F5, 0x24FD}, {0x24FF, 0x24FF}, {0x2776, 0x277E}, {0x2780, 0x2788}, {0x278A, 0x2792}, {0xA620, 0xA629}, {0xA8D0, 0xA8D9}, {0xA900, 0xA909}, +{0xA9D0, 0xA9D9}, {0xA9F0, 0xA9F9}, {0xAA50, 0xAA59}, {0xABF0, 0xABF9}, {0xFF10, 0xFF19}, {0x104A0, 0x104A9}, {0x10A40, 0x10A43}, {0x10D30, 0x10D39}, {0x10E60, 0x10E68}, {0x11052, 0x1105A}, +{0x11066, 0x1106F}, {0x110F0, 0x110F9}, {0x11136, 0x1113F}, {0x111D0, 0x111D9}, {0x112F0, 0x112F9}, {0x11450, 0x11459}, {0x114D0, 0x114D9}, {0x11650, 0x11659}, {0x116C0, 0x116C9}, {0x11730, 0x11739}, +{0x118E0, 0x118E9}, {0x11950, 0x11959}, {0x11C50, 0x11C59}, {0x11D50, 0x11D59}, {0x11DA0, 0x11DA9}, {0x16A60, 0x16A69}, {0x16B50, 0x16B59}, {0x1D7CE, 0x1D7FF}, {0x1E140, 0x1E149}, {0x1E2F0, 0x1E2F9}, +{0x1E950, 0x1E959}, {0x1F100, 0x1F10A}, {0x1FBF0, 0x1FBF9}, +}; + +static const std::vector> letter_ranges = { +{0x41, 0x5A}, {0x61, 0x7A}, {0xAA, 0xAA}, {0xB5, 0xB5}, {0xBA, 0xBA}, {0xC0, 0xD6}, {0xD8, 0xF6}, {0xF8, 0x2C1}, {0x2C6, 0x2D1}, {0x2E0, 0x2E4}, {0x2EC, 0x2EC}, {0x2EE, 0x2EE}, {0x370, 0x374}, +{0x376, 0x377}, {0x37A, 0x37D}, {0x37F, 0x37F}, {0x386, 0x386}, {0x388, 0x38A}, {0x38C, 0x38C}, {0x38E, 0x3A1}, {0x3A3, 0x3F5}, {0x3F7, 0x481}, {0x48A, 0x52F}, {0x531, 0x556}, {0x559, 0x559}, +{0x560, 0x588}, {0x5D0, 0x5EA}, {0x5EF, 0x5F2}, {0x620, 0x64A}, {0x66E, 0x66F}, {0x671, 0x6D3}, {0x6D5, 0x6D5}, {0x6E5, 0x6E6}, {0x6EE, 0x6EF}, {0x6FA, 0x6FC}, {0x6FF, 0x6FF}, {0x710, 0x710}, +{0x712, 0x72F}, {0x74D, 0x7A5}, {0x7B1, 0x7B1}, {0x7CA, 0x7EA}, {0x7F4, 0x7F5}, {0x7FA, 0x7FA}, {0x800, 0x815}, {0x81A, 0x81A}, {0x824, 0x824}, {0x828, 0x828}, {0x840, 0x858}, {0x860, 0x86A}, +{0x8A0, 0x8B4}, {0x8B6, 0x8C7}, {0x904, 0x939}, {0x93D, 0x93D}, {0x950, 0x950}, {0x958, 0x961}, {0x971, 0x980}, {0x985, 0x98C}, {0x98F, 0x990}, {0x993, 0x9A8}, {0x9AA, 0x9B0}, {0x9B2, 0x9B2}, +{0x9B6, 0x9B9}, {0x9BD, 0x9BD}, {0x9CE, 0x9CE}, {0x9DC, 0x9DD}, {0x9DF, 0x9E1}, {0x9F0, 0x9F1}, {0x9FC, 0x9FC}, {0xA05, 0xA0A}, {0xA0F, 0xA10}, {0xA13, 0xA28}, {0xA2A, 0xA30}, {0xA32, 0xA33}, +{0xA35, 0xA36}, {0xA38, 0xA39}, {0xA59, 0xA5C}, {0xA5E, 0xA5E}, {0xA72, 0xA74}, {0xA85, 0xA8D}, {0xA8F, 0xA91}, {0xA93, 0xAA8}, {0xAAA, 0xAB0}, {0xAB2, 0xAB3}, {0xAB5, 0xAB9}, {0xABD, 0xABD}, +{0xAD0, 0xAD0}, {0xAE0, 0xAE1}, {0xAF9, 0xAF9}, {0xB05, 0xB0C}, {0xB0F, 0xB10}, {0xB13, 0xB28}, {0xB2A, 0xB30}, {0xB32, 0xB33}, {0xB35, 0xB39}, {0xB3D, 0xB3D}, {0xB5C, 0xB5D}, {0xB5F, 0xB61}, +{0xB71, 0xB71}, {0xB83, 0xB83}, {0xB85, 0xB8A}, {0xB8E, 0xB90}, {0xB92, 0xB95}, {0xB99, 0xB9A}, {0xB9C, 0xB9C}, {0xB9E, 0xB9F}, {0xBA3, 0xBA4}, {0xBA8, 0xBAA}, {0xBAE, 0xBB9}, {0xBD0, 0xBD0}, +{0xC05, 0xC0C}, {0xC0E, 0xC10}, {0xC12, 0xC28}, {0xC2A, 0xC39}, {0xC3D, 0xC3D}, {0xC58, 0xC5A}, {0xC60, 0xC61}, {0xC80, 0xC80}, {0xC85, 0xC8C}, {0xC8E, 0xC90}, {0xC92, 0xCA8}, {0xCAA, 0xCB3}, +{0xCB5, 0xCB9}, {0xCBD, 0xCBD}, {0xCDE, 0xCDE}, {0xCE0, 0xCE1}, {0xCF1, 0xCF2}, {0xD04, 0xD0C}, {0xD0E, 0xD10}, {0xD12, 0xD3A}, {0xD3D, 0xD3D}, {0xD4E, 0xD4E}, {0xD54, 0xD56}, {0xD5F, 0xD61}, +{0xD7A, 0xD7F}, {0xD85, 0xD96}, {0xD9A, 0xDB1}, {0xDB3, 0xDBB}, {0xDBD, 0xDBD}, {0xDC0, 0xDC6}, {0xE01, 0xE30}, {0xE32, 0xE33}, {0xE40, 0xE46}, {0xE81, 0xE82}, {0xE84, 0xE84}, {0xE86, 0xE8A}, +{0xE8C, 0xEA3}, {0xEA5, 0xEA5}, {0xEA7, 0xEB0}, {0xEB2, 0xEB3}, {0xEBD, 0xEBD}, {0xEC0, 0xEC4}, {0xEC6, 0xEC6}, {0xEDC, 0xEDF}, {0xF00, 0xF00}, {0xF40, 0xF47}, {0xF49, 0xF6C}, {0xF88, 0xF8C}, +{0x1000, 0x102A}, {0x103F, 0x103F}, {0x1050, 0x1055}, {0x105A, 0x105D}, {0x1061, 0x1061}, {0x1065, 0x1066}, {0x106E, 0x1070}, {0x1075, 0x1081}, {0x108E, 0x108E}, {0x10A0, 0x10C5}, {0x10C7, 0x10C7}, +{0x10CD, 0x10CD}, {0x10D0, 0x10FA}, {0x10FC, 0x1248}, {0x124A, 0x124D}, {0x1250, 0x1256}, {0x1258, 0x1258}, {0x125A, 0x125D}, {0x1260, 0x1288}, {0x128A, 0x128D}, {0x1290, 0x12B0}, {0x12B2, 0x12B5}, +{0x12B8, 0x12BE}, {0x12C0, 0x12C0}, {0x12C2, 0x12C5}, {0x12C8, 0x12D6}, {0x12D8, 0x1310}, {0x1312, 0x1315}, {0x1318, 0x135A}, {0x1380, 0x138F}, {0x13A0, 0x13F5}, {0x13F8, 0x13FD}, {0x1401, 0x166C}, +{0x166F, 0x167F}, {0x1681, 0x169A}, {0x16A0, 0x16EA}, {0x16F1, 0x16F8}, {0x1700, 0x170C}, {0x170E, 0x1711}, {0x1720, 0x1731}, {0x1740, 0x1751}, {0x1760, 0x176C}, {0x176E, 0x1770}, {0x1780, 0x17B3}, +{0x17D7, 0x17D7}, {0x17DC, 0x17DC}, {0x1820, 0x1878}, {0x1880, 0x1884}, {0x1887, 0x18A8}, {0x18AA, 0x18AA}, {0x18B0, 0x18F5}, {0x1900, 0x191E}, {0x1950, 0x196D}, {0x1970, 0x1974}, {0x1980, 0x19AB}, +{0x19B0, 0x19C9}, {0x1A00, 0x1A16}, {0x1A20, 0x1A54}, {0x1AA7, 0x1AA7}, {0x1B05, 0x1B33}, {0x1B45, 0x1B4B}, {0x1B83, 0x1BA0}, {0x1BAE, 0x1BAF}, {0x1BBA, 0x1BE5}, {0x1C00, 0x1C23}, {0x1C4D, 0x1C4F}, +{0x1C5A, 0x1C7D}, {0x1C80, 0x1C88}, {0x1C90, 0x1CBA}, {0x1CBD, 0x1CBF}, {0x1CE9, 0x1CEC}, {0x1CEE, 0x1CF3}, {0x1CF5, 0x1CF6}, {0x1CFA, 0x1CFA}, {0x1D00, 0x1DBF}, {0x1E00, 0x1F15}, {0x1F18, 0x1F1D}, +{0x1F20, 0x1F45}, {0x1F48, 0x1F4D}, {0x1F50, 0x1F57}, {0x1F59, 0x1F59}, {0x1F5B, 0x1F5B}, {0x1F5D, 0x1F5D}, {0x1F5F, 0x1F7D}, {0x1F80, 0x1FB4}, {0x1FB6, 0x1FBC}, {0x1FBE, 0x1FBE}, {0x1FC2, 0x1FC4}, +{0x1FC6, 0x1FCC}, {0x1FD0, 0x1FD3}, {0x1FD6, 0x1FDB}, {0x1FE0, 0x1FEC}, {0x1FF2, 0x1FF4}, {0x1FF6, 0x1FFC}, {0x2071, 0x2071}, {0x207F, 0x207F}, {0x2090, 0x209C}, {0x2102, 0x2102}, {0x2107, 0x2107}, +{0x210A, 0x2113}, {0x2115, 0x2115}, {0x2119, 0x211D}, {0x2124, 0x2124}, {0x2126, 0x2126}, {0x2128, 0x2128}, {0x212A, 0x212D}, {0x212F, 0x2139}, {0x213C, 0x213F}, {0x2145, 0x2149}, {0x214E, 0x214E}, +{0x2183, 0x2184}, {0x2C00, 0x2C2E}, {0x2C30, 0x2C5E}, {0x2C60, 0x2CE4}, {0x2CEB, 0x2CEE}, {0x2CF2, 0x2CF3}, {0x2D00, 0x2D25}, {0x2D27, 0x2D27}, {0x2D2D, 0x2D2D}, {0x2D30, 0x2D67}, {0x2D6F, 0x2D6F}, +{0x2D80, 0x2D96}, {0x2DA0, 0x2DA6}, {0x2DA8, 0x2DAE}, {0x2DB0, 0x2DB6}, {0x2DB8, 0x2DBE}, {0x2DC0, 0x2DC6}, {0x2DC8, 0x2DCE}, {0x2DD0, 0x2DD6}, {0x2DD8, 0x2DDE}, {0x2E2F, 0x2E2F}, {0x3005, 0x3006}, +{0x3031, 0x3035}, {0x303B, 0x303C}, {0x3041, 0x3096}, {0x309D, 0x309F}, {0x30A1, 0x30FA}, {0x30FC, 0x30FF}, {0x3105, 0x312F}, {0x3131, 0x318E}, {0x31A0, 0x31BF}, {0x31F0, 0x31FF}, {0x3400, 0x4DBF}, +{0x4E00, 0x9FFC}, {0xA000, 0xA48C}, {0xA4D0, 0xA4FD}, {0xA500, 0xA60C}, {0xA610, 0xA61F}, {0xA62A, 0xA62B}, {0xA640, 0xA66E}, {0xA67F, 0xA69D}, {0xA6A0, 0xA6E5}, {0xA717, 0xA71F}, {0xA722, 0xA788}, +{0xA78B, 0xA7BF}, {0xA7C2, 0xA7CA}, {0xA7F5, 0xA801}, {0xA803, 0xA805}, {0xA807, 0xA80A}, {0xA80C, 0xA822}, {0xA840, 0xA873}, {0xA882, 0xA8B3}, {0xA8F2, 0xA8F7}, {0xA8FB, 0xA8FB}, {0xA8FD, 0xA8FE}, +{0xA90A, 0xA925}, {0xA930, 0xA946}, {0xA960, 0xA97C}, {0xA984, 0xA9B2}, {0xA9CF, 0xA9CF}, {0xA9E0, 0xA9E4}, {0xA9E6, 0xA9EF}, {0xA9FA, 0xA9FE}, {0xAA00, 0xAA28}, {0xAA40, 0xAA42}, {0xAA44, 0xAA4B}, +{0xAA60, 0xAA76}, {0xAA7A, 0xAA7A}, {0xAA7E, 0xAAAF}, {0xAAB1, 0xAAB1}, {0xAAB5, 0xAAB6}, {0xAAB9, 0xAABD}, {0xAAC0, 0xAAC0}, {0xAAC2, 0xAAC2}, {0xAADB, 0xAADD}, {0xAAE0, 0xAAEA}, {0xAAF2, 0xAAF4}, +{0xAB01, 0xAB06}, {0xAB09, 0xAB0E}, {0xAB11, 0xAB16}, {0xAB20, 0xAB26}, {0xAB28, 0xAB2E}, {0xAB30, 0xAB5A}, {0xAB5C, 0xAB69}, {0xAB70, 0xABE2}, {0xAC00, 0xD7A3}, {0xD7B0, 0xD7C6}, {0xD7CB, 0xD7FB}, +{0xF900, 0xFA6D}, {0xFA70, 0xFAD9}, {0xFB00, 0xFB06}, {0xFB13, 0xFB17}, {0xFB1D, 0xFB1D}, {0xFB1F, 0xFB28}, {0xFB2A, 0xFB36}, {0xFB38, 0xFB3C}, {0xFB3E, 0xFB3E}, {0xFB40, 0xFB41}, {0xFB43, 0xFB44}, +{0xFB46, 0xFBB1}, {0xFBD3, 0xFD3D}, {0xFD50, 0xFD8F}, {0xFD92, 0xFDC7}, {0xFDF0, 0xFDFB}, {0xFE70, 0xFE74}, {0xFE76, 0xFEFC}, {0xFF21, 0xFF3A}, {0xFF41, 0xFF5A}, {0xFF66, 0xFFBE}, {0xFFC2, 0xFFC7}, +{0xFFCA, 0xFFCF}, {0xFFD2, 0xFFD7}, {0xFFDA, 0xFFDC}, {0x10000, 0x1000B}, {0x1000D, 0x10026}, {0x10028, 0x1003A}, {0x1003C, 0x1003D}, {0x1003F, 0x1004D}, {0x10050, 0x1005D}, {0x10080, 0x100FA}, +{0x10280, 0x1029C}, {0x102A0, 0x102D0}, {0x10300, 0x1031F}, {0x1032D, 0x10340}, {0x10342, 0x10349}, {0x10350, 0x10375}, {0x10380, 0x1039D}, {0x103A0, 0x103C3}, {0x103C8, 0x103CF}, {0x10400, 0x1049D}, +{0x104B0, 0x104D3}, {0x104D8, 0x104FB}, {0x10500, 0x10527}, {0x10530, 0x10563}, {0x10600, 0x10736}, {0x10740, 0x10755}, {0x10760, 0x10767}, {0x10800, 0x10805}, {0x10808, 0x10808}, {0x1080A, 0x10835}, +{0x10837, 0x10838}, {0x1083C, 0x1083C}, {0x1083F, 0x10855}, {0x10860, 0x10876}, {0x10880, 0x1089E}, {0x108E0, 0x108F2}, {0x108F4, 0x108F5}, {0x10900, 0x10915}, {0x10920, 0x10939}, {0x10980, 0x109B7}, +{0x109BE, 0x109BF}, {0x10A00, 0x10A00}, {0x10A10, 0x10A13}, {0x10A15, 0x10A17}, {0x10A19, 0x10A35}, {0x10A60, 0x10A7C}, {0x10A80, 0x10A9C}, {0x10AC0, 0x10AC7}, {0x10AC9, 0x10AE4}, {0x10B00, 0x10B35}, +{0x10B40, 0x10B55}, {0x10B60, 0x10B72}, {0x10B80, 0x10B91}, {0x10C00, 0x10C48}, {0x10C80, 0x10CB2}, {0x10CC0, 0x10CF2}, {0x10D00, 0x10D23}, {0x10E80, 0x10EA9}, {0x10EB0, 0x10EB1}, {0x10F00, 0x10F1C}, +{0x10F27, 0x10F27}, {0x10F30, 0x10F45}, {0x10FB0, 0x10FC4}, {0x10FE0, 0x10FF6}, {0x11003, 0x11037}, {0x11083, 0x110AF}, {0x110D0, 0x110E8}, {0x11103, 0x11126}, {0x11144, 0x11144}, {0x11147, 0x11147}, +{0x11150, 0x11172}, {0x11176, 0x11176}, {0x11183, 0x111B2}, {0x111C1, 0x111C4}, {0x111DA, 0x111DA}, {0x111DC, 0x111DC}, {0x11200, 0x11211}, {0x11213, 0x1122B}, {0x11280, 0x11286}, {0x11288, 0x11288}, +{0x1128A, 0x1128D}, {0x1128F, 0x1129D}, {0x1129F, 0x112A8}, {0x112B0, 0x112DE}, {0x11305, 0x1130C}, {0x1130F, 0x11310}, {0x11313, 0x11328}, {0x1132A, 0x11330}, {0x11332, 0x11333}, {0x11335, 0x11339}, +{0x1133D, 0x1133D}, {0x11350, 0x11350}, {0x1135D, 0x11361}, {0x11400, 0x11434}, {0x11447, 0x1144A}, {0x1145F, 0x11461}, {0x11480, 0x114AF}, {0x114C4, 0x114C5}, {0x114C7, 0x114C7}, {0x11580, 0x115AE}, +{0x115D8, 0x115DB}, {0x11600, 0x1162F}, {0x11644, 0x11644}, {0x11680, 0x116AA}, {0x116B8, 0x116B8}, {0x11700, 0x1171A}, {0x11800, 0x1182B}, {0x118A0, 0x118DF}, {0x118FF, 0x11906}, {0x11909, 0x11909}, +{0x1190C, 0x11913}, {0x11915, 0x11916}, {0x11918, 0x1192F}, {0x1193F, 0x1193F}, {0x11941, 0x11941}, {0x119A0, 0x119A7}, {0x119AA, 0x119D0}, {0x119E1, 0x119E1}, {0x119E3, 0x119E3}, {0x11A00, 0x11A00}, +{0x11A0B, 0x11A32}, {0x11A3A, 0x11A3A}, {0x11A50, 0x11A50}, {0x11A5C, 0x11A89}, {0x11A9D, 0x11A9D}, {0x11AC0, 0x11AF8}, {0x11C00, 0x11C08}, {0x11C0A, 0x11C2E}, {0x11C40, 0x11C40}, {0x11C72, 0x11C8F}, +{0x11D00, 0x11D06}, {0x11D08, 0x11D09}, {0x11D0B, 0x11D30}, {0x11D46, 0x11D46}, {0x11D60, 0x11D65}, {0x11D67, 0x11D68}, {0x11D6A, 0x11D89}, {0x11D98, 0x11D98}, {0x11EE0, 0x11EF2}, {0x11FB0, 0x11FB0}, +{0x12000, 0x12399}, {0x12480, 0x12543}, {0x13000, 0x1342E}, {0x14400, 0x14646}, {0x16800, 0x16A38}, {0x16A40, 0x16A5E}, {0x16AD0, 0x16AED}, {0x16B00, 0x16B2F}, {0x16B40, 0x16B43}, {0x16B63, 0x16B77}, +{0x16B7D, 0x16B8F}, {0x16E40, 0x16E7F}, {0x16F00, 0x16F4A}, {0x16F50, 0x16F50}, {0x16F93, 0x16F9F}, {0x16FE0, 0x16FE1}, {0x16FE3, 0x16FE3}, {0x17000, 0x187F7}, {0x18800, 0x18CD5}, {0x18D00, 0x18D08}, +{0x1B000, 0x1B11E}, {0x1B150, 0x1B152}, {0x1B164, 0x1B167}, {0x1B170, 0x1B2FB}, {0x1BC00, 0x1BC6A}, {0x1BC70, 0x1BC7C}, {0x1BC80, 0x1BC88}, {0x1BC90, 0x1BC99}, {0x1D400, 0x1D454}, {0x1D456, 0x1D49C}, +{0x1D49E, 0x1D49F}, {0x1D4A2, 0x1D4A2}, {0x1D4A5, 0x1D4A6}, {0x1D4A9, 0x1D4AC}, {0x1D4AE, 0x1D4B9}, {0x1D4BB, 0x1D4BB}, {0x1D4BD, 0x1D4C3}, {0x1D4C5, 0x1D505}, {0x1D507, 0x1D50A}, {0x1D50D, 0x1D514}, +{0x1D516, 0x1D51C}, {0x1D51E, 0x1D539}, {0x1D53B, 0x1D53E}, {0x1D540, 0x1D544}, {0x1D546, 0x1D546}, {0x1D54A, 0x1D550}, {0x1D552, 0x1D6A5}, {0x1D6A8, 0x1D6C0}, {0x1D6C2, 0x1D6DA}, {0x1D6DC, 0x1D6FA}, +{0x1D6FC, 0x1D714}, {0x1D716, 0x1D734}, {0x1D736, 0x1D74E}, {0x1D750, 0x1D76E}, {0x1D770, 0x1D788}, {0x1D78A, 0x1D7A8}, {0x1D7AA, 0x1D7C2}, {0x1D7C4, 0x1D7CB}, {0x1E100, 0x1E12C}, {0x1E137, 0x1E13D}, +{0x1E14E, 0x1E14E}, {0x1E2C0, 0x1E2EB}, {0x1E800, 0x1E8C4}, {0x1E900, 0x1E943}, {0x1E94B, 0x1E94B}, {0x1EE00, 0x1EE03}, {0x1EE05, 0x1EE1F}, {0x1EE21, 0x1EE22}, {0x1EE24, 0x1EE24}, {0x1EE27, 0x1EE27}, +{0x1EE29, 0x1EE32}, {0x1EE34, 0x1EE37}, {0x1EE39, 0x1EE39}, {0x1EE3B, 0x1EE3B}, {0x1EE42, 0x1EE42}, {0x1EE47, 0x1EE47}, {0x1EE49, 0x1EE49}, {0x1EE4B, 0x1EE4B}, {0x1EE4D, 0x1EE4F}, {0x1EE51, 0x1EE52}, +{0x1EE54, 0x1EE54}, {0x1EE57, 0x1EE57}, {0x1EE59, 0x1EE59}, {0x1EE5B, 0x1EE5B}, {0x1EE5D, 0x1EE5D}, {0x1EE5F, 0x1EE5F}, {0x1EE61, 0x1EE62}, {0x1EE64, 0x1EE64}, {0x1EE67, 0x1EE6A}, {0x1EE6C, 0x1EE72}, +{0x1EE74, 0x1EE77}, {0x1EE79, 0x1EE7C}, {0x1EE7E, 0x1EE7E}, {0x1EE80, 0x1EE89}, {0x1EE8B, 0x1EE9B}, {0x1EEA1, 0x1EEA3}, {0x1EEA5, 0x1EEA9}, {0x1EEAB, 0x1EEBB}, {0x20000, 0x2A6DD}, {0x2A700, 0x2B734}, +{0x2B740, 0x2B81D}, {0x2B820, 0x2CEA1}, {0x2CEB0, 0x2EBE0}, {0x2F800, 0x2FA1D}, {0x30000, 0x3134A}, +}; + +static const std::vector> whitespace_ranges = { +{0x9, 0xD}, {0x1C, 0x20}, {0x85, 0x85}, {0xA0, 0xA0}, {0x1680, 0x1680}, {0x2000, 0x200A}, {0x2028, 0x2029}, {0x202F, 0x202F}, {0x205F, 0x205F}, {0x3000, 0x3000}, +}; + +static const std::vector> accent_mark_ranges = { +{0x300, 0x36F}, {0x483, 0x489}, {0x591, 0x5BD}, {0x5BF, 0x5BF}, {0x5C1, 0x5C2}, {0x5C4, 0x5C5}, {0x5C7, 0x5C7}, {0x610, 0x61A}, {0x64B, 0x65F}, {0x670, 0x670}, {0x6D6, 0x6DC}, {0x6DF, 0x6E4}, +{0x6E7, 0x6E8}, {0x6EA, 0x6ED}, {0x711, 0x711}, {0x730, 0x74A}, {0x7A6, 0x7B0}, {0x7EB, 0x7F3}, {0x7FD, 0x7FD}, {0x816, 0x819}, {0x81B, 0x823}, {0x825, 0x827}, {0x829, 0x82D}, {0x859, 0x85B}, +{0x8D3, 0x8E1}, {0x8E3, 0x903}, {0x93A, 0x93C}, {0x93E, 0x94F}, {0x951, 0x957}, {0x962, 0x963}, {0x981, 0x983}, {0x9BC, 0x9BC}, {0x9BE, 0x9C4}, {0x9C7, 0x9C8}, {0x9CB, 0x9CD}, {0x9D7, 0x9D7}, +{0x9E2, 0x9E3}, {0x9FE, 0x9FE}, {0xA01, 0xA03}, {0xA3C, 0xA3C}, {0xA3E, 0xA42}, {0xA47, 0xA48}, {0xA4B, 0xA4D}, {0xA51, 0xA51}, {0xA70, 0xA71}, {0xA75, 0xA75}, {0xA81, 0xA83}, {0xABC, 0xABC}, +{0xABE, 0xAC5}, {0xAC7, 0xAC9}, {0xACB, 0xACD}, {0xAE2, 0xAE3}, {0xAFA, 0xAFF}, {0xB01, 0xB03}, {0xB3C, 0xB3C}, {0xB3E, 0xB44}, {0xB47, 0xB48}, {0xB4B, 0xB4D}, {0xB55, 0xB57}, {0xB62, 0xB63}, +{0xB82, 0xB82}, {0xBBE, 0xBC2}, {0xBC6, 0xBC8}, {0xBCA, 0xBCD}, {0xBD7, 0xBD7}, {0xC00, 0xC04}, {0xC3E, 0xC44}, {0xC46, 0xC48}, {0xC4A, 0xC4D}, {0xC55, 0xC56}, {0xC62, 0xC63}, {0xC81, 0xC83}, +{0xCBC, 0xCBC}, {0xCBE, 0xCC4}, {0xCC6, 0xCC8}, {0xCCA, 0xCCD}, {0xCD5, 0xCD6}, {0xCE2, 0xCE3}, {0xD00, 0xD03}, {0xD3B, 0xD3C}, {0xD3E, 0xD44}, {0xD46, 0xD48}, {0xD4A, 0xD4D}, {0xD57, 0xD57}, +{0xD62, 0xD63}, {0xD81, 0xD83}, {0xDCA, 0xDCA}, {0xDCF, 0xDD4}, {0xDD6, 0xDD6}, {0xDD8, 0xDDF}, {0xDF2, 0xDF3}, {0xE31, 0xE31}, {0xE34, 0xE3A}, {0xE47, 0xE4E}, {0xEB1, 0xEB1}, {0xEB4, 0xEBC}, +{0xEC8, 0xECD}, {0xF18, 0xF19}, {0xF35, 0xF35}, {0xF37, 0xF37}, {0xF39, 0xF39}, {0xF3E, 0xF3F}, {0xF71, 0xF84}, {0xF86, 0xF87}, {0xF8D, 0xF97}, {0xF99, 0xFBC}, {0xFC6, 0xFC6}, {0x102B, 0x103E}, +{0x1056, 0x1059}, {0x105E, 0x1060}, {0x1062, 0x1064}, {0x1067, 0x106D}, {0x1071, 0x1074}, {0x1082, 0x108D}, {0x108F, 0x108F}, {0x109A, 0x109D}, {0x135D, 0x135F}, {0x1712, 0x1714}, {0x1732, 0x1734}, +{0x1752, 0x1753}, {0x1772, 0x1773}, {0x17B4, 0x17D3}, {0x17DD, 0x17DD}, {0x180B, 0x180D}, {0x1885, 0x1886}, {0x18A9, 0x18A9}, {0x1920, 0x192B}, {0x1930, 0x193B}, {0x1A17, 0x1A1B}, {0x1A55, 0x1A5E}, +{0x1A60, 0x1A7C}, {0x1A7F, 0x1A7F}, {0x1AB0, 0x1AC0}, {0x1B00, 0x1B04}, {0x1B34, 0x1B44}, {0x1B6B, 0x1B73}, {0x1B80, 0x1B82}, {0x1BA1, 0x1BAD}, {0x1BE6, 0x1BF3}, {0x1C24, 0x1C37}, {0x1CD0, 0x1CD2}, +{0x1CD4, 0x1CE8}, {0x1CED, 0x1CED}, {0x1CF4, 0x1CF4}, {0x1CF7, 0x1CF9}, {0x1DC0, 0x1DF9}, {0x1DFB, 0x1DFF}, {0x20D0, 0x20F0}, {0x2CEF, 0x2CF1}, {0x2D7F, 0x2D7F}, {0x2DE0, 0x2DFF}, {0x302A, 0x302F}, +{0x3099, 0x309A}, {0xA66F, 0xA672}, {0xA674, 0xA67D}, {0xA69E, 0xA69F}, {0xA6F0, 0xA6F1}, {0xA802, 0xA802}, {0xA806, 0xA806}, {0xA80B, 0xA80B}, {0xA823, 0xA827}, {0xA82C, 0xA82C}, {0xA880, 0xA881}, +{0xA8B4, 0xA8C5}, {0xA8E0, 0xA8F1}, {0xA8FF, 0xA8FF}, {0xA926, 0xA92D}, {0xA947, 0xA953}, {0xA980, 0xA983}, {0xA9B3, 0xA9C0}, {0xA9E5, 0xA9E5}, {0xAA29, 0xAA36}, {0xAA43, 0xAA43}, {0xAA4C, 0xAA4D}, +{0xAA7B, 0xAA7D}, {0xAAB0, 0xAAB0}, {0xAAB2, 0xAAB4}, {0xAAB7, 0xAAB8}, {0xAABE, 0xAABF}, {0xAAC1, 0xAAC1}, {0xAAEB, 0xAAEF}, {0xAAF5, 0xAAF6}, {0xABE3, 0xABEA}, {0xABEC, 0xABED}, {0xFB1E, 0xFB1E}, +{0xFE00, 0xFE0F}, {0xFE20, 0xFE2F}, {0x101FD, 0x101FD}, {0x102E0, 0x102E0}, {0x10376, 0x1037A}, {0x10A01, 0x10A03}, {0x10A05, 0x10A06}, {0x10A0C, 0x10A0F}, {0x10A38, 0x10A3A}, {0x10A3F, 0x10A3F}, +{0x10AE5, 0x10AE6}, {0x10D24, 0x10D27}, {0x10EAB, 0x10EAC}, {0x10F46, 0x10F50}, {0x11000, 0x11002}, {0x11038, 0x11046}, {0x1107F, 0x11082}, {0x110B0, 0x110BA}, {0x11100, 0x11102}, {0x11127, 0x11134}, +{0x11145, 0x11146}, {0x11173, 0x11173}, {0x11180, 0x11182}, {0x111B3, 0x111C0}, {0x111C9, 0x111CC}, {0x111CE, 0x111CF}, {0x1122C, 0x11237}, {0x1123E, 0x1123E}, {0x112DF, 0x112EA}, {0x11300, 0x11303}, +{0x1133B, 0x1133C}, {0x1133E, 0x11344}, {0x11347, 0x11348}, {0x1134B, 0x1134D}, {0x11357, 0x11357}, {0x11362, 0x11363}, {0x11366, 0x1136C}, {0x11370, 0x11374}, {0x11435, 0x11446}, {0x1145E, 0x1145E}, +{0x114B0, 0x114C3}, {0x115AF, 0x115B5}, {0x115B8, 0x115C0}, {0x115DC, 0x115DD}, {0x11630, 0x11640}, {0x116AB, 0x116B7}, {0x1171D, 0x1172B}, {0x1182C, 0x1183A}, {0x11930, 0x11935}, {0x11937, 0x11938}, +{0x1193B, 0x1193E}, {0x11940, 0x11940}, {0x11942, 0x11943}, {0x119D1, 0x119D7}, {0x119DA, 0x119E0}, {0x119E4, 0x119E4}, {0x11A01, 0x11A0A}, {0x11A33, 0x11A39}, {0x11A3B, 0x11A3E}, {0x11A47, 0x11A47}, +{0x11A51, 0x11A5B}, {0x11A8A, 0x11A99}, {0x11C2F, 0x11C36}, {0x11C38, 0x11C3F}, {0x11C92, 0x11CA7}, {0x11CA9, 0x11CB6}, {0x11D31, 0x11D36}, {0x11D3A, 0x11D3A}, {0x11D3C, 0x11D3D}, {0x11D3F, 0x11D45}, +{0x11D47, 0x11D47}, {0x11D8A, 0x11D8E}, {0x11D90, 0x11D91}, {0x11D93, 0x11D97}, {0x11EF3, 0x11EF6}, {0x16AF0, 0x16AF4}, {0x16B30, 0x16B36}, {0x16F4F, 0x16F4F}, {0x16F51, 0x16F87}, {0x16F8F, 0x16F92}, +{0x16FE4, 0x16FE4}, {0x16FF0, 0x16FF1}, {0x1BC9D, 0x1BC9E}, {0x1D165, 0x1D169}, {0x1D16D, 0x1D172}, {0x1D17B, 0x1D182}, {0x1D185, 0x1D18B}, {0x1D1AA, 0x1D1AD}, {0x1D242, 0x1D244}, {0x1DA00, 0x1DA36}, +{0x1DA3B, 0x1DA6C}, {0x1DA75, 0x1DA75}, {0x1DA84, 0x1DA84}, {0x1DA9B, 0x1DA9F}, {0x1DAA1, 0x1DAAF}, {0x1E000, 0x1E006}, {0x1E008, 0x1E018}, {0x1E01B, 0x1E021}, {0x1E023, 0x1E024}, {0x1E026, 0x1E02A}, +{0x1E130, 0x1E136}, {0x1E2EC, 0x1E2EF}, {0x1E8D0, 0x1E8D6}, {0x1E944, 0x1E94A}, {0xE0100, 0xE01EF}, +}; + +static const std::vector> punctuation_ranges = { +{0x21, 0x23}, {0x25, 0x2A}, {0x2C, 0x2F}, {0x3A, 0x3B}, {0x3F, 0x40}, {0x5B, 0x5D}, {0x5F, 0x5F}, {0x7B, 0x7B}, {0x7D, 0x7D}, {0xA1, 0xA1}, {0xA7, 0xA7}, {0xAB, 0xAB}, {0xB6, 0xB7}, {0xBB, 0xBB}, +{0xBF, 0xBF}, {0x37E, 0x37E}, {0x387, 0x387}, {0x55A, 0x55F}, {0x589, 0x58A}, {0x5BE, 0x5BE}, {0x5C0, 0x5C0}, {0x5C3, 0x5C3}, {0x5C6, 0x5C6}, {0x5F3, 0x5F4}, {0x609, 0x60A}, {0x60C, 0x60D}, +{0x61B, 0x61B}, {0x61E, 0x61F}, {0x66A, 0x66D}, {0x6D4, 0x6D4}, {0x700, 0x70D}, {0x7F7, 0x7F9}, {0x830, 0x83E}, {0x85E, 0x85E}, {0x964, 0x965}, {0x970, 0x970}, {0x9FD, 0x9FD}, {0xA76, 0xA76}, +{0xAF0, 0xAF0}, {0xC77, 0xC77}, {0xC84, 0xC84}, {0xDF4, 0xDF4}, {0xE4F, 0xE4F}, {0xE5A, 0xE5B}, {0xF04, 0xF12}, {0xF14, 0xF14}, {0xF3A, 0xF3D}, {0xF85, 0xF85}, {0xFD0, 0xFD4}, {0xFD9, 0xFDA}, +{0x104A, 0x104F}, {0x10FB, 0x10FB}, {0x1360, 0x1368}, {0x1400, 0x1400}, {0x166E, 0x166E}, {0x169B, 0x169C}, {0x16EB, 0x16ED}, {0x1735, 0x1736}, {0x17D4, 0x17D6}, {0x17D8, 0x17DA}, {0x1800, 0x180A}, +{0x1944, 0x1945}, {0x1A1E, 0x1A1F}, {0x1AA0, 0x1AA6}, {0x1AA8, 0x1AAD}, {0x1B5A, 0x1B60}, {0x1BFC, 0x1BFF}, {0x1C3B, 0x1C3F}, {0x1C7E, 0x1C7F}, {0x1CC0, 0x1CC7}, {0x1CD3, 0x1CD3}, {0x2010, 0x2027}, +{0x2030, 0x2043}, {0x2045, 0x2051}, {0x2053, 0x205E}, {0x207D, 0x207E}, {0x208D, 0x208E}, {0x2308, 0x230B}, {0x2329, 0x232A}, {0x2768, 0x2775}, {0x27C5, 0x27C6}, {0x27E6, 0x27EF}, {0x2983, 0x2998}, +{0x29D8, 0x29DB}, {0x29FC, 0x29FD}, {0x2CF9, 0x2CFC}, {0x2CFE, 0x2CFF}, {0x2D70, 0x2D70}, {0x2E00, 0x2E2E}, {0x2E30, 0x2E4F}, {0x2E52, 0x2E52}, {0x3001, 0x3003}, {0x3008, 0x3011}, {0x3014, 0x301F}, +{0x3030, 0x3030}, {0x303D, 0x303D}, {0x30A0, 0x30A0}, {0x30FB, 0x30FB}, {0xA4FE, 0xA4FF}, {0xA60D, 0xA60F}, {0xA673, 0xA673}, {0xA67E, 0xA67E}, {0xA6F2, 0xA6F7}, {0xA874, 0xA877}, {0xA8CE, 0xA8CF}, +{0xA8F8, 0xA8FA}, {0xA8FC, 0xA8FC}, {0xA92E, 0xA92F}, {0xA95F, 0xA95F}, {0xA9C1, 0xA9CD}, {0xA9DE, 0xA9DF}, {0xAA5C, 0xAA5F}, {0xAADE, 0xAADF}, {0xAAF0, 0xAAF1}, {0xABEB, 0xABEB}, {0xFD3E, 0xFD3F}, +{0xFE10, 0xFE19}, {0xFE30, 0xFE52}, {0xFE54, 0xFE61}, {0xFE63, 0xFE63}, {0xFE68, 0xFE68}, {0xFE6A, 0xFE6B}, {0xFF01, 0xFF03}, {0xFF05, 0xFF0A}, {0xFF0C, 0xFF0F}, {0xFF1A, 0xFF1B}, {0xFF1F, 0xFF20}, +{0xFF3B, 0xFF3D}, {0xFF3F, 0xFF3F}, {0xFF5B, 0xFF5B}, {0xFF5D, 0xFF5D}, {0xFF5F, 0xFF65}, {0x10100, 0x10102}, {0x1039F, 0x1039F}, {0x103D0, 0x103D0}, {0x1056F, 0x1056F}, {0x10857, 0x10857}, +{0x1091F, 0x1091F}, {0x1093F, 0x1093F}, {0x10A50, 0x10A58}, {0x10A7F, 0x10A7F}, {0x10AF0, 0x10AF6}, {0x10B39, 0x10B3F}, {0x10B99, 0x10B9C}, {0x10EAD, 0x10EAD}, {0x10F55, 0x10F59}, {0x11047, 0x1104D}, +{0x110BB, 0x110BC}, {0x110BE, 0x110C1}, {0x11140, 0x11143}, {0x11174, 0x11175}, {0x111C5, 0x111C8}, {0x111CD, 0x111CD}, {0x111DB, 0x111DB}, {0x111DD, 0x111DF}, {0x11238, 0x1123D}, {0x112A9, 0x112A9}, +{0x1144B, 0x1144F}, {0x1145A, 0x1145B}, {0x1145D, 0x1145D}, {0x114C6, 0x114C6}, {0x115C1, 0x115D7}, {0x11641, 0x11643}, {0x11660, 0x1166C}, {0x1173C, 0x1173E}, {0x1183B, 0x1183B}, {0x11944, 0x11946}, +{0x119E2, 0x119E2}, {0x11A3F, 0x11A46}, {0x11A9A, 0x11A9C}, {0x11A9E, 0x11AA2}, {0x11C41, 0x11C45}, {0x11C70, 0x11C71}, {0x11EF7, 0x11EF8}, {0x11FFF, 0x11FFF}, {0x12470, 0x12474}, {0x16A6E, 0x16A6F}, +{0x16AF5, 0x16AF5}, {0x16B37, 0x16B3B}, {0x16B44, 0x16B44}, {0x16E97, 0x16E9A}, {0x16FE2, 0x16FE2}, {0x1BC9F, 0x1BC9F}, {0x1DA87, 0x1DA8B}, {0x1E95E, 0x1E95F}, +}; + +static const std::vector> symbol_ranges = { +{0x24, 0x24}, {0x2B, 0x2B}, {0x3C, 0x3E}, {0x5E, 0x5E}, {0x60, 0x60}, {0x7C, 0x7C}, {0x7E, 0x7E}, {0xA2, 0xA6}, {0xA8, 0xA9}, {0xAC, 0xAC}, {0xAE, 0xB1}, {0xB4, 0xB4}, {0xB8, 0xB8}, {0xD7, 0xD7}, +{0xF7, 0xF7}, {0x2C2, 0x2C5}, {0x2D2, 0x2DF}, {0x2E5, 0x2EB}, {0x2ED, 0x2ED}, {0x2EF, 0x2FF}, {0x375, 0x375}, {0x384, 0x385}, {0x3F6, 0x3F6}, {0x482, 0x482}, {0x58D, 0x58F}, {0x606, 0x608}, +{0x60B, 0x60B}, {0x60E, 0x60F}, {0x6DE, 0x6DE}, {0x6E9, 0x6E9}, {0x6FD, 0x6FE}, {0x7F6, 0x7F6}, {0x7FE, 0x7FF}, {0x9F2, 0x9F3}, {0x9FA, 0x9FB}, {0xAF1, 0xAF1}, {0xB70, 0xB70}, {0xBF3, 0xBFA}, +{0xC7F, 0xC7F}, {0xD4F, 0xD4F}, {0xD79, 0xD79}, {0xE3F, 0xE3F}, {0xF01, 0xF03}, {0xF13, 0xF13}, {0xF15, 0xF17}, {0xF1A, 0xF1F}, {0xF34, 0xF34}, {0xF36, 0xF36}, {0xF38, 0xF38}, {0xFBE, 0xFC5}, +{0xFC7, 0xFCC}, {0xFCE, 0xFCF}, {0xFD5, 0xFD8}, {0x109E, 0x109F}, {0x1390, 0x1399}, {0x166D, 0x166D}, {0x17DB, 0x17DB}, {0x1940, 0x1940}, {0x19DE, 0x19FF}, {0x1B61, 0x1B6A}, {0x1B74, 0x1B7C}, +{0x1FBD, 0x1FBD}, {0x1FBF, 0x1FC1}, {0x1FCD, 0x1FCF}, {0x1FDD, 0x1FDF}, {0x1FED, 0x1FEF}, {0x1FFD, 0x1FFE}, {0x2044, 0x2044}, {0x2052, 0x2052}, {0x207A, 0x207C}, {0x208A, 0x208C}, {0x20A0, 0x20BF}, +{0x2100, 0x2101}, {0x2103, 0x2106}, {0x2108, 0x2109}, {0x2114, 0x2114}, {0x2116, 0x2118}, {0x211E, 0x2123}, {0x2125, 0x2125}, {0x2127, 0x2127}, {0x2129, 0x2129}, {0x212E, 0x212E}, {0x213A, 0x213B}, +{0x2140, 0x2144}, {0x214A, 0x214D}, {0x214F, 0x214F}, {0x218A, 0x218B}, {0x2190, 0x2307}, {0x230C, 0x2328}, {0x232B, 0x2426}, {0x2440, 0x244A}, {0x249C, 0x24E9}, {0x2500, 0x2767}, {0x2794, 0x27C4}, +{0x27C7, 0x27E5}, {0x27F0, 0x2982}, {0x2999, 0x29D7}, {0x29DC, 0x29FB}, {0x29FE, 0x2B73}, {0x2B76, 0x2B95}, {0x2B97, 0x2BFF}, {0x2CE5, 0x2CEA}, {0x2E50, 0x2E51}, {0x2E80, 0x2E99}, {0x2E9B, 0x2EF3}, +{0x2F00, 0x2FD5}, {0x2FF0, 0x2FFB}, {0x3004, 0x3004}, {0x3012, 0x3013}, {0x3020, 0x3020}, {0x3036, 0x3037}, {0x303E, 0x303F}, {0x309B, 0x309C}, {0x3190, 0x3191}, {0x3196, 0x319F}, {0x31C0, 0x31E3}, +{0x3200, 0x321E}, {0x322A, 0x3247}, {0x3250, 0x3250}, {0x3260, 0x327F}, {0x328A, 0x32B0}, {0x32C0, 0x33FF}, {0x4DC0, 0x4DFF}, {0xA490, 0xA4C6}, {0xA700, 0xA716}, {0xA720, 0xA721}, {0xA789, 0xA78A}, +{0xA828, 0xA82B}, {0xA836, 0xA839}, {0xAA77, 0xAA79}, {0xAB5B, 0xAB5B}, {0xAB6A, 0xAB6B}, {0xFB29, 0xFB29}, {0xFBB2, 0xFBC1}, {0xFDFC, 0xFDFD}, {0xFE62, 0xFE62}, {0xFE64, 0xFE66}, {0xFE69, 0xFE69}, +{0xFF04, 0xFF04}, {0xFF0B, 0xFF0B}, {0xFF1C, 0xFF1E}, {0xFF3E, 0xFF3E}, {0xFF40, 0xFF40}, {0xFF5C, 0xFF5C}, {0xFF5E, 0xFF5E}, {0xFFE0, 0xFFE6}, {0xFFE8, 0xFFEE}, {0xFFFC, 0xFFFD}, {0x10137, 0x1013F}, +{0x10179, 0x10189}, {0x1018C, 0x1018E}, {0x10190, 0x1019C}, {0x101A0, 0x101A0}, {0x101D0, 0x101FC}, {0x10877, 0x10878}, {0x10AC8, 0x10AC8}, {0x1173F, 0x1173F}, {0x11FD5, 0x11FF1}, {0x16B3C, 0x16B3F}, +{0x16B45, 0x16B45}, {0x1BC9C, 0x1BC9C}, {0x1D000, 0x1D0F5}, {0x1D100, 0x1D126}, {0x1D129, 0x1D164}, {0x1D16A, 0x1D16C}, {0x1D183, 0x1D184}, {0x1D18C, 0x1D1A9}, {0x1D1AE, 0x1D1E8}, {0x1D200, 0x1D241}, +{0x1D245, 0x1D245}, {0x1D300, 0x1D356}, {0x1D6C1, 0x1D6C1}, {0x1D6DB, 0x1D6DB}, {0x1D6FB, 0x1D6FB}, {0x1D715, 0x1D715}, {0x1D735, 0x1D735}, {0x1D74F, 0x1D74F}, {0x1D76F, 0x1D76F}, {0x1D789, 0x1D789}, +{0x1D7A9, 0x1D7A9}, {0x1D7C3, 0x1D7C3}, {0x1D800, 0x1D9FF}, {0x1DA37, 0x1DA3A}, {0x1DA6D, 0x1DA74}, {0x1DA76, 0x1DA83}, {0x1DA85, 0x1DA86}, {0x1E14F, 0x1E14F}, {0x1E2FF, 0x1E2FF}, {0x1ECAC, 0x1ECAC}, +{0x1ECB0, 0x1ECB0}, {0x1ED2E, 0x1ED2E}, {0x1EEF0, 0x1EEF1}, {0x1F000, 0x1F02B}, {0x1F030, 0x1F093}, {0x1F0A0, 0x1F0AE}, {0x1F0B1, 0x1F0BF}, {0x1F0C1, 0x1F0CF}, {0x1F0D1, 0x1F0F5}, {0x1F10D, 0x1F1AD}, +{0x1F1E6, 0x1F202}, {0x1F210, 0x1F23B}, {0x1F240, 0x1F248}, {0x1F250, 0x1F251}, {0x1F260, 0x1F265}, {0x1F300, 0x1F6D7}, {0x1F6E0, 0x1F6EC}, {0x1F6F0, 0x1F6FC}, {0x1F700, 0x1F773}, {0x1F780, 0x1F7D8}, +{0x1F7E0, 0x1F7EB}, {0x1F800, 0x1F80B}, {0x1F810, 0x1F847}, {0x1F850, 0x1F859}, {0x1F860, 0x1F887}, {0x1F890, 0x1F8AD}, {0x1F8B0, 0x1F8B1}, {0x1F900, 0x1F978}, {0x1F97A, 0x1F9CB}, {0x1F9CD, 0x1FA53}, +{0x1FA60, 0x1FA6D}, {0x1FA70, 0x1FA74}, {0x1FA78, 0x1FA7A}, {0x1FA80, 0x1FA86}, {0x1FA90, 0x1FAA8}, {0x1FAB0, 0x1FAB6}, {0x1FAC0, 0x1FAC2}, {0x1FAD0, 0x1FAD6}, {0x1FB00, 0x1FB92}, {0x1FB94, 0x1FBCA}, +}; + +static const std::vector> control_ranges = { +{0x0, 0x8}, {0xE, 0x1B}, {0x7F, 0x84}, {0x86, 0x9F}, {0xAD, 0xAD}, {0x378, 0x379}, {0x380, 0x383}, {0x38B, 0x38B}, {0x38D, 0x38D}, {0x3A2, 0x3A2}, {0x530, 0x530}, {0x557, 0x558}, {0x58B, 0x58C}, +{0x590, 0x590}, {0x5C8, 0x5CF}, {0x5EB, 0x5EE}, {0x5F5, 0x605}, {0x61C, 0x61D}, {0x6DD, 0x6DD}, {0x70E, 0x70F}, {0x74B, 0x74C}, {0x7B2, 0x7BF}, {0x7FB, 0x7FC}, {0x82E, 0x82F}, {0x83F, 0x83F}, +{0x85C, 0x85D}, {0x85F, 0x85F}, {0x86B, 0x89F}, {0x8B5, 0x8B5}, {0x8C8, 0x8D2}, {0x8E2, 0x8E2}, {0x984, 0x984}, {0x98D, 0x98E}, {0x991, 0x992}, {0x9A9, 0x9A9}, {0x9B1, 0x9B1}, {0x9B3, 0x9B5}, +{0x9BA, 0x9BB}, {0x9C5, 0x9C6}, {0x9C9, 0x9CA}, {0x9CF, 0x9D6}, {0x9D8, 0x9DB}, {0x9DE, 0x9DE}, {0x9E4, 0x9E5}, {0x9FF, 0xA00}, {0xA04, 0xA04}, {0xA0B, 0xA0E}, {0xA11, 0xA12}, {0xA29, 0xA29}, +{0xA31, 0xA31}, {0xA34, 0xA34}, {0xA37, 0xA37}, {0xA3A, 0xA3B}, {0xA3D, 0xA3D}, {0xA43, 0xA46}, {0xA49, 0xA4A}, {0xA4E, 0xA50}, {0xA52, 0xA58}, {0xA5D, 0xA5D}, {0xA5F, 0xA65}, {0xA77, 0xA80}, +{0xA84, 0xA84}, {0xA8E, 0xA8E}, {0xA92, 0xA92}, {0xAA9, 0xAA9}, {0xAB1, 0xAB1}, {0xAB4, 0xAB4}, {0xABA, 0xABB}, {0xAC6, 0xAC6}, {0xACA, 0xACA}, {0xACE, 0xACF}, {0xAD1, 0xADF}, {0xAE4, 0xAE5}, +{0xAF2, 0xAF8}, {0xB00, 0xB00}, {0xB04, 0xB04}, {0xB0D, 0xB0E}, {0xB11, 0xB12}, {0xB29, 0xB29}, {0xB31, 0xB31}, {0xB34, 0xB34}, {0xB3A, 0xB3B}, {0xB45, 0xB46}, {0xB49, 0xB4A}, {0xB4E, 0xB54}, +{0xB58, 0xB5B}, {0xB5E, 0xB5E}, {0xB64, 0xB65}, {0xB78, 0xB81}, {0xB84, 0xB84}, {0xB8B, 0xB8D}, {0xB91, 0xB91}, {0xB96, 0xB98}, {0xB9B, 0xB9B}, {0xB9D, 0xB9D}, {0xBA0, 0xBA2}, {0xBA5, 0xBA7}, +{0xBAB, 0xBAD}, {0xBBA, 0xBBD}, {0xBC3, 0xBC5}, {0xBC9, 0xBC9}, {0xBCE, 0xBCF}, {0xBD1, 0xBD6}, {0xBD8, 0xBE5}, {0xBFB, 0xBFF}, {0xC0D, 0xC0D}, {0xC11, 0xC11}, {0xC29, 0xC29}, {0xC3A, 0xC3C}, +{0xC45, 0xC45}, {0xC49, 0xC49}, {0xC4E, 0xC54}, {0xC57, 0xC57}, {0xC5B, 0xC5F}, {0xC64, 0xC65}, {0xC70, 0xC76}, {0xC8D, 0xC8D}, {0xC91, 0xC91}, {0xCA9, 0xCA9}, {0xCB4, 0xCB4}, {0xCBA, 0xCBB}, +{0xCC5, 0xCC5}, {0xCC9, 0xCC9}, {0xCCE, 0xCD4}, {0xCD7, 0xCDD}, {0xCDF, 0xCDF}, {0xCE4, 0xCE5}, {0xCF0, 0xCF0}, {0xCF3, 0xCFF}, {0xD0D, 0xD0D}, {0xD11, 0xD11}, {0xD45, 0xD45}, {0xD49, 0xD49}, +{0xD50, 0xD53}, {0xD64, 0xD65}, {0xD80, 0xD80}, {0xD84, 0xD84}, {0xD97, 0xD99}, {0xDB2, 0xDB2}, {0xDBC, 0xDBC}, {0xDBE, 0xDBF}, {0xDC7, 0xDC9}, {0xDCB, 0xDCE}, {0xDD5, 0xDD5}, {0xDD7, 0xDD7}, +{0xDE0, 0xDE5}, {0xDF0, 0xDF1}, {0xDF5, 0xE00}, {0xE3B, 0xE3E}, {0xE5C, 0xE80}, {0xE83, 0xE83}, {0xE85, 0xE85}, {0xE8B, 0xE8B}, {0xEA4, 0xEA4}, {0xEA6, 0xEA6}, {0xEBE, 0xEBF}, {0xEC5, 0xEC5}, +{0xEC7, 0xEC7}, {0xECE, 0xECF}, {0xEDA, 0xEDB}, {0xEE0, 0xEFF}, {0xF48, 0xF48}, {0xF6D, 0xF70}, {0xF98, 0xF98}, {0xFBD, 0xFBD}, {0xFCD, 0xFCD}, {0xFDB, 0xFFF}, {0x10C6, 0x10C6}, {0x10C8, 0x10CC}, +{0x10CE, 0x10CF}, {0x1249, 0x1249}, {0x124E, 0x124F}, {0x1257, 0x1257}, {0x1259, 0x1259}, {0x125E, 0x125F}, {0x1289, 0x1289}, {0x128E, 0x128F}, {0x12B1, 0x12B1}, {0x12B6, 0x12B7}, {0x12BF, 0x12BF}, +{0x12C1, 0x12C1}, {0x12C6, 0x12C7}, {0x12D7, 0x12D7}, {0x1311, 0x1311}, {0x1316, 0x1317}, {0x135B, 0x135C}, {0x137D, 0x137F}, {0x139A, 0x139F}, {0x13F6, 0x13F7}, {0x13FE, 0x13FF}, {0x169D, 0x169F}, +{0x16F9, 0x16FF}, {0x170D, 0x170D}, {0x1715, 0x171F}, {0x1737, 0x173F}, {0x1754, 0x175F}, {0x176D, 0x176D}, {0x1771, 0x1771}, {0x1774, 0x177F}, {0x17DE, 0x17DF}, {0x17EA, 0x17EF}, {0x17FA, 0x17FF}, +{0x180E, 0x180F}, {0x181A, 0x181F}, {0x1879, 0x187F}, {0x18AB, 0x18AF}, {0x18F6, 0x18FF}, {0x191F, 0x191F}, {0x192C, 0x192F}, {0x193C, 0x193F}, {0x1941, 0x1943}, {0x196E, 0x196F}, {0x1975, 0x197F}, +{0x19AC, 0x19AF}, {0x19CA, 0x19CF}, {0x19DB, 0x19DD}, {0x1A1C, 0x1A1D}, {0x1A5F, 0x1A5F}, {0x1A7D, 0x1A7E}, {0x1A8A, 0x1A8F}, {0x1A9A, 0x1A9F}, {0x1AAE, 0x1AAF}, {0x1AC1, 0x1AFF}, {0x1B4C, 0x1B4F}, +{0x1B7D, 0x1B7F}, {0x1BF4, 0x1BFB}, {0x1C38, 0x1C3A}, {0x1C4A, 0x1C4C}, {0x1C89, 0x1C8F}, {0x1CBB, 0x1CBC}, {0x1CC8, 0x1CCF}, {0x1CFB, 0x1CFF}, {0x1DFA, 0x1DFA}, {0x1F16, 0x1F17}, {0x1F1E, 0x1F1F}, +{0x1F46, 0x1F47}, {0x1F4E, 0x1F4F}, {0x1F58, 0x1F58}, {0x1F5A, 0x1F5A}, {0x1F5C, 0x1F5C}, {0x1F5E, 0x1F5E}, {0x1F7E, 0x1F7F}, {0x1FB5, 0x1FB5}, {0x1FC5, 0x1FC5}, {0x1FD4, 0x1FD5}, {0x1FDC, 0x1FDC}, +{0x1FF0, 0x1FF1}, {0x1FF5, 0x1FF5}, {0x1FFF, 0x1FFF}, {0x200B, 0x200F}, {0x202A, 0x202E}, {0x2060, 0x206F}, {0x2072, 0x2073}, {0x208F, 0x208F}, {0x209D, 0x209F}, {0x20C0, 0x20CF}, {0x20F1, 0x20FF}, +{0x218C, 0x218F}, {0x2427, 0x243F}, {0x244B, 0x245F}, {0x2B74, 0x2B75}, {0x2B96, 0x2B96}, {0x2C2F, 0x2C2F}, {0x2C5F, 0x2C5F}, {0x2CF4, 0x2CF8}, {0x2D26, 0x2D26}, {0x2D28, 0x2D2C}, {0x2D2E, 0x2D2F}, +{0x2D68, 0x2D6E}, {0x2D71, 0x2D7E}, {0x2D97, 0x2D9F}, {0x2DA7, 0x2DA7}, {0x2DAF, 0x2DAF}, {0x2DB7, 0x2DB7}, {0x2DBF, 0x2DBF}, {0x2DC7, 0x2DC7}, {0x2DCF, 0x2DCF}, {0x2DD7, 0x2DD7}, {0x2DDF, 0x2DDF}, +{0x2E53, 0x2E7F}, {0x2E9A, 0x2E9A}, {0x2EF4, 0x2EFF}, {0x2FD6, 0x2FEF}, {0x2FFC, 0x2FFF}, {0x3040, 0x3040}, {0x3097, 0x3098}, {0x3100, 0x3104}, {0x3130, 0x3130}, {0x318F, 0x318F}, {0x31E4, 0x31EF}, +{0x321F, 0x321F}, {0x9FFD, 0x9FFF}, {0xA48D, 0xA48F}, {0xA4C7, 0xA4CF}, {0xA62C, 0xA63F}, {0xA6F8, 0xA6FF}, {0xA7C0, 0xA7C1}, {0xA7CB, 0xA7F4}, {0xA82D, 0xA82F}, {0xA83A, 0xA83F}, {0xA878, 0xA87F}, +{0xA8C6, 0xA8CD}, {0xA8DA, 0xA8DF}, {0xA954, 0xA95E}, {0xA97D, 0xA97F}, {0xA9CE, 0xA9CE}, {0xA9DA, 0xA9DD}, {0xA9FF, 0xA9FF}, {0xAA37, 0xAA3F}, {0xAA4E, 0xAA4F}, {0xAA5A, 0xAA5B}, {0xAAC3, 0xAADA}, +{0xAAF7, 0xAB00}, {0xAB07, 0xAB08}, {0xAB0F, 0xAB10}, {0xAB17, 0xAB1F}, {0xAB27, 0xAB27}, {0xAB2F, 0xAB2F}, {0xAB6C, 0xAB6F}, {0xABEE, 0xABEF}, {0xABFA, 0xABFF}, {0xD7A4, 0xD7AF}, {0xD7C7, 0xD7CA}, +{0xD7FC, 0xF8FF}, {0xFA6E, 0xFA6F}, {0xFADA, 0xFAFF}, {0xFB07, 0xFB12}, {0xFB18, 0xFB1C}, {0xFB37, 0xFB37}, {0xFB3D, 0xFB3D}, {0xFB3F, 0xFB3F}, {0xFB42, 0xFB42}, {0xFB45, 0xFB45}, {0xFBC2, 0xFBD2}, +{0xFD40, 0xFD4F}, {0xFD90, 0xFD91}, {0xFDC8, 0xFDEF}, {0xFDFE, 0xFDFF}, {0xFE1A, 0xFE1F}, {0xFE53, 0xFE53}, {0xFE67, 0xFE67}, {0xFE6C, 0xFE6F}, {0xFE75, 0xFE75}, {0xFEFD, 0xFF00}, {0xFFBF, 0xFFC1}, +{0xFFC8, 0xFFC9}, {0xFFD0, 0xFFD1}, {0xFFD8, 0xFFD9}, {0xFFDD, 0xFFDF}, {0xFFE7, 0xFFE7}, {0xFFEF, 0xFFFB}, {0xFFFE, 0xFFFF}, {0x1000C, 0x1000C}, {0x10027, 0x10027}, {0x1003B, 0x1003B}, +{0x1003E, 0x1003E}, {0x1004E, 0x1004F}, {0x1005E, 0x1007F}, {0x100FB, 0x100FF}, {0x10103, 0x10106}, {0x10134, 0x10136}, {0x1018F, 0x1018F}, {0x1019D, 0x1019F}, {0x101A1, 0x101CF}, {0x101FE, 0x1027F}, +{0x1029D, 0x1029F}, {0x102D1, 0x102DF}, {0x102FC, 0x102FF}, {0x10324, 0x1032C}, {0x1034B, 0x1034F}, {0x1037B, 0x1037F}, {0x1039E, 0x1039E}, {0x103C4, 0x103C7}, {0x103D6, 0x103FF}, {0x1049E, 0x1049F}, +{0x104AA, 0x104AF}, {0x104D4, 0x104D7}, {0x104FC, 0x104FF}, {0x10528, 0x1052F}, {0x10564, 0x1056E}, {0x10570, 0x105FF}, {0x10737, 0x1073F}, {0x10756, 0x1075F}, {0x10768, 0x107FF}, {0x10806, 0x10807}, +{0x10809, 0x10809}, {0x10836, 0x10836}, {0x10839, 0x1083B}, {0x1083D, 0x1083E}, {0x10856, 0x10856}, {0x1089F, 0x108A6}, {0x108B0, 0x108DF}, {0x108F3, 0x108F3}, {0x108F6, 0x108FA}, {0x1091C, 0x1091E}, +{0x1093A, 0x1093E}, {0x10940, 0x1097F}, {0x109B8, 0x109BB}, {0x109D0, 0x109D1}, {0x10A04, 0x10A04}, {0x10A07, 0x10A0B}, {0x10A14, 0x10A14}, {0x10A18, 0x10A18}, {0x10A36, 0x10A37}, {0x10A3B, 0x10A3E}, +{0x10A49, 0x10A4F}, {0x10A59, 0x10A5F}, {0x10AA0, 0x10ABF}, {0x10AE7, 0x10AEA}, {0x10AF7, 0x10AFF}, {0x10B36, 0x10B38}, {0x10B56, 0x10B57}, {0x10B73, 0x10B77}, {0x10B92, 0x10B98}, {0x10B9D, 0x10BA8}, +{0x10BB0, 0x10BFF}, {0x10C49, 0x10C7F}, {0x10CB3, 0x10CBF}, {0x10CF3, 0x10CF9}, {0x10D28, 0x10D2F}, {0x10D3A, 0x10E5F}, {0x10E7F, 0x10E7F}, {0x10EAA, 0x10EAA}, {0x10EAE, 0x10EAF}, {0x10EB2, 0x10EFF}, +{0x10F28, 0x10F2F}, {0x10F5A, 0x10FAF}, {0x10FCC, 0x10FDF}, {0x10FF7, 0x10FFF}, {0x1104E, 0x11051}, {0x11070, 0x1107E}, {0x110BD, 0x110BD}, {0x110C2, 0x110CF}, {0x110E9, 0x110EF}, {0x110FA, 0x110FF}, +{0x11135, 0x11135}, {0x11148, 0x1114F}, {0x11177, 0x1117F}, {0x111E0, 0x111E0}, {0x111F5, 0x111FF}, {0x11212, 0x11212}, {0x1123F, 0x1127F}, {0x11287, 0x11287}, {0x11289, 0x11289}, {0x1128E, 0x1128E}, +{0x1129E, 0x1129E}, {0x112AA, 0x112AF}, {0x112EB, 0x112EF}, {0x112FA, 0x112FF}, {0x11304, 0x11304}, {0x1130D, 0x1130E}, {0x11311, 0x11312}, {0x11329, 0x11329}, {0x11331, 0x11331}, {0x11334, 0x11334}, +{0x1133A, 0x1133A}, {0x11345, 0x11346}, {0x11349, 0x1134A}, {0x1134E, 0x1134F}, {0x11351, 0x11356}, {0x11358, 0x1135C}, {0x11364, 0x11365}, {0x1136D, 0x1136F}, {0x11375, 0x113FF}, {0x1145C, 0x1145C}, +{0x11462, 0x1147F}, {0x114C8, 0x114CF}, {0x114DA, 0x1157F}, {0x115B6, 0x115B7}, {0x115DE, 0x115FF}, {0x11645, 0x1164F}, {0x1165A, 0x1165F}, {0x1166D, 0x1167F}, {0x116B9, 0x116BF}, {0x116CA, 0x116FF}, +{0x1171B, 0x1171C}, {0x1172C, 0x1172F}, {0x11740, 0x117FF}, {0x1183C, 0x1189F}, {0x118F3, 0x118FE}, {0x11907, 0x11908}, {0x1190A, 0x1190B}, {0x11914, 0x11914}, {0x11917, 0x11917}, {0x11936, 0x11936}, +{0x11939, 0x1193A}, {0x11947, 0x1194F}, {0x1195A, 0x1199F}, {0x119A8, 0x119A9}, {0x119D8, 0x119D9}, {0x119E5, 0x119FF}, {0x11A48, 0x11A4F}, {0x11AA3, 0x11ABF}, {0x11AF9, 0x11BFF}, {0x11C09, 0x11C09}, +{0x11C37, 0x11C37}, {0x11C46, 0x11C4F}, {0x11C6D, 0x11C6F}, {0x11C90, 0x11C91}, {0x11CA8, 0x11CA8}, {0x11CB7, 0x11CFF}, {0x11D07, 0x11D07}, {0x11D0A, 0x11D0A}, {0x11D37, 0x11D39}, {0x11D3B, 0x11D3B}, +{0x11D3E, 0x11D3E}, {0x11D48, 0x11D4F}, {0x11D5A, 0x11D5F}, {0x11D66, 0x11D66}, {0x11D69, 0x11D69}, {0x11D8F, 0x11D8F}, {0x11D92, 0x11D92}, {0x11D99, 0x11D9F}, {0x11DAA, 0x11EDF}, {0x11EF9, 0x11FAF}, +{0x11FB1, 0x11FBF}, {0x11FF2, 0x11FFE}, {0x1239A, 0x123FF}, {0x1246F, 0x1246F}, {0x12475, 0x1247F}, {0x12544, 0x12FFF}, {0x1342F, 0x143FF}, {0x14647, 0x167FF}, {0x16A39, 0x16A3F}, {0x16A5F, 0x16A5F}, +{0x16A6A, 0x16A6D}, {0x16A70, 0x16ACF}, {0x16AEE, 0x16AEF}, {0x16AF6, 0x16AFF}, {0x16B46, 0x16B4F}, {0x16B5A, 0x16B5A}, {0x16B62, 0x16B62}, {0x16B78, 0x16B7C}, {0x16B90, 0x16E3F}, {0x16E9B, 0x16EFF}, +{0x16F4B, 0x16F4E}, {0x16F88, 0x16F8E}, {0x16FA0, 0x16FDF}, {0x16FE5, 0x16FEF}, {0x16FF2, 0x16FFF}, {0x187F8, 0x187FF}, {0x18CD6, 0x18CFF}, {0x18D09, 0x1AFFF}, {0x1B11F, 0x1B14F}, {0x1B153, 0x1B163}, +{0x1B168, 0x1B16F}, {0x1B2FC, 0x1BBFF}, {0x1BC6B, 0x1BC6F}, {0x1BC7D, 0x1BC7F}, {0x1BC89, 0x1BC8F}, {0x1BC9A, 0x1BC9B}, {0x1BCA0, 0x1CFFF}, {0x1D0F6, 0x1D0FF}, {0x1D127, 0x1D128}, {0x1D173, 0x1D17A}, +{0x1D1E9, 0x1D1FF}, {0x1D246, 0x1D2DF}, {0x1D2F4, 0x1D2FF}, {0x1D357, 0x1D35F}, {0x1D379, 0x1D3FF}, {0x1D455, 0x1D455}, {0x1D49D, 0x1D49D}, {0x1D4A0, 0x1D4A1}, {0x1D4A3, 0x1D4A4}, {0x1D4A7, 0x1D4A8}, +{0x1D4AD, 0x1D4AD}, {0x1D4BA, 0x1D4BA}, {0x1D4BC, 0x1D4BC}, {0x1D4C4, 0x1D4C4}, {0x1D506, 0x1D506}, {0x1D50B, 0x1D50C}, {0x1D515, 0x1D515}, {0x1D51D, 0x1D51D}, {0x1D53A, 0x1D53A}, {0x1D53F, 0x1D53F}, +{0x1D545, 0x1D545}, {0x1D547, 0x1D549}, {0x1D551, 0x1D551}, {0x1D6A6, 0x1D6A7}, {0x1D7CC, 0x1D7CD}, {0x1DA8C, 0x1DA9A}, {0x1DAA0, 0x1DAA0}, {0x1DAB0, 0x1DFFF}, {0x1E007, 0x1E007}, {0x1E019, 0x1E01A}, +{0x1E022, 0x1E022}, {0x1E025, 0x1E025}, {0x1E02B, 0x1E0FF}, {0x1E12D, 0x1E12F}, {0x1E13E, 0x1E13F}, {0x1E14A, 0x1E14D}, {0x1E150, 0x1E2BF}, {0x1E2FA, 0x1E2FE}, {0x1E300, 0x1E7FF}, {0x1E8C5, 0x1E8C6}, +{0x1E8D7, 0x1E8FF}, {0x1E94C, 0x1E94F}, {0x1E95A, 0x1E95D}, {0x1E960, 0x1EC70}, {0x1ECB5, 0x1ED00}, {0x1ED3E, 0x1EDFF}, {0x1EE04, 0x1EE04}, {0x1EE20, 0x1EE20}, {0x1EE23, 0x1EE23}, {0x1EE25, 0x1EE26}, +{0x1EE28, 0x1EE28}, {0x1EE33, 0x1EE33}, {0x1EE38, 0x1EE38}, {0x1EE3A, 0x1EE3A}, {0x1EE3C, 0x1EE41}, {0x1EE43, 0x1EE46}, {0x1EE48, 0x1EE48}, {0x1EE4A, 0x1EE4A}, {0x1EE4C, 0x1EE4C}, {0x1EE50, 0x1EE50}, +{0x1EE53, 0x1EE53}, {0x1EE55, 0x1EE56}, {0x1EE58, 0x1EE58}, {0x1EE5A, 0x1EE5A}, {0x1EE5C, 0x1EE5C}, {0x1EE5E, 0x1EE5E}, {0x1EE60, 0x1EE60}, {0x1EE63, 0x1EE63}, {0x1EE65, 0x1EE66}, {0x1EE6B, 0x1EE6B}, +{0x1EE73, 0x1EE73}, {0x1EE78, 0x1EE78}, {0x1EE7D, 0x1EE7D}, {0x1EE7F, 0x1EE7F}, {0x1EE8A, 0x1EE8A}, {0x1EE9C, 0x1EEA0}, {0x1EEA4, 0x1EEA4}, {0x1EEAA, 0x1EEAA}, {0x1EEBC, 0x1EEEF}, {0x1EEF2, 0x1EFFF}, +{0x1F02C, 0x1F02F}, {0x1F094, 0x1F09F}, {0x1F0AF, 0x1F0B0}, {0x1F0C0, 0x1F0C0}, {0x1F0D0, 0x1F0D0}, {0x1F0F6, 0x1F0FF}, {0x1F1AE, 0x1F1E5}, {0x1F203, 0x1F20F}, {0x1F23C, 0x1F23F}, {0x1F249, 0x1F24F}, +{0x1F252, 0x1F25F}, {0x1F266, 0x1F2FF}, {0x1F6D8, 0x1F6DF}, {0x1F6ED, 0x1F6EF}, {0x1F6FD, 0x1F6FF}, {0x1F774, 0x1F77F}, {0x1F7D9, 0x1F7DF}, {0x1F7EC, 0x1F7FF}, {0x1F80C, 0x1F80F}, {0x1F848, 0x1F84F}, +{0x1F85A, 0x1F85F}, {0x1F888, 0x1F88F}, {0x1F8AE, 0x1F8AF}, {0x1F8B2, 0x1F8FF}, {0x1F979, 0x1F979}, {0x1F9CC, 0x1F9CC}, {0x1FA54, 0x1FA5F}, {0x1FA6E, 0x1FA6F}, {0x1FA75, 0x1FA77}, {0x1FA7B, 0x1FA7F}, +{0x1FA87, 0x1FA8F}, {0x1FAA9, 0x1FAAF}, {0x1FAB7, 0x1FABF}, {0x1FAC3, 0x1FACF}, {0x1FAD7, 0x1FAFF}, {0x1FB93, 0x1FB93}, {0x1FBCB, 0x1FBEF}, {0x1FBFA, 0x1FFFF}, {0x2A6DE, 0x2A6FF}, {0x2B735, 0x2B73F}, +{0x2B81E, 0x2B81F}, {0x2CEA2, 0x2CEAF}, {0x2EBE1, 0x2F7FF}, {0x2FA1E, 0x2FFFF}, {0x3134B, 0xE00FF}, {0xE01F0, 0x10FFFF}, +}; + +//String +bool CNCTString::operator==(const std::string& other) const { + return str.compare(other) == 0; +} +bool CNCTString::operator==(const char other) const { + return str.compare(std::string(1, other)) == 0; +} +bool CNCTString::operator==(const CNCTString& other) const { + return str.compare(other.str) == 0; +} +// + operators +CNCTString& CNCTString::operator+=(const std::string& other) { + str += other; + int new_len = CNCTUnicode::strlen_utf8(other); + utf8_chars += new_len; + char_type = CNCTUnicode::string_identify(str); + seq_offset_bytes += other.size(); + seq_offset_utf8_chars += new_len; + return *this; +} + +CNCTString& CNCTString::operator+=(const char other) { + std::string str = std::string(1, other); + *this += str; + return *this; +} + +CNCTString& CNCTString::operator+=(const CNCTString& other) { + str += other.str; + utf8_chars += other.utf8_chars; + char_type = CNCTUnicode::string_identify(str); + seq_offset_bytes += other.str.size(); + seq_offset_utf8_chars += other.utf8_chars; + return *this; +} + +struct CRCompare { + bool operator()(const std::pair& p, int i) { + return p.second < i; + } + bool operator()(int i, const std::pair& p) { + return i < p.first; + } +}; + +// binary search for code range +bool CNCTUnicode::check_code_range(int c, const std::vector> &ranges) { + auto it = std::upper_bound(ranges.begin(), ranges.end(), c, CRCompare()); + if (it != ranges.begin()) { + --it; + } + return c >= it->first && c <= it->second; +} + +// these are binary searches, it takes only a few operations +CNCTCharType CNCTUnicode::get_code_type(int c) { + if (check_code_range(c, letter_ranges)) { + return LETTER; + } + if (check_code_range(c, digit_ranges)) { + return DIGIT; + } + if (check_code_range(c, whitespace_ranges)) { + return WHITESPACE; + } + if (check_code_range(c, punctuation_ranges)) { + return PUNCTUATION; + } + if (check_code_range(c, symbol_ranges)) { + return SYMBOL; + } + if (check_code_range(c, accent_mark_ranges)) { + return ACCENT_MARK; + } + if (check_code_range(c, control_ranges)) { + return CONTROL; + } + return UNIDENTIFIED; +} + +static int utf8_to_unicode(const std::string& utf8_char) { + int c = 0; + int len = (int)utf8_char.size(); + if (len == 1) { + c = utf8_char[0]; + } else if (len == 2) { + c = ((utf8_char[0] & 0x1F) << 6) | (utf8_char[1] & 0x3F); + } else if (len == 3) { + c = ((utf8_char[0] & 0x0F) << 12) | ((utf8_char[1] & 0x3F) << 6) | (utf8_char[2] & 0x3F); + } else if (len == 4) { + c = ((utf8_char[0] & 0x07) << 18) | ((utf8_char[1] & 0x3F) << 12) | ((utf8_char[2] & 0x3F) << 6) | (utf8_char[3] & 0x3F); + } + return c; +} + +CNCTCharType CNCTUnicode::get_code_type(const std::string &utf8_char) { + return get_code_type(utf8_to_unicode(utf8_char)); +} + +int CNCTUnicode::utf8_len(const char c) +{ + if ((c & 0x80) == 0) { + return 1; // ASCII character + } + if ((c & 0xE0) == 0xC0) { + return 2; // 2-byte character + } + if ((c & 0xF0) == 0xE0) { + return 3; // 3-byte character + } + if ((c & 0xF0) == 0xF0) { + return 4; // 4-byte character + } + return 1; // not valid utf8 + // static const uint8_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; + // return lookup[static_cast(c) >> 4]; +} + +int CNCTUnicode::strlen_utf8(const std::string src) { + int len = 0; + for (std::string::const_iterator it = src.begin(); it != src.end(); ++it) { + int char_len = utf8_len(*it); + if (char_len > 1) { + it += char_len - 1; + } + len += 1; + } + return len; +} + +// split a string into unicode strings +std::vector CNCTUnicode::split_utf8(const std::string &src) { + std::vector result; + for (std::string::const_iterator it = src.begin(); it != src.end(); ++it) { + int char_len = utf8_len(*it); + std::string str(it, it + char_len); + result.push_back(str); + if (char_len > 1) { + it += char_len - 1; + } + } + return result; +} + +// split a string into unicode strings (CNCTString) with sequence information +std::vector CNCTUnicode::split_utf8_enhanced(const std::string &src) { + std::vector result; + int seq_offset_bytes=0; + int seq_offset_utf8_chars=0; + for (std::string::const_iterator it = src.begin(); it != src.end(); ++it) { + int char_len = utf8_len(*it); + std::string str(it, it + char_len); + CNCTString cnct_str; + cnct_str.seq_offset_bytes = seq_offset_bytes; + cnct_str.seq_offset_utf8_chars = seq_offset_utf8_chars; + cnct_str.str = str; + cnct_str.utf8_chars = 1; + cnct_str.char_type = get_code_type(str); + #if 0 + switch (cnct_str.char_type) + { + case DIGIT: + printf("%s = DIGIT\n", str.c_str()); + break; + case LETTER: + printf("%s = LETTER\n", str.c_str()); + break; + case WHITESPACE: + printf("%s = WHITESPACE\n", str.c_str()); + break; + case PUNCTUATION: + printf("%s = PUNCTUATION\n", str.c_str()); + break; + case UNIDENTIFIED: + printf("%s = UNIDENTIFIED\n", str.c_str()); + break; + case SYMBOL: + printf("%s = SYMBOL\n", str.c_str()); + break; + case CONTROL: + printf("%s = CONTROL\n", str.c_str()); + break; + } + #endif + + result.push_back(cnct_str); + seq_offset_bytes += char_len; + seq_offset_utf8_chars += 1; + if (char_len > 1) { + it += char_len - 1; + } + + } + return result; +} + +// return the type of the string +CNCTCharType CNCTUnicode::string_identify(const std::string &str) { + CNCTCharType result = UNIDENTIFIED; + std::string::const_iterator it = str.begin(); + while (it != str.end()) { + int len = utf8_len(*it); + int c = 0; + for (int i = 0; i < len && it != str.end(); ++i, ++it) { + c = (c << 8) | static_cast(*it); + } + switch (get_code_type(c)) { + case DIGIT: + if (result == UNIDENTIFIED) { + result = DIGIT; + } else if (result != DIGIT) { + return MIXED; + } + break; + case LETTER: + if (result == UNIDENTIFIED) { + result = LETTER; + } else if (result != LETTER) { + return MIXED; + } + break; + case WHITESPACE: + if (result == UNIDENTIFIED) { + result = WHITESPACE; + } else if (result != WHITESPACE) { + return MIXED; + } + break; + case PUNCTUATION: + if (result == UNIDENTIFIED) { + result = PUNCTUATION; + } else if (result != PUNCTUATION) { + return MIXED; + } + break; + default: + return MIXED; + break; + } + } + return result; +} + +// verify the content of a string +bool CNCTUnicode::string_test(const std::string &str, CNCTCharType chartype) +{ + std::string::const_iterator it = str.begin(); + while (it != str.end()) { + int len = utf8_len(*it); + int c = 0; + for (int i = 0; i < len && it != str.end(); ++i, ++it) { + c = (c << 8) | static_cast(*it); + } + if (get_code_type(c) != chartype) { + return false; + } + } + return true; +} + +//----------------- +// llama.cpp GPT2 vocab (from libfalcon.cpp) +//----------------- + +std::string replaceAll(std::string str, const std::string& from, const std::string& to) { + size_t start_pos = 0; + while((start_pos = str.find(from, start_pos)) != std::string::npos) { + str.replace(start_pos, from.length(), to); + start_pos += to.length(); // Handles case where 'to' is a substring of 'from' + } + return str; +} + +struct TrieNode { + std::map map; + int32_t Id = -1; +}; + +struct Trie { + TrieNode *root; + + Trie() : root(new TrieNode()) {} + + ~Trie() { + if(root) + deleteTrie(root); + } + + // Move constructor + Trie(Trie&& other) noexcept : root(other.root) { + other.root = nullptr; + } + + // Move assignment operator + Trie& operator=(Trie&& other) noexcept { + if (this != &other) { + if(root) + deleteTrie(root); + root = other.root; + other.root = nullptr; + } + return *this; + } + + void insert(const std::string &token, int32_t Id) { + TrieNode* current = root; + for(auto ch : token) { + if(current->map.find(ch) == current->map.end()) { + current->map[ch] = new TrieNode(); + } + current = current->map[ch]; + } + current->Id = Id; + } + + void reset() { + deleteTrie(root); + root = new TrieNode(); + } + +private: + void deleteTrie(TrieNode* node) { + for(auto &it: node->map) { + deleteTrie(it.second); + } + delete node; + } + +}; + +struct gpt2bpe_vocab { + using id = int32_t; + using token = std::string; + + std::map max_token_length; // max length, for each 2byte prefix + std::map, int> bpe_ranks; + std::vector> bpe_merges; + + id special_bos_id = -1; + id special_eos_id = -1; + id special_unk_id = -1; + id special_sep_id = -1; + id special_pad_id = -1; + + id linefeed_id = -1; + + std::unordered_map token_to_id; + std::unordered_map id_to_token; + + Trie trie; // highspeed access to tokens by prefix tree + + // populate trie from map + void populate_trie_from_map() { + trie.reset(); + for (const auto& pair : token_to_id) { + trie.insert(pair.first, pair.second); + if (pair.first.size() >= 2) { + std::string prefix = pair.first.substr(0, 2); + max_token_length[prefix] = std::max(max_token_length[prefix], (uint32_t)pair.first.size()); + } + } + } + // populate token ranks map + int populate_bpe_ranks(std::vector> bpe_merges_) { + for (int i = 0; i < (int)bpe_merges_.size(); i++) { + bpe_ranks.emplace(bpe_merges_[i], i); + } + bpe_merges = bpe_merges_; + return bpe_merges_.size(); + } + + // Trim whitespace characters from the beginning and end of the string + void trim(std::string& str) { + // Remove whitespace characters from the beginning of the string + str.erase(str.begin(), std::find_if(str.begin(), str.end(), [](int ch) { + return !std::isspace(ch); + })); + + // Remove whitespace characters from the end of the string + str.erase(std::find_if(str.rbegin(), str.rend(), [](int ch) { + return !std::isspace(ch); + }).base(), str.end()); + } + + // get max token length available for a prefix of 2 bytes (string at least 2 bytes long) + int get_max_token_length(const std::string& string) const { + if (string.size() < 2) { + return -1; + } + std::string prefix = string.substr(0, 2); + if (max_token_length.find(prefix) == max_token_length.end()) { + return 0; + } + return max_token_length.at(prefix); + } + + // function to find if two tokens match in bpe_rank, return rank or -1 + int find_bpe_rank(const std::string& token1, const std::string& token2) const { + std::string left_token = token1; + std::string right_token = token2; + left_token = replaceAll(left_token, " ", "Ġ"); + left_token = replaceAll(left_token, "\n", "Ċ"); + right_token = replaceAll(right_token, " ", "Ġ"); + right_token = replaceAll(right_token, "\n", "Ċ"); + + auto it = bpe_ranks.find(std::make_pair(left_token, right_token)); + if (it == bpe_ranks.end()) { + return -1; + } + return it->second; + } + + std::pair find_longest_match(const std::string& snippet) const { + TrieNode* current = trie.root; + gpt2bpe_vocab::id last_matched_id = -1; + std::string last_matched_token = ""; + std::string current_token = ""; + for (auto ch : snippet) { + if (current->map.find(ch) == current->map.end()) { + break; + } + current = current->map[ch]; + current_token += ch; + if (current->Id != -1) { + last_matched_id = current->Id; + last_matched_token = current_token; + } + } + return {last_matched_id, last_matched_token}; + } + +}; + + +// +// tokenizer - bpe type, gpt2 tokenization compatible +// + +struct ggllm_bpe_symbol { + using index = int; + index prev; + index next; + const char * text; + size_t n; +}; + +static_assert(std::is_trivially_copyable::value, "ggllm_bpe_symbol is not trivially copyable"); + +struct ggllm_bpe_bigram { + struct comparator { + bool operator()(ggllm_bpe_bigram & l, ggllm_bpe_bigram & r) { + return l.rank > r.rank || (l.rank == r.rank && l.left > r.left); + } + }; + + using queue_storage = std::vector; + using queue = std::priority_queue; + ggllm_bpe_symbol::index left; + ggllm_bpe_symbol::index right; + std::string text; + int rank; + size_t size; +}; + +struct gpt2bpe_tokenizer { + gpt2bpe_tokenizer(const gpt2bpe_vocab & vocab, bool g2ws_): vocab_(vocab) { flag_g2ws = g2ws_; } + + void tokenize(const std::string & text, std::vector & output) { + int final_prev_index = -1; + // auto start = ggml_time_us(); + auto word_collection = bpe_gpt2_preprocess(text); + // auto end = ggml_time_us(); + // fprintf(stderr, "%s: preprocessing took %0.3f ms\n", __func__, (end - start) / 1000.0); + + symbols_final.clear(); + + for (auto & word : word_collection) { + work_queue_ = ggllm_bpe_bigram::queue(); + symbols_.clear(); + + int index = 0; + size_t offset = 0; + + while (offset < word.size()) { + ggllm_bpe_symbol sym; + size_t char_len = std::min(word.size() - offset, (size_t) CNCTUnicode::utf8_len(word[offset])); + sym.text = word.c_str() + offset; + sym.n = 1; + sym.n = char_len; + offset += sym.n; + sym.prev = index - 1; + sym.next = offset == word.size() ? -1 : index + 1; + index++; + symbols_.emplace_back(sym); + } + for (size_t i = 1; i < symbols_.size(); ++i) { + add_new_bigram(i - 1, i); + } + + // build token(s) + while (!work_queue_.empty()) { + auto bigram = work_queue_.top(); + work_queue_.pop(); + + auto & left_symbol = symbols_[bigram.left]; + auto & right_symbol = symbols_[bigram.right]; + + if (left_symbol.n == 0 || right_symbol.n == 0) { + continue; + } + std::string left_token = std::string(left_symbol.text, left_symbol.n); + std::string right_token = std::string(right_symbol.text, right_symbol.n); + if (left_token + right_token != bigram.text) { + continue; // Skip this bigram if it's outdated + } + + // merge the right sym into the left one + left_symbol.n += right_symbol.n; + right_symbol.n = 0; + + // remove the right sym from the chain + left_symbol.next = right_symbol.next; + if (right_symbol.next >= 0) { + symbols_[right_symbol.next].prev = bigram.left; + } + + add_new_bigram(left_symbol.prev, bigram.left); // left side of current symbol + add_new_bigram(bigram.left, left_symbol.next); // right side of current symbol + } + + // add the fnished tokens to the final list keeping correct order for next and prev + for (auto & sym : symbols_) { + if (sym.n > 0) { + sym.prev = final_prev_index; + sym.next = -1; + if (final_prev_index != -1) { + symbols_final[final_prev_index].next = symbols_final.size(); + } + symbols_final.emplace_back(sym); + final_prev_index = symbols_final.size() - 1; + } + } + } + + symbols_ = symbols_final; + if (symbols_.size()) + for (int i = 0; i != -1; i = symbols_[i].next) { + auto & symbol = symbols_[i]; + if (symbol.n == 0) { + continue; + } + std::string str = std::string(symbol.text, symbol.n); + std::string str_decoded = decode_token(str); + auto token = vocab_.token_to_id.find(str_decoded); + + if (token == vocab_.token_to_id.end()) { + for (auto j = str_decoded.begin(); j != str_decoded.end(); ++j) { + std::string byte_str(1, *j); + auto token_multibyte = vocab_.token_to_id.find(byte_str); + if (token_multibyte == vocab_.token_to_id.end()) { + fprintf(stderr,"ERROR: byte not found in vocab: '%s'\n", byte_str.c_str()); + } + output.push_back((*token_multibyte).second); + } + } else { + output.push_back((*token).second); + } + } + } + +private: + void add_new_bigram(int left, int right) { + if (left == -1 || right == -1) return; + + std::string left_token = std::string(symbols_[left].text, symbols_[left].n); + std::string right_token = std::string(symbols_[right].text, symbols_[right].n); + + int rank_found = -1; + rank_found = vocab_.find_bpe_rank(left_token, right_token); + + if (rank_found < 0) { + return; + } + + ggllm_bpe_bigram bigram; + bigram.left = left; + bigram.right = right; + bigram.rank = rank_found; + bigram.size = left_token.size() + right_token.size(); + bigram.text = left_token + right_token; + work_queue_.push(bigram); + } + + std::unordered_map bytes_to_unicode() { + static std::unordered_map hex_map = { + { 0x21, "\x21" }, { 0x22, "\x22" }, { 0x23, "\x23" }, { 0x24, "\x24" }, { 0x25, "\x25" }, { 0x26, "\x26" }, { 0x27, "\x27" }, { 0x28, "\x28" }, { 0x29, "\x29" }, { 0x2A, "\x2A" }, + { 0x2B, "\x2B" }, { 0x2C, "\x2C" }, { 0x2D, "\x2D" }, { 0x2E, "\x2E" }, { 0x2F, "\x2F" }, { 0x30, "\x30" }, { 0x31, "\x31" }, { 0x32, "\x32" }, { 0x33, "\x33" }, { 0x34, "\x34" }, + { 0x35, "\x35" }, { 0x36, "\x36" }, { 0x37, "\x37" }, { 0x38, "\x38" }, { 0x39, "\x39" }, { 0x3A, "\x3A" }, { 0x3B, "\x3B" }, { 0x3C, "\x3C" }, { 0x3D, "\x3D" }, { 0x3E, "\x3E" }, + { 0x3F, "\x3F" }, { 0x40, "\x40" }, { 0x41, "\x41" }, { 0x42, "\x42" }, { 0x43, "\x43" }, { 0x44, "\x44" }, { 0x45, "\x45" }, { 0x46, "\x46" }, { 0x47, "\x47" }, { 0x48, "\x48" }, + { 0x49, "\x49" }, { 0x4A, "\x4A" }, { 0x4B, "\x4B" }, { 0x4C, "\x4C" }, { 0x4D, "\x4D" }, { 0x4E, "\x4E" }, { 0x4F, "\x4F" }, { 0x50, "\x50" }, { 0x51, "\x51" }, { 0x52, "\x52" }, + { 0x53, "\x53" }, { 0x54, "\x54" }, { 0x55, "\x55" }, { 0x56, "\x56" }, { 0x57, "\x57" }, { 0x58, "\x58" }, { 0x59, "\x59" }, { 0x5A, "\x5A" }, { 0x5B, "\x5B" }, { 0x5C, "\x5C" }, + { 0x5D, "\x5D" }, { 0x5E, "\x5E" }, { 0x5F, "\x5F" }, { 0x60, "\x60" }, { 0x61, "\x61" }, { 0x62, "\x62" }, { 0x63, "\x63" }, { 0x64, "\x64" }, { 0x65, "\x65" }, { 0x66, "\x66" }, + { 0x67, "\x67" }, { 0x68, "\x68" }, { 0x69, "\x69" }, { 0x6A, "\x6A" }, { 0x6B, "\x6B" }, { 0x6C, "\x6C" }, { 0x6D, "\x6D" }, { 0x6E, "\x6E" }, { 0x6F, "\x6F" }, { 0x70, "\x70" }, + { 0x71, "\x71" }, { 0x72, "\x72" }, { 0x73, "\x73" }, { 0x74, "\x74" }, { 0x75, "\x75" }, { 0x76, "\x76" }, { 0x77, "\x77" }, { 0x78, "\x78" }, { 0x79, "\x79" }, { 0x7A, "\x7A" }, + { 0x7B, "\x7B" }, { 0x7C, "\x7C" }, { 0x7D, "\x7D" }, { 0x7E, "\x7E" }, { 0xA1, "\xC2\xA1" }, { 0xA2, "\xC2\xA2" }, { 0xA3, "\xC2\xA3" }, { 0xA4, "\xC2\xA4" }, { 0xA5, "\xC2\xA5" }, + { 0xA6, "\xC2\xA6" }, { 0xA7, "\xC2\xA7" }, { 0xA8, "\xC2\xA8" }, { 0xA9, "\xC2\xA9" }, { 0xAA, "\xC2\xAA" }, { 0xAB, "\xC2\xAB" }, { 0xAC, "\xC2\xAC" }, { 0xAE, "\xC2\xAE" }, + { 0xAF, "\xC2\xAF" }, { 0xB0, "\xC2\xB0" }, { 0xB1, "\xC2\xB1" }, { 0xB2, "\xC2\xB2" }, { 0xB3, "\xC2\xB3" }, { 0xB4, "\xC2\xB4" }, { 0xB5, "\xC2\xB5" }, { 0xB6, "\xC2\xB6" }, + { 0xB7, "\xC2\xB7" }, { 0xB8, "\xC2\xB8" }, { 0xB9, "\xC2\xB9" }, { 0xBA, "\xC2\xBA" }, { 0xBB, "\xC2\xBB" }, { 0xBC, "\xC2\xBC" }, { 0xBD, "\xC2\xBD" }, { 0xBE, "\xC2\xBE" }, + { 0xBF, "\xC2\xBF" }, { 0xC0, "\xC3\x80" }, { 0xC1, "\xC3\x81" }, { 0xC2, "\xC3\x82" }, { 0xC3, "\xC3\x83" }, { 0xC4, "\xC3\x84" }, { 0xC5, "\xC3\x85" }, { 0xC6, "\xC3\x86" }, + { 0xC7, "\xC3\x87" }, { 0xC8, "\xC3\x88" }, { 0xC9, "\xC3\x89" }, { 0xCA, "\xC3\x8A" }, { 0xCB, "\xC3\x8B" }, { 0xCC, "\xC3\x8C" }, { 0xCD, "\xC3\x8D" }, { 0xCE, "\xC3\x8E" }, + { 0xCF, "\xC3\x8F" }, { 0xD0, "\xC3\x90" }, { 0xD1, "\xC3\x91" }, { 0xD2, "\xC3\x92" }, { 0xD3, "\xC3\x93" }, { 0xD4, "\xC3\x94" }, { 0xD5, "\xC3\x95" }, { 0xD6, "\xC3\x96" }, + { 0xD7, "\xC3\x97" }, { 0xD8, "\xC3\x98" }, { 0xD9, "\xC3\x99" }, { 0xDA, "\xC3\x9A" }, { 0xDB, "\xC3\x9B" }, { 0xDC, "\xC3\x9C" }, { 0xDD, "\xC3\x9D" }, { 0xDE, "\xC3\x9E" }, + { 0xDF, "\xC3\x9F" }, { 0xE0, "\xC3\xA0" }, { 0xE1, "\xC3\xA1" }, { 0xE2, "\xC3\xA2" }, { 0xE3, "\xC3\xA3" }, { 0xE4, "\xC3\xA4" }, { 0xE5, "\xC3\xA5" }, { 0xE6, "\xC3\xA6" }, + { 0xE7, "\xC3\xA7" }, { 0xE8, "\xC3\xA8" }, { 0xE9, "\xC3\xA9" }, { 0xEA, "\xC3\xAA" }, { 0xEB, "\xC3\xAB" }, { 0xEC, "\xC3\xAC" }, { 0xED, "\xC3\xAD" }, { 0xEE, "\xC3\xAE" }, + { 0xEF, "\xC3\xAF" }, { 0xF0, "\xC3\xB0" }, { 0xF1, "\xC3\xB1" }, { 0xF2, "\xC3\xB2" }, { 0xF3, "\xC3\xB3" }, { 0xF4, "\xC3\xB4" }, { 0xF5, "\xC3\xB5" }, { 0xF6, "\xC3\xB6" }, + { 0xF7, "\xC3\xB7" }, { 0xF8, "\xC3\xB8" }, { 0xF9, "\xC3\xB9" }, { 0xFA, "\xC3\xBA" }, { 0xFB, "\xC3\xBB" }, { 0xFC, "\xC3\xBC" }, { 0xFD, "\xC3\xBD" }, { 0xFE, "\xC3\xBE" }, + { 0xFF, "\xC3\xBF" }, { 0x00, "\xC4\x80" }, { 0x01, "\xC4\x81" }, { 0x02, "\xC4\x82" }, { 0x03, "\xC4\x83" }, { 0x04, "\xC4\x84" }, { 0x05, "\xC4\x85" }, { 0x06, "\xC4\x86" }, + { 0x07, "\xC4\x87" }, { 0x08, "\xC4\x88" }, { 0x09, "\xC4\x89" }, { 0x0A, "\xC4\x8A" }, { 0x0B, "\xC4\x8B" }, { 0x0C, "\xC4\x8C" }, { 0x0D, "\xC4\x8D" }, { 0x0E, "\xC4\x8E" }, + { 0x0F, "\xC4\x8F" }, { 0x10, "\xC4\x90" }, { 0x11, "\xC4\x91" }, { 0x12, "\xC4\x92" }, { 0x13, "\xC4\x93" }, { 0x14, "\xC4\x94" }, { 0x15, "\xC4\x95" }, { 0x16, "\xC4\x96" }, + { 0x17, "\xC4\x97" }, { 0x18, "\xC4\x98" }, { 0x19, "\xC4\x99" }, { 0x1A, "\xC4\x9A" }, { 0x1B, "\xC4\x9B" }, { 0x1C, "\xC4\x9C" }, { 0x1D, "\xC4\x9D" }, { 0x1E, "\xC4\x9E" }, + { 0x1F, "\xC4\x9F" }, { 0x20, "\xC4\xA0" }, { 0x7F, "\xC4\xA1" }, { 0x80, "\xC4\xA2" }, { 0x81, "\xC4\xA3" }, { 0x82, "\xC4\xA4" }, { 0x83, "\xC4\xA5" }, { 0x84, "\xC4\xA6" }, + { 0x85, "\xC4\xA7" }, { 0x86, "\xC4\xA8" }, { 0x87, "\xC4\xA9" }, { 0x88, "\xC4\xAA" }, { 0x89, "\xC4\xAB" }, { 0x8A, "\xC4\xAC" }, { 0x8B, "\xC4\xAD" }, { 0x8C, "\xC4\xAE" }, + { 0x8D, "\xC4\xAF" }, { 0x8E, "\xC4\xB0" }, { 0x8F, "\xC4\xB1" }, { 0x90, "\xC4\xB2" }, { 0x91, "\xC4\xB3" }, { 0x92, "\xC4\xB4" }, { 0x93, "\xC4\xB5" }, { 0x94, "\xC4\xB6" }, + { 0x95, "\xC4\xB7" }, { 0x96, "\xC4\xB8" }, { 0x97, "\xC4\xB9" }, { 0x98, "\xC4\xBA" }, { 0x99, "\xC4\xBB" }, { 0x9A, "\xC4\xBC" }, { 0x9B, "\xC4\xBD" }, { 0x9C, "\xC4\xBE" }, + { 0x9D, "\xC4\xBF" }, { 0x9E, "\xC5\x80" }, { 0x9F, "\xC5\x81" }, { 0xA0, "\xC5\x82" }, { 0xAD, "\xC5\x83" } + }; + return hex_map; + } + + std::unordered_map unicode_to_bytes() { + static std::unordered_map hex_map = { + { "\x21", 0x21 }, { "\x22", 0x22 }, { "\x23", 0x23 }, { "\x24", 0x24 }, { "\x25", 0x25 }, { "\x26", 0x26 }, { "\x27", 0x27 }, { "\x28", 0x28 }, { "\x29", 0x29 }, { "\x2A", 0x2A }, + { "\x2B", 0x2B }, { "\x2C", 0x2C }, { "\x2D", 0x2D }, { "\x2E", 0x2E }, { "\x2F", 0x2F }, { "\x30", 0x30 }, { "\x31", 0x31 }, { "\x32", 0x32 }, { "\x33", 0x33 }, { "\x34", 0x34 }, + { "\x35", 0x35 }, { "\x36", 0x36 }, { "\x37", 0x37 }, { "\x38", 0x38 }, { "\x39", 0x39 }, { "\x3A", 0x3A }, { "\x3B", 0x3B }, { "\x3C", 0x3C }, { "\x3D", 0x3D }, { "\x3E", 0x3E }, + { "\x3F", 0x3F }, { "\x40", 0x40 }, { "\x41", 0x41 }, { "\x42", 0x42 }, { "\x43", 0x43 }, { "\x44", 0x44 }, { "\x45", 0x45 }, { "\x46", 0x46 }, { "\x47", 0x47 }, { "\x48", 0x48 }, + { "\x49", 0x49 }, { "\x4A", 0x4A }, { "\x4B", 0x4B }, { "\x4C", 0x4C }, { "\x4D", 0x4D }, { "\x4E", 0x4E }, { "\x4F", 0x4F }, { "\x50", 0x50 }, { "\x51", 0x51 }, { "\x52", 0x52 }, + { "\x53", 0x53 }, { "\x54", 0x54 }, { "\x55", 0x55 }, { "\x56", 0x56 }, { "\x57", 0x57 }, { "\x58", 0x58 }, { "\x59", 0x59 }, { "\x5A", 0x5A }, { "\x5B", 0x5B }, { "\x5C", 0x5C }, + { "\x5D", 0x5D }, { "\x5E", 0x5E }, { "\x5F", 0x5F }, { "\x60", 0x60 }, { "\x61", 0x61 }, { "\x62", 0x62 }, { "\x63", 0x63 }, { "\x64", 0x64 }, { "\x65", 0x65 }, { "\x66", 0x66 }, + { "\x67", 0x67 }, { "\x68", 0x68 }, { "\x69", 0x69 }, { "\x6A", 0x6A }, { "\x6B", 0x6B }, { "\x6C", 0x6C }, { "\x6D", 0x6D }, { "\x6E", 0x6E }, { "\x6F", 0x6F }, { "\x70", 0x70 }, + { "\x71", 0x71 }, { "\x72", 0x72 }, { "\x73", 0x73 }, { "\x74", 0x74 }, { "\x75", 0x75 }, { "\x76", 0x76 }, { "\x77", 0x77 }, { "\x78", 0x78 }, { "\x79", 0x79 }, { "\x7A", 0x7A }, + { "\x7B", 0x7B }, { "\x7C", 0x7C }, { "\x7D", 0x7D }, { "\x7E", 0x7E }, { "\xC2\xA1", 0xA1 }, { "\xC2\xA2", 0xA2 }, { "\xC2\xA3", 0xA3 }, { "\xC2\xA4", 0xA4 }, { "\xC2\xA5", 0xA5 }, + { "\xC2\xA6", 0xA6 }, { "\xC2\xA7", 0xA7 }, { "\xC2\xA8", 0xA8 }, { "\xC2\xA9", 0xA9 }, { "\xC2\xAA", 0xAA }, { "\xC2\xAB", 0xAB }, { "\xC2\xAC", 0xAC }, { "\xC2\xAE", 0xAE }, + { "\xC2\xAF", 0xAF }, { "\xC2\xB0", 0xB0 }, { "\xC2\xB1", 0xB1 }, { "\xC2\xB2", 0xB2 }, { "\xC2\xB3", 0xB3 }, { "\xC2\xB4", 0xB4 }, { "\xC2\xB5", 0xB5 }, { "\xC2\xB6", 0xB6 }, + { "\xC2\xB7", 0xB7 }, { "\xC2\xB8", 0xB8 }, { "\xC2\xB9", 0xB9 }, { "\xC2\xBA", 0xBA }, { "\xC2\xBB", 0xBB }, { "\xC2\xBC", 0xBC }, { "\xC2\xBD", 0xBD }, { "\xC2\xBE", 0xBE }, + { "\xC2\xBF", 0xBF }, { "\xC3\x80", 0xC0 }, { "\xC3\x81", 0xC1 }, { "\xC3\x82", 0xC2 }, { "\xC3\x83", 0xC3 }, { "\xC3\x84", 0xC4 }, { "\xC3\x85", 0xC5 }, { "\xC3\x86", 0xC6 }, + { "\xC3\x87", 0xC7 }, { "\xC3\x88", 0xC8 }, { "\xC3\x89", 0xC9 }, { "\xC3\x8A", 0xCA }, { "\xC3\x8B", 0xCB }, { "\xC3\x8C", 0xCC }, { "\xC3\x8D", 0xCD }, { "\xC3\x8E", 0xCE }, + { "\xC3\x8F", 0xCF }, { "\xC3\x90", 0xD0 }, { "\xC3\x91", 0xD1 }, { "\xC3\x92", 0xD2 }, { "\xC3\x93", 0xD3 }, { "\xC3\x94", 0xD4 }, { "\xC3\x95", 0xD5 }, { "\xC3\x96", 0xD6 }, + { "\xC3\x97", 0xD7 }, { "\xC3\x98", 0xD8 }, { "\xC3\x99", 0xD9 }, { "\xC3\x9A", 0xDA }, { "\xC3\x9B", 0xDB }, { "\xC3\x9C", 0xDC }, { "\xC3\x9D", 0xDD }, { "\xC3\x9E", 0xDE }, + { "\xC3\x9F", 0xDF }, { "\xC3\xA0", 0xE0 }, { "\xC3\xA1", 0xE1 }, { "\xC3\xA2", 0xE2 }, { "\xC3\xA3", 0xE3 }, { "\xC3\xA4", 0xE4 }, { "\xC3\xA5", 0xE5 }, { "\xC3\xA6", 0xE6 }, + { "\xC3\xA7", 0xE7 }, { "\xC3\xA8", 0xE8 }, { "\xC3\xA9", 0xE9 }, { "\xC3\xAA", 0xEA }, { "\xC3\xAB", 0xEB }, { "\xC3\xAC", 0xEC }, { "\xC3\xAD", 0xED }, { "\xC3\xAE", 0xEE }, + { "\xC3\xAF", 0xEF }, { "\xC3\xB0", 0xF0 }, { "\xC3\xB1", 0xF1 }, { "\xC3\xB2", 0xF2 }, { "\xC3\xB3", 0xF3 }, { "\xC3\xB4", 0xF4 }, { "\xC3\xB5", 0xF5 }, { "\xC3\xB6", 0xF6 }, + { "\xC3\xB7", 0xF7 }, { "\xC3\xB8", 0xF8 }, { "\xC3\xB9", 0xF9 }, { "\xC3\xBA", 0xFA }, { "\xC3\xBB", 0xFB }, { "\xC3\xBC", 0xFC }, { "\xC3\xBD", 0xFD }, { "\xC3\xBE", 0xFE }, + { "\xC3\xBF", 0xFF }, { "\xC4\x80", 0x00 }, { "\xC4\x81", 0x01 }, { "\xC4\x82", 0x02 }, { "\xC4\x83", 0x03 }, { "\xC4\x84", 0x04 }, { "\xC4\x85", 0x05 }, { "\xC4\x86", 0x06 }, + { "\xC4\x87", 0x07 }, { "\xC4\x88", 0x08 }, { "\xC4\x89", 0x09 }, { "\xC4\x8A", 0x0A }, { "\xC4\x8B", 0x0B }, { "\xC4\x8C", 0x0C }, { "\xC4\x8D", 0x0D }, { "\xC4\x8E", 0x0E }, + { "\xC4\x8F", 0x0F }, { "\xC4\x90", 0x10 }, { "\xC4\x91", 0x11 }, { "\xC4\x92", 0x12 }, { "\xC4\x93", 0x13 }, { "\xC4\x94", 0x14 }, { "\xC4\x95", 0x15 }, { "\xC4\x96", 0x16 }, + { "\xC4\x97", 0x17 }, { "\xC4\x98", 0x18 }, { "\xC4\x99", 0x19 }, { "\xC4\x9A", 0x1A }, { "\xC4\x9B", 0x1B }, { "\xC4\x9C", 0x1C }, { "\xC4\x9D", 0x1D }, { "\xC4\x9E", 0x1E }, + { "\xC4\x9F", 0x1F }, { "\xC4\xA0", 0x20 }, { "\xC4\xA1", 0x7F }, { "\xC4\xA2", 0x80 }, { "\xC4\xA3", 0x81 }, { "\xC4\xA4", 0x82 }, { "\xC4\xA5", 0x83 }, { "\xC4\xA6", 0x84 }, + { "\xC4\xA7", 0x85 }, { "\xC4\xA8", 0x86 }, { "\xC4\xA9", 0x87 }, { "\xC4\xAA", 0x88 }, { "\xC4\xAB", 0x89 }, { "\xC4\xAC", 0x8A }, { "\xC4\xAD", 0x8B }, { "\xC4\xAE", 0x8C }, + { "\xC4\xAF", 0x8D }, { "\xC4\xB0", 0x8E }, { "\xC4\xB1", 0x8F }, { "\xC4\xB2", 0x90 }, { "\xC4\xB3", 0x91 }, { "\xC4\xB4", 0x92 }, { "\xC4\xB5", 0x93 }, { "\xC4\xB6", 0x94 }, + { "\xC4\xB7", 0x95 }, { "\xC4\xB8", 0x96 }, { "\xC4\xB9", 0x97 }, { "\xC4\xBA", 0x98 }, { "\xC4\xBB", 0x99 }, { "\xC4\xBC", 0x9A }, { "\xC4\xBD", 0x9B }, { "\xC4\xBE", 0x9C }, + { "\xC4\xBF", 0x9D }, { "\xC5\x80", 0x9E }, { "\xC5\x81", 0x9F }, { "\xC5\x82", 0xA0 }, { "\xC5\x83", 0xAD } + }; + return hex_map; + } + + // len must be available + bool inline str_is_equal(const char* str1, const char* str2, size_t len) { + for (size_t i = 0; i < len; ++i) { + if (str1[i] != str2[i]) { + return false; + } + } + return true; + } + + std::vector bpe_gpt2_preprocess(const std::string& text) { + static std::unordered_map< unsigned char, std::string> byte_encoder = bytes_to_unicode(); + std::vector bpe_words; + std::vector bpe_encoded_words; + + std::string token=""; + const char *raw_text_p = text.c_str(); + // GPT2 system regex: 's|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+ + bool collecting_numeric = false; + bool collecting_letter = false; + bool collecting_special = false; + bool collecting_whitespace_lookahead = false; + bool collecting=false; + + std::vector text_utf; + text_utf.reserve(text.size()); + bpe_words.reserve(text.size()); + bpe_encoded_words.reserve(text.size()); + + text_utf = CNCTUnicode::split_utf8_enhanced(text); + + for (int i = 0; i < (int)text_utf.size(); i++) { + const CNCTString &utf_char = text_utf[i]; + bool split_condition = false; + const char *text_pos = raw_text_p + utf_char.seq_offset_bytes; + int bytes_remain = strlen(text_pos); + // forward backward lookups + const CNCTString &utf_char_next = (i+1 < (int)text_utf.size()) ? text_utf[i+1] : CNCTString(); + const CNCTString &utf_char_next_next = (i+2 < (int)text_utf.size()) ? text_utf[i+2] : CNCTString(); + // const CNCTString &utf_char_prev = (i > 0) ? text_utf[i-1] : CNCTString(); + + // handling contractions + if (!split_condition && bytes_remain >= 2) { + // 's|'t|'m|'d + if (utf_char == '\'' && (utf_char_next == 's' || utf_char_next == 't' || utf_char_next == 'm' || utf_char_next == 'd')) { + split_condition = true; + } + if (split_condition) { + if (token.size()) { + bpe_words.emplace_back(token); // push previous content as token + } + token = utf_char.str + utf_char_next.str; + bpe_words.emplace_back(token); + token=""; + i++; + continue; + } + } + if (!split_condition && bytes_remain >= 3) { + // 're|'ve|'ll + if (utf_char == '\'' && ( + (utf_char_next == 'r' || utf_char_next_next == 'e') || + (utf_char_next == 'v' || utf_char_next_next == 'e') || + (utf_char_next == 'l' || utf_char_next_next == 'l')) + ) { + split_condition = true; + } + if (split_condition) { + // current token + next token can be defined + if (token.size()) { + bpe_words.emplace_back(token); // push previous content as token + } + token = utf_char.str + utf_char_next.str + utf_char_next_next.str; + bpe_words.emplace_back(token); // the contraction + token=""; + i+=2; + continue; + } + } + + if (!split_condition && !collecting) { + if (utf_char.char_type == CNCTCharType::LETTER || (!token.size() && utf_char==" " && utf_char_next.char_type == CNCTCharType::LETTER)) { + collecting_letter = true; + collecting = true; + } else if (utf_char.char_type == CNCTCharType::DIGIT || (!token.size() && utf_char==" " && utf_char_next.char_type == CNCTCharType::DIGIT)) { + collecting_numeric = true; + collecting = true; + } else if ( + ((utf_char.char_type != CNCTCharType::LETTER && utf_char.char_type != CNCTCharType::DIGIT) && (utf_char.char_type != CNCTCharType::WHITESPACE)) || + (!token.size() && utf_char==" " && utf_char_next.char_type != CNCTCharType::LETTER && utf_char_next.char_type != CNCTCharType::DIGIT && utf_char_next.char_type != CNCTCharType::WHITESPACE) + ) { + collecting_special = true; + collecting = true; + } else if (utf_char.char_type == CNCTCharType::WHITESPACE && utf_char_next.char_type == CNCTCharType::WHITESPACE) { + collecting_whitespace_lookahead = true; + collecting = true; + } else if (utf_char.char_type == CNCTCharType::WHITESPACE) { + split_condition = true; + } + } else if (!split_condition && collecting) { + if (collecting_letter && utf_char.char_type != CNCTCharType::LETTER) { + split_condition = true; + } else if (collecting_numeric && utf_char.char_type != CNCTCharType::DIGIT) { + split_condition = true; + } else if (collecting_special && (utf_char.char_type == CNCTCharType::LETTER || utf_char.char_type == CNCTCharType::DIGIT || utf_char.char_type == CNCTCharType::WHITESPACE)) { + split_condition = true; + } else if (collecting_whitespace_lookahead && utf_char_next.char_type != CNCTCharType::WHITESPACE) { + split_condition = true; + } + } + + if(utf_char_next.str.size() == 0) { + split_condition = true; // final + token += utf_char.str; + } + + if (split_condition) { + if (token.size()) { + bpe_words.emplace_back(token); + } + token = utf_char.str; + collecting = false; + collecting_letter = false; + collecting_numeric = false; + collecting_special = false; + collecting_whitespace_lookahead = false; + } else { + token += utf_char.str; + } + } + + for (std::string& word : bpe_words) { + std::string encoded_token=""; + for (char& c : word) { + encoded_token += byte_encoder[c]; + } + bpe_encoded_words.emplace_back(encoded_token); + } + + return bpe_encoded_words; + } + + // decoder (for one token) + std::string decode_token(const std::string& token) { + static std::unordered_map< std::string, unsigned char> byte_decoder = unicode_to_bytes(); + std::string decoded_token=""; + auto unicode_seqeunces = CNCTUnicode::split_utf8(token); + for (auto& unicode_sequence : unicode_seqeunces) { + decoded_token += byte_decoder[unicode_sequence]; + } + + return decoded_token; + } + + const gpt2bpe_vocab & vocab_; + std::vector symbols_; + std::vector symbols_final; + ggllm_bpe_bigram::queue work_queue_; + bool flag_g2ws=false; +}; + +static std::vector gpt2bpe_tokenize(const gpt2bpe_vocab & vocab, const std::string & text, bool bos, bool g2ws ) { + gpt2bpe_tokenizer tokenizer(vocab, g2ws); + std::vector output; + + if (text.empty()) { + return output; + } + + if (bos && vocab.special_bos_id != -1) { + output.push_back(vocab.special_bos_id); + } + + tokenizer.tokenize(text, output); + return output; +} + +#endif // CMPNCT_GPT2BPE diff --git a/examples/gptneox-wip/falcon-main.cpp b/examples/gptneox-wip/falcon-main.cpp new file mode 100644 index 000000000..43b6a29f3 --- /dev/null +++ b/examples/gptneox-wip/falcon-main.cpp @@ -0,0 +1,1111 @@ +#include "ggml.h" +#include "cmpnct_gpt2bpe.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +// default hparams +struct falcon_hparams { + size_t n_merges = 0; + size_t n_vocab = 0; + uint32_t n_ctx = 0; + uint32_t n_embd = 0; + uint32_t n_head = 0; + uint32_t n_head_kv = 1; // Needs to be 1 for 7B model + uint32_t n_ff = 0; + uint32_t n_block = 0; + float norm_eps = 1e-5; +}; +struct falcon_block { + // normalization + struct ggml_tensor* input_layernorm; + struct ggml_tensor* input_layernorm_b; + struct ggml_tensor* attention_norm; // Falcon-40B only + struct ggml_tensor* attention_norm_b; // Falcon-40B only + + // attention + struct ggml_tensor* query_key_value; + struct ggml_tensor* wo; + + // ff + struct ggml_tensor* ffn_up; + struct ggml_tensor* ffn_down; +}; + +struct falcon_model { + falcon_hparams hparams; + + struct ggml_tensor* tok_embeddings; + struct ggml_tensor* output_norm; + struct ggml_tensor* output_norm_b; + struct ggml_tensor* lm_head; + + std::vector blocks; + + // key + value memory + struct ggml_tensor* memory_k; + struct ggml_tensor* memory_v; + + struct gguf_context * ggufctx; + struct ggml_context * ctx; + struct ggml_context * kvctx; + + std::map tensors; +}; + +struct gpt_params { + int32_t seed = -1; // RNG seed + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + uint32_t n_predict = 200; // new tokens to predict + uint32_t n_batch = 512; // batch size for prompt processing + + // sampling parameters + int32_t top_k = 40; + float top_p = 1.0f; + float temp = 0.8f; + int32_t repeat_last_n = 64; + float repeat_penalty = 1.02f; + + std::string model = ""; // model path + std::string prompt = ""; + + std::string token_test = ""; + bool interactive = false; + int32_t interactive_port = -1; + int32_t n_gpu_layers = 0; +}; + +void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { + fprintf(stderr, "usage: %s [options]\n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); + fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); + fprintf(stderr, " -ngl N, --gpu-layers N number of layers to offload to GPU on supported models (default: %d)\n", params.n_gpu_layers); + fprintf(stderr, " -p PROMPT, --prompt PROMPT\n"); + fprintf(stderr, " prompt to start generation with (default: random)\n"); + fprintf(stderr, " -f FNAME, --file FNAME\n"); + fprintf(stderr, " load prompt from a file\n"); + fprintf(stderr, " -tt TOKEN_TEST, --token_test TOKEN_TEST\n"); + fprintf(stderr, " test tokenization\n"); + fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict); + fprintf(stderr, " --top_k N top-k sampling, 0 = n_vocab (default: %d)\n", params.top_k); + fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p); + fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp); + fprintf(stderr, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n); + fprintf(stderr, " --repeat-penalty N penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty); + fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch); + fprintf(stderr, " -m FNAME, --model FNAME\n"); + fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); + fprintf(stderr, "\n"); +} + +// Function to check if the next argument exists +std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) { + if (i + 1 < argc && argv[i + 1][0] != '-') { + return argv[++i]; + } else { + fprintf(stderr, "error: %s requires one argument.\n", flag.c_str()); + gpt_print_usage(argc, argv, params); + exit(0); + } +} + +bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "-s" || arg == "--seed") { + params.seed = std::stoi(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "-t" || arg == "--threads") { + params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") { + params.n_gpu_layers = std::stoi(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "-p" || arg == "--prompt") { + params.prompt = get_next_arg(i, argc, argv, arg, params); + } else if (arg == "-n" || arg == "--n_predict") { + params.n_predict = std::stoi(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "--top_k") { + params.top_k = std::stoi(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "--top_p") { + params.top_p = std::stof(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "--temp") { + params.temp = std::stof(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "--repeat-last-n") { + params.repeat_last_n = std::stoi(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "--repeat-penalty") { + params.repeat_penalty = std::stof(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "-b" || arg == "--batch_size") { + params.n_batch= std::stoi(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "-m" || arg == "--model") { + params.model = get_next_arg(i, argc, argv, arg, params); + } else if (arg == "-i" || arg == "--interactive") { + params.interactive = true; + } else if (arg == "-ip" || arg == "--interactive-port") { + params.interactive = true; + params.interactive_port = std::stoi(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "-h" || arg == "--help") { + gpt_print_usage(argc, argv, params); + exit(0); + } else if (arg == "-f" || arg == "--file") { + get_next_arg(i, argc, argv, arg, params); + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + break; + } + std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); + if (params.prompt.back() == '\n') { + params.prompt.pop_back(); + } + } else if (arg == "-tt" || arg == "--token_test") { + params.token_test = get_next_arg(i, argc, argv, arg, params); + } + else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + gpt_print_usage(argc, argv, params); + exit(0); + } + } + + return true; +} + +gpt2bpe_vocab::id sample_top_k_top_p_repeat( + const gpt2bpe_vocab & vocab, + const float * logits, + const int32_t * last_n_tokens_data, + size_t last_n_tokens_data_size, + int top_k, + double top_p, + double temp, + int repeat_last_n, + float repeat_penalty, + std::mt19937 & rng) { + + int n_logits = vocab.id_to_token.size(); + + const auto * plogits = logits; + + const auto last_n_tokens = std::vector(last_n_tokens_data, last_n_tokens_data + last_n_tokens_data_size); + + if (temp <= 0) { + // select the token with the highest logit directly + float max_logit = plogits[0]; + gpt2bpe_vocab::id max_id = 0; + + for (int i = 1; i < n_logits; ++i) { + if (plogits[i] > max_logit) { + max_logit = plogits[i]; + max_id = i; + } + } + return max_id; + } + + + std::vector> logits_id; + logits_id.reserve(n_logits); + + { + const float scale = 1.0f/temp; + for (int i = 0; i < n_logits; ++i) { + // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858) + // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main + if (repeat_last_n > 0 && std::find(last_n_tokens.end()-repeat_last_n, last_n_tokens.end(), i) != last_n_tokens.end()) { + // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability + if (plogits[i] < 0.0f) { + logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i)); + } else { + logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i)); + } + } else { + logits_id.push_back(std::make_pair(plogits[i]*scale, i)); + } + } + } + + // find the top K tokens + std::partial_sort( + logits_id.begin(), + logits_id.begin() + top_k, logits_id.end(), + [](const std::pair & a, const std::pair & b) { + return a.first > b.first; + }); + + logits_id.resize(top_k); + + double maxl = -INFINITY; + for (const auto & kv : logits_id) { + maxl = std::max(maxl, kv.first); + } + + // compute probs for the top K tokens + std::vector probs; + probs.reserve(logits_id.size()); + + double sum = 0.0; + for (const auto & kv : logits_id) { + double p = exp(kv.first - maxl); + probs.push_back(p); + sum += p; + } + + // normalize the probs + for (auto & p : probs) { + p /= sum; + } + + if (top_p < 1.0f) { + double cumsum = 0.0f; + for (int i = 0; i < top_k; i++) { + cumsum += probs[i]; + if (cumsum >= top_p) { + top_k = i + 1; + probs.resize(top_k); + logits_id.resize(top_k); + break; + } + } + + cumsum = 1.0/cumsum; + for (int i = 0; i < (int) probs.size(); i++) { + probs[i] *= cumsum; + } + } + +// printf("\n"); +// for (int i = 0; i < (int) probs.size(); i++) { +// for (int i = 0; i < 10; i++) { +// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]); +// } + + std::discrete_distribution<> dist(probs.begin(), probs.end()); + int idx = dist(rng); + + return logits_id[idx].second; + +} + +struct ggml_tensor * get_tensor_ex( struct ggml_context * ctx, std::string name){ + + struct ggml_tensor * cur = ggml_get_tensor(ctx, name.c_str()); + if( cur == NULL ) { + fprintf(stdout, "%s: tensor '%s' not found!\n", __func__, name.c_str()); + } else { +// fprintf(stdout, "%s: n_dims = %d, name = '%s'\n", __func__, cur->n_dims, cur->name); + } + + return cur; +} + +// load the model's weights from a file +bool falcon_model_load(const std::string & fname, falcon_model & model, gpt2bpe_vocab & vocab) { + printf("%s: loading model from '%s'..\n", __func__, fname.c_str()); + + model.ctx = NULL; + + struct gguf_init_params ggufparams = { + /*.no_alloc = */ false, + /*.ctx = */ &model.ctx, + }; + + auto & ggufctx = model.ggufctx; + + ggufctx = gguf_init_from_file(fname.c_str(), ggufparams); + + if (!ggufctx) { + fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__); + return false; + } + + fprintf(stdout, "%s: gguf version = %d\n", __func__, gguf_get_version(ggufctx)); + fprintf(stdout, "%s: gguf alignment = %zu\n", __func__, gguf_get_alignment(ggufctx)); + fprintf(stdout, "%s: gguf data offset = %zu\n", __func__, gguf_get_data_offset(ggufctx)); + + // print all kv + #if 0 + { + const int n_kv = gguf_get_n_kv(ggufctx); + + fprintf(stdout, "%s: n_kv: %d\n", __func__, n_kv); + + for (int i = 0; i < n_kv; ++i) { + const char * key = gguf_get_key(ggufctx, i); + + fprintf(stdout, "%s: kv[%d]: key = %s\n", __func__, i, key); + } + } + #endif + + // print some standard metadata + { + int keyidx; + + keyidx = gguf_find_key(ggufctx, "general.name"); + if (keyidx != -1) { fprintf(stdout, "%s: model name = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } + keyidx = gguf_find_key(ggufctx, "general.description"); + if (keyidx != -1) { fprintf(stdout, "%s: model description = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } + keyidx = gguf_find_key(ggufctx, "general.author"); + if (keyidx != -1) { fprintf(stdout, "%s: model author = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } + keyidx = gguf_find_key(ggufctx, "general.license"); + if (keyidx != -1) { fprintf(stdout, "%s: model license = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } + keyidx = gguf_find_key(ggufctx, "general.architecture"); + if (keyidx != -1) { fprintf(stdout, "%s: model architecture = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } + keyidx = gguf_find_key(ggufctx, "general.file_type"); + if (keyidx != -1) { fprintf(stdout, "%s: model file type = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } + keyidx = gguf_find_key(ggufctx, "gptneox.tensor_data_layout"); + if (keyidx != -1) { fprintf(stdout, "%s: model data layout = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } + keyidx = gguf_find_key(ggufctx, "general.source.hugginface.repository"); + if (keyidx != -1) { fprintf(stdout, "%s: model source HF repo = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } + } + + // check required metadata + { + int keyidx; + + // check model architecture kv + keyidx = gguf_find_key(ggufctx, "general.architecture"); + if (keyidx != -1) { + if ( strcmp(gguf_get_val_str(ggufctx, keyidx), "falcon") != 0) { + fprintf(stdout, "%s: model architecture not supported!\n", __func__); + return false; + } + } else { + fprintf(stdout, "%s: gguf model architecture not found!\n", __func__); + return false; + } + + // check model tensor data layout kv + keyidx = gguf_find_key(ggufctx, "falcon.tensor_data_layout"); + if (keyidx != -1) { + if ( strcmp(gguf_get_val_str(ggufctx, keyidx), "jploski") != 0) { + fprintf(stdout, "%s: model tensor data layout not supported!\n", __func__); + return false; + } + } else { + fprintf(stdout, "%s: gguf model tensor data layout not found!\n", __func__); + return false; + } + + } + + // load hparams + { + auto & hparams = model.hparams; + + bool ok = true; + int keyidx; + + if (ok) { keyidx = gguf_find_key(ggufctx, "falcon.context_length"); + if (keyidx != -1) { hparams.n_ctx = gguf_get_val_u32(ggufctx, keyidx); } else { ok = false; } } + + if (ok) { keyidx = gguf_find_key(ggufctx, "falcon.embedding_length"); + if (keyidx != -1) { hparams.n_embd = gguf_get_val_u32(ggufctx, keyidx); } else { ok = false; } } + + if (ok) { keyidx = gguf_find_key(ggufctx, "falcon.attention.head_count"); + if (keyidx != -1) { hparams.n_head = gguf_get_val_u32(ggufctx, keyidx); } else { ok = false; } } + + if (ok) { keyidx = gguf_find_key(ggufctx, "falcon.feed_forward_length"); + if (keyidx != -1) { hparams.n_ff = gguf_get_val_u32(ggufctx, keyidx); } else { ok = false; } } + + if (ok) { keyidx = gguf_find_key(ggufctx, "falcon.block_count"); + if (keyidx != -1) { hparams.n_block = gguf_get_val_u32(ggufctx, keyidx); } else { ok = false; } } + + if (ok) { keyidx = gguf_find_key(ggufctx, "falcon.attention.layer_norm_epsilon"); + if (keyidx != -1) { hparams.norm_eps= gguf_get_val_f32(ggufctx, keyidx); } else { ok = false; } } + + if (!ok) { + fprintf(stderr, "%s: required hparam missing!\n", __func__); + return false; + } + + keyidx = gguf_find_key(ggufctx, "falcon.attention.head_count_kv"); + if (keyidx != -1) { hparams.n_head_kv = gguf_get_val_u32(ggufctx, keyidx); } + + + printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx); + printf("%s: n_embd = %d\n", __func__, hparams.n_embd); + printf("%s: n_head = %d\n", __func__, hparams.n_head); + printf("%s: n_head_kv = %d\n", __func__, hparams.n_head_kv); + printf("%s: n_block = %d\n", __func__, hparams.n_block); + printf("%s: norm_eps = %g\n", __func__, hparams.norm_eps); + + } + + // load vocab + { + auto & hparams = model.hparams; + + int keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.model"); + + if (keyidx != -1) { + if ( strcmp(gguf_get_val_str(ggufctx, keyidx), "gpt2") != 0) { + fprintf(stdout, "%s: tokenizer model not supported!\n", __func__); + return false; + } + } else { + fprintf(stdout, "%s: tokenizer model not found!\n", __func__); + return false; + } + + + int tokens_keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.tokens"); + + if (tokens_keyidx == -1) { + fprintf(stdout, "%s: gpt2 tokenizer vocab not found!\n", __func__); + return false; + } + + int merges_keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.merges"); + + if (merges_keyidx == -1) { + fprintf(stdout, "%s: gpt2 tokenizer merges not found!\n", __func__); + return false; + } + + hparams.n_vocab = gguf_get_arr_n(ggufctx,tokens_keyidx); + hparams.n_merges = gguf_get_arr_n(ggufctx,merges_keyidx); + + fprintf(stdout, "%s: gpt2 tokenizer vocab = %zu\n", __func__, hparams.n_vocab); + fprintf(stdout, "%s: gpt2 tokenizer merges = %zu\n", __func__, hparams.n_merges); + + for (size_t i = 0; i < hparams.n_vocab; i++) { + std::string word = gguf_get_arr_str(ggufctx, tokens_keyidx, i); + +// printf("token %d = '%s'\n",i,word.c_str() ); + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + + if( vocab.id_to_token[i] == "\n" ) { + vocab.linefeed_id = i; + } + } + + std::vector> bpe_merges; + + for (size_t i = 0; i < hparams.n_merges; i++) { + + std::string word = gguf_get_arr_str(ggufctx, merges_keyidx, i); + + // Split the merges + std::string first, second; + size_t pos = word.find(' ', 1); // Start the search from the second character + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } + + bpe_merges.push_back(std::make_pair(first, second)); + } + + vocab.populate_bpe_ranks(bpe_merges); + + + keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.bos_token_id"); if( keyidx != -1 ) { vocab.special_bos_id = (int32_t)gguf_get_val_u32(ggufctx, keyidx); } + keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.eos_token_id"); if( keyidx != -1 ) { vocab.special_eos_id = (int32_t)gguf_get_val_u32(ggufctx, keyidx); } + keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.unknown_token_id"); if( keyidx != -1 ) { vocab.special_unk_id = (int32_t)gguf_get_val_u32(ggufctx, keyidx); } + keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.separator_token_id"); if( keyidx != -1 ) { vocab.special_sep_id = (int32_t)gguf_get_val_u32(ggufctx, keyidx); } + keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.padding_token_id"); if( keyidx != -1 ) { vocab.special_pad_id = (int32_t)gguf_get_val_u32(ggufctx, keyidx); } + + if( vocab.special_bos_id != -1 ) { fprintf(stdout, "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].c_str() ); } + if( vocab.special_eos_id != -1 ) { fprintf(stdout, "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].c_str() ); } + if( vocab.special_unk_id != -1 ) { fprintf(stdout, "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].c_str() ); } + if( vocab.special_sep_id != -1 ) { fprintf(stdout, "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].c_str() ); } + if( vocab.special_pad_id != -1 ) { fprintf(stdout, "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].c_str() ); } + if( vocab.linefeed_id != -1 ) { fprintf(stdout, "%s: LF token = %d\n", __func__, vocab.linefeed_id ); } + + } + + + auto & ctx = model.ctx; + size_t ctx_size = ggml_get_mem_size(ctx); + + printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); + + // print tensor info + #if 0 + { + const int n_tensors = gguf_get_n_tensors(ggufctx); + + fprintf(stdout, "%s: n_tensors: %d\n", __func__, n_tensors); + + for (int i = 0; i < n_tensors; ++i) { + const char * name = gguf_get_tensor_name (ggufctx, i); + const size_t offset = gguf_get_tensor_offset(ggufctx, i); + + fprintf(stdout, "%s: tensor[%d]: name = %s, offset = %zu\n", __func__, i, name, offset); + } + } + #endif + + // prepare memory for the weights + { + + auto & hparams = model.hparams; + + const int n_block = hparams.n_block; + + model.blocks.resize(n_block); + + model.tok_embeddings = ggml_get_tensor(ctx, "token_embd.weight"); + + model.output_norm = ggml_get_tensor(ctx, "output_norm.weight"); + model.output_norm_b = ggml_get_tensor(ctx, "output_norm.bias"); + model.lm_head = ggml_get_tensor(ctx, "output.weight"); + + // map by name + model.tensors["token_embd.weight"] = model.tok_embeddings; + model.tensors["output_norm.weight"] = model.output_norm; + model.tensors["output_norm.bias"] = model.output_norm_b; + model.tensors["output.weight"] = model.lm_head; + + for (int i = 0; i < n_block; ++i) { + + auto& block = model.blocks[i]; + std::string blocknamestart = "blk." + std::to_string(i) + "."; + + block.input_layernorm = get_tensor_ex(ctx, blocknamestart + "attn_norm.weight" ); + block.input_layernorm_b = get_tensor_ex(ctx, blocknamestart + "attn_norm.bias" ); + + if ( hparams.n_head_kv == 8 ) { // Falcon-40B + block.attention_norm = get_tensor_ex(ctx, blocknamestart + "attn_norm_2.weight" ); + block.attention_norm_b = get_tensor_ex(ctx, blocknamestart + "attn_norm_2.bias" ); + } + + // query_key_value shape for config.multi_query == True: + block.query_key_value = get_tensor_ex(ctx, blocknamestart + "attn_qkv.weight" ); + block.wo = get_tensor_ex(ctx, blocknamestart + "attn_output.weight" ); + + block.ffn_up = get_tensor_ex(ctx, blocknamestart + "ffn_up.weight" ); + block.ffn_down = get_tensor_ex(ctx, blocknamestart + "ffn_down.weight" ); + + // map by name + if ( hparams.n_head_kv == 8 ) { // Falcon-40B + // Falcon-40B: + model.tensors[blocknamestart + "attn_norm.weight"] = block.input_layernorm; + model.tensors[blocknamestart + "attn_norm.bias"] = block.input_layernorm_b; + model.tensors[blocknamestart + "attn_norm_2.weight"] = block.attention_norm; + model.tensors[blocknamestart + "attn_norm_2.bias"] = block.attention_norm_b; + } else { + // Falcon-7B: + model.tensors[blocknamestart + "attn_norm.weight"] = block.input_layernorm; + model.tensors[blocknamestart + "attn_norm.bias"] = block.input_layernorm_b; + } + + model.tensors[blocknamestart + "attn_qkv.weight"] = block.query_key_value; + model.tensors[blocknamestart + "attn_output.weight"] = block.wo; + + model.tensors[blocknamestart + "ffn_up.weight"] = block.ffn_up; + model.tensors[blocknamestart + "ffn_down.weight"] = block.ffn_down; + } + } + + // key + value memory + { + const auto & kvctx = model.kvctx; + const auto & hparams = model.hparams; + + const int n_block = hparams.n_block; + const int n_ctx = hparams.n_ctx; + const int n_embd = hparams.n_embd; + + const int64_t n_mem = n_block*n_ctx; + const int64_t n_elements = n_embd*n_mem; + + // create the ggml context + { + struct ggml_init_params params = { + /*.mem_size =*/ size_t(n_elements*4+ggml_tensor_overhead()*2), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ false, + }; + + model.kvctx = ggml_init(params); + if (!model.kvctx) { + fprintf(stderr, "%s: kv ggml_init() failed\n", __func__); + return false; + } + + } + + + model.memory_k = ggml_new_tensor_1d(kvctx, GGML_TYPE_F16, n_elements); + model.memory_v = ggml_new_tensor_1d(kvctx, GGML_TYPE_F16, n_elements); + + const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v); + + printf("%s: memory_size = %8.2f MB, n_mem = %" PRId64 "\n", __func__, memory_size/1024.0/1024.0, n_mem); + } + + return true; +} + + +// evaluate the transformer +// +// - model: the model +// - n_threads: number of threads to use +// - n_past: the context size so far +// - embd_inp: the embeddings of the tokens in the context +// - embd_w: the predicted logits for the next token +// +bool falcon_eval( + const falcon_model & model, + const int n_threads, + const int n_past, + const std::vector & embd_inp, + std::vector & embd_w, + size_t & mem_per_token) { + + + const int N = embd_inp.size(); + + const auto & hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_block = hparams.n_block; + const int n_ctx = hparams.n_ctx; + const int n_head = hparams.n_head; + const int n_head_kv = hparams.n_head_kv; + const int n_vocab = hparams.n_vocab; + const size_t head_dim = n_embd / n_head; + + static size_t buf_size = 256u*1024*1024; + static void * buf = malloc(buf_size); + + // use 2 scratch buffers + // TODO: very hacky solution - reimplement in a more elegant way + static size_t scr0_size = 256u*1024*1024; + static void * scr0 = malloc(scr0_size); + + static size_t scr1_size = 256u*1024*1024; + static void * scr1 = malloc(scr1_size); + + if (mem_per_token > 0 && mem_per_token*N > buf_size) { + const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead + //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new); + + // reallocate + buf_size = buf_size_new; + buf = realloc(buf, buf_size); + if (buf == nullptr) { + fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size); + return false; + } + } + + struct ggml_init_params params = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf, + /*.no_alloc =*/ false, + }; + + struct ggml_context * ctx0 = ggml_init(params); + struct ggml_cgraph gf = {}; +// gf.n_threads = n_threads; + + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd)); + + // wte + struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd); +// struct ggml_tensor* repeat_dummy = ggml_new_tensor_3d(ctx0, inpL->type, head_dim, N + n_past, n_head); + + ggml_type wtype = GGML_TYPE_F32; + const int sizeof_wtype = ggml_type_sizef(wtype); + + for (int il = 0; il < n_block; ++il) { + struct ggml_tensor * cur; + struct ggml_tensor * layernorm_output; + + ggml_set_scratch(ctx0, { 0, scr0_size, scr0, }); + + // self-attention + { + layernorm_output = ggml_norm(ctx0, inpL); + + layernorm_output = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.blocks[il].input_layernorm, layernorm_output), + layernorm_output), + ggml_repeat(ctx0, model.blocks[il].input_layernorm_b, layernorm_output)); + + if ( hparams.n_head_kv == 8 ) { // Falcon-40B + cur = ggml_norm(ctx0, inpL); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.blocks[il].attention_norm, cur), + cur), + ggml_repeat(ctx0, model.blocks[il].attention_norm_b, cur)); + } + else { // Falcon 7B + cur = layernorm_output; + } + + // compute QKV + + cur = ggml_mul_mat(ctx0, model.blocks[il].query_key_value, cur); + + // Note that the strides for Kcur, Vcur are set up so that the + // resulting views are misaligned with the tensor's storage + // (by applying the K/V offset we shift the tensor's original + // view to stick out behind the viewed QKV tensor's allocated + // memory, so to say). This is ok because no actual accesses + // happen to that out-of-range memory, but it can require some + // trickery when trying to accurately dump these views for + // debugging. + + struct ggml_tensor * Qcur = ggml_view_3d( + ctx0, cur, head_dim, n_head, N, + head_dim * sizeof_wtype, + head_dim * (n_head + 2 * n_head_kv) * sizeof_wtype, + 0); + + struct ggml_tensor * Kcur = ggml_view_3d( + ctx0, cur, head_dim, n_head_kv, N, + head_dim * sizeof_wtype, + head_dim * (n_head + 2 * n_head_kv) * sizeof_wtype, + head_dim * n_head * sizeof_wtype); + + struct ggml_tensor * Vcur = ggml_view_3d( + ctx0, cur, head_dim, n_head_kv, N, + head_dim * sizeof_wtype, + head_dim * (n_head + 2 * n_head_kv) * sizeof_wtype, + head_dim * (n_head + n_head_kv) * sizeof_wtype); + + // using mode = 2 for neox mode + Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, head_dim, 2, 0); + Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, head_dim, 2, 0); + + // store key and value to memory + { + struct ggml_tensor* k = ggml_view_1d( + ctx0, model.memory_k, N * n_head_kv * head_dim, + (ggml_element_size(model.memory_k) * n_head_kv * head_dim) * + (il * n_ctx + n_past)); + struct ggml_tensor* v = ggml_view_1d( + ctx0, model.memory_v, N * n_head_kv * head_dim, + (ggml_element_size(model.memory_v) * n_head_kv * head_dim) * + (il * n_ctx + n_past)); + + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); + } + + struct ggml_tensor * K = ggml_permute( + ctx0, + ggml_reshape_3d( + ctx0, + ggml_view_1d(ctx0, model.memory_k, (n_past + N) * n_head_kv * head_dim, + il * n_ctx * + ggml_element_size(model.memory_k) * + n_head_kv * + head_dim), + head_dim, n_head_kv, n_past + N), + 0, 2, 1, 3); + + // K * Q + +// K = ggml_cont(ctx0, ggml_repeat2(ctx0, K, repeat_dummy)); + + struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + struct ggml_tensor * KQ_scaled = + ggml_scale_inplace(ctx0, + KQ, + ggml_new_f32(ctx0, 1.0f/sqrt(float(head_dim))) + ); + + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + + // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() + struct ggml_tensor* V = ggml_permute( + ctx0, + ggml_reshape_3d( + ctx0, + ggml_view_1d(ctx0, model.memory_v, (n_past + N) * n_head_kv * head_dim, + il * n_ctx * + ggml_element_size(model.memory_v) * + n_head_kv * + head_dim), + head_dim, n_head_kv, n_past + N), + 0, 2, 1, 3); + +// V = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_repeat2(ctx0, V, repeat_dummy))); + V = ggml_cont(ctx0, ggml_transpose(ctx0, V)); + + // KQV = transpose(V) * KQ_soft_max + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_embd, N) + cur = ggml_cpy(ctx0, + KQV_merged, + ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + + // projection + { + cur = ggml_mul_mat(ctx0, + model.blocks[il].wo, + cur); + } + } + + ggml_set_scratch(ctx0, { 0, scr1_size, scr1, }); + + struct ggml_tensor* inpFF = layernorm_output; + struct ggml_tensor* attn_out = ggml_cpy( + ctx0, cur, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + + { + cur = ggml_mul_mat(ctx0, model.blocks[il].ffn_up, inpFF); + cur = ggml_gelu(ctx0, cur); + cur = ggml_mul_mat(ctx0, model.blocks[il].ffn_down, cur); + } + + cur = ggml_add(ctx0, cur, attn_out); + cur = ggml_add(ctx0, cur, inpL); + // input for next layer + inpL = cur; + } + + ggml_set_scratch(ctx0, { 0, scr0_size, scr0, }); + + // norm + { + inpL = ggml_norm(ctx0, inpL); + + // inpL = ln_f_g*inpL + ln_f_b + inpL = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.output_norm, inpL), + inpL), + ggml_repeat(ctx0, model.output_norm_b, inpL)); + } + + ggml_set_scratch(ctx0, { 0, 0, nullptr, }); + + // lm_head + { + inpL = ggml_mul_mat(ctx0, model.lm_head, inpL); + + //inpL = ggml_add(ctx0, + // ggml_repeat(ctx0, model.lmh_b, inpL), + // inpL); + } + + // logits -> probs + //inpL = ggml_soft_max_inplace(ctx0, inpL); + + // run the computation + ggml_build_forward_expand(&gf, inpL); +// ggml_graph_compute (ctx0, &gf); + ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); + + //if (n_past%100 == 0) { + // ggml_graph_print (&gf); + // ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot"); + //} + + // return result for just the last token + embd_w.resize(n_vocab); + memcpy(embd_w.data(), (float *)ggml_get_data(inpL) + (n_vocab * (N - 1)), sizeof(float) * n_vocab); + + if (mem_per_token == 0) { + mem_per_token = ggml_used_mem(ctx0)/N; + } + //printf("used_mem = %zu\n", ggml_used_mem(ctx0)); + + ggml_free(ctx0); + + return true; +} + +int main(int argc, char ** argv) { + ggml_time_init(); + + const int64_t t_main_start_us = ggml_time_us(); + + gpt_params params; + + if (gpt_params_parse(argc, argv, params) == false) { + return 1; + } + + int64_t t_load_us = 0; + + gpt2bpe_vocab vocab; + falcon_model model; + + // load the model + { + const int64_t t_start_us = ggml_time_us(); + + if (!falcon_model_load(params.model, model, vocab)) { + fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); + return 1; + } + + t_load_us = ggml_time_us() - t_start_us; + + } + + if (params.seed < 0) { + params.seed = time(NULL); + } + + if (params.top_k == 0) { + params.top_k = model.hparams.n_vocab; + } + + printf("%s: seed = %d\n", __func__, params.seed); + printf("%s: temp = %.3f\n", __func__, params.temp); + printf("%s: top_k = %d\n", __func__, params.top_k); + printf("%s: top_p = %.3f\n", __func__, params.top_p); + printf("%s: repeat_last_n = %d\n", __func__, params.repeat_last_n); + printf("%s: repeat_penalty = %.3f\n", __func__, params.repeat_penalty); + + std::mt19937 rng(params.seed); + + if (params.prompt.empty()) { + params.prompt = "Once upon"; + } + + std::vector last_n_tokens(model.hparams.n_ctx); + std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); + + int n_past = 0; + + int64_t t_sample_us = 0; + int64_t t_predict_us = 0; + + std::vector logits; + + // tokenize the prompt + std::vector embd_inp = gpt2bpe_tokenize(vocab, params.prompt,false, false); + + params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size()); + + printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); +// for (size_t i = 0; i < embd_inp.size(); i++) { +// printf("%s: token[%zu] = %6d, %s\n", __func__, i, embd_inp[i], vocab.id_to_token[embd_inp[i]].c_str()); +// } + + if( model.hparams.n_ctx < params.n_predict+embd_inp.size() ) { + params.n_predict = model.hparams.n_ctx-embd_inp.size(); + } + + printf("%s: n_predict = %d\n", __func__, params.n_predict); + printf("\n"); + + std::vector embd; + + // determine the required inference memory per token: + size_t mem_per_token = 0; + falcon_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token); + + for (size_t i = embd.size(); i < embd_inp.size() + params.n_predict; i++) { + // predict + if (embd.size() > 0) { + const int64_t t_start_us = ggml_time_us(); + + if (!falcon_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) { + printf("Failed to predict\n"); + return 1; + } + + t_predict_us += ggml_time_us() - t_start_us; + } + + n_past += embd.size(); + embd.clear(); + + if (i >= embd_inp.size()) { + // sample next token + const int top_k = params.top_k; + const float top_p = params.top_p; + const float temp = params.temp; + const int repeat_last_n = params.repeat_last_n; + const float repeat_penalty = params.repeat_penalty; + + const int n_vocab = model.hparams.n_vocab; + + gpt2bpe_vocab::id id = 0; + + { + const int64_t t_start_sample_us = ggml_time_us(); + + id = sample_top_k_top_p_repeat(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens.data(), last_n_tokens.size(), top_k, top_p, temp, repeat_last_n, repeat_penalty, rng); + + last_n_tokens.erase(last_n_tokens.begin()); + last_n_tokens.push_back(id); + + t_sample_us += ggml_time_us() - t_start_sample_us; + } + + // add it to the context + embd.push_back(id); + } else { + // if here, it means we are still processing the input prompt + for (size_t k = i; k < embd_inp.size(); k++) { + embd.push_back(embd_inp[k]); + if (embd.size() > params.n_batch) { + break; + } + } + i += embd.size() - 1; + } + + // display text + for (auto id : embd) { + printf("%s", vocab.id_to_token[id].c_str() ); + } + fflush(stdout); + + // end of text token + if (vocab.special_eos_id != -1 && embd.back() == vocab.special_eos_id) { + break; + } + } + + // report timing + { + const int64_t t_main_end_us = ggml_time_us(); + + printf("\n\n"); + printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token); + printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f); + printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f); + printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past); + printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f); + } + + ggml_free(model.ctx); + + return 0; +} diff --git a/examples/gptneox-wip/gptneox-main.cpp b/examples/gptneox-wip/gptneox-main.cpp new file mode 100644 index 000000000..04af50245 --- /dev/null +++ b/examples/gptneox-wip/gptneox-main.cpp @@ -0,0 +1,1082 @@ +#include "ggml.h" +#include "cmpnct_gpt2bpe.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +#pragma warning(disable: 4244 4267) // possible loss of data +#endif + +// default hparams +struct gpt_neox_hparams { + size_t n_merges = 0; + size_t n_vocab = 0; + uint32_t n_ctx = 0; + uint32_t n_embd = 0; + uint32_t n_head = 0; + uint32_t n_block = 0; + uint32_t n_rot = 0; // rotary_pct * (n_embd / n_head) + bool par_res = true; + float norm_eps = 1e-5; +}; + +struct gpt_neox_block { + // pre normalization + struct ggml_tensor * ln_1_g; + struct ggml_tensor * ln_1_b; + + // attention + struct ggml_tensor * c_attn_attn_w; + struct ggml_tensor * c_attn_attn_b; + + struct ggml_tensor * c_attn_proj_w; + struct ggml_tensor * c_attn_proj_b; + + // post normalization + struct ggml_tensor * ln_2_g; + struct ggml_tensor * ln_2_b; + + // ff + struct ggml_tensor * c_mlp_fc_w; + struct ggml_tensor * c_mlp_fc_b; + + struct ggml_tensor * c_mlp_proj_w; + struct ggml_tensor * c_mlp_proj_b; +}; + +struct gpt_neox_model { + gpt_neox_hparams hparams; + + // normalization + struct ggml_tensor * ln_f_g; + struct ggml_tensor * ln_f_b; + + struct ggml_tensor * wte; // position embedding + + struct ggml_tensor * lmh_g; // language model head + + std::vector blocks; + + // key + value memory + struct ggml_tensor * memory_k; + struct ggml_tensor * memory_v; + + // + struct gguf_context * ggufctx; + struct ggml_context * ctx; + struct ggml_context * kvctx; + + std::map tensors; +}; + +struct gpt_params { + int32_t seed = -1; // RNG seed + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + uint32_t n_predict = 200; // new tokens to predict + uint32_t n_batch = 512; // batch size for prompt processing + + // sampling parameters + int32_t top_k = 40; + float top_p = 1.0f; + float temp = 0.8f; + int32_t repeat_last_n = 64; + float repeat_penalty = 1.02f; + + std::string model = ""; // model path + std::string prompt = ""; + + std::string token_test = ""; + bool interactive = false; + int32_t interactive_port = -1; + int32_t n_gpu_layers = 0; +}; + +void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { + fprintf(stderr, "usage: %s [options]\n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help show this help message and exit\n"); + fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1)\n"); + fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); + fprintf(stderr, " -ngl N, --gpu-layers N number of layers to offload to GPU on supported models (default: %d)\n", params.n_gpu_layers); + fprintf(stderr, " -p PROMPT, --prompt PROMPT\n"); + fprintf(stderr, " prompt to start generation with (default: random)\n"); + fprintf(stderr, " -f FNAME, --file FNAME\n"); + fprintf(stderr, " load prompt from a file\n"); + fprintf(stderr, " -tt TOKEN_TEST, --token_test TOKEN_TEST\n"); + fprintf(stderr, " test tokenization\n"); + fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d)\n", params.n_predict); + fprintf(stderr, " --top_k N top-k sampling, 0 = n_vocab (default: %d)\n", params.top_k); + fprintf(stderr, " --top_p N top-p sampling (default: %.1f)\n", params.top_p); + fprintf(stderr, " --temp N temperature (default: %.1f)\n", params.temp); + fprintf(stderr, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled)\n", params.repeat_last_n); + fprintf(stderr, " --repeat-penalty N penalize repeat sequence of tokens (default: %.2f, 1.0 = disabled)\n", (double)params.repeat_penalty); + fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch); + fprintf(stderr, " -m FNAME, --model FNAME\n"); + fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); + fprintf(stderr, "\n"); +} + +// Function to check if the next argument exists +std::string get_next_arg(int& i, int argc, char** argv, const std::string& flag, gpt_params& params) { + if (i + 1 < argc && argv[i + 1][0] != '-') { + return argv[++i]; + } else { + fprintf(stderr, "error: %s requires one argument.\n", flag.c_str()); + gpt_print_usage(argc, argv, params); + exit(0); + } +} + +bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "-s" || arg == "--seed") { + params.seed = std::stoi(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "-t" || arg == "--threads") { + params.n_threads = std::stoi(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "-ngl" || arg == "--gpu-layers" || arg == "--n-gpu-layers") { + params.n_gpu_layers = std::stoi(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "-p" || arg == "--prompt") { + params.prompt = get_next_arg(i, argc, argv, arg, params); + } else if (arg == "-n" || arg == "--n_predict") { + params.n_predict = std::stoi(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "--top_k") { + params.top_k = std::stoi(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "--top_p") { + params.top_p = std::stof(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "--temp") { + params.temp = std::stof(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "--repeat-last-n") { + params.repeat_last_n = std::stoi(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "--repeat-penalty") { + params.repeat_penalty = std::stof(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "-b" || arg == "--batch_size") { + params.n_batch= std::stoi(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "-m" || arg == "--model") { + params.model = get_next_arg(i, argc, argv, arg, params); + } else if (arg == "-i" || arg == "--interactive") { + params.interactive = true; + } else if (arg == "-ip" || arg == "--interactive-port") { + params.interactive = true; + params.interactive_port = std::stoi(get_next_arg(i, argc, argv, arg, params)); + } else if (arg == "-h" || arg == "--help") { + gpt_print_usage(argc, argv, params); + exit(0); + } else if (arg == "-f" || arg == "--file") { + get_next_arg(i, argc, argv, arg, params); + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + break; + } + std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.prompt)); + if (params.prompt.back() == '\n') { + params.prompt.pop_back(); + } + } else if (arg == "-tt" || arg == "--token_test") { + params.token_test = get_next_arg(i, argc, argv, arg, params); + } + else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + gpt_print_usage(argc, argv, params); + exit(0); + } + } + + return true; +} + +gpt2bpe_vocab::id sample_top_k_top_p_repeat( + const gpt2bpe_vocab & vocab, + const float * logits, + const int32_t * last_n_tokens_data, + size_t last_n_tokens_data_size, + int top_k, + double top_p, + double temp, + int repeat_last_n, + float repeat_penalty, + std::mt19937 & rng) { + + int n_logits = vocab.id_to_token.size(); + + const auto * plogits = logits; + + const auto last_n_tokens = std::vector(last_n_tokens_data, last_n_tokens_data + last_n_tokens_data_size); + + if (temp <= 0) { + // select the token with the highest logit directly + float max_logit = plogits[0]; + gpt2bpe_vocab::id max_id = 0; + + for (int i = 1; i < n_logits; ++i) { + if (plogits[i] > max_logit) { + max_logit = plogits[i]; + max_id = i; + } + } + return max_id; + } + + + std::vector> logits_id; + logits_id.reserve(n_logits); + + { + const float scale = 1.0f/temp; + for (int i = 0; i < n_logits; ++i) { + // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858) + // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main + if (repeat_last_n > 0 && std::find(last_n_tokens.end()-repeat_last_n, last_n_tokens.end(), i) != last_n_tokens.end()) { + // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability + if (plogits[i] < 0.0f) { + logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i)); + } else { + logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i)); + } + } else { + logits_id.push_back(std::make_pair(plogits[i]*scale, i)); + } + } + } + + // find the top K tokens + std::partial_sort( + logits_id.begin(), + logits_id.begin() + top_k, logits_id.end(), + [](const std::pair & a, const std::pair & b) { + return a.first > b.first; + }); + + logits_id.resize(top_k); + + double maxl = -INFINITY; + for (const auto & kv : logits_id) { + maxl = std::max(maxl, kv.first); + } + + // compute probs for the top K tokens + std::vector probs; + probs.reserve(logits_id.size()); + + double sum = 0.0; + for (const auto & kv : logits_id) { + double p = exp(kv.first - maxl); + probs.push_back(p); + sum += p; + } + + // normalize the probs + for (auto & p : probs) { + p /= sum; + } + + if (top_p < 1.0f) { + double cumsum = 0.0f; + for (int i = 0; i < top_k; i++) { + cumsum += probs[i]; + if (cumsum >= top_p) { + top_k = i + 1; + probs.resize(top_k); + logits_id.resize(top_k); + break; + } + } + + cumsum = 1.0/cumsum; + for (int i = 0; i < (int) probs.size(); i++) { + probs[i] *= cumsum; + } + } + +// printf("\n"); +// for (int i = 0; i < (int) probs.size(); i++) { +// for (int i = 0; i < 10; i++) { +// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]); +// } + + std::discrete_distribution<> dist(probs.begin(), probs.end()); + int idx = dist(rng); + + return logits_id[idx].second; + +} + +struct ggml_tensor * get_tensor_ex( struct ggml_context * ctx, std::string name){ + + struct ggml_tensor * cur = ggml_get_tensor(ctx, name.c_str()); + if( cur == NULL ) { + fprintf(stdout, "%s: tensor '%s' not found!\n", __func__, name.c_str()); + } else { +// fprintf(stdout, "%s: n_dims = %d, name = '%s'\n", __func__, cur->n_dims, cur->name); + } + + return cur; +} + +// load the model's weights from a file +bool gpt_neox_model_load(const std::string & fname, gpt_neox_model & model, gpt2bpe_vocab & vocab) { + printf("%s: loading model from '%s'..\n", __func__, fname.c_str()); + + model.ctx = NULL; + + struct gguf_init_params ggufparams = { + /*.no_alloc = */ false, + /*.ctx = */ &model.ctx, + }; + + auto & ggufctx = model.ggufctx; + + ggufctx = gguf_init_from_file(fname.c_str(), ggufparams); + + if (!ggufctx) { + fprintf(stderr, "%s: gguf_init_from_file() failed\n", __func__); + return false; + } + + fprintf(stdout, "%s: gguf version = %d\n", __func__, gguf_get_version(ggufctx)); + fprintf(stdout, "%s: gguf alignment = %zu\n", __func__, gguf_get_alignment(ggufctx)); + fprintf(stdout, "%s: gguf data offset = %zu\n", __func__, gguf_get_data_offset(ggufctx)); + + // print all kv + #if 0 + { + const int n_kv = gguf_get_n_kv(ggufctx); + + fprintf(stdout, "%s: n_kv: %d\n", __func__, n_kv); + + for (int i = 0; i < n_kv; ++i) { + const char * key = gguf_get_key(ggufctx, i); + + fprintf(stdout, "%s: kv[%d]: key = %s\n", __func__, i, key); + } + } + #endif + + // print some standard metadata + { + int keyidx; + + keyidx = gguf_find_key(ggufctx, "general.name"); + if (keyidx != -1) { fprintf(stdout, "%s: model name = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } + keyidx = gguf_find_key(ggufctx, "general.description"); + if (keyidx != -1) { fprintf(stdout, "%s: model description = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } + keyidx = gguf_find_key(ggufctx, "general.author"); + if (keyidx != -1) { fprintf(stdout, "%s: model author = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } + keyidx = gguf_find_key(ggufctx, "general.license"); + if (keyidx != -1) { fprintf(stdout, "%s: model license = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } + keyidx = gguf_find_key(ggufctx, "general.architecture"); + if (keyidx != -1) { fprintf(stdout, "%s: model architecture = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } + keyidx = gguf_find_key(ggufctx, "general.file_type"); + if (keyidx != -1) { fprintf(stdout, "%s: model file type = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } + keyidx = gguf_find_key(ggufctx, "gptneox.tensor_data_layout"); + if (keyidx != -1) { fprintf(stdout, "%s: model data layout = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } + keyidx = gguf_find_key(ggufctx, "general.source.hugginface.repository"); + if (keyidx != -1) { fprintf(stdout, "%s: model source HF repo = %s\n", __func__, gguf_get_val_str(ggufctx, keyidx)); } + } + + // check required metadata + { + int keyidx; + + // check model architecture kv + keyidx = gguf_find_key(ggufctx, "general.architecture"); + if (keyidx != -1) { + if ( strcmp(gguf_get_val_str(ggufctx, keyidx), "gptneox") != 0) { + fprintf(stdout, "%s: model architecture not supported!\n", __func__); + return false; + } + } else { + fprintf(stdout, "%s: gguf model architecture not found!\n", __func__); + return false; + } + + } + + // load hparams + { + auto & hparams = model.hparams; + + bool ok = true; + int keyidx; + + if (ok) { keyidx = gguf_find_key(ggufctx, "gptneox.context_length"); + if (keyidx != -1) { hparams.n_ctx = gguf_get_val_u32(ggufctx, keyidx); } else { ok = false; } } + + if (ok) { keyidx = gguf_find_key(ggufctx, "gptneox.embedding_length"); + if (keyidx != -1) { hparams.n_embd = gguf_get_val_u32(ggufctx, keyidx); } else { ok = false; } } + + if (ok) { keyidx = gguf_find_key(ggufctx, "gptneox.attention.head_count"); + if (keyidx != -1) { hparams.n_head = gguf_get_val_u32(ggufctx, keyidx); } else { ok = false; } } + + if (ok) { keyidx = gguf_find_key(ggufctx, "gptneox.block_count"); + if (keyidx != -1) { hparams.n_block = gguf_get_val_u32(ggufctx, keyidx); } else { ok = false; } } + + if (ok) { keyidx = gguf_find_key(ggufctx, "gptneox.rope.dimension_count"); + if (keyidx != -1) { hparams.n_rot = gguf_get_val_u32(ggufctx, keyidx); } else { ok = false; } } + + if (ok) { keyidx = gguf_find_key(ggufctx, "gptneox.use_parallel_residual"); + if (keyidx != -1) { hparams.par_res = gguf_get_val_bool(ggufctx, keyidx); } else { ok = false; } } + + if (ok) { keyidx = gguf_find_key(ggufctx, "gptneox.attention.layer_norm_epsilon"); + if (keyidx != -1) { hparams.norm_eps= gguf_get_val_f32(ggufctx, keyidx); } else { ok = false; } } + + if (!ok) { + fprintf(stderr, "%s: required hparam missing!\n", __func__); + return false; + } + + printf("%s: n_ctx = %d\n", __func__, hparams.n_ctx); + printf("%s: n_embd = %d\n", __func__, hparams.n_embd); + printf("%s: n_head = %d\n", __func__, hparams.n_head); + printf("%s: n_block = %d\n", __func__, hparams.n_block); + printf("%s: n_rot = %d\n", __func__, hparams.n_rot); + printf("%s: par_res = %d\n", __func__, hparams.par_res); + printf("%s: norm_eps = %g\n", __func__, hparams.norm_eps); + + } + + // load vocab + { + auto & hparams = model.hparams; + + int keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.model"); + + if (keyidx != -1) { + if ( strcmp(gguf_get_val_str(ggufctx, keyidx), "gpt2") != 0) { + fprintf(stdout, "%s: tokenizer model not supported!\n", __func__); + return false; + } + } else { + fprintf(stdout, "%s: tokenizer model not found!\n", __func__); + return false; + } + + + int tokens_keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.tokens"); + + if (tokens_keyidx == -1) { + fprintf(stdout, "%s: gpt2 tokenizer vocab not found!\n", __func__); + return false; + } + + int merges_keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.merges"); + + if (merges_keyidx == -1) { + fprintf(stdout, "%s: gpt2 tokenizer merges not found!\n", __func__); + return false; + } + + hparams.n_vocab = gguf_get_arr_n(ggufctx,tokens_keyidx); + hparams.n_merges = gguf_get_arr_n(ggufctx,merges_keyidx); + + fprintf(stdout, "%s: gpt2 tokenizer vocab = %zu\n", __func__, hparams.n_vocab); + fprintf(stdout, "%s: gpt2 tokenizer merges = %zu\n", __func__, hparams.n_merges); + + for (size_t i = 0; i < hparams.n_vocab; i++) { + std::string word = gguf_get_arr_str(ggufctx, tokens_keyidx, i); + +// printf("token %d = '%s'\n",i,word.c_str() ); + + vocab.token_to_id[word] = i; + vocab.id_to_token[i] = word; + + if( vocab.id_to_token[i] == "\n" ) { + vocab.linefeed_id = i; + } + } + + std::vector> bpe_merges; + + for (size_t i = 0; i < hparams.n_merges; i++) { + + std::string word = gguf_get_arr_str(ggufctx, merges_keyidx, i); + + // Split the merges + std::string first, second; + size_t pos = word.find(' ', 1); // Start the search from the second character + if (pos != std::string::npos) { + first = word.substr(0, pos); + second = word.substr(pos + 1); + } + + bpe_merges.push_back(std::make_pair(first, second)); + } + + vocab.populate_bpe_ranks(bpe_merges); + + + keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.bos_token_id"); if( keyidx != -1 ) { vocab.special_bos_id = (int32_t)gguf_get_val_u32(ggufctx, keyidx); } + keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.eos_token_id"); if( keyidx != -1 ) { vocab.special_eos_id = (int32_t)gguf_get_val_u32(ggufctx, keyidx); } + keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.unknown_token_id"); if( keyidx != -1 ) { vocab.special_unk_id = (int32_t)gguf_get_val_u32(ggufctx, keyidx); } + keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.separator_token_id"); if( keyidx != -1 ) { vocab.special_sep_id = (int32_t)gguf_get_val_u32(ggufctx, keyidx); } + keyidx = gguf_find_key(ggufctx, "tokenizer.ggml.padding_token_id"); if( keyidx != -1 ) { vocab.special_pad_id = (int32_t)gguf_get_val_u32(ggufctx, keyidx); } + + if( vocab.special_bos_id != -1 ) { fprintf(stdout, "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].c_str() ); } + if( vocab.special_eos_id != -1 ) { fprintf(stdout, "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].c_str() ); } + if( vocab.special_unk_id != -1 ) { fprintf(stdout, "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].c_str() ); } + if( vocab.special_sep_id != -1 ) { fprintf(stdout, "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].c_str() ); } + if( vocab.special_pad_id != -1 ) { fprintf(stdout, "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].c_str() ); } + if( vocab.linefeed_id != -1 ) { fprintf(stdout, "%s: LF token = %d\n", __func__, vocab.linefeed_id ); } + } + + + auto & ctx = model.ctx; + size_t ctx_size = ggml_get_mem_size(ctx); + + printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0)); + + // print tensor info + #if 0 + { + const int n_tensors = gguf_get_n_tensors(ggufctx); + + fprintf(stdout, "%s: n_tensors: %d\n", __func__, n_tensors); + + for (int i = 0; i < n_tensors; ++i) { + const char * name = gguf_get_tensor_name (ggufctx, i); + const size_t offset = gguf_get_tensor_offset(ggufctx, i); + + fprintf(stdout, "%s: tensor[%d]: name = %s, offset = %zu\n", __func__, i, name, offset); + } + } + #endif + + // prepare memory for the weights + { + const int n_block = model.hparams.n_block; + + model.blocks.resize(n_block); + + model.wte = ggml_get_tensor(ctx, "token_embd.weight"); + model.ln_f_g = ggml_get_tensor(ctx, "output_norm.weight"); + model.ln_f_b = ggml_get_tensor(ctx, "output_norm.bias"); + model.lmh_g = ggml_get_tensor(ctx, "output.weight"); + + // map by name + model.tensors["token_embd.weight"] = model.wte; + model.tensors["output_norm.weight"] = model.ln_f_g; + model.tensors["output_norm.bias"] = model.ln_f_b; + model.tensors["output.weight"] = model.lmh_g; + + for (int i = 0; i < n_block; ++i) { + auto & block = model.blocks[i]; + + std::string blocknamestart = "blk." + std::to_string(i) + "."; + + block.ln_1_g = get_tensor_ex(ctx, blocknamestart + "attn_norm.weight" ); + block.ln_1_b = get_tensor_ex(ctx, blocknamestart + "attn_norm.bias" ); + + block.c_attn_attn_w = get_tensor_ex(ctx, blocknamestart + "attn_qkv.weight" ); + block.c_attn_attn_b = get_tensor_ex(ctx ,blocknamestart + "attn_qkv.bias" ); + + block.c_attn_proj_w = get_tensor_ex(ctx, blocknamestart + "attn_output.weight" ); + block.c_attn_proj_b = get_tensor_ex(ctx, blocknamestart + "attn_output.bias" ); + + block.ln_2_g = get_tensor_ex(ctx, blocknamestart + "ffn_norm.weight" ); + block.ln_2_b = get_tensor_ex(ctx, blocknamestart + "ffn_norm.bias"); + + block.c_mlp_fc_w = get_tensor_ex(ctx, blocknamestart + "ffn_up.weight" ); + block.c_mlp_fc_b = get_tensor_ex(ctx, blocknamestart + "ffn_up.bias" ); + + block.c_mlp_proj_w = get_tensor_ex(ctx, blocknamestart + "ffn_down.weight" ); + block.c_mlp_proj_b = get_tensor_ex(ctx, blocknamestart + "ffn_down.bias" ); + + // map by name + model.tensors[blocknamestart + "attn_norm.weight"] = block.ln_1_g; + model.tensors[blocknamestart + "attn_norm.bias"] = block.ln_1_b; + + model.tensors[blocknamestart + "attn_qkv.weight"] = block.c_attn_attn_w; + model.tensors[blocknamestart + "attn_qkv.bias"] = block.c_attn_attn_b; + + model.tensors[blocknamestart + "attn_output.weight"] = block.c_attn_proj_w; + model.tensors[blocknamestart + "attn_output.bias"] = block.c_attn_proj_b; + + model.tensors[blocknamestart + "ffn_norm.weight"] = block.ln_2_g; + model.tensors[blocknamestart + "ffn_norm.bias"] = block.ln_2_b; + + model.tensors[blocknamestart + "ffn_up.weight"] = block.c_mlp_fc_w; + model.tensors[blocknamestart + "ffn_up.bias"] = block.c_mlp_fc_b; + + model.tensors[blocknamestart + "ffn_down.weight"] = block.c_mlp_proj_w; + model.tensors[blocknamestart + "ffn_down.bias"] = block.c_mlp_proj_b; + } + } + + // key + value memory + { + const auto & kvctx = model.kvctx; + const auto & hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_block = hparams.n_block; + const int n_ctx = hparams.n_ctx; + + const int64_t n_mem = n_block*n_ctx; + const int64_t n_elements = n_embd*n_mem; + + // create the ggml context + { + struct ggml_init_params params = { + /*.mem_size =*/ size_t(n_elements*4+ggml_tensor_overhead()*2), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ false, + }; + + model.kvctx = ggml_init(params); + if (!model.kvctx) { + fprintf(stderr, "%s: kv ggml_init() failed\n", __func__); + return false; + } + + } + + + model.memory_k = ggml_new_tensor_1d(kvctx, GGML_TYPE_F16, n_elements); + model.memory_v = ggml_new_tensor_1d(kvctx, GGML_TYPE_F16, n_elements); + + const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v); + + printf("%s: memory_size = %8.2f MB, n_mem = %" PRId64 "\n", __func__, memory_size/1024.0/1024.0, n_mem); + } + + return true; +} + + +// feed-forward network +ggml_tensor * gpt_neox_ff( + const gpt_neox_block &block, + ggml_context * ctx0, + ggml_tensor * inp) { + + ggml_tensor * cur = ggml_norm(ctx0, inp); + + cur = ggml_add(ctx0, ggml_mul(ctx0, ggml_repeat(ctx0, block.ln_2_g, cur), cur), ggml_repeat(ctx0, block.ln_2_b, cur)); + cur = ggml_mul_mat(ctx0, block.c_mlp_fc_w, cur); + cur = ggml_add(ctx0, ggml_repeat(ctx0, block.c_mlp_fc_b, cur), cur); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + // projection + // cur = proj_w*cur + proj_b + cur = ggml_mul_mat(ctx0, block.c_mlp_proj_w, cur); + + cur = ggml_add(ctx0, ggml_repeat(ctx0, block.c_mlp_proj_b, cur), cur); + return cur; +} + +// evaluate the transformer +// +// - model: the model +// - n_threads: number of threads to use +// - n_past: the context size so far +// - embd_inp: the embeddings of the tokens in the context +// - embd_w: the predicted logits for the next token +// +bool gpt_neox_eval( + const gpt_neox_model & model, + const int n_threads, + const int n_past, + const std::vector & embd_inp, + std::vector & embd_w, + size_t & mem_per_token) { + const int N = embd_inp.size(); + + const auto & hparams = model.hparams; + + const int n_embd = hparams.n_embd; + const int n_block = hparams.n_block; + const int n_ctx = hparams.n_ctx; + const int n_head = hparams.n_head; + const int n_vocab = hparams.n_vocab; + const int n_rot = hparams.n_rot; + + static size_t buf_size = 256u*1024*1024; + static void * buf = malloc(buf_size); + + // use 2 scratch buffers + // TODO: very hacky solution - reimplement in a more elegant way + static size_t scr0_size = 256u*1024*1024; + static void * scr0 = malloc(scr0_size); + + static size_t scr1_size = 256u*1024*1024; + static void * scr1 = malloc(scr1_size); + + if (mem_per_token > 0 && mem_per_token*N > buf_size) { + const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead + //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new); + + // reallocate + buf_size = buf_size_new; + buf = realloc(buf, buf_size); + if (buf == nullptr) { + fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size); + return false; + } + } + + struct ggml_init_params params = { + /*.mem_size =*/ buf_size, + /*.mem_buffer =*/ buf, + /*.no_alloc =*/ false, + }; + + struct ggml_context * ctx0 = ggml_init(params); + struct ggml_cgraph gf = {}; + + struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N); + memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd)); + + + // wte + struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.wte, embd); + + for (int il = 0; il < n_block; ++il) { + struct ggml_tensor * cur; + + ggml_set_scratch(ctx0, { 0, scr0_size, scr0, }); + + // self-attention + { + { + cur = ggml_norm(ctx0, inpL); + + cur = ggml_add(ctx0, + ggml_mul(ctx0, ggml_repeat(ctx0, model.blocks[il].ln_1_g, cur), cur), + ggml_repeat(ctx0, model.blocks[il].ln_1_b, cur)); + } + + // compute QKV + { + + cur = ggml_mul_mat(ctx0, model.blocks[il].c_attn_attn_w, cur); + cur = ggml_add(ctx0, ggml_repeat(ctx0, model.blocks[il].c_attn_attn_b, cur), cur); + } + + struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd/n_head, n_head, N, cur->nb[1]/n_head, cur->nb[1], 0*sizeof(float)*n_embd/n_head)); + struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd/n_head, n_head, N, cur->nb[1]/n_head, cur->nb[1], 1*sizeof(float)*n_embd/n_head)); + struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd/n_head, n_head, N, cur->nb[1]/n_head, cur->nb[1], 2*sizeof(float)*n_embd/n_head)); + + // using mode = 2 for GPT-NeoX mode + Qcur = ggml_rope_inplace(ctx0, Qcur, n_past, n_rot, 2, 0); + Kcur = ggml_rope_inplace(ctx0, Kcur, n_past, n_rot, 2, 0); + + // store key and value to memory + { + Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd, N)); + + struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past)); + struct ggml_tensor * v = ggml_view_2d(ctx0, model.memory_v, N, n_embd, + ( n_ctx)*ggml_element_size(model.memory_v), + (il*n_ctx)*ggml_element_size(model.memory_v)*n_embd + n_past*ggml_element_size(model.memory_v)); + + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); + } + + // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3) + struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + + // K = Kmem.view(n_embd/n_head, n_head, n_past + N).permute(0, 2, 1, 3) + struct ggml_tensor * K = + ggml_permute(ctx0, + ggml_reshape_3d(ctx0, + ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd), + n_embd/n_head, n_head, n_past + N), + 0, 2, 1, 3); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + + // KQ_scaled = KQ / sqrt(n_embd/n_head) + struct ggml_tensor * KQ_scaled = + ggml_scale_inplace(ctx0, + KQ, + ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head)) + ); + + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + + // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous() + struct ggml_tensor * V = + ggml_view_3d(ctx0, model.memory_v, + n_past + N, n_embd/n_head, n_head, + n_ctx*ggml_element_size(model.memory_v), + n_ctx*ggml_element_size(model.memory_v)*n_embd/n_head, + il*n_ctx*ggml_element_size(model.memory_v)*n_embd); + + // KQV = transpose(V) * KQ_soft_max + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + + // cur = KQV_merged.contiguous().view(n_embd, N) + cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N)); + + // projection + { + cur = ggml_mul_mat(ctx0, model.blocks[il].c_attn_proj_w, cur); + cur = ggml_add(ctx0, ggml_repeat(ctx0, model.blocks[il].c_attn_proj_b, cur), cur); + } + } + + ggml_set_scratch(ctx0, { 0, scr1_size, scr1, }); + + if (hparams.par_res == 0) { + struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpL); + + cur = gpt_neox_ff(model.blocks[il], ctx0, inpFF); + + // input for next layer + inpL = ggml_add(ctx0, cur, inpFF); + } else { + struct ggml_tensor * inpFF = cur; + + // this is independent of the self-attention result, so it could be done in parallel to the self-attention + // note here we pass inpL instead of cur + cur = gpt_neox_ff(model.blocks[il], ctx0, inpL); + + // layer input + FF + cur = ggml_add(ctx0, cur, inpFF); + + // input for next layer + inpL = ggml_add(ctx0, cur, inpL); + } + } + + ggml_set_scratch(ctx0, { 0, scr0_size, scr0, }); + + // norm + { + inpL = ggml_norm(ctx0, inpL); + + // inpL = ln_f_g*inpL + ln_f_b + inpL = ggml_add(ctx0, + ggml_mul(ctx0, + ggml_repeat(ctx0, model.ln_f_g, inpL), + inpL), + ggml_repeat(ctx0, model.ln_f_b, inpL)); + } + + ggml_set_scratch(ctx0, { 0, 0, nullptr, }); + + // lm_head + { + inpL = ggml_mul_mat(ctx0, model.lmh_g, inpL); + + //inpL = ggml_add(ctx0, + // ggml_repeat(ctx0, model.lmh_b, inpL), + // inpL); + } + + // logits -> probs + //inpL = ggml_soft_max_inplace(ctx0, inpL); + + // run the computation + ggml_build_forward_expand(&gf, inpL); + ggml_graph_compute_with_ctx(ctx0, &gf, n_threads); + + //if (n_past%100 == 0) { + // ggml_graph_print (&gf); + // ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot"); + //} + + //embd_w.resize(n_vocab*N); + //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N); + + // return result for just the last token + embd_w.resize(n_vocab); + memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab); + + if (mem_per_token == 0) { + mem_per_token = ggml_used_mem(ctx0)/N; + } + //printf("used_mem = %zu\n", ggml_used_mem(ctx0)); + + ggml_free(ctx0); + + return true; +} + +int main(int argc, char ** argv) { + ggml_time_init(); + + const int64_t t_main_start_us = ggml_time_us(); + + gpt_params params; + + if (gpt_params_parse(argc, argv, params) == false) { + return 1; + } + + int64_t t_load_us = 0; + + gpt2bpe_vocab vocab; + gpt_neox_model model; + + // load the model + { + const int64_t t_start_us = ggml_time_us(); + + if (!gpt_neox_model_load(params.model, model, vocab)) { + fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); + return 1; + } + + t_load_us = ggml_time_us() - t_start_us; + + } + + if (params.seed < 0) { + params.seed = time(NULL); + } + + if (params.top_k == 0) { + params.top_k = model.hparams.n_vocab; + } + + printf("%s: seed = %d\n", __func__, params.seed); + printf("%s: temp = %.3f\n", __func__, params.temp); + printf("%s: top_k = %d\n", __func__, params.top_k); + printf("%s: top_p = %.3f\n", __func__, params.top_p); + printf("%s: repeat_last_n = %d\n", __func__, params.repeat_last_n); + printf("%s: repeat_penalty = %.3f\n", __func__, params.repeat_penalty); + + std::mt19937 rng(params.seed); + + if (params.prompt.empty()) { + params.prompt = "Once upon"; + } + + std::vector last_n_tokens(model.hparams.n_ctx); + std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); + + int n_past = 0; + + int64_t t_sample_us = 0; + int64_t t_predict_us = 0; + + std::vector logits; + + // tokenize the prompt + std::vector embd_inp = gpt2bpe_tokenize(vocab, params.prompt,false, false); + + params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size()); + + printf("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); +// for (size_t i = 0; i < embd_inp.size(); i++) { +// printf("%s: token[%zu] = %6d, %s\n", __func__, i, embd_inp[i], vocab.id_to_token[embd_inp[i]].c_str()); +// } + + if( model.hparams.n_ctx < params.n_predict+embd_inp.size() ) { + params.n_predict = model.hparams.n_ctx-embd_inp.size(); + } + + printf("%s: n_predict = %d\n", __func__, params.n_predict); + printf("\n"); + + std::vector embd; + + // determine the required inference memory per token: + size_t mem_per_token = 0; + gpt_neox_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token); + + for (size_t i = embd.size(); i < embd_inp.size() + params.n_predict; i++) { + // predict + if (embd.size() > 0) { + const int64_t t_start_us = ggml_time_us(); + + if (!gpt_neox_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) { + printf("Failed to predict\n"); + return 1; + } + + t_predict_us += ggml_time_us() - t_start_us; + } + + n_past += embd.size(); + embd.clear(); + + if (i >= embd_inp.size()) { + // sample next token + const int top_k = params.top_k; + const float top_p = params.top_p; + const float temp = params.temp; + const int repeat_last_n = params.repeat_last_n; + const float repeat_penalty = params.repeat_penalty; + + const int n_vocab = model.hparams.n_vocab; + + gpt2bpe_vocab::id id = 0; + + { + const int64_t t_start_sample_us = ggml_time_us(); + + id = sample_top_k_top_p_repeat(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens.data(), last_n_tokens.size(), top_k, top_p, temp, repeat_last_n, repeat_penalty, rng); + + last_n_tokens.erase(last_n_tokens.begin()); + last_n_tokens.push_back(id); + + t_sample_us += ggml_time_us() - t_start_sample_us; + } + + // add it to the context + embd.push_back(id); + } else { + // if here, it means we are still processing the input prompt + for (size_t k = i; k < embd_inp.size(); k++) { + embd.push_back(embd_inp[k]); + if (embd.size() > params.n_batch) { + break; + } + } + i += embd.size() - 1; + } + + // display text + for (auto id : embd) { + printf("%s", vocab.id_to_token[id].c_str() ); + } + fflush(stdout); + + // end of text token + if (vocab.special_eos_id != -1 && embd.back() == vocab.special_eos_id) { + break; + } + } + + // report timing + { + const int64_t t_main_end_us = ggml_time_us(); + + printf("\n\n"); + printf("%s: mem per token = %8zu bytes\n", __func__, mem_per_token); + printf("%s: load time = %8.2f ms\n", __func__, t_load_us/1000.0f); + printf("%s: sample time = %8.2f ms\n", __func__, t_sample_us/1000.0f); + printf("%s: predict time = %8.2f ms / %.2f ms per token\n", __func__, t_predict_us/1000.0f, t_predict_us/1000.0f/n_past); + printf("%s: total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f); + } + + ggml_free(model.ctx); + + return 0; +} diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 266c8eab3..d11fff288 100755 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -606,6 +606,8 @@ const std::string test::cpu_info = get_cpu_info(); const std::string test::gpu_info = get_gpu_info(); struct printer { + virtual ~printer() {} + FILE * fout; virtual void print_header(const cmd_params & params) { (void) params; }; virtual void print_test(const test & t) = 0; @@ -849,7 +851,7 @@ struct sql_printer : public printer { }; static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) { - std::vector tokens(n_batch, llama_token_bos()); + std::vector tokens(n_batch, llama_token_bos(ctx)); int n_processed = 0; while (n_processed < n_prompt) { int n_tokens = std::min(n_prompt - n_processed, n_batch); @@ -859,7 +861,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat } static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) { - llama_token token = llama_token_bos(); + llama_token token = llama_token_bos(ctx); for (int i = 0; i < n_gen; i++) { llama_eval(ctx, &token, 1, n_past + i, n_threads); } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index a632bea1c..388e1f7d7 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -143,7 +143,7 @@ int main(int argc, char ** argv) { { fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx); - const std::vector tmp(params.n_batch, llama_token_bos()); + const std::vector tmp(params.n_batch, llama_token_bos(ctx)); llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads); } @@ -191,10 +191,6 @@ int main(int argc, char ** argv) { // tokenize the prompt std::vector embd_inp; - - // Add a space in front of the first character to match OG llama tokenizer behavior - params.prompt.insert(0, 1, ' '); - if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) { embd_inp = ::llama_tokenize(ctx, params.prompt, true); } else { @@ -270,15 +266,12 @@ int main(int argc, char ** argv) { params.interactive = true; } - // determine newline token - auto llama_token_newline = ::llama_tokenize(ctx, "\n", false); - if (params.verbose_prompt) { fprintf(stderr, "\n"); fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str()); fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); for (int i = 0; i < (int) embd_inp.size(); i++) { - fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i])); + fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i]).c_str()); } if (ctx_guidance) { @@ -286,14 +279,14 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s: negative prompt: '%s'\n", __func__, params.cfg_negative_prompt.c_str()); fprintf(stderr, "%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size()); for (int i = 0; i < (int) guidance_inp.size(); i++) { - fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i])); + fprintf(stderr, "%6d -> '%s'\n", guidance_inp[i], llama_token_to_str(ctx, guidance_inp[i]).c_str()); } } if (params.n_keep > 0) { fprintf(stderr, "%s: static prompt based on n_keep: '", __func__); for (int i = 0; i < params.n_keep; i++) { - fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i])); + fprintf(stderr, "%s", llama_token_to_str(ctx, embd_inp[i]).c_str()); } fprintf(stderr, "'\n"); } @@ -311,7 +304,7 @@ int main(int argc, char ** argv) { auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false; }; - SetConsoleCtrlHandler(static_cast(console_ctrl_handler), true); + SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); #endif fprintf(stderr, "%s: interactive mode on.\n", __func__); @@ -352,10 +345,9 @@ int main(int argc, char ** argv) { fprintf(stderr, "\n"); { - auto it = params.logit_bias.find(llama_token_eos()); + auto it = params.logit_bias.find(llama_token_eos(ctx)); if (it != params.logit_bias.end() && it->second == -INFINITY) { - fprintf(stderr, - "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__); + fprintf(stderr, "%s: warning: EOS token is disabled, which will cause most grammars to fail\n", __func__); } } @@ -405,7 +397,7 @@ int main(int argc, char ** argv) { // do one empty run to warm up the model { - const std::vector tmp = { llama_token_bos(), }; + const std::vector tmp = { llama_token_bos(ctx), }; llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads); llama_reset_timings(ctx); } @@ -589,7 +581,7 @@ int main(int argc, char ** argv) { } // Apply penalties - float nl_logit = logits[llama_token_nl()]; + float nl_logit = logits[llama_token_nl(ctx)]; auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx); llama_sample_repetition_penalty(ctx, &candidates_p, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, @@ -598,7 +590,7 @@ int main(int argc, char ** argv) { last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, last_n_repeat, alpha_frequency, alpha_presence); if (!penalize_nl) { - logits[llama_token_nl()] = nl_logit; + logits[llama_token_nl(ctx)] = nl_logit; } if (grammar != NULL) { @@ -662,7 +654,7 @@ int main(int argc, char ** argv) { // display text if (input_echo) { for (auto id : embd) { - printf("%s", llama_token_to_str(ctx, id)); + printf("%s", llama_token_to_str(ctx, id).c_str()); } fflush(stdout); } @@ -704,7 +696,7 @@ int main(int argc, char ** argv) { } // deal with end of text token in interactive mode - if (last_n_tokens.back() == llama_token_eos()) { + if (last_n_tokens.back() == llama_token_eos(ctx)) { if (params.interactive) { if (params.antiprompt.size() != 0) { // tokenize and inject first reverse prompt @@ -728,7 +720,7 @@ int main(int argc, char ** argv) { } if (params.input_prefix_bos) { - embd_inp.push_back(llama_token_bos()); + embd_inp.push_back(llama_token_bos(ctx)); } std::string buffer; @@ -782,8 +774,7 @@ int main(int argc, char ** argv) { if (grammar != NULL) { llama_grammar_free(grammar); - std::vector grammar_rules( - parsed_grammar.c_rules()); + std::vector grammar_rules( parsed_grammar.c_rules()); grammar = llama_grammar_init( grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root")); @@ -794,7 +785,7 @@ int main(int argc, char ** argv) { } // end of text token - if (!embd.empty() && embd.back() == llama_token_eos() && !(params.instruct || params.interactive)) { + if (!embd.empty() && embd.back() == llama_token_eos(ctx) && !(params.instruct || params.interactive)) { fprintf(stderr, " [end of text]\n"); break; } diff --git a/examples/metal/metal.cpp b/examples/metal/metal.cpp index 7438defde..c05a4fa93 100644 --- a/examples/metal/metal.cpp +++ b/examples/metal/metal.cpp @@ -2,7 +2,7 @@ // // - First, export a LLaMA graph: // -// $ ./bin/main -m ../models/7B/ggml-model-q4_0.bin --export +// $ ./bin/main -m ../models/7B/ggml-model-q4_0.gguf --export // // - Run this tool to evaluate the exported graph: // diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 2409db69f..f3c045aec 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -64,7 +64,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) { // add BOS token for the first batch of each chunk if (j == 0) { - tokens[batch_start] = llama_token_bos(); + tokens[batch_start] = llama_token_bos(ctx); } if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 6aa06ec8f..06ce18f09 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -24,7 +24,7 @@ #endif struct quantize_stats_params { - std::string model = "models/7B/ggml-model-f16.bin"; + std::string model = "models/7B/ggml-model-f16.gguf"; bool verbose = false; bool per_layer_stats = false; bool print_histogram = false; diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 744f549c5..f628d0642 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -68,10 +68,10 @@ bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftype, std: } // usage: -// ./quantize [--allow-requantize] [--leave-output-tensor] models/llama/ggml-model.bin [models/llama/ggml-model-quant.bin] type [nthreads] +// ./quantize [--allow-requantize] [--leave-output-tensor] models/llama/ggml-model.gguf [models/llama/ggml-model-quant.gguf] type [nthreads] // void usage(const char * executable) { - fprintf(stderr, "usage: %s [--help] [--allow-requantize] [--leave-output-tensor] model-f32.bin [model-quant.bin] type [nthreads]\n\n", executable); + fprintf(stderr, "usage: %s [--help] [--allow-requantize] [--leave-output-tensor] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable); fprintf(stderr, " --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n"); fprintf(stderr, " --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n"); fprintf(stderr, "\nAllowed quantization types:\n"); @@ -118,8 +118,8 @@ int main(int argc, char ** argv) { if (pos != std::string::npos) { fpath = fname_inp.substr(0, pos + 1); } - // export as [inp path]/ggml-model-[ftype].bin - fname_out = fpath + "ggml-model-" + ftype_str + ".bin"; + // export as [inp path]/ggml-model-[ftype].gguf + fname_out = fpath + "ggml-model-" + ftype_str + ".gguf"; arg_idx++; } else { diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 61c71c358..3db61b754 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -26,7 +26,6 @@ int main(int argc, char ** argv) { auto lparams = llama_context_default_params(); lparams.n_ctx = params.n_ctx; - lparams.n_gqa = params.n_gqa; lparams.seed = params.seed; lparams.f16_kv = params.memory_f16; lparams.use_mmap = params.use_mmap; @@ -45,9 +44,8 @@ int main(int argc, char ** argv) { llama_free_model(model); return 1; } - auto tokens = std::vector(params.n_ctx); - auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true); - + auto tokens = llama_tokenize(ctx, params.prompt.c_str(), true); + auto n_prompt_tokens = tokens.size(); if (n_prompt_tokens < 1) { fprintf(stderr, "%s : failed to tokenize prompt\n", __func__); llama_free(ctx); @@ -92,7 +90,7 @@ int main(int argc, char ** argv) { auto next_token_str = llama_token_to_str(ctx, next_token); last_n_tokens_data.push_back(next_token); - printf("%s", next_token_str); + printf("%s", next_token_str.c_str()); if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx); @@ -152,7 +150,7 @@ int main(int argc, char ** argv) { auto next_token_str = llama_token_to_str(ctx2, next_token); last_n_tokens_data.push_back(next_token); - printf("%s", next_token_str); + printf("%s", next_token_str.c_str()); if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx2); diff --git a/examples/server/README.md b/examples/server/README.md index 1559dd3f2..4d97db2e4 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -5,7 +5,7 @@ This example demonstrates a simple HTTP API server and a simple web front end to Command line options: - `--threads N`, `-t N`: Set the number of threads to use during computation. -- `-m FNAME`, `--model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`). +- `-m FNAME`, `--model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.gguf`). - `-m ALIAS`, `--alias ALIAS`: Set an alias for the model. The alias will be returned in API responses. - `-c N`, `--ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference. The size may differ in other models, for example, baichuan models were build with a context of 4096. - `-ngl N`, `--n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance. @@ -48,15 +48,14 @@ To get started right away, run the following command, making sure to use the cor ### Unix-based systems (Linux, macOS, etc.): ```bash -./server -m models/7B/ggml-model.bin -c 2048 +./server -m models/7B/ggml-model.gguf -c 2048 ``` ### Windows: ```powershell -server.exe -m models\7B\ggml-model.bin -c 2048 +server.exe -m models\7B\ggml-model.gguf -c 2048 ``` - The above command will start a server that by default listens on `127.0.0.1:8080`. You can consume the endpoints with Postman or NodeJS with axios library. You can visit the web front end at the same url. diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 99660455a..a04f1910c 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -279,7 +279,7 @@ struct llama_server_context grammar_parser::print_grammar(stderr, parsed_grammar); { - auto it = params.logit_bias.find(llama_token_eos()); + auto it = params.logit_bias.find(llama_token_eos(ctx)); if (it != params.logit_bias.end() && it->second == -INFINITY) { LOG_WARNING("EOS token is disabled, which will cause most grammars to fail", {}); } @@ -402,7 +402,7 @@ struct llama_server_context if (params.n_predict == 0) { has_next_token = false; - result.tok = llama_token_eos(); + result.tok = llama_token_eos(ctx); return result; } @@ -442,7 +442,7 @@ struct llama_server_context llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false}; // Apply penalties - float nl_logit = logits[llama_token_nl()]; + float nl_logit = logits[llama_token_nl(ctx)]; auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), params.n_ctx); llama_sample_repetition_penalty(ctx, &candidates_p, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, @@ -452,7 +452,7 @@ struct llama_server_context last_n_repeat, alpha_frequency, alpha_presence); if (!penalize_nl) { - logits[llama_token_nl()] = nl_logit; + logits[llama_token_nl(ctx)] = nl_logit; } if (grammar != nullptr) { @@ -515,7 +515,7 @@ struct llama_server_context // decrement remaining sampling budget --n_remain; - if (!embd.empty() && embd.back() == llama_token_eos()) + if (!embd.empty() && embd.back() == llama_token_eos(ctx)) { // stopping_word = llama_token_to_str(ctx, embd.back()); has_next_token = false; @@ -652,8 +652,6 @@ static void server_print_usage(const char *argv0, const gpt_params ¶ms, fprintf(stdout, " -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled"); fprintf(stdout, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); fprintf(stdout, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx); - fprintf(stdout, " -gqa N, --gqa N grouped-query attention factor (TEMP!!! use 8 for LLaMAv2 70B) (default: %d)\n", params.n_gqa); - fprintf(stdout, " -eps N, --rms-norm-eps N rms norm eps (TEMP!!! use 1e-5 for LLaMAv2) (default: %.1e)\n", params.rms_norm_eps); fprintf(stdout, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base); fprintf(stdout, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale); fprintf(stdout, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); @@ -774,23 +772,6 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, } params.n_ctx = std::stoi(argv[i]); } - else if (arg == "-gqa" || arg == "--gqa") - { - if (++i >= argc) - { - invalid_param = true; - break; - } - params.n_gqa = std::stoi(argv[i]); - } - else if (arg == "-eps" || arg == "--rms-norm-eps") { - if (++i >= argc) - { - invalid_param = true; - break; - } - params.rms_norm_eps = std::stof(argv[i]); - } else if (arg == "--rope-freq-base") { if (++i >= argc) @@ -968,7 +949,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, static json format_generation_settings(llama_server_context &llama) { - const auto eos_bias = llama.params.logit_bias.find(llama_token_eos()); + const auto eos_bias = llama.params.logit_bias.find(llama_token_eos(llama.ctx)); const bool ignore_eos = eos_bias != llama.params.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second); @@ -1103,7 +1084,7 @@ static void parse_options_completion(const json &body, llama_server_context &lla llama.params.logit_bias.clear(); if (body.value("ignore_eos", false)) { - llama.params.logit_bias[llama_token_eos()] = -INFINITY; + llama.params.logit_bias[llama_token_eos(llama.ctx)] = -INFINITY; } const auto &logit_bias = body.find("logit_bias"); diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 97137a658..132f7fbf9 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -2,180 +2,129 @@ #define _GNU_SOURCE #endif -#include "common.h" -#include "llama.h" #include "build-info.h" -#include -#include +#include "common.h" +#include "llama.h" + #include #include -#include -#include -#include -#include #include #include -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) -#include -#include -#elif defined (_WIN32) -#define WIN32_LEAN_AND_MEAN -#define NOMINMAX -#include -#include -#endif - - - -int main(int argc, char ** argv) -{ +int main(int argc, char ** argv) { gpt_params params; - //--------------------------------- - // Print help : - //--------------------------------- - - if ( argc == 1 || argv[1][0] == '-' ) - { - printf( "usage: %s MODEL_PATH [PROMPT]\n" , argv[0] ); + if (argc == 1 || argv[1][0] == '-') { + printf("usage: %s MODEL_PATH [PROMPT]\n" , argv[0]); return 1 ; } - //--------------------------------- - // Load parameters : - //--------------------------------- - - if ( argc >= 2 ) - { + if (argc >= 2) { params.model = argv[1]; } - if ( argc >= 3 ) - { + if (argc >= 3) { params.prompt = argv[2]; } - if ( params.prompt.empty() ) - { + if (params.prompt.empty()) { params.prompt = "Hello my name is"; } - //--------------------------------- - // Init LLM : - //--------------------------------- + // init LLM llama_backend_init(params.numa); - llama_model * model; - llama_context * ctx; + llama_context_params ctx_params = llama_context_default_params(); - std::tie(model, ctx) = llama_init_from_gpt_params( params ); + llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params); - if ( model == NULL ) - { - fprintf( stderr , "%s: error: unable to load model\n" , __func__ ); + if (model == NULL) { + fprintf(stderr , "%s: error: unable to load model\n" , __func__); return 1; } - //--------------------------------- - // Tokenize the prompt : - //--------------------------------- + llama_context * ctx = llama_new_context_with_model(model, ctx_params); + + // tokenize the prompt std::vector tokens_list; - tokens_list = ::llama_tokenize( ctx , params.prompt , true ); + tokens_list = ::llama_tokenize(ctx, params.prompt, true); - const int max_context_size = llama_n_ctx( ctx ); - const int max_tokens_list_size = max_context_size - 4 ; + const int max_context_size = llama_n_ctx(ctx); + const int max_tokens_list_size = max_context_size - 4; - if ( (int)tokens_list.size() > max_tokens_list_size ) - { - fprintf( stderr , "%s: error: prompt too long (%d tokens, max %d)\n" , - __func__ , (int)tokens_list.size() , max_tokens_list_size ); + if ((int) tokens_list.size() > max_tokens_list_size) { + fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) tokens_list.size(), max_tokens_list_size); return 1; } - fprintf( stderr, "\n\n" ); + fprintf(stderr, "\n\n"); - // Print the tokens from the prompt : - - for( auto id : tokens_list ) - { - printf( "%s" , llama_token_to_str( ctx , id ) ); + for (auto id : tokens_list) { + fprintf(stderr, "%s", llama_token_to_str(ctx, id).c_str()); } - fflush(stdout); + fflush(stderr); - - //--------------------------------- - // Main prediction loop : - //--------------------------------- + // main loop // The LLM keeps a contextual cache memory of previous token evaluation. // Usually, once this cache is full, it is required to recompute a compressed context based on previous // tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist // example, we will just stop the loop once this cache is full or once an end of stream is detected. - while ( llama_get_kv_cache_token_count( ctx ) < max_context_size ) - { - //--------------------------------- - // Evaluate the tokens : - //--------------------------------- + const int n_gen = std::min(32, max_context_size); - if ( llama_eval( ctx , tokens_list.data() , int(tokens_list.size()) , llama_get_kv_cache_token_count( ctx ) , params.n_threads ) ) - { - fprintf( stderr, "%s : failed to eval\n" , __func__ ); + while (llama_get_kv_cache_token_count(ctx) < n_gen) { + // evaluate the transformer + + if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) { + fprintf(stderr, "%s : failed to eval\n", __func__); return 1; } tokens_list.clear(); - //--------------------------------- - // Select the best prediction : - //--------------------------------- + // sample the next token llama_token new_token_id = 0; - auto logits = llama_get_logits( ctx ); - auto n_vocab = llama_n_vocab( ctx ); // the size of the LLM vocabulary (in tokens) + auto logits = llama_get_logits(ctx); + auto n_vocab = llama_n_vocab(ctx); std::vector candidates; - candidates.reserve( n_vocab ); + candidates.reserve(n_vocab); - for( llama_token token_id = 0 ; token_id < n_vocab ; token_id++ ) - { - candidates.emplace_back( llama_token_data{ token_id , logits[ token_id ] , 0.0f } ); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f }); } llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - // Select it using the "Greedy sampling" method : - new_token_id = llama_sample_token_greedy( ctx , &candidates_p ); - + new_token_id = llama_sample_token_greedy(ctx , &candidates_p); // is it an end of stream ? - if ( new_token_id == llama_token_eos() ) - { + if (new_token_id == llama_token_eos(ctx)) { fprintf(stderr, " [end of text]\n"); break; } - // Print the new token : - printf( "%s" , llama_token_to_str( ctx , new_token_id ) ); - fflush( stdout ); + // print the new token : + printf("%s", llama_token_to_str(ctx, new_token_id).c_str()); + fflush(stdout); - // Push this new token for next evaluation : - tokens_list.push_back( new_token_id ); + // push this new token for next evaluation + tokens_list.push_back(new_token_id); + } - } // wend of main loop - - llama_free( ctx ); - llama_free_model( model ); + llama_free(ctx); + llama_free_model(model); llama_backend_free(); + fprintf(stderr, "\n\n"); + return 0; } - -// EOF diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 54dc2beed..31d6620a2 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -1,4 +1,5 @@ #include "ggml.h" +#include "common.h" #include "llama.h" #include #include @@ -16,7 +17,7 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif -static const float rms_norm_eps = LLAMA_DEFAULT_RMS_EPS; +static const float rms_norm_eps = 1e-5f; struct random_normal_distribution { std::mt19937 gen; @@ -169,14 +170,16 @@ struct ggml_tensor * randomize_tensor_uniform(struct ggml_tensor * tensor, struc struct llama_vocab { using id = int32_t; using token = std::string; + using ttype = llama_token_type; - struct token_score { - token tok; + struct token_data { + token text; float score; + ttype type; }; std::unordered_map token_to_id; - std::vector id_to_token; + std::vector id_to_token; }; struct my_llama_hparams { @@ -1961,7 +1964,7 @@ void print_matrix(struct ggml_tensor * probs) { void print_token(struct llama_context * ctx, llama_token token) { - printf("%s", llama_token_to_str(ctx, token)); + printf("%s", llama_token_to_str(ctx, token).c_str()); } void print_tokens(struct llama_context* ctx, struct ggml_tensor * tokens) { @@ -1995,7 +1998,7 @@ void print_tokens_batch(struct llama_context* ctx, struct ggml_tensor * tokens) } } -void get_example_targets(const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) { +void get_example_targets(struct llama_context * lctx, const int * train_samples, size_t n_train_samples, const llama_token * train_data, size_t n_train_data, int example_id, struct ggml_tensor * tokens_input, struct ggml_tensor * target_logits, struct ggml_tensor * target_probs) { int n_tokens = tokens_input->ne[0]; int n_vocab = target_logits->ne[0]; @@ -2004,7 +2007,7 @@ void get_example_targets(const int * train_samples, size_t n_train_samples, cons ggml_set_f32(target_logits, -1.0f/n_vocab); ggml_set_f32(target_probs, 0.0f); - ggml_set_i32_1d(tokens_input, 0, llama_token_bos()); + ggml_set_i32_1d(tokens_input, 0, llama_token_bos(lctx)); for (int i=1; in_dims == 2); GGML_ASSERT(target_logits->n_dims == 3); GGML_ASSERT(target_probs->n_dims == 3); @@ -2035,7 +2038,7 @@ void get_example_targets_batch(struct llama_context * /*lctx*/, const int * trai size_t sample = train_samples[(example_id*n_batch + k) % n_train_samples]; GGML_ASSERT(sample+n_tokens-1 < n_train_data); - set_i32_2d(tokens_input, 0, k, llama_token_bos()); + set_i32_2d(tokens_input, 0, k, llama_token_bos(lctx)); for (int i=1; i= 0) { - out.resize(n_tokens); + int n_tokens = llama_tokenize(lctx, buf.data(), out.data(), out.size(), false); + if (n_tokens < 0) { + out.resize(-n_tokens); + llama_tokenize(lctx, buf.data(), out.data(), out.size(), false); } bool verify = false; @@ -2200,17 +2202,17 @@ int tokenize_file(struct llama_context * lctx, const char * filename, std::vecto const char * in = buf.data(); const char * end = buf.data() + buf.size(); for (int i = 0; i < (int) out.size(); ++i) { - const char * s = llama_token_to_str(lctx, out[i]); - int len = strlen(s); + std::string s = llama_token_to_str(lctx, out[i]); + int len = s.length(); if (in >= end) { printf("%s: unexpected end of original text.\n", __func__); break; } - const bool matches = (strncmp(in, s, len) == 0); + const bool matches = (strncmp(in, s.c_str(), len) == 0); if (matches) { in += len; } else { - printf("%s: mismatch: expected '%s', but got '%s'\n", __func__, std::string(in, len).c_str(), s); + printf("%s: mismatch: expected '%s', but got '%s'\n", __func__, std::string(in, len).c_str(), s.c_str()); } } } @@ -2294,7 +2296,7 @@ llama_token sample(struct my_llama_sampler * sampler, float * logits, const llam const auto params = sampler->params; // Apply penalties - const float nl_logit = logits[llama_token_nl()]; + const float nl_logit = logits[llama_token_nl(ctx)]; const int n_last = std::min(std::min(n_last_tokens, params.repeat_last_n), sampler->n_ctx); @@ -2313,7 +2315,7 @@ llama_token sample(struct my_llama_sampler * sampler, float * logits, const llam params.alpha_presence); if (!params.penalize_nl) { - logits[llama_token_nl()] = nl_logit; + logits[llama_token_nl(ctx)] = nl_logit; } llama_token token = 0; @@ -2612,42 +2614,45 @@ void save_as_llama_model(struct llama_vocab * vocab, struct my_llama_model * mod return; } - // write_magic - file.write_u32(LLAMA_FILE_MAGIC); // magic - file.write_u32(LLAMA_FILE_VERSION); // version - // write_hparams - file.write_u32(model->hparams.n_vocab); - file.write_u32(model->hparams.n_embd); - file.write_u32(model->hparams.n_mult); - file.write_u32(model->hparams.n_head); - file.write_u32(model->hparams.n_layer); - file.write_u32(model->hparams.n_rot); - file.write_u32(LLAMA_FTYPE_ALL_F32); - // write_vocab - uint32_t n_vocab = model->hparams.n_vocab; - for (uint32_t i = 0; i < n_vocab; i++) { - const auto & token_score = vocab->id_to_token.at(i); - file.write_u32((uint32_t) token_score.tok.size()); - file.write_raw(token_score.tok.data(), token_score.tok.size()); - file.write_raw(&token_score.score, sizeof(token_score.score)); - } - // write tensors - write_tensor(&file, model->tok_embeddings); - write_tensor(&file, model->norm); - write_tensor(&file, model->output); - for (uint32_t i = 0; i < model->hparams.n_layer; ++i) { - auto & layer = model->layers[i]; - - write_tensor(&file, layer.attention_norm); - write_tensor(&file, layer.wq); - write_tensor(&file, layer.wk); - write_tensor(&file, layer.wv); - write_tensor(&file, layer.wo); - write_tensor(&file, layer.ffn_norm); - write_tensor(&file, layer.w1); - write_tensor(&file, layer.w2); - write_tensor(&file, layer.w3); - } +#pragma message("TODO: implement file saving using gguf") + (void) vocab; + (void) model; +// // write_magic +// file.write_u32(LLAMA_FILE_MAGIC); // magic +// file.write_u32(LLAMA_FILE_VERSION); // version +// // write_hparams +// file.write_u32(model->hparams.n_vocab); +// file.write_u32(model->hparams.n_embd); +// file.write_u32(model->hparams.n_mult); +// file.write_u32(model->hparams.n_head); +// file.write_u32(model->hparams.n_layer); +// file.write_u32(model->hparams.n_rot); +// file.write_u32(LLAMA_FTYPE_ALL_F32); +// // write_vocab +// uint32_t n_vocab = model->hparams.n_vocab; +// for (uint32_t i = 0; i < n_vocab; i++) { +// const auto & token_data = vocab->id_to_token.at(i); +// file.write_u32((uint32_t) token_data.tok.size()); +// file.write_raw(token_data.tok.data(), token_data.tok.size()); +// file.write_raw(&token_data.score, sizeof(token_data.score)); +// } +// // write tensors +// write_tensor(&file, model->tok_embeddings); +// write_tensor(&file, model->norm); +// write_tensor(&file, model->output); +// for (uint32_t i = 0; i < model->hparams.n_layer; ++i) { +// auto & layer = model->layers[i]; +// +// write_tensor(&file, layer.attention_norm); +// write_tensor(&file, layer.wq); +// write_tensor(&file, layer.wk); +// write_tensor(&file, layer.wv); +// write_tensor(&file, layer.wo); +// write_tensor(&file, layer.ffn_norm); +// write_tensor(&file, layer.w1); +// write_tensor(&file, layer.w2); +// write_tensor(&file, layer.w3); +// } } float cosine_decay(const int decay_steps, const float alpha, int step) { @@ -3052,20 +3057,13 @@ int main(int argc, char ** argv) { struct llama_vocab vocab; { - std::vector strings; - std::vector scores; - int n_vocab = llama_n_vocab(lctx); - strings.resize(n_vocab, NULL); - scores.resize(n_vocab, 0); - n_vocab = llama_get_vocab(lctx, strings.data(), scores.data(), n_vocab); - GGML_ASSERT(n_vocab == llama_n_vocab(lctx)); + const int n_vocab = llama_n_vocab(lctx); vocab.id_to_token.resize(n_vocab); for (int i=0; i train_samples; train_samples.push_back(0); for (int i = 1; i < (int) train_tokens.size() - n_tokens; ++i) { - if (!params.samples_start_after_nl || (train_tokens[i-1] == llama_token_nl())) { + if (!params.samples_start_after_nl || (train_tokens[i-1] == llama_token_nl(lctx))) { train_samples.push_back(i); } } @@ -3338,7 +3336,7 @@ int main(int argc, char ** argv) { struct ggml_tensor * target_logits = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens); struct ggml_tensor * target_probs = ggml_new_tensor_2d(model.ctx, GGML_TYPE_F32, n_vocab, n_tokens); - get_example_targets(train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), rand()%train_samples.size(), tokens_input, target_logits, target_probs); + get_example_targets(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), rand()%train_samples.size(), tokens_input, target_logits, target_probs); for (int i=sample_ctx; in_cb = n_cb; } diff --git a/ggml.c b/ggml.c index 44c43b424..c917d73c7 100644 --- a/ggml.c +++ b/ggml.c @@ -213,10 +213,10 @@ inline static void * ggml_aligned_malloc(size_t size) { error_desc = "insufficient memory"; break; } - GGML_PRINT("%s: %s (attempted to allocate %6.2f MB)\n", - __func__, error_desc, size/(1024.0*1024.0)); + GGML_PRINT("%s: %s (attempted to allocate %6.2f MB)\n", __func__, error_desc, size/(1024.0*1024.0)); return NULL; } + return aligned_memory; } #define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size) @@ -4091,7 +4091,11 @@ size_t ggml_nbytes(const struct ggml_tensor * tensor) { // // is enough, but just in case, adding the second part - return GGML_PAD(MAX(tensor->ne[3]*tensor->nb[3], ggml_nelements(tensor)*ggml_type_size(tensor->type))/ggml_blck_size(tensor->type), GGML_MEM_ALIGN); + return MAX(tensor->ne[3]*tensor->nb[3], (ggml_nelements(tensor)*ggml_type_size(tensor->type))/ggml_blck_size(tensor->type)); +} + +size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) { + return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN); } size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) { @@ -9118,6 +9122,8 @@ static void ggml_compute_forward_mul( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { + GGML_ASSERT(src1->type == GGML_TYPE_F32 && "only f32 src1 supported for now"); + switch (src0->type) { case GGML_TYPE_F32: { @@ -16881,7 +16887,7 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { // compute size of intermediate results // TODO: does not take into account scratch buffers !!!! for (int i = 0; i < cgraph->n_nodes; ++i) { - size_eval += ggml_nbytes(cgraph->nodes[i]); + size_eval += ggml_nbytes_pad(cgraph->nodes[i]); } // print @@ -18542,6 +18548,1005 @@ size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, i //////////////////////////////////////////////////////////////////////////////// +struct gguf_str { + uint32_t n; + char * data; +}; + +static const size_t GGUF_TYPE_SIZE[GGUF_TYPE_COUNT] = { + [GGUF_TYPE_UINT8] = sizeof(uint8_t), + [GGUF_TYPE_INT8] = sizeof(int8_t), + [GGUF_TYPE_UINT16] = sizeof(uint16_t), + [GGUF_TYPE_INT16] = sizeof(int16_t), + [GGUF_TYPE_UINT32] = sizeof(uint32_t), + [GGUF_TYPE_INT32] = sizeof(int32_t), + [GGUF_TYPE_FLOAT32] = sizeof(float), + [GGUF_TYPE_BOOL] = sizeof(bool), + [GGUF_TYPE_STRING] = sizeof(struct gguf_str), + [GGUF_TYPE_ARRAY] = 0, // undefined +}; +static_assert(GGUF_TYPE_COUNT == 10, "GGUF_TYPE_COUNT != 10"); + +static const char * GGUF_TYPE_NAME[GGUF_TYPE_COUNT] = { + [GGUF_TYPE_UINT8] = "u8", + [GGUF_TYPE_INT8] = "i8", + [GGUF_TYPE_UINT16] = "u16", + [GGUF_TYPE_INT16] = "i16", + [GGUF_TYPE_UINT32] = "u32", + [GGUF_TYPE_INT32] = "i32", + [GGUF_TYPE_FLOAT32] = "f32", + [GGUF_TYPE_BOOL] = "bool", + [GGUF_TYPE_STRING] = "str", + [GGUF_TYPE_ARRAY] = "arr", +}; +static_assert(GGUF_TYPE_COUNT == 10, "GGUF_TYPE_COUNT != 10"); + +union gguf_value { + uint8_t uint8; + int8_t int8; + uint16_t uint16; + int16_t int16; + uint32_t uint32; + int32_t int32; + float float32; + bool bool_; + + struct gguf_str str; + + struct { + enum gguf_type type; + + uint32_t n; + void * data; + } arr; +}; + +struct gguf_kv { + struct gguf_str key; + + uint32_t n_bytes; // TODO: is this actually needed? + + enum gguf_type type; + union gguf_value value; +}; + +struct gguf_header { + uint32_t magic; + uint32_t version; + uint32_t n_tensors; + uint32_t n_kv; +}; + +struct gguf_tensor_info { + struct gguf_str name; + + uint32_t n_dims; + uint32_t ne[GGML_MAX_DIMS]; + + enum ggml_type type; + + uint64_t offset; // offset from start of `data`, must be a multiple of `ALIGNMENT` + + // for writing API + const void * data; + size_t size; +}; + +struct gguf_context { + struct gguf_header header; + + struct gguf_kv * kv; + struct gguf_tensor_info * infos; + + size_t alignment; + size_t offset; // offset of `data` from beginning of file + size_t size; // size of `data` in bytes + + //uint8_t * padding; + void * data; +}; + +static bool gguf_fread_el(FILE * file, void * dst, size_t size, size_t * offset) { + const size_t n = fread(dst, 1, size, file); + *offset += n; + return n == size; +} + +static bool gguf_fread_str(FILE * file, struct gguf_str * p, size_t * offset) { + p->n = 0; + p->data = NULL; + + bool ok = true; + + // TODO: how to avoid mallocs for strings? + ok = ok && gguf_fread_el(file, &p->n, sizeof(p->n), offset); p->data = calloc(p->n + 1, 1); + ok = ok && gguf_fread_el(file, p->data, p->n, offset); + + return ok; +} + +struct gguf_context * gguf_init_empty(void) { + struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context)); + + ctx->header.magic = GGUF_MAGIC; + ctx->header.version = GGUF_VERSION; + ctx->header.n_tensors = 0; + ctx->header.n_kv = 0; + + ctx->kv = NULL; + ctx->infos = NULL; + + ctx->alignment = GGUF_DEFAULT_ALIGNMENT; + ctx->offset = 0; + ctx->size = 0; + + ctx->data = NULL; + + return ctx; +} + +struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params) { + FILE * file = fopen(fname, "rb"); + if (!file) { + return NULL; + } + + // offset from start of file + size_t offset = 0; + + uint32_t magic = 0; + + // check the magic before making allocations + { + gguf_fread_el(file, &magic, sizeof(magic), &offset); + + if (magic != GGUF_MAGIC) { + fprintf(stderr, "%s: invalid magic number %08x\n", __func__, magic); + fclose(file); + return NULL; + } + } + + bool ok = true; + + struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context)); + + // read the header + { + ctx->header.magic = magic; + + ctx->kv = NULL; + ctx->infos = NULL; + ctx->data = NULL; + + ok = ok && gguf_fread_el(file, &ctx->header.version, sizeof(ctx->header.version), &offset); + ok = ok && gguf_fread_el(file, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors), &offset); + ok = ok && gguf_fread_el(file, &ctx->header.n_kv, sizeof(ctx->header.n_kv), &offset); + + if (!ok) { + fprintf(stderr, "%s: failed to read header\n", __func__); + fclose(file); + gguf_free(ctx); + return NULL; + } + } + + // read the kv pairs + { + ctx->kv = GGML_ALIGNED_MALLOC(ctx->header.n_kv * sizeof(struct gguf_kv)); + + for (uint32_t i = 0; i < ctx->header.n_kv; ++i) { + struct gguf_kv * kv = &ctx->kv[i]; + + //fprintf(stderr, "%s: reading kv %d\n", __func__, i); + + ok = ok && gguf_fread_str(file, &kv->key, &offset); + //ok = ok && gguf_fread_el (file, &kv->n_bytes, sizeof(kv->n_bytes), &offset); + ok = ok && gguf_fread_el (file, &kv->type, sizeof(kv->type), &offset); + + //fprintf(stderr, "%s: reading kv with key %s\n", __func__, kv->key.data); + + switch (kv->type) { + case GGUF_TYPE_UINT8: ok = ok && gguf_fread_el (file, &kv->value.uint8, sizeof(kv->value.uint8), &offset); break; + case GGUF_TYPE_INT8: ok = ok && gguf_fread_el (file, &kv->value.int8, sizeof(kv->value.int8), &offset); break; + case GGUF_TYPE_UINT16: ok = ok && gguf_fread_el (file, &kv->value.uint16, sizeof(kv->value.uint16), &offset); break; + case GGUF_TYPE_INT16: ok = ok && gguf_fread_el (file, &kv->value.int16, sizeof(kv->value.int16), &offset); break; + case GGUF_TYPE_UINT32: ok = ok && gguf_fread_el (file, &kv->value.uint32, sizeof(kv->value.uint32), &offset); break; + case GGUF_TYPE_INT32: ok = ok && gguf_fread_el (file, &kv->value.int32, sizeof(kv->value.int32), &offset); break; + case GGUF_TYPE_FLOAT32: ok = ok && gguf_fread_el (file, &kv->value.float32, sizeof(kv->value.float32), &offset); break; + case GGUF_TYPE_BOOL: ok = ok && gguf_fread_el (file, &kv->value.bool_, sizeof(kv->value.bool_), &offset); break; + case GGUF_TYPE_STRING: ok = ok && gguf_fread_str(file, &kv->value.str, &offset); break; + case GGUF_TYPE_ARRAY: + { + ok = ok && gguf_fread_el(file, &kv->value.arr.type, sizeof(kv->value.arr.type), &offset); + ok = ok && gguf_fread_el(file, &kv->value.arr.n, sizeof(kv->value.arr.n), &offset); + + switch (kv->value.arr.type) { + case GGUF_TYPE_UINT8: + case GGUF_TYPE_INT8: + case GGUF_TYPE_UINT16: + case GGUF_TYPE_INT16: + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: + case GGUF_TYPE_FLOAT32: + case GGUF_TYPE_BOOL: + { + kv->value.arr.data = malloc(kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type]); + ok = ok && gguf_fread_el(file, kv->value.arr.data, kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type], &offset); + } break; + case GGUF_TYPE_STRING: + { + kv->value.arr.data = malloc(kv->value.arr.n * sizeof(struct gguf_str)); + for (uint32_t j = 0; j < kv->value.arr.n; ++j) { + ok = ok && gguf_fread_str(file, &((struct gguf_str *) kv->value.arr.data)[j], &offset); + } + } break; + case GGUF_TYPE_ARRAY: + case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break; + }; + } break; + case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); + }; + + if (!ok) { + break; + } + } + + if (!ok) { + fprintf(stderr, "%s: failed to read key-value pairs\n", __func__); + fclose(file); + gguf_free(ctx); + return NULL; + } + } + + // read the tensor infos + { + ctx->infos = GGML_ALIGNED_MALLOC(ctx->header.n_tensors * sizeof(struct gguf_tensor_info)); + + for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { + struct gguf_tensor_info * info = &ctx->infos[i]; + + for (int j = 0; j < GGML_MAX_DIMS; ++j) { + info->ne[j] = 1; + } + + ok = ok && gguf_fread_str(file, &info->name, &offset); + ok = ok && gguf_fread_el (file, &info->n_dims, sizeof(info->n_dims), &offset); + for (uint32_t j = 0; j < info->n_dims; ++j) { + ok = ok && gguf_fread_el(file, &info->ne[j], sizeof(info->ne[j]), &offset); + } + ok = ok && gguf_fread_el (file, &info->type, sizeof(info->type), &offset); + ok = ok && gguf_fread_el (file, &info->offset, sizeof(info->offset), &offset); + + if (!ok) { + fprintf(stderr, "%s: failed to read tensor info\n", __func__); + fclose(file); + gguf_free(ctx); + return NULL; + } + } + } + + ctx->alignment = GGUF_DEFAULT_ALIGNMENT; + + int alignment_idx = gguf_find_key(ctx, "general.alignment"); + if (alignment_idx != -1) { + ctx->alignment = gguf_get_val_u32(ctx, alignment_idx); + } + + // we require the data section to be aligned, so take into account any padding + { + const size_t offset_pad = offset % ctx->alignment; + + if (offset_pad != 0) { + offset += ctx->alignment - offset_pad; + fseek(file, offset, SEEK_SET); + } + } + + // store the current file offset - this is where the data section starts + ctx->offset = offset; + + // compute the total size of the data section, taking into account the alignment + { + ctx->size = 0; + for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { + struct gguf_tensor_info * info = &ctx->infos[i]; + + const int64_t ne = + (int64_t) info->ne[0] * + (int64_t) info->ne[1] * + (int64_t) info->ne[2] * + (int64_t) info->ne[3]; + + if (ne % ggml_blck_size(info->type) != 0) { + fprintf(stderr, "%s: tensor '%s' number of elements (%" PRId64 ") is not a multiple of block size (%d)\n", + __func__, info->name.data, ne, ggml_blck_size(info->type)); + fclose(file); + gguf_free(ctx); + return NULL; + } + + const size_t size_cur = (ne*ggml_type_size(info->type))/ggml_blck_size(info->type); + + ctx->size += GGML_PAD(size_cur, ctx->alignment); + } + } + + // load the tensor data only if requested + if (params.ctx != NULL) { + // if the provided gguf_context is no_alloc, then we create "empty" tensors and do not read the binary blob + // otherwise, we load the binary blob into the created ggml_context as well, and point the "data" members of + // the ggml_tensor structs to the appropriate locations in the binary blob + + // compute the exact size needed for the new ggml_context + const size_t mem_size = + params.no_alloc ? + (ctx->header.n_tensors )*ggml_tensor_overhead() : + (ctx->header.n_tensors + 1)*ggml_tensor_overhead() + ctx->size; + + struct ggml_init_params pdata = { + .mem_size = mem_size, + .mem_buffer = NULL, + .no_alloc = params.no_alloc, + }; + + *params.ctx = ggml_init(pdata); + + struct ggml_context * ctx_data = *params.ctx; + + struct ggml_tensor * data = NULL; + + if (params.no_alloc == false) { + data = ggml_new_tensor_1d(ctx_data, GGML_TYPE_I8, ctx->size); + + ok = ok && data != NULL; + + // read the binary blob with the tensor data + ok = ok && gguf_fread_el(file, data->data, ctx->size, &offset); + + if (!ok) { + fprintf(stderr, "%s: failed to read tensor data\n", __func__); + fclose(file); + ggml_free(ctx_data); + gguf_free(ctx); + return NULL; + } + + ctx->data = data->data; + } + + ggml_set_no_alloc(ctx_data, true); + + // create the tensors + for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { + const int64_t ne[GGML_MAX_DIMS] = { + ctx->infos[i].ne[0], + ctx->infos[i].ne[1], + ctx->infos[i].ne[2], + ctx->infos[i].ne[3], + }; + + struct ggml_tensor * cur = ggml_new_tensor(ctx_data, ctx->infos[i].type, ctx->infos[i].n_dims, ne); + + ok = ok && cur != NULL; + + ggml_set_name(cur, ctx->infos[i].name.data); + + if (!ok) { + break; + } + + // point the data member to the appropriate location in the binary blob using the tensor infos + if (params.no_alloc == false) { + //cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file + cur->data = (char *) data->data + ctx->infos[i].offset; // offset from data + } + } + + if (!ok) { + fprintf(stderr, "%s: failed to read the tensor data\n", __func__); + fclose(file); + ggml_free(ctx_data); + gguf_free(ctx); + return NULL; + } + + ggml_set_no_alloc(ctx_data, params.no_alloc); + } + + fclose(file); + + return ctx; +} + +void gguf_free(struct gguf_context * ctx) { + if (ctx == NULL) { + return; + } + + if (ctx->kv) { + // free string memory - not great.. + for (uint32_t i = 0; i < ctx->header.n_kv; ++i) { + struct gguf_kv * kv = &ctx->kv[i]; + + if (kv->key.data) { + free(kv->key.data); + } + + if (kv->type == GGUF_TYPE_STRING) { + if (kv->value.str.data) { + free(kv->value.str.data); + } + } + + if (kv->type == GGUF_TYPE_ARRAY) { + if (kv->value.arr.data) { + if (kv->value.arr.type == GGUF_TYPE_STRING) { + for (uint32_t j = 0; j < kv->value.arr.n; ++j) { + struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[j]; + if (str->data) { + free(str->data); + } + } + } + free(kv->value.arr.data); + } + } + } + + GGML_ALIGNED_FREE(ctx->kv); + } + + if (ctx->infos) { + for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { + struct gguf_tensor_info * info = &ctx->infos[i]; + + if (info->name.data) { + free(info->name.data); + } + } + + GGML_ALIGNED_FREE(ctx->infos); + } + + GGML_ALIGNED_FREE(ctx); +} + +const char * gguf_type_name(enum gguf_type type) { + return GGUF_TYPE_NAME[type]; +} + +int gguf_get_version(struct gguf_context * ctx) { + return ctx->header.version; +} + +size_t gguf_get_alignment(struct gguf_context * ctx) { + return ctx->alignment; +} + +size_t gguf_get_data_offset(struct gguf_context * ctx) { + return ctx->offset; +} + +void * gguf_get_data(struct gguf_context * ctx) { + return ctx->data; +} + +int gguf_get_n_kv(struct gguf_context * ctx) { + return ctx->header.n_kv; +} + +int gguf_find_key(struct gguf_context * ctx, const char * key) { + // return -1 if key not found + int keyfound = -1; + + const int n_kv = gguf_get_n_kv(ctx); + + for (int i = 0; i < n_kv; ++i) { + if (strcmp(key, gguf_get_key(ctx, i)) == 0) { + keyfound = i; + break; + } + } + + return keyfound; +} + +const char * gguf_get_key(struct gguf_context * ctx, int i) { + return ctx->kv[i].key.data; +} + +enum gguf_type gguf_get_kv_type(struct gguf_context * ctx, int i) { + return ctx->kv[i].type; +} + +enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.arr.type; +} + +const void * gguf_get_arr_data(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.arr.data; +} + +const char * gguf_get_arr_str(struct gguf_context * ctx, int key_id, int i) { + struct gguf_kv * kv = &ctx->kv[key_id]; + struct gguf_str * str = &((struct gguf_str *) kv->value.arr.data)[i]; + return str->data; +} + +int gguf_get_arr_n(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.arr.n; +} + +uint8_t gguf_get_val_u8(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.uint8; +} + +int8_t gguf_get_val_i8(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.int8; +} + +uint16_t gguf_get_val_u16(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.uint16; +} + +int16_t gguf_get_val_i16(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.int16; +} + +uint32_t gguf_get_val_u32(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.uint32; +} + +int32_t gguf_get_val_i32(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.int32; +} + +float gguf_get_val_f32(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.float32; +} + +bool gguf_get_val_bool(struct gguf_context * ctx, int i) { + return ctx->kv[i].value.bool_; +} + +const char * gguf_get_val_str (struct gguf_context * ctx, int i) { + return ctx->kv[i].value.str.data; +} + +int gguf_get_n_tensors(struct gguf_context * ctx) { + return ctx->header.n_tensors; +} + +int gguf_find_tensor(struct gguf_context * ctx, const char * name) { + // return -1 if tensor not found + int tensorfound = -1; + + const int n_tensors = gguf_get_n_tensors(ctx); + + for (int i = 0; i < n_tensors; ++i) { + if (strcmp(name, gguf_get_tensor_name(ctx, i)) == 0) { + tensorfound = i; + break; + } + } + + return tensorfound; +} + +size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i) { + return ctx->infos[i].offset; +} + +char * gguf_get_tensor_name(struct gguf_context * ctx, int i) { + return ctx->infos[i].name.data; +} + +// returns the index +static int gguf_get_or_add_key(struct gguf_context * ctx, const char * key) { + const int idx = gguf_find_key(ctx, key); + if (idx >= 0) { + return idx; + } + + const int n_kv = gguf_get_n_kv(ctx); + + ctx->kv = realloc(ctx->kv, (n_kv + 1) * sizeof(struct gguf_kv)); + ctx->kv[n_kv].key.n = strlen(key) + 1; + ctx->kv[n_kv].key.data = strdup(key); + ctx->header.n_kv++; + + return n_kv; +} + +void gguf_set_val_u8(struct gguf_context * ctx, const char * key, uint8_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_UINT8; + ctx->kv[idx].value.uint8 = val; +} + +void gguf_set_val_i8(struct gguf_context * ctx, const char * key, int8_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_INT8; + ctx->kv[idx].value.int8 = val; +} + +void gguf_set_val_u16(struct gguf_context * ctx, const char * key, uint16_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_UINT16; + ctx->kv[idx].value.uint16 = val; +} + +void gguf_set_val_i16(struct gguf_context * ctx, const char * key, int16_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_INT16; + ctx->kv[idx].value.int16 = val; +} + +void gguf_set_val_u32(struct gguf_context * ctx, const char * key, uint32_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_UINT32; + ctx->kv[idx].value.uint32 = val; +} + +void gguf_set_val_i32(struct gguf_context * ctx, const char * key, int32_t val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_INT32; + ctx->kv[idx].value.int32 = val; +} + +void gguf_set_val_f32(struct gguf_context * ctx, const char * key, float val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_FLOAT32; + ctx->kv[idx].value.float32 = val; +} + +void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_BOOL; + ctx->kv[idx].value.bool_ = val; +} + +void gguf_set_val_str(struct gguf_context * ctx, const char * key, const char * val) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_STRING; + ctx->kv[idx].value.str.n = strlen(val) + 1; + ctx->kv[idx].value.str.data = strdup(val); +} + +void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_ARRAY; + ctx->kv[idx].value.arr.type = type; + ctx->kv[idx].value.arr.n = n; + ctx->kv[idx].value.arr.data = malloc(n*GGUF_TYPE_SIZE[type]); + memcpy(ctx->kv[idx].value.arr.data, data, n*GGUF_TYPE_SIZE[type]); +} + +void gguf_set_arr_str(struct gguf_context * ctx, const char * key, const char ** data, int n) { + const int idx = gguf_get_or_add_key(ctx, key); + + ctx->kv[idx].type = GGUF_TYPE_ARRAY; + ctx->kv[idx].value.arr.type = GGUF_TYPE_STRING; + ctx->kv[idx].value.arr.n = n; + ctx->kv[idx].value.arr.data = malloc(n*sizeof(struct gguf_str)); + for (int i = 0; i < n; i++) { + struct gguf_str * str = &((struct gguf_str *)ctx->kv[idx].value.arr.data)[i]; + str->n = strlen(data[i]) + 1; + str->data = strdup(data[i]); + } +} + +// set or add KV pairs from another context +void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src) { + for (uint32_t i = 0; i < src->header.n_kv; i++) { + switch (src->kv[i].type) { + case GGUF_TYPE_UINT8: gguf_set_val_u8 (ctx, src->kv[i].key.data, src->kv[i].value.uint8); break; + case GGUF_TYPE_INT8: gguf_set_val_i8 (ctx, src->kv[i].key.data, src->kv[i].value.int8); break; + case GGUF_TYPE_UINT16: gguf_set_val_u16 (ctx, src->kv[i].key.data, src->kv[i].value.uint16); break; + case GGUF_TYPE_INT16: gguf_set_val_i16 (ctx, src->kv[i].key.data, src->kv[i].value.int16); break; + case GGUF_TYPE_UINT32: gguf_set_val_u32 (ctx, src->kv[i].key.data, src->kv[i].value.uint32); break; + case GGUF_TYPE_INT32: gguf_set_val_i32 (ctx, src->kv[i].key.data, src->kv[i].value.int32); break; + case GGUF_TYPE_FLOAT32: gguf_set_val_f32 (ctx, src->kv[i].key.data, src->kv[i].value.float32); break; + case GGUF_TYPE_BOOL: gguf_set_val_bool(ctx, src->kv[i].key.data, src->kv[i].value.bool_); break; + case GGUF_TYPE_STRING: gguf_set_val_str (ctx, src->kv[i].key.data, src->kv[i].value.str.data); break; + case GGUF_TYPE_ARRAY: + { + if (src->kv[i].value.arr.type == GGUF_TYPE_STRING) { + const char ** data = malloc(src->kv[i].value.arr.n*sizeof(char *)); + for (uint32_t j = 0; j < src->kv[i].value.arr.n; j++) { + data[j] = ((struct gguf_str *)src->kv[i].value.arr.data)[j].data; + } + gguf_set_arr_str(ctx, src->kv[i].key.data, data, src->kv[i].value.arr.n); + free(data); + } else if (src->kv[i].value.arr.type == GGUF_TYPE_ARRAY) { + GGML_ASSERT(false && "nested arrays not supported"); + } else { + gguf_set_arr_data(ctx, src->kv[i].key.data, src->kv[i].value.arr.type, src->kv[i].value.arr.data, src->kv[i].value.arr.n); + } + } break; + case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break; + } + } +} + +void gguf_add_tensor( + struct gguf_context * ctx, + const struct ggml_tensor * tensor) { + const int idx = ctx->header.n_tensors; + ctx->infos = realloc(ctx->infos, (idx + 1)*sizeof(struct gguf_tensor_info)); + + ctx->infos[idx].name.n = strlen(tensor->name) + 1; + ctx->infos[idx].name.data = strdup(tensor->name); + + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + ctx->infos[idx].ne[i] = 1; + } + + ctx->infos[idx].n_dims = tensor->n_dims; + for (int i = 0; i < tensor->n_dims; i++) { + ctx->infos[idx].ne[i] = tensor->ne[i]; + } + + ctx->infos[idx].type = tensor->type; + ctx->infos[idx].offset = 0; + ctx->infos[idx].data = tensor->data; + ctx->infos[idx].size = ggml_nbytes(tensor); + + if (ctx->header.n_tensors > 0) { + ctx->infos[idx].offset = ctx->infos[idx - 1].offset + GGML_PAD(ctx->infos[idx - 1].size, ctx->alignment); + } + + ctx->header.n_tensors++; +} + +void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type) { + const int idx = gguf_find_tensor(ctx, name); + if (idx < 0) { + GGML_ASSERT(false && "tensor not found"); + } + + ctx->infos[idx].type = type; +} + +void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size) { + const int idx = gguf_find_tensor(ctx, name); + if (idx < 0) { + GGML_ASSERT(false && "tensor not found"); + } + + ctx->infos[idx].data = data; + ctx->infos[idx].size = size; + + // update offsets + for (uint32_t i = idx + 1; i < ctx->header.n_tensors; ++i) { + ctx->infos[i].offset = ctx->infos[i - 1].offset + GGML_PAD(ctx->infos[i - 1].size, ctx->alignment); + } +} + +//static void gguf_fwrite_str(FILE * file, const struct gguf_str * val) { +// fwrite(&val->n, sizeof(val->n), 1, file); +// fwrite(val->data, sizeof(char), val->n, file); +//} +// +//static void gguf_fwrite_el(FILE * file, const void * val, size_t size) { +// fwrite(val, sizeof(char), size, file); +//} + +struct gguf_buf { + void * data; + size_t size; + size_t offset; +}; + +static struct gguf_buf gguf_buf_init(size_t size) { + struct gguf_buf buf = { + /*buf.data =*/ size == 0 ? NULL : malloc(size), + /*buf.size =*/ size, + /*buf.offset =*/ 0, + }; + + return buf; +} + +static void gguf_buf_free(struct gguf_buf buf) { + if (buf.data) { + free(buf.data); + } +} + +static void gguf_buf_grow(struct gguf_buf * buf, size_t size) { + if (buf->offset + size > buf->size) { + buf->size = 1.5*(buf->offset + size); + if (buf->data) { + buf->data = realloc(buf->data, buf->size); + } + } +} + +static void gguf_bwrite_str(struct gguf_buf * buf, const struct gguf_str * val) { + gguf_buf_grow(buf, sizeof(val->n) + val->n); + + if (buf->data) { + memcpy((char *) buf->data + buf->offset, &val->n, sizeof(val->n)); + } + buf->offset += sizeof(val->n); + + if (buf->data) { + memcpy((char *) buf->data + buf->offset, val->data, val->n); + } + buf->offset += val->n; +} + +static void gguf_bwrite_el(struct gguf_buf * buf, const void * val, size_t el_size) { + gguf_buf_grow(buf, el_size); + + if (buf->data) { + memcpy((char *) buf->data + buf->offset, val, el_size); + } + buf->offset += el_size; +} + +static void gguf_write_to_buf(struct gguf_context * ctx, struct gguf_buf * buf, bool only_meta) { + // write header + gguf_bwrite_el(buf, &ctx->header.magic, sizeof(ctx->header.magic)); + gguf_bwrite_el(buf, &ctx->header.version, sizeof(ctx->header.version)); + gguf_bwrite_el(buf, &ctx->header.n_tensors, sizeof(ctx->header.n_tensors)); + gguf_bwrite_el(buf, &ctx->header.n_kv, sizeof(ctx->header.n_kv)); + + // write key-value pairs + for (uint32_t i = 0; i < ctx->header.n_kv; ++i) { + struct gguf_kv * kv = &ctx->kv[i]; + + gguf_bwrite_str(buf, &kv->key); + gguf_bwrite_el (buf, &kv->type, sizeof(kv->type)); + + switch (kv->type) { + case GGUF_TYPE_UINT8: gguf_bwrite_el( buf, &kv->value.uint8, sizeof(kv->value.uint8) ); break; + case GGUF_TYPE_INT8: gguf_bwrite_el (buf, &kv->value.int8, sizeof(kv->value.int8) ); break; + case GGUF_TYPE_UINT16: gguf_bwrite_el (buf, &kv->value.uint16, sizeof(kv->value.uint16) ); break; + case GGUF_TYPE_INT16: gguf_bwrite_el (buf, &kv->value.int16, sizeof(kv->value.int16) ); break; + case GGUF_TYPE_UINT32: gguf_bwrite_el (buf, &kv->value.uint32, sizeof(kv->value.uint32) ); break; + case GGUF_TYPE_INT32: gguf_bwrite_el (buf, &kv->value.int32, sizeof(kv->value.int32) ); break; + case GGUF_TYPE_FLOAT32: gguf_bwrite_el (buf, &kv->value.float32, sizeof(kv->value.float32)); break; + case GGUF_TYPE_BOOL: gguf_bwrite_el (buf, &kv->value.bool_, sizeof(kv->value.bool_) ); break; + case GGUF_TYPE_STRING: gguf_bwrite_str(buf, &kv->value.str ); break; + case GGUF_TYPE_ARRAY: + { + gguf_bwrite_el(buf, &kv->value.arr.type, sizeof(kv->value.arr.type)); + gguf_bwrite_el(buf, &kv->value.arr.n, sizeof(kv->value.arr.n) ); + + switch (kv->value.arr.type) { + case GGUF_TYPE_UINT8: + case GGUF_TYPE_INT8: + case GGUF_TYPE_UINT16: + case GGUF_TYPE_INT16: + case GGUF_TYPE_UINT32: + case GGUF_TYPE_INT32: + case GGUF_TYPE_FLOAT32: + case GGUF_TYPE_BOOL: + { + gguf_bwrite_el(buf, kv->value.arr.data, kv->value.arr.n * GGUF_TYPE_SIZE[kv->value.arr.type]); + } break; + case GGUF_TYPE_STRING: + { + for (uint32_t j = 0; j < kv->value.arr.n; ++j) { + gguf_bwrite_str(buf, &((struct gguf_str *) kv->value.arr.data)[j]); + } + } break; + case GGUF_TYPE_ARRAY: + case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); break; + }; + } break; + case GGUF_TYPE_COUNT: GGML_ASSERT(false && "invalid type"); + }; + } + + // write tensor infos + for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { + struct gguf_tensor_info * info = &ctx->infos[i]; + + gguf_bwrite_str(buf, &info->name); + gguf_bwrite_el (buf, &info->n_dims, sizeof(info->n_dims)); + for (uint32_t j = 0; j < info->n_dims; ++j) { + gguf_bwrite_el(buf, &info->ne[j], sizeof(info->ne[j])); + } + gguf_bwrite_el(buf, &info->type, sizeof(info->type)); + gguf_bwrite_el(buf, &info->offset, sizeof(info->offset)); + } + + // we require the data section to be aligned, so take into account any padding + { + const size_t offset = buf->offset; + const size_t offset_pad = GGML_PAD(offset, ctx->alignment); + + if (offset_pad != offset) { + uint8_t pad = 0; + for (size_t i = 0; i < offset_pad - offset; ++i) { + gguf_bwrite_el(buf, &pad, sizeof(pad)); + } + } + } + + if (only_meta) { + return; + } + + size_t offset = 0; + + // write tensor data + for (uint32_t i = 0; i < ctx->header.n_tensors; ++i) { + struct gguf_tensor_info * info = &ctx->infos[i]; + + const size_t size = info->size; + const size_t size_pad = GGML_PAD(size, ctx->alignment); + + gguf_bwrite_el(buf, info->data, size); + + if (size_pad != size) { + uint8_t pad = 0; + for (size_t j = 0; j < size_pad - size; ++j) { + gguf_bwrite_el(buf, &pad, sizeof(pad)); + } + } + + GGML_ASSERT(offset == info->offset); + + offset += size_pad; + } +} + +void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta) { + FILE * file = fopen(fname, "wb"); + if (!file) { + GGML_ASSERT(false && "failed to open file for writing"); + } + + struct gguf_buf buf = gguf_buf_init(16*1024); + + gguf_write_to_buf(ctx, &buf, only_meta); + + fwrite(buf.data, 1, buf.offset, file); + + gguf_buf_free(buf); + + fclose(file); +} + +size_t gguf_get_meta_size(struct gguf_context * ctx) { + // no allocs - only compute size + struct gguf_buf buf = gguf_buf_init(0); + + gguf_write_to_buf(ctx, &buf, true); + + return buf.offset; +} + +void gguf_get_meta_data(struct gguf_context * ctx, void * data) { + struct gguf_buf buf = gguf_buf_init(16*1024); + + gguf_write_to_buf(ctx, &buf, true); + + memcpy(data, buf.data, buf.offset); + + gguf_buf_free(buf); +} + +//////////////////////////////////////////////////////////////////////////////// + int ggml_cpu_has_avx(void) { #if defined(__AVX__) return 1; diff --git a/ggml.h b/ggml.h index 3a946dbdc..544ad2d11 100644 --- a/ggml.h +++ b/ggml.h @@ -207,14 +207,18 @@ #define GGML_MAX_PARAMS 256 #define GGML_MAX_CONTEXTS 64 #define GGML_MAX_SRC 6 -#define GGML_MAX_NAME 48 +#define GGML_MAX_NAME 64 #define GGML_MAX_OP_PARAMS 32 #define GGML_DEFAULT_N_THREADS 4 - #define GGML_EXIT_SUCCESS 0 #define GGML_EXIT_ABORTED 1 +#define GGUF_MAGIC 0x46554747 // "GGUF" +#define GGUF_VERSION 1 + +#define GGUF_DEFAULT_ALIGNMENT 32 + #define GGML_UNUSED(x) (void)(x) #define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) @@ -562,6 +566,7 @@ extern "C" { GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor); GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor); GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor); + GGML_API size_t ggml_nbytes_pad (const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN GGML_API size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split); GGML_API int ggml_blck_size (enum ggml_type type); @@ -1494,7 +1499,6 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * tensor); - GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor); @@ -1703,6 +1707,118 @@ extern "C" { GGML_API size_t ggml_quantize_chunk(enum ggml_type type, const float * src, void * dst, int start, int n, int64_t * hist); + // + // gguf + // + + enum gguf_type { + GGUF_TYPE_UINT8 = 0, + GGUF_TYPE_INT8 = 1, + GGUF_TYPE_UINT16 = 2, + GGUF_TYPE_INT16 = 3, + GGUF_TYPE_UINT32 = 4, + GGUF_TYPE_INT32 = 5, + GGUF_TYPE_FLOAT32 = 6, + GGUF_TYPE_BOOL = 7, + GGUF_TYPE_STRING = 8, + GGUF_TYPE_ARRAY = 9, + GGUF_TYPE_COUNT, // marks the end of the enum + }; + + struct gguf_context; + + struct gguf_init_params { + bool no_alloc; + + // if not NULL, create a ggml_context and allocate the tensor data in it + struct ggml_context ** ctx; + }; + + GGML_API struct gguf_context * gguf_init_empty(void); + GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params); + //GGML_API struct gguf_context * gguf_init_from_buffer(..); + + GGML_API void gguf_free(struct gguf_context * ctx); + + GGML_API const char * gguf_type_name(enum gguf_type type); + + GGML_API int gguf_get_version (struct gguf_context * ctx); + GGML_API size_t gguf_get_alignment (struct gguf_context * ctx); + GGML_API size_t gguf_get_data_offset(struct gguf_context * ctx); + GGML_API void * gguf_get_data (struct gguf_context * ctx); + + GGML_API int gguf_get_n_kv(struct gguf_context * ctx); + GGML_API int gguf_find_key(struct gguf_context * ctx, const char * key); + GGML_API const char * gguf_get_key (struct gguf_context * ctx, int i); + + GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i); + GGML_API enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i); + + // results are undefined if the wrong type is used for the key + GGML_API uint8_t gguf_get_val_u8 (struct gguf_context * ctx, int i); + GGML_API int8_t gguf_get_val_i8 (struct gguf_context * ctx, int i); + GGML_API uint16_t gguf_get_val_u16 (struct gguf_context * ctx, int i); + GGML_API int16_t gguf_get_val_i16 (struct gguf_context * ctx, int i); + GGML_API uint32_t gguf_get_val_u32 (struct gguf_context * ctx, int i); + GGML_API int32_t gguf_get_val_i32 (struct gguf_context * ctx, int i); + GGML_API float gguf_get_val_f32 (struct gguf_context * ctx, int i); + GGML_API bool gguf_get_val_bool(struct gguf_context * ctx, int i); + GGML_API const char * gguf_get_val_str (struct gguf_context * ctx, int i); + GGML_API int gguf_get_arr_n (struct gguf_context * ctx, int i); + GGML_API const void * gguf_get_arr_data(struct gguf_context * ctx, int i); + GGML_API const char * gguf_get_arr_str (struct gguf_context * ctx, int key_id, int i); + + GGML_API int gguf_get_n_tensors (struct gguf_context * ctx); + GGML_API int gguf_find_tensor (struct gguf_context * ctx, const char * name); + GGML_API size_t gguf_get_tensor_offset(struct gguf_context * ctx, int i); + GGML_API char * gguf_get_tensor_name (struct gguf_context * ctx, int i); + + // overrides existing values or adds a new one + GGML_API void gguf_set_val_u8 (struct gguf_context * ctx, const char * key, uint8_t val); + GGML_API void gguf_set_val_i8 (struct gguf_context * ctx, const char * key, int8_t val); + GGML_API void gguf_set_val_u16 (struct gguf_context * ctx, const char * key, uint16_t val); + GGML_API void gguf_set_val_i16 (struct gguf_context * ctx, const char * key, int16_t val); + GGML_API void gguf_set_val_u32 (struct gguf_context * ctx, const char * key, uint32_t val); + GGML_API void gguf_set_val_i32 (struct gguf_context * ctx, const char * key, int32_t val); + GGML_API void gguf_set_val_f32 (struct gguf_context * ctx, const char * key, float val); + GGML_API void gguf_set_val_bool(struct gguf_context * ctx, const char * key, bool val); + GGML_API void gguf_set_val_str (struct gguf_context * ctx, const char * key, const char * val); + GGML_API void gguf_set_arr_data(struct gguf_context * ctx, const char * key, enum gguf_type type, const void * data, int n); + GGML_API void gguf_set_arr_str (struct gguf_context * ctx, const char * key, const char ** data, int n); + + // set or add KV pairs from another context + GGML_API void gguf_set_kv(struct gguf_context * ctx, struct gguf_context * src); + + // manage tensor info + GGML_API void gguf_add_tensor(struct gguf_context * ctx, const struct ggml_tensor * tensor); + GGML_API void gguf_set_tensor_type(struct gguf_context * ctx, const char * name, enum ggml_type type); + GGML_API void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const void * data, size_t size); + + // writing gguf files can be done in 2 ways: + // + // - write the entire gguf_context to a binary file in a single pass: + // + // gguf_write_to_file(ctx, fname); + // + // - first prepare a file with a placeholder for the meta data, write the tensor data, then write the meta data: + // + // FILE * f = fopen(fname, "wb"); + // fseek(f, gguf_get_meta_size(ctx), SEEK_SET); + // fwrite(f, ...); + // void * data = gguf_meta_get_meta_data(ctx); + // fseek(f, 0, SEEK_SET); + // fwrite(f, data, gguf_get_meta_size(ctx)); + // free(data); + // fclose(f); + // + + // write the entire context to a binary file + GGML_API void gguf_write_to_file(struct gguf_context * ctx, const char * fname, bool only_meta); + + // get the size in bytes of the meta data (header, kv pairs, tensor info) including padding + GGML_API size_t gguf_get_meta_size(struct gguf_context * ctx); + GGML_API void gguf_get_meta_data(struct gguf_context * ctx, void * data); + // // system info // diff --git a/gguf.py b/gguf.py new file mode 100644 index 000000000..9776649c7 --- /dev/null +++ b/gguf.py @@ -0,0 +1,718 @@ +import shutil +import sys +import struct +import tempfile +import numpy as np + +from enum import IntEnum, auto +from typing import Any, IO, List, Optional + +# +# constants +# + +GGUF_MAGIC = 0x46554747 +GGUF_VERSION = 1 +GGUF_DEFAULT_ALIGNMENT = 32 + +# general +KEY_GENERAL_ARCHITECTURE = "general.architecture" +KEY_GENERAL_QUANTIZATION_VERSION = "general.quantization_version" +KEY_GENERAL_ALIGNMENT = "general.alignment" +KEY_GENERAL_NAME = "general.name" +KEY_GENERAL_AUTHOR = "general.author" +KEY_GENERAL_URL = "general.url" +KEY_GENERAL_DESCRIPTION = "general.description" +KEY_GENERAL_LICENSE = "general.license" +KEY_GENERAL_SOURCE_URL = "general.source.url" +KEY_GENERAL_SOURCE_HF_REPO = "general.source.hugginface.repository" + +# LLM +KEY_LLM_CONTEXT_LENGTH = "{arch}.context_length" +KEY_LLM_EMBEDDING_LENGTH = "{arch}.embedding_length" +KEY_LLM_BLOCK_COUNT = "{arch}.block_count" +KEY_LLM_FEED_FORWARD_LENGTH = "{arch}.feed_forward_length" +KEY_LLM_USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual" +KEY_LLM_TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout" + +# attention +KEY_ATTENTION_HEAD_COUNT = "{arch}.attention.head_count" +KEY_ATTENTION_HEAD_COUNT_KV = "{arch}.attention.head_count_kv" +KEY_ATTENTION_MAX_ALIBI_BIAS = "{arch}.attention.max_alibi_bias" +KEY_ATTENTION_CLAMP_KQV = "{arch}.attention.clamp_kqv" +KEY_ATTENTION_LAYERNORM_EPS = "{arch}.attention.layer_norm_epsilon" +KEY_ATTENTION_LAYERNORM_RMS_EPS = "{arch}.attention.layer_norm_rms_epsilon" + +# RoPE +KEY_ROPE_DIMENSION_COUNT = "{arch}.rope.dimension_count" +KEY_ROPE_SCALE_LINEAR = "{arch}.rope.scale_linear" + +# tokenization +KEY_TOKENIZER_MODEL = "tokenizer.ggml.model" +KEY_TOKENIZER_LIST = "tokenizer.ggml.tokens" +KEY_TOKENIZER_TOKEN_TYPE = "tokenizer.ggml.token_type" +KEY_TOKENIZER_SCORES = "tokenizer.ggml.scores" +KEY_TOKENIZER_MERGES = "tokenizer.ggml.merges" +KEY_TOKENIZER_BOS_ID = "tokenizer.ggml.bos_token_id" +KEY_TOKENIZER_EOS_ID = "tokenizer.ggml.eos_token_id" +KEY_TOKENIZER_UNK_ID = "tokenizer.ggml.unknown_token_id" +KEY_TOKENIZER_SEP_ID = "tokenizer.ggml.seperator_token_id" +KEY_TOKENIZER_PAD_ID = "tokenizer.ggml.padding_token_id" +KEY_TOKENIZER_HF_JSON = "tokenizer.huggingface.json" +KEY_TOKENIZER_RWKV = "tokenizer.rwkv.world" + + +# +# recommended mapping of model tensor names for storage in gguf +# + + +class MODEL_ARCH(IntEnum): + LLAMA = auto() + FALCON = auto() + GPT2 = auto() + GPTJ = auto() + GPTNEOX = auto() + MPT = auto() + + +class MODEL_TENSOR(IntEnum): + TOKEN_EMBD = auto() + POS_EMBD = auto() + OUTPUT = auto() + OUTPUT_NORM = auto() + ROPE_FREQS = auto() + ATTN_Q = auto() + ATTN_K = auto() + ATTN_V = auto() + ATTN_QKV = auto() + ATTN_OUT = auto() + ATTN_NORM = auto() + ATTN_NORM_2 = auto() + ATTN_ROT_EMBD = auto() + FFN_GATE = auto() + FFN_DOWN = auto() + FFN_UP = auto() + FFN_NORM = auto() + + +MODEL_ARCH_NAMES = { + MODEL_ARCH.LLAMA: "llama", + MODEL_ARCH.FALCON: "falcon", + MODEL_ARCH.GPT2: "gpt2", + MODEL_ARCH.GPTJ: "gptj", + MODEL_ARCH.GPTNEOX: "gptneox", + MODEL_ARCH.MPT: "mpt", +} + +MODEL_TENSOR_NAMES = { + MODEL_ARCH.LLAMA: { + MODEL_TENSOR.TOKEN_EMBD: "token_embd", + MODEL_TENSOR.OUTPUT_NORM: "output_norm", + MODEL_TENSOR.OUTPUT: "output", + MODEL_TENSOR.ROPE_FREQS: "rope_freqs", + MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", + MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", + MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", + MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", + MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", + MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", + MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", + MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", + MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", + MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", + }, + MODEL_ARCH.GPTNEOX: { + MODEL_TENSOR.TOKEN_EMBD: "token_embd", + MODEL_TENSOR.OUTPUT_NORM: "output_norm", + MODEL_TENSOR.OUTPUT: "output", + MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", + MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", + MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", + MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", + MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", + MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", + }, + MODEL_ARCH.FALCON: { + MODEL_TENSOR.TOKEN_EMBD: "token_embd", + MODEL_TENSOR.OUTPUT_NORM: "output_norm", + MODEL_TENSOR.OUTPUT: "output", + MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", + MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", + MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", + MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", + MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", + MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", + }, + MODEL_ARCH.GPT2: { + # TODO + }, + # TODO +} + +# tensors that will not be serialized +MODEL_TENSOR_SKIP = { + MODEL_ARCH.LLAMA: [ + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_ROT_EMBD, + ], +} + + +# TODO: the following helper functions should be removed +# instead, get_tensor_name_map should return tuples of (name, MODEL_TENSOR) +# however, my Python is very bad, and I couldn't figure out how to do this, hence these functions +# REMOVE +def should_skip_tensor_TMP(arch: MODEL_ARCH, n_blocks: int, name: str) -> bool: + for skip in MODEL_TENSOR_SKIP.get(arch, []): + for i in range(n_blocks): + if name == MODEL_TENSOR_NAMES[arch][skip].format(bid=i): + return True + + return False + + +def get_tensor_name_map(arch: MODEL_ARCH, n_blocks: int) -> dict: + tensor_map = {} + + # Token embeddings + mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.TOKEN_EMBD, None) + + tensor_map["gpt_neox.embed_in"] = mapped_to # gptneox + tensor_map["transformer.wte"] = mapped_to # gpt2 mpt + tensor_map["transformer.word_embeddings"] = mapped_to # falcon + tensor_map["model.embed_tokens"] = mapped_to # llama-hf + tensor_map["tok_embeddings"] = mapped_to # llama-pth + + # Position embeddings + mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.POS_EMBD, None) + + tensor_map["transformer.wpe"] = mapped_to # gpt2 + + # Output + mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.OUTPUT, None) + + tensor_map["embed_out"] = mapped_to # gptneox + tensor_map["lm_head"] = mapped_to # gpt2 mpt falcon llama-hf + tensor_map["output"] = mapped_to # llama-pth + + # Output norm + mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.OUTPUT_NORM, None) + + tensor_map["gpt_neox.final_layer_norm"] = mapped_to # gptneox + tensor_map["transformer.ln_f"] = mapped_to # gpt2 falcon + tensor_map["transformer.norm_f"] = mapped_to # mpt + tensor_map["model.norm"] = mapped_to # llama-hf + tensor_map["norm"] = mapped_to # llama-pth + + # Rope frequencies + mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ROPE_FREQS, None) + + tensor_map["rope.freqs"] = mapped_to # llama-pth + + # Attention and feed-forward blocks + for i in range(0, n_blocks): + # Attention norm + # TODO: is there are simpler way to write these 2 lines in Python? + mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_NORM, None) + mapped_to = mapped_to.format(bid=i) if mapped_to else None + + tensor_map["gpt_neox.layers."+str(i)+".input_layernorm"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".ln_1"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".norm_1"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".input_layernorm"] = mapped_to # falcon7b + tensor_map["transformer.h."+str(i)+".ln_mlp"] = mapped_to # falcon40b + tensor_map["model.layers."+str(i)+".input_layernorm"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".attention_norm"] = mapped_to # llama-pth + + # Attention norm 2 + mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_NORM_2, None) + mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None + + tensor_map["transformer.h."+str(i)+".ln_attn"] = mapped_to # falcon40b + + # Attention query-key-value + mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_QKV, None) + mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None + + tensor_map["gpt_neox.layers."+str(i)+".attention.query_key_value"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".attn.c_attn"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".attn.Wqkv"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".self_attention.query_key_value"] = mapped_to # falcon + + # Attention query + mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_Q, None) + mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None + + tensor_map["model.layers."+str(i)+".self_attn.q_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".attention.wq"] = mapped_to # llama-pth + + # Attention key + mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_K, None) + mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None + + tensor_map["model.layers."+str(i)+".self_attn.k_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".attention.wk"] = mapped_to # llama-pth + + # Attention value + mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_V, None) + mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None + + tensor_map["model.layers."+str(i)+".self_attn.v_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".attention.wv"] = mapped_to # llama-pth + + # Attention output + mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_OUT, None) + mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None + + tensor_map["gpt_neox.layers."+str(i)+".attention.dense"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".attn.c_proj"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".attn.out_proj"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".self_attention.dense"] = mapped_to # falcon + tensor_map["model.layers."+str(i)+".self_attn.o_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".attention.wo"] = mapped_to # llama-pth + + # Rotary embeddings + mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.ATTN_ROT_EMBD, None) + mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None + + tensor_map["model.layers."+str(i)+".self_attn.rotary_emb.inv_freq"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".attention.inner_attention.rope.freqs"] = mapped_to # llama-pth + + # Feed-forward norm + mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_NORM, None) + mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None + + tensor_map["gpt_neox.layers."+str(i)+".post_attention_layernorm"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".ln_2"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".norm_2"] = mapped_to # mpt + tensor_map["model.layers."+str(i)+".post_attention_layernorm"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".ffn_norm"] = mapped_to # llama-pth + + # Feed-forward up + mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_UP, None) + mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None + + tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".mlp.c_fc"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".ffn.up_proj"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".mlp.dense_h_to_4h"] = mapped_to # falcon + tensor_map["model.layers."+str(i)+".mlp.up_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".feed_forward.w3"] = mapped_to # llama-pth + + # Feed-forward gate + mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_GATE, None) + mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None + + tensor_map["model.layers."+str(i)+".mlp.gate_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".feed_forward.w1"] = mapped_to # llama-pth + + # Feed-forward down + mapped_to = MODEL_TENSOR_NAMES[arch].get(MODEL_TENSOR.FFN_DOWN, None) + mapped_to = mapped_to.format(bid=i) if mapped_to is not None else None + + tensor_map["gpt_neox.layers."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # gptneox + tensor_map["transformer.h."+str(i)+".mlp.c_proj"] = mapped_to # gpt2 + tensor_map["transformer.blocks."+str(i)+".ffn.down_proj"] = mapped_to # mpt + tensor_map["transformer.h."+str(i)+".mlp.dense_4h_to_h"] = mapped_to # falcon + tensor_map["model.layers."+str(i)+".mlp.down_proj"] = mapped_to # llama-hf + tensor_map["layers."+str(i)+".feed_forward.w2"] = mapped_to # llama-pth + + return tensor_map + + +class TokenType(IntEnum): + NORMAL = 1 + UNKNOWN = 2 + CONTROL = 3 + USER_DEFINED = 4 + UNUSED = 5 + BYTE = 6 + +# +# implementation +# + + +class GGMLQuantizationType(IntEnum): + F32 = 0 + F16 = 1 + Q4_0 = 2 + Q4_1 = 3 + Q5_0 = 6 + Q5_1 = 7 + Q8_0 = 8 + Q8_1 = 9 + Q2_K = 10 + Q3_K = 11 + Q4_K = 12 + Q5_K = 13 + Q6_K = 14 + Q8_K = 15 + + +class GGUFValueType(IntEnum): + UINT8 = 0 + INT8 = 1 + UINT16 = 2 + INT16 = 3 + UINT32 = 4 + INT32 = 5 + FLOAT32 = 6 + BOOL = 7 + STRING = 8 + ARRAY = 9 + + @staticmethod + def get_type(val): + if isinstance(val, str) or isinstance(val, bytes) or isinstance(val, bytearray): + return GGUFValueType.STRING + elif isinstance(val, list): + return GGUFValueType.ARRAY + elif isinstance(val, float): + return GGUFValueType.FLOAT32 + elif isinstance(val, bool): + return GGUFValueType.BOOL + elif isinstance(val, int): + return GGUFValueType.INT32 + else: + print("Unknown type: "+str(type(val))) + sys.exit() + + +class GGUFWriter: + def __init__(self, path: str, arch: str, use_temp_file = True): + self.fout = open(path, "wb") + self.arch = arch + self.offset_tensor = 0 + self.data_alignment = GGUF_DEFAULT_ALIGNMENT + self.kv_data = b"" + self.kv_data_count = 0 + self.ti_data = b"" + self.ti_data_count = 0 + self.add_architecture() + self.use_temp_file = use_temp_file + self.tensors = [] + + def write_header_to_file(self): + self.fout.write(struct.pack(" int: + return ((x + n - 1) // n) * n + + def add_tensor_info(self, name: str, tensor_shape: np.ndarray, tensor_dtype: np.dtype, tensor_nbytes: int, raw_dtype: Optional[GGMLQuantizationType] = None): + assert raw_dtype is not None or tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now" + + encoded_name = name.encode("utf8") + self.ti_data += struct.pack(" -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#ifdef __has_include - #if __has_include() - #include - #if defined(_POSIX_MAPPED_FILES) - #include - #endif - #if defined(_POSIX_MEMLOCK_RANGE) - #include - #endif - #endif -#endif - -#if defined(_WIN32) - #define WIN32_LEAN_AND_MEAN - #ifndef NOMINMAX - #define NOMINMAX - #endif - #include - #include - #include // for _fseeki64 -#endif - -#define LLAMA_ASSERT(x) \ - do { \ - if (!(x)) { \ - fprintf(stderr, "LLAMA_ASSERT: %s:%d: %s\n", __FILE__, __LINE__, #x); \ - abort(); \ - } \ - } while (0) - -#ifdef __GNUC__ -#ifdef __MINGW32__ -__attribute__((format(gnu_printf, 1, 2))) -#else -__attribute__((format(printf, 1, 2))) -#endif -#endif -static std::string format(const char * fmt, ...) { - va_list ap, ap2; - va_start(ap, fmt); - va_copy(ap2, ap); - int size = vsnprintf(NULL, 0, fmt, ap); - LLAMA_ASSERT(size >= 0 && size < INT_MAX); - std::vector buf(size + 1); - int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); - LLAMA_ASSERT(size2 == size); - va_end(ap2); - va_end(ap); - return std::string(buf.data(), size); -} - -struct llama_file { - // use FILE * so we don't have to re-open the file to mmap - FILE * fp; - size_t size; - - llama_file(const char * fname, const char * mode) { - fp = std::fopen(fname, mode); - if (fp == NULL) { - throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); - } - seek(0, SEEK_END); - size = tell(); - seek(0, SEEK_SET); - } - - size_t tell() const { -#ifdef _WIN32 - __int64 ret = _ftelli64(fp); -#else - long ret = std::ftell(fp); -#endif - LLAMA_ASSERT(ret != -1); // this really shouldn't fail - return (size_t) ret; - } - - void seek(size_t offset, int whence) { -#ifdef _WIN32 - int ret = _fseeki64(fp, (__int64) offset, whence); -#else - int ret = std::fseek(fp, (long) offset, whence); -#endif - LLAMA_ASSERT(ret == 0); // same - } - - void read_raw(void * ptr, size_t len) const { - if (len == 0) { - return; - } - errno = 0; - std::size_t ret = std::fread(ptr, len, 1, fp); - if (ferror(fp)) { - throw std::runtime_error(format("read error: %s", strerror(errno))); - } - if (ret != 1) { - throw std::runtime_error(std::string("unexpectedly reached end of file")); - } - } - - std::uint32_t read_u32() { - std::uint32_t ret; - read_raw(&ret, sizeof(ret)); - return ret; - } - - std::string read_string(std::uint32_t len) { - std::vector chars(len); - read_raw(chars.data(), len); - return std::string(chars.data(), len); - } - - void write_raw(const void * ptr, size_t len) const { - if (len == 0) { - return; - } - errno = 0; - size_t ret = std::fwrite(ptr, len, 1, fp); - if (ret != 1) { - throw std::runtime_error(format("write error: %s", strerror(errno))); - } - } - - void write_u32(std::uint32_t val) { - write_raw(&val, sizeof(val)); - } - - ~llama_file() { - if (fp) { - std::fclose(fp); - } - } -}; - -// llama_context_data -struct llama_data_context { - virtual void write(const void * src, size_t size) = 0; - virtual size_t get_size_written() = 0; - virtual ~llama_data_context() = default; -}; - -struct llama_data_buffer_context : llama_data_context { - uint8_t* ptr; - size_t size_written = 0; - - llama_data_buffer_context(uint8_t * p) : ptr(p) {} - - void write(const void * src, size_t size) override { - memcpy(ptr, src, size); - ptr += size; - size_written += size; - } - - size_t get_size_written() override { - return size_written; - } -}; - -struct llama_data_file_context : llama_data_context { - llama_file* file; - size_t size_written = 0; - - llama_data_file_context(llama_file * f) : file(f) {} - - void write(const void * src, size_t size) override { - file->write_raw(src, size); - size_written += size; - } - - size_t get_size_written() override { - return size_written; - } -}; - -#if defined(_WIN32) -static std::string llama_format_win_err(DWORD err) { - LPSTR buf; - size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, - NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL); - if (!size) { - return "FormatMessageA failed"; - } - std::string ret(buf, size); - LocalFree(buf); - return ret; -} -#endif - -struct llama_mmap { - void * addr; - size_t size; - - llama_mmap(const llama_mmap &) = delete; - -#ifdef _POSIX_MAPPED_FILES - static constexpr bool SUPPORTED = true; - - llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */, bool numa = false) { - size = file->size; - int fd = fileno(file->fp); - int flags = MAP_SHARED; - // prefetch/readahead impairs performance on NUMA systems - if (numa) { prefetch = 0; } -#ifdef __linux__ - if (prefetch >= file->size) { flags |= MAP_POPULATE; } -#endif - addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0); - if (addr == MAP_FAILED) { - throw std::runtime_error(format("mmap failed: %s", strerror(errno))); - } - - if (prefetch > 0) { - // Advise the kernel to preload the mapped memory - if (madvise(addr, std::min(file->size, prefetch), MADV_WILLNEED)) { - fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n", - strerror(errno)); - } - } - if (numa) { - // advise the kernel not to use readahead - // (because the next page might not belong on the same node) - if (madvise(addr, file->size, MADV_RANDOM)) { - fprintf(stderr, "warning: madvise(.., MADV_RANDOM) failed: %s\n", - strerror(errno)); - } - } - } - - ~llama_mmap() { - munmap(addr, size); - } -#elif defined(_WIN32) - static constexpr bool SUPPORTED = true; - - llama_mmap(struct llama_file * file, bool prefetch = true, bool numa = false) { - (void) numa; - - size = file->size; - - HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp)); - - HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); - DWORD error = GetLastError(); - - if (hMapping == NULL) { - throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str())); - } - - addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); - error = GetLastError(); - CloseHandle(hMapping); - - if (addr == NULL) { - throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str())); - } - - if (prefetch) { - // The PrefetchVirtualMemory API is only present on Windows 8 and above, so we - // will dynamically load it using GetProcAddress. - BOOL (WINAPI *pPrefetchVirtualMemory) (HANDLE, ULONG_PTR, PWIN32_MEMORY_RANGE_ENTRY, ULONG); - HMODULE hKernel32; - - // This call is guaranteed to succeed. - hKernel32 = GetModuleHandleW(L"kernel32.dll"); - - // This call may fail if on a pre-Win8 system. - pPrefetchVirtualMemory = reinterpret_cast (GetProcAddress(hKernel32, "PrefetchVirtualMemory")); - - if (pPrefetchVirtualMemory) { - // Advise the kernel to preload the mapped memory. - WIN32_MEMORY_RANGE_ENTRY range; - range.VirtualAddress = addr; - range.NumberOfBytes = (SIZE_T)size; - if (!pPrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) { - fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n", - llama_format_win_err(GetLastError()).c_str()); - } - } - } - } - - ~llama_mmap() { - if (!UnmapViewOfFile(addr)) { - fprintf(stderr, "warning: UnmapViewOfFile failed: %s\n", - llama_format_win_err(GetLastError()).c_str()); - } - } -#else - static constexpr bool SUPPORTED = false; - - llama_mmap(struct llama_file *, bool prefetch = true, bool numa = false) { - (void) prefetch; - (void) numa; - - throw std::runtime_error(std::string("mmap not supported")); - } -#endif -}; - -// Represents some region of memory being locked using mlock or VirtualLock; -// will automatically unlock on destruction. -struct llama_mlock { - void * addr = NULL; - size_t size = 0; - bool failed_already = false; - - llama_mlock() {} - llama_mlock(const llama_mlock &) = delete; - - ~llama_mlock() { - if (size) { - raw_unlock(addr, size); - } - } - - void init(void * ptr) { - LLAMA_ASSERT(addr == NULL && size == 0); - addr = ptr; - } - - void grow_to(size_t target_size) { - LLAMA_ASSERT(addr); - if (failed_already) { - return; - } - size_t granularity = lock_granularity(); - target_size = (target_size + granularity - 1) & ~(granularity - 1); - if (target_size > size) { - if (raw_lock((uint8_t *) addr + size, target_size - size)) { - size = target_size; - } else { - failed_already = true; - } - } - } - -#ifdef _POSIX_MEMLOCK_RANGE - static constexpr bool SUPPORTED = true; - - size_t lock_granularity() { - return (size_t) sysconf(_SC_PAGESIZE); - } - - #ifdef __APPLE__ - #define MLOCK_SUGGESTION \ - "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \ - "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n" - #else - #define MLOCK_SUGGESTION \ - "Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n" - #endif - - bool raw_lock(const void * addr, size_t size) { - if (!mlock(addr, size)) { - return true; - } else { - char* errmsg = std::strerror(errno); - bool suggest = (errno == ENOMEM); - - // Check if the resource limit is fine after all - struct rlimit lock_limit; - if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) - suggest = false; - if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) - suggest = false; - - fprintf(stderr, "warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s", - size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : ""); - return false; - } - } - - #undef MLOCK_SUGGESTION - - void raw_unlock(void * addr, size_t size) { - if (munlock(addr, size)) { - fprintf(stderr, "warning: failed to munlock buffer: %s\n", std::strerror(errno)); - } - } -#elif defined(_WIN32) - static constexpr bool SUPPORTED = true; - - size_t lock_granularity() { - SYSTEM_INFO si; - GetSystemInfo(&si); - return (size_t) si.dwPageSize; - } - - bool raw_lock(void * ptr, size_t len) { - for (int tries = 1; ; tries++) { - if (VirtualLock(ptr, len)) { - return true; - } - if (tries == 2) { - fprintf(stderr, "warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n", - len, size, llama_format_win_err(GetLastError()).c_str()); - return false; - } - - // It failed but this was only the first try; increase the working - // set size and try again. - SIZE_T min_ws_size, max_ws_size; - if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) { - fprintf(stderr, "warning: GetProcessWorkingSetSize failed: %s\n", - llama_format_win_err(GetLastError()).c_str()); - return false; - } - // Per MSDN: "The maximum number of pages that a process can lock - // is equal to the number of pages in its minimum working set minus - // a small overhead." - // Hopefully a megabyte is enough overhead: - size_t increment = len + 1048576; - // The minimum must be <= the maximum, so we need to increase both: - min_ws_size += increment; - max_ws_size += increment; - if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) { - fprintf(stderr, "warning: SetProcessWorkingSetSize failed: %s\n", - llama_format_win_err(GetLastError()).c_str()); - return false; - } - } - } - - void raw_unlock(void * ptr, size_t len) { - if (!VirtualUnlock(ptr, len)) { - fprintf(stderr, "warning: failed to VirtualUnlock buffer: %s\n", - llama_format_win_err(GetLastError()).c_str()); - } - } -#else - static constexpr bool SUPPORTED = false; - - size_t lock_granularity() { - return (size_t) 65536; - } - - bool raw_lock(const void * addr, size_t len) { - fprintf(stderr, "warning: mlock not supported on this system\n"); - return false; - } - - void raw_unlock(const void * addr, size_t len) {} -#endif -}; - -// Replacement for std::vector that doesn't require zero-initialization. -struct llama_buffer { - uint8_t * addr = NULL; - size_t size = 0; - - llama_buffer() = default; - - void resize(size_t len) { -#ifdef GGML_USE_METAL - free(addr); - int result = posix_memalign((void **) &addr, getpagesize(), len); - if (result == 0) { - memset(addr, 0, len); - } - else { - addr = NULL; - } -#else - delete[] addr; - addr = new uint8_t[len]; -#endif - size = len; - } - - ~llama_buffer() { -#ifdef GGML_USE_METAL - free(addr); -#else - delete[] addr; -#endif - addr = NULL; - } - - // disable copy and move - llama_buffer(const llama_buffer&) = delete; - llama_buffer(llama_buffer&&) = delete; - llama_buffer& operator=(const llama_buffer&) = delete; - llama_buffer& operator=(llama_buffer&&) = delete; -}; - -#ifdef GGML_USE_CUBLAS -#include "ggml-cuda.h" -struct llama_ctx_buffer { - uint8_t * addr = NULL; - bool is_cuda; - size_t size = 0; - - llama_ctx_buffer() = default; - - void resize(size_t size) { - free(); - - addr = (uint8_t *) ggml_cuda_host_malloc(size); - if (addr) { - is_cuda = true; - } - else { - // fall back to pageable memory - addr = new uint8_t[size]; - is_cuda = false; - } - this->size = size; - } - - void free() { - if (addr) { - if (is_cuda) { - ggml_cuda_host_free(addr); - } - else { - delete[] addr; - } - } - addr = NULL; - } - - ~llama_ctx_buffer() { - free(); - } - - // disable copy and move - llama_ctx_buffer(const llama_ctx_buffer&) = delete; - llama_ctx_buffer(llama_ctx_buffer&&) = delete; - llama_ctx_buffer& operator=(const llama_ctx_buffer&) = delete; - llama_ctx_buffer& operator=(llama_ctx_buffer&&) = delete; -}; -#else -typedef llama_buffer llama_ctx_buffer; -#endif - -#endif diff --git a/llama.cpp b/llama.cpp index f2cbe7641..c97aaee69 100644 --- a/llama.cpp +++ b/llama.cpp @@ -6,94 +6,146 @@ #include #endif -#include "llama-util.h" #include "llama.h" #include "ggml.h" + +#if !defined(GGML_USE_CUBLAS) +# include "ggml-alloc.h" +# define LLAMA_USE_ALLOCATOR +#else +# define LLAMA_USE_SCRATCH +# define LLAMA_MAX_SCRATCH_BUFFERS 16 +#endif + #ifdef GGML_USE_CUBLAS -#include "ggml-cuda.h" +# include "ggml-cuda.h" #elif defined(GGML_USE_CLBLAST) -#include "ggml-opencl.h" +# include "ggml-opencl.h" #endif #ifdef GGML_USE_METAL -#include "ggml-metal.h" +# include "ggml-metal.h" #endif #ifdef GGML_USE_MPI -#include "ggml-mpi.h" +# include "ggml-mpi.h" #endif #ifdef GGML_USE_K_QUANTS -#ifndef QK_K -#ifdef GGML_QKK_64 -#define QK_K 64 -#else -#define QK_K 256 -#endif -#endif +# ifndef QK_K +# ifdef GGML_QKK_64 +# define QK_K 64 +# else +# define QK_K 256 +# endif +# endif +#endif + +#ifdef __has_include + #if __has_include() + #include + #if defined(_POSIX_MAPPED_FILES) + #include + #endif + #if defined(_POSIX_MEMLOCK_RANGE) + #include + #endif + #endif +#endif + +#if defined(_WIN32) + #define WIN32_LEAN_AND_MEAN + #ifndef NOMINMAX + #define NOMINMAX + #endif + #include + #include + #include // for _fseeki64 #endif -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include +#include +#include +#include #include -#include -#include +#include +#include #include -#include #include +#include +#include +#include +#include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data #endif -static void llama_log_internal(llama_log_level level, const char* format, ...); +// tensor names +#define TN_TOKEN_EMBD "token_embd.weight" +#define TN_OUTPUT_NORM "output_norm.weight" +#define TN_OUTPUT "output.weight" +#define TN_ATTN_NORM "blk.%d.attn_norm.weight" +#define TN_ATTN_Q "blk.%d.attn_q.weight" +#define TN_ATTN_K "blk.%d.attn_k.weight" +#define TN_ATTN_V "blk.%d.attn_v.weight" +#define TN_ATTN_OUTPUT "blk.%d.attn_output.weight" +#define TN_FFN_NORM "blk.%d.ffn_norm.weight" +#define TN_FFN_GATE "blk.%d.ffn_gate.weight" +#define TN_FFN_DOWN "blk.%d.ffn_down.weight" +#define TN_FFN_UP "blk.%d.ffn_up.weight" + +#ifdef __GNUC__ +#ifdef __MINGW32__ +#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(gnu_printf, __VA_ARGS__))) +#else +#define LLAMA_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__))) +#endif +#else +#define LLAMA_ATTRIBUTE_FORMAT(...) +#endif + +// +// logging +// +LLAMA_ATTRIBUTE_FORMAT(2, 3) +static void llama_log_internal (llama_log_level level, const char* format, ...); static void llama_log_callback_default(llama_log_level level, const char * text, void * user_data); + #define LLAMA_LOG_INFO(...) llama_log_internal(LLAMA_LOG_LEVEL_INFO , __VA_ARGS__) #define LLAMA_LOG_WARN(...) llama_log_internal(LLAMA_LOG_LEVEL_WARN , __VA_ARGS__) #define LLAMA_LOG_ERROR(...) llama_log_internal(LLAMA_LOG_LEVEL_ERROR, __VA_ARGS__) +// +// helpers +// -#if !defined(GGML_USE_CUBLAS) -#include "ggml-alloc.h" -#define LLAMA_USE_ALLOCATOR -#else -#define LLAMA_USE_SCRATCH -#define LLAMA_MAX_SCRATCH_BUFFERS 16 -#endif +static void zeros(std::ofstream & file, size_t n) { + char zero = 0; + for (size_t i = 0; i < n; ++i) { + file.write(&zero, 1); + } +} - -// available llama models -enum e_model { - MODEL_UNKNOWN, - MODEL_3B, - MODEL_7B, - MODEL_13B, - MODEL_30B, - MODEL_65B, - MODEL_70B, -}; - -static const size_t kB = 1024; -static const size_t MB = 1024*1024; - -// computed for n_ctx == 2048 -// TODO: dynamically determine these sizes -// needs modifications in ggml - -typedef void (*offload_func_t)(struct ggml_tensor * tensor); - -void llama_nop(struct ggml_tensor * tensor) { // don't offload by default - (void) tensor; +LLAMA_ATTRIBUTE_FORMAT(1, 2) +static std::string format(const char * fmt, ...) { + va_list ap; + va_list ap2; + va_start(ap, fmt); + va_copy(ap2, ap); + int size = vsnprintf(NULL, 0, fmt, ap); + GGML_ASSERT(size >= 0 && size < INT_MAX); // NOLINT + std::vector buf(size + 1); + int size2 = vsnprintf(buf.data(), size + 1, fmt, ap2); + GGML_ASSERT(size2 == size); + va_end(ap2); + va_end(ap); + return std::string(buf.data(), size); } // @@ -111,10 +163,453 @@ static void ggml_graph_compute_helper(std::vector & buf, ggml_cgraph * ggml_graph_compute(graph, &plan); } +// +// llama helpers +// + +#ifdef GGML_USE_CUBLAS +# define llama_host_malloc(n) ggml_cuda_host_malloc(n) +# define llama_host_free(data) ggml_cuda_host_free(data) +#elif GGML_USE_METAL +# define llama_host_malloc(n) ggml_metal_host_malloc(n) +# define llama_host_free(data) ggml_metal_host_free(data) +#else +# define llama_host_malloc(n) malloc(n) +# define llama_host_free(data) free(data) +#endif + +#if defined(_WIN32) +static std::string llama_format_win_err(DWORD err) { + LPSTR buf; + size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, + NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL); + if (!size) { + return "FormatMessageA failed"; + } + std::string ret(buf, size); + LocalFree(buf); + return ret; +} +#endif + +struct llama_buffer { + void * data = NULL; + size_t size = 0; + + // fallback to malloc / free + // useful in cases where CUDA can try to allocate PINNED memory + bool fallback = false; + + void resize(size_t n) { + llama_host_free(data); + + data = llama_host_malloc(n); + if (!data) { + fallback = true; + data = malloc(n); + } else { + fallback = false; + } + + GGML_ASSERT(data); + size = n; + } + + ~llama_buffer() { + if (data) { + if (fallback) { // NOLINT + free(data); + } else { + llama_host_free(data); + } + } + + data = NULL; + } +}; + +struct llama_file { + // use FILE * so we don't have to re-open the file to mmap + FILE * fp; + size_t size; + + llama_file(const char * fname, const char * mode) { + fp = std::fopen(fname, mode); + if (fp == NULL) { + throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno))); + } + seek(0, SEEK_END); + size = tell(); + seek(0, SEEK_SET); + } + + size_t tell() const { +#ifdef _WIN32 + __int64 ret = _ftelli64(fp); +#else + long ret = std::ftell(fp); +#endif + GGML_ASSERT(ret != -1); // this really shouldn't fail + return (size_t) ret; + } + + void seek(size_t offset, int whence) const { +#ifdef _WIN32 + int ret = _fseeki64(fp, (__int64) offset, whence); +#else + int ret = std::fseek(fp, (long) offset, whence); +#endif + GGML_ASSERT(ret == 0); // same + } + + void read_raw(void * ptr, size_t len) const { + if (len == 0) { + return; + } + errno = 0; + std::size_t ret = std::fread(ptr, len, 1, fp); + if (ferror(fp)) { + throw std::runtime_error(format("read error: %s", strerror(errno))); + } + if (ret != 1) { + throw std::runtime_error(std::string("unexpectedly reached end of file")); + } + } + + uint32_t read_u32() const { + uint32_t ret; + read_raw(&ret, sizeof(ret)); + return ret; + } + + void write_raw(const void * ptr, size_t len) const { + if (len == 0) { + return; + } + errno = 0; + size_t ret = std::fwrite(ptr, len, 1, fp); + if (ret != 1) { + throw std::runtime_error(format("write error: %s", strerror(errno))); + } + } + + void write_u32(std::uint32_t val) const { + write_raw(&val, sizeof(val)); + } + + ~llama_file() { + if (fp) { + std::fclose(fp); + } + } +}; + +struct llama_mmap { + void * addr; + size_t size; + + llama_mmap(const llama_mmap &) = delete; + +#ifdef _POSIX_MAPPED_FILES + static constexpr bool SUPPORTED = true; + + llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */, bool numa = false) { + size = file->size; + int fd = fileno(file->fp); + int flags = MAP_SHARED; + // prefetch/readahead impairs performance on NUMA systems + if (numa) { prefetch = 0; } +#ifdef __linux__ + if (prefetch) { flags |= MAP_POPULATE; } +#endif + addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0); + if (addr == MAP_FAILED) { + throw std::runtime_error(format("mmap failed: %s", strerror(errno))); + } + + if (prefetch > 0) { + // Advise the kernel to preload the mapped memory + if (madvise(addr, std::min(file->size, prefetch), MADV_WILLNEED)) { + fprintf(stderr, "warning: madvise(.., MADV_WILLNEED) failed: %s\n", + strerror(errno)); + } + } + if (numa) { + // advise the kernel not to use readahead + // (because the next page might not belong on the same node) + if (madvise(addr, file->size, MADV_RANDOM)) { + fprintf(stderr, "warning: madvise(.., MADV_RANDOM) failed: %s\n", + strerror(errno)); + } + } + } + + ~llama_mmap() { + munmap(addr, size); + } +#elif defined(_WIN32) + static constexpr bool SUPPORTED = true; + + llama_mmap(struct llama_file * file, bool prefetch = true, bool numa = false) { + (void) numa; + + size = file->size; + + HANDLE hFile = (HANDLE) _get_osfhandle(_fileno(file->fp)); + + HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); + DWORD error = GetLastError(); + + if (hMapping == NULL) { + throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str())); + } + + addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); + error = GetLastError(); + CloseHandle(hMapping); + + if (addr == NULL) { + throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str())); + } + + #if _WIN32_WINNT >= _WIN32_WINNT_WIN8 + if (prefetch) { + // Advise the kernel to preload the mapped memory + WIN32_MEMORY_RANGE_ENTRY range; + range.VirtualAddress = addr; + range.NumberOfBytes = (SIZE_T)size; + if (!PrefetchVirtualMemory(GetCurrentProcess(), 1, &range, 0)) { + fprintf(stderr, "warning: PrefetchVirtualMemory failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } + #else + #pragma message("warning: You are building for pre-Windows 8; prefetch not supported") + #endif // _WIN32_WINNT >= _WIN32_WINNT_WIN8 + } + + ~llama_mmap() { + if (!UnmapViewOfFile(addr)) { + fprintf(stderr, "warning: UnmapViewOfFile failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } +#else + static constexpr bool SUPPORTED = false; + + llama_mmap(struct llama_file * file, bool prefetch = true, bool numa = false) { + (void) file; + (void) prefetch; + (void) numa; + + throw std::runtime_error(std::string("mmap not supported")); + } +#endif +}; + +// Represents some region of memory being locked using mlock or VirtualLock; +// will automatically unlock on destruction. +struct llama_mlock { + void * addr = NULL; + size_t size = 0; + + bool failed_already = false; + + llama_mlock() {} + llama_mlock(const llama_mlock &) = delete; + + ~llama_mlock() { + if (size) { + raw_unlock(addr, size); + } + } + + void init(void * ptr) { + GGML_ASSERT(addr == NULL && size == 0); // NOLINT + addr = ptr; + } + + void grow_to(size_t target_size) { + GGML_ASSERT(addr); + if (failed_already) { + return; + } + size_t granularity = lock_granularity(); + target_size = (target_size + granularity - 1) & ~(granularity - 1); + if (target_size > size) { + if (raw_lock((uint8_t *) addr + size, target_size - size)) { + size = target_size; + } else { + failed_already = true; + } + } + } + +#ifdef _POSIX_MEMLOCK_RANGE + static constexpr bool SUPPORTED = true; + + static size_t lock_granularity() { + return (size_t) sysconf(_SC_PAGESIZE); + } + + #ifdef __APPLE__ + #define MLOCK_SUGGESTION \ + "Try increasing the sysctl values 'vm.user_wire_limit' and 'vm.global_user_wire_limit' and/or " \ + "decreasing 'vm.global_no_user_wire_amount'. Also try increasing RLIMIT_MLOCK (ulimit -l).\n" + #else + #define MLOCK_SUGGESTION \ + "Try increasing RLIMIT_MLOCK ('ulimit -l' as root).\n" + #endif + + bool raw_lock(const void * addr, size_t size) const { + if (!mlock(addr, size)) { + return true; + } + + char* errmsg = std::strerror(errno); + bool suggest = (errno == ENOMEM); + + // Check if the resource limit is fine after all + struct rlimit lock_limit; + if (suggest && getrlimit(RLIMIT_MEMLOCK, &lock_limit)) { + suggest = false; + } + if (suggest && (lock_limit.rlim_max > lock_limit.rlim_cur + size)) { + suggest = false; + } + + fprintf(stderr, "warning: failed to mlock %zu-byte buffer (after previously locking %zu bytes): %s\n%s", + size, this->size, errmsg, suggest ? MLOCK_SUGGESTION : ""); + return false; + } + + #undef MLOCK_SUGGESTION + + static void raw_unlock(void * addr, size_t size) { + if (munlock(addr, size)) { + fprintf(stderr, "warning: failed to munlock buffer: %s\n", std::strerror(errno)); + } + } +#elif defined(_WIN32) + static constexpr bool SUPPORTED = true; + + static size_t lock_granularity() { + SYSTEM_INFO si; + GetSystemInfo(&si); + return (size_t) si.dwPageSize; + } + + bool raw_lock(void * ptr, size_t len) const { + for (int tries = 1; ; tries++) { + if (VirtualLock(ptr, len)) { + return true; + } + if (tries == 2) { + fprintf(stderr, "warning: failed to VirtualLock %zu-byte buffer (after previously locking %zu bytes): %s\n", + len, size, llama_format_win_err(GetLastError()).c_str()); + return false; + } + + // It failed but this was only the first try; increase the working + // set size and try again. + SIZE_T min_ws_size, max_ws_size; + if (!GetProcessWorkingSetSize(GetCurrentProcess(), &min_ws_size, &max_ws_size)) { + fprintf(stderr, "warning: GetProcessWorkingSetSize failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + return false; + } + // Per MSDN: "The maximum number of pages that a process can lock + // is equal to the number of pages in its minimum working set minus + // a small overhead." + // Hopefully a megabyte is enough overhead: + size_t increment = len + 1048576; + // The minimum must be <= the maximum, so we need to increase both: + min_ws_size += increment; + max_ws_size += increment; + if (!SetProcessWorkingSetSize(GetCurrentProcess(), min_ws_size, max_ws_size)) { + fprintf(stderr, "warning: SetProcessWorkingSetSize failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + return false; + } + } + } + + static void raw_unlock(void * ptr, size_t len) { + if (!VirtualUnlock(ptr, len)) { + fprintf(stderr, "warning: failed to VirtualUnlock buffer: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } +#else + static constexpr bool SUPPORTED = false; + + static size_t lock_granularity() { + return (size_t) 65536; + } + + bool raw_lock(const void * addr, size_t len) const { + fprintf(stderr, "warning: mlock not supported on this system\n"); + return false; + } + + static void raw_unlock(const void * addr, size_t len) {} +#endif +}; + +typedef void (*offload_func_t)(struct ggml_tensor * tensor); + +static void llama_nop(struct ggml_tensor * tensor) { // don't offload by default + (void) tensor; +} + +static std::string llama_token_to_text(const struct llama_context * ctx, llama_token token) { + std::vector result(8, 0); + const int n_tokens = llama_token_to_str(ctx, token, result.data(), result.size()); + if (n_tokens < 0) { + result.resize(-n_tokens); + int check = llama_token_to_str(ctx, token, result.data(), result.size()); + GGML_ASSERT(check == -n_tokens); + } else { + result.resize(n_tokens); + } + + return std::string(result.data(), result.size()); +} + +// +// globals +// + +struct llama_state { + // We save the log callback globally + llama_log_callback log_callback = llama_log_callback_default; + void * log_callback_user_data = nullptr; +}; + +static llama_state g_state; + // // memory sizes (calculated for n_batch == 512) // +// computed for n_ctx == 2048 +// TODO: dynamically determine these sizes +// needs modifications in ggml + +// available llama models +enum e_model { + MODEL_UNKNOWN, + MODEL_3B, + MODEL_7B, + MODEL_13B, + MODEL_30B, + MODEL_65B, + MODEL_70B, +}; + +static const size_t kB = 1024; +static const size_t MB = 1024*1024; + static std::map MEM_REQ_SCRATCH0(int n_ctx) { std::map k_sizes = { @@ -187,25 +682,21 @@ static const std::map & VRAM_REQ_SCRATCH_PER_CONTEXT() // default hparams (LLaMA 7B) struct llama_hparams { - uint32_t n_vocab = 32000; - uint32_t n_ctx = 512; // this is provided as user input? - uint32_t n_embd = 4096; - uint32_t n_mult = 256; - uint32_t n_head = 32; - uint32_t n_head_kv = 32; - uint32_t n_layer = 32; - uint32_t n_rot = 64; + uint32_t n_vocab = 32000; + uint32_t n_ctx_train = 2048; // the context size used during training + uint32_t n_ctx = 512; // the context size used during inference + uint32_t n_embd = 4096; + uint32_t n_head = 32; + uint32_t n_head_kv = 32; + uint32_t n_layer = 32; + uint32_t n_rot = 64; + uint32_t n_ff = 11008; - // LLaMAv2 - // TODO: load from model data hparams - float f_ffn_mult = 1.0f; - float f_rms_norm_eps = LLAMA_DEFAULT_RMS_EPS; + float f_norm_rms_eps = 1e-5; float rope_freq_base = 10000.0f; float rope_freq_scale = 1.0f; - enum llama_ftype ftype = LLAMA_FTYPE_MOSTLY_F16; - bool operator!=(const llama_hparams & other) const { return static_cast(memcmp(this, &other, sizeof(llama_hparams))); // NOLINT } @@ -257,7 +748,7 @@ struct llama_kv_cache { struct ggml_context * ctx = NULL; - llama_ctx_buffer buf; + llama_buffer buf; int n; // number of tokens currently in the cache @@ -274,22 +765,41 @@ struct llama_kv_cache { }; struct llama_vocab { + // TODO: + // - add a vector of merges + // so that we can pass it to different types of tokenizers with a common interface + using id = int32_t; using token = std::string; + using ttype = llama_token_type; - struct token_score { - token tok; + struct token_data { + token text; float score; + ttype type; }; + llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; + std::unordered_map token_to_id; - std::vector id_to_token; + std::vector id_to_token; + + // default LLaMA special tokens + id special_bos_id = 1; + id special_eos_id = 2; + id special_unk_id = -1; + id special_sep_id = -1; + id special_pad_id = -1; + + id linefeed_id = 13; }; struct llama_model { - e_model type = MODEL_UNKNOWN; + e_model type = MODEL_UNKNOWN; + llama_ftype ftype = LLAMA_FTYPE_ALL_F32; llama_hparams hparams; + llama_vocab vocab; struct ggml_tensor * tok_embeddings; @@ -303,7 +813,7 @@ struct llama_model { struct ggml_context * ctx = NULL; // the model memory buffer - llama_ctx_buffer buf; + llama_buffer buf; // model memory mapped file std::unique_ptr mapping; @@ -318,8 +828,6 @@ struct llama_model { int64_t t_load_us = 0; int64_t t_start_us = 0; - llama_vocab vocab; - ~llama_model() { if (ctx) { ggml_free(ctx); @@ -391,16 +899,16 @@ struct llama_context { std::vector work_buffer; // memory buffers used to evaluate the model - // TODO: move in llama_state - llama_ctx_buffer buf_compute; + llama_buffer buf_compute; #ifdef LLAMA_USE_ALLOCATOR - llama_ctx_buffer buf_alloc; + llama_buffer buf_alloc; ggml_allocr * alloc = NULL; #endif #ifdef LLAMA_USE_SCRATCH - llama_ctx_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS]; + llama_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS]; + int buf_last = 0; size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 }; #endif @@ -413,7 +921,7 @@ struct llama_context { ggml_mpi_context * ctx_mpi = NULL; #endif - void use_buf(struct ggml_context * ctx, int i) { + void use_buf(struct ggml_context * ctx, int i) { // NOLINT #if defined(LLAMA_USE_SCRATCH) size_t last_size = 0; @@ -421,7 +929,7 @@ struct llama_context { last_size = ggml_set_scratch(ctx, { 0, 0, nullptr, }); } else { auto & buf = buf_scratch[i]; - last_size = ggml_set_scratch(ctx, { 0, buf.size, buf.addr, }); + last_size = ggml_set_scratch(ctx, { 0, buf.size, buf.data, }); } if (buf_last >= 0) { @@ -435,7 +943,7 @@ struct llama_context { #endif } - size_t get_buf_max_mem(int i) const { + size_t get_buf_max_mem(int i) { // NOLINT #if defined(LLAMA_USE_SCRATCH) return buf_max_size[i]; #else @@ -445,418 +953,11 @@ struct llama_context { } }; -struct llama_state { - // We save the log callback globally - llama_log_callback log_callback = llama_log_callback_default; - void * log_callback_user_data = nullptr; -}; -// global state -static llama_state g_state; - -template -static T checked_mul(T a, T b) { - T ret = a * b; - if (a != 0 && ret / a != b) { - throw std::runtime_error(format("overflow multiplying %llu * %llu", - (unsigned long long) a, (unsigned long long) b)); - } - return ret; -} - -static size_t checked_div(size_t a, size_t b) { - if (b == 0 || a % b != 0) { - throw std::runtime_error(format("error dividing %zu / %zu", a, b)); - } - return a / b; -} - -static std::string llama_format_tensor_shape(const std::vector & ne) { - char buf[256]; - snprintf(buf, sizeof(buf), "%5u", ne.at(0)); - for (size_t i = 1; i < ne.size(); i++) { - snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), " x %5u", ne.at(i)); - } - return buf; -} - -static size_t llama_calc_tensor_size(const std::vector & ne, enum ggml_type type) { - size_t size = ggml_type_size(type); - for (uint32_t dim : ne) { - size = checked_mul(size, dim); - } - return size / ggml_blck_size(type); -} - -struct llama_load_tensor { - std::string name; - enum ggml_type type = GGML_TYPE_F32; - std::vector ne; - size_t file_off; - size_t size; - struct ggml_tensor * ggml_tensor = NULL; - uint8_t * data; -}; - -struct llama_load_tensors_map { - // tensors is kept in a separate vector to preserve file order - std::vector tensors; - std::unordered_map name_to_idx; -}; - -enum llama_file_version { - LLAMA_FILE_VERSION_GGML, - LLAMA_FILE_VERSION_GGMF_V1, // added version field and scores in vocab - LLAMA_FILE_VERSION_GGJT_V1, // added padding - LLAMA_FILE_VERSION_GGJT_V2, // changed quantization format - LLAMA_FILE_VERSION_GGJT_V3, // changed Q4 and Q8 quantization format -}; - -struct llama_file_loader { - llama_file file; - llama_file_version file_version; - llama_hparams hparams; - llama_vocab vocab; - - llama_file_loader(const char * fname, llama_load_tensors_map & tensors_map) - : file(fname, "rb") { - LLAMA_LOG_INFO("llama.cpp: loading model from %s\n", fname); - read_magic(); - read_hparams(); - read_vocab(); - read_tensor_metadata(tensors_map); - } - void read_magic() { - uint32_t magic = file.read_u32(); - - if (magic == LLAMA_FILE_MAGIC_GGML) { - file_version = LLAMA_FILE_VERSION_GGML; - return; - } - - uint32_t version = file.read_u32(); - - switch (magic) { - case LLAMA_FILE_MAGIC_GGMF: - switch (version) { - case 1: file_version = LLAMA_FILE_VERSION_GGMF_V1; return; - } - break; - case LLAMA_FILE_MAGIC_GGJT: - switch (version) { - case 1: file_version = LLAMA_FILE_VERSION_GGJT_V1; return; - case 2: file_version = LLAMA_FILE_VERSION_GGJT_V2; return; - case 3: file_version = LLAMA_FILE_VERSION_GGJT_V3; return; - } - } - - throw std::runtime_error(format("unknown (magic, version) combination: %08x, %08x; is this really a GGML file?", - magic, version)); - } - void read_hparams() { - hparams.n_vocab = file.read_u32(); - hparams.n_embd = file.read_u32(); - hparams.n_mult = file.read_u32(); - hparams.n_head = file.read_u32(); - hparams.n_layer = file.read_u32(); - hparams.n_rot = file.read_u32(); - hparams.ftype = (enum llama_ftype) file.read_u32(); - - // LLaMAv2 - // TODO: read from header - hparams.n_head_kv = hparams.n_head; - } - void read_vocab() { - vocab.id_to_token.resize(hparams.n_vocab); - - for (uint32_t i = 0; i < hparams.n_vocab; i++) { - uint32_t len = file.read_u32(); - std::string word = file.read_string(len); - - float score = 0.0f; - file.read_raw(&score, sizeof(score)); - - vocab.token_to_id[word] = i; - - auto & tok_score = vocab.id_to_token[i]; - tok_score.tok = std::move(word); - tok_score.score = score; - } - } - void read_tensor_metadata(llama_load_tensors_map & tensors_map) { - while (file.tell() < file.size) { - llama_load_tensor tensor; - uint32_t n_dims = file.read_u32(); - uint32_t name_len = file.read_u32(); - tensor.type = (enum ggml_type) file.read_u32(); - tensor.ne.resize(n_dims); - file.read_raw(tensor.ne.data(), sizeof(tensor.ne[0]) * n_dims); - std::string name = file.read_string(name_len); - if (n_dims < 1 || n_dims > 2) { - throw std::runtime_error(format("llama.cpp: tensor '%s' should not be %u-dimensional", name.c_str(), n_dims)); - } - switch (tensor.type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - break; - default: { - throw std::runtime_error(format("unrecognized tensor type %u\n", tensor.type)); - } - } - - // skip to the next multiple of 32 bytes - if (file_version >= LLAMA_FILE_VERSION_GGJT_V1) { - file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); - } - - tensor.file_off = file.tell(); - tensor.name = name; - tensor.size = llama_calc_tensor_size(tensor.ne, tensor.type); - file.seek(tensor.size, SEEK_CUR); - - tensors_map.tensors.push_back(tensor); - tensors_map.name_to_idx[name] = tensors_map.tensors.size() - 1; - } - } -}; - -struct llama_file_saver { - llama_file file; - llama_file_loader * any_file_loader; - llama_file_saver(const char * fname, llama_file_loader * any_file_loader, enum llama_ftype new_ftype) - : file(fname, "wb"), any_file_loader(any_file_loader) { - LLAMA_LOG_INFO("llama.cpp: saving model to %s\n", fname); - write_magic(); - write_hparams(new_ftype); - write_vocab(); - } - void write_magic() { - file.write_u32(LLAMA_FILE_MAGIC); // magic - file.write_u32(LLAMA_FILE_VERSION); // version - } - void write_hparams(enum llama_ftype new_ftype) { - const llama_hparams & hparams = any_file_loader->hparams; - file.write_u32(hparams.n_vocab); - file.write_u32(hparams.n_embd); - file.write_u32(hparams.n_mult); - file.write_u32(hparams.n_head); - file.write_u32(hparams.n_layer); - file.write_u32(hparams.n_rot); - file.write_u32(new_ftype); - } - void write_vocab() { - if (any_file_loader->file_version == LLAMA_FILE_VERSION_GGML) { - LLAMA_LOG_WARN("llama.cpp: WARNING: input is an old file that doesn't have scores; will add dummy scores\n"); - } - uint32_t n_vocab = any_file_loader->hparams.n_vocab; - for (uint32_t i = 0; i < n_vocab; i++) { - const auto & token_score = any_file_loader->vocab.id_to_token.at(i); - file.write_u32((uint32_t) token_score.tok.size()); - file.write_raw(token_score.tok.data(), token_score.tok.size()); - file.write_raw(&token_score.score, sizeof(token_score.score)); - } - } - void write_tensor(llama_load_tensor & tensor, enum ggml_type new_type, const void * new_data, size_t new_size) { - switch (new_type) { - case GGML_TYPE_F32: - case GGML_TYPE_F16: - case GGML_TYPE_Q4_0: - case GGML_TYPE_Q4_1: - case GGML_TYPE_Q5_0: - case GGML_TYPE_Q5_1: - case GGML_TYPE_Q8_0: - case GGML_TYPE_Q2_K: - case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: - case GGML_TYPE_Q5_K: - case GGML_TYPE_Q6_K: - break; - default: LLAMA_ASSERT(false); - } - file.write_u32((uint32_t) tensor.ne.size()); - file.write_u32((uint32_t) tensor.name.size()); - file.write_u32(new_type); - file.write_raw(tensor.ne.data(), sizeof(tensor.ne[0]) * tensor.ne.size()); - file.write_raw(tensor.name.data(), tensor.name.size()); - file.seek(-static_cast(file.tell()) & 31, SEEK_CUR); - LLAMA_ASSERT(new_size == llama_calc_tensor_size(tensor.ne, new_type)); - file.write_raw(new_data, new_size); - } -}; - -struct llama_model_loader { - std::unique_ptr file_loader; - llama_load_tensors_map tensors_map; - bool use_mmap; - size_t num_ggml_tensors_created = 0; - struct ggml_context * ggml_ctx = NULL; - std::unique_ptr mapping; - - llama_model_loader(const std::string & fname_base, bool use_mmap) { - file_loader = std::unique_ptr(new llama_file_loader(fname_base.c_str(), tensors_map)); - if (!llama_mmap::SUPPORTED) { - use_mmap = false; - } - this->use_mmap = use_mmap; - } - - void calc_sizes(size_t * ctx_size_p, size_t * mmapped_size_p) const { - *ctx_size_p = *mmapped_size_p = 0; - for (const llama_load_tensor & lt : tensors_map.tensors) { - *ctx_size_p += sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE; - *(use_mmap ? mmapped_size_p : ctx_size_p) += lt.size + 16; - } - } - - struct ggml_tensor * get_tensor(const std::string & name, const std::vector & ne, ggml_backend backend) { - auto it = tensors_map.name_to_idx.find(name); - if (it == tensors_map.name_to_idx.end()) { - throw std::runtime_error(std::runtime_error(format("llama.cpp: tensor '%s' is missing from model", name.c_str()))); - } - llama_load_tensor & lt = tensors_map.tensors.at(it->second); - if (lt.ne != ne) { - throw std::runtime_error(format("llama.cpp: tensor '%s' has wrong shape; expected %s, got %s", - name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str())); - } - - return get_tensor_for(lt, backend); - } - - struct ggml_tensor * get_tensor_for(llama_load_tensor & lt, ggml_backend backend) { - struct ggml_tensor * tensor; - if (backend != GGML_BACKEND_CPU) { - ggml_set_no_alloc(ggml_ctx, true); - } - if (lt.ne.size() == 2) { - tensor = ggml_new_tensor_2d(ggml_ctx, lt.type, lt.ne.at(0), lt.ne.at(1)); - } else { - LLAMA_ASSERT(lt.ne.size() == 1); - tensor = ggml_new_tensor_1d(ggml_ctx, lt.type, lt.ne.at(0)); - } - ggml_set_name(tensor, lt.name.c_str()); - LLAMA_ASSERT(lt.ggml_tensor == NULL); // if this fails, we called get_tensor twice on the same tensor - - if (backend != GGML_BACKEND_CPU) { - ggml_set_no_alloc(ggml_ctx, use_mmap); - } - tensor->backend = backend; - lt.ggml_tensor = tensor; - num_ggml_tensors_created++; - return tensor; - } - - void done_getting_tensors() const { - if (num_ggml_tensors_created != tensors_map.tensors.size()) { - throw std::runtime_error(std::string("llama.cpp: file contained more tensors than expected")); - } - } - - void load_all_data(llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) { - size_t data_size = 0; - size_t prefetch_size = file_loader->file.size; - size_t lock_size = 0; - for (const llama_load_tensor & lt : tensors_map.tensors) { - data_size += lt.size; - if (lt.ggml_tensor->backend != GGML_BACKEND_CPU) { - prefetch_size -= lt.size; - } - } - - if (use_mmap) { - mapping.reset(new llama_mmap(&file_loader->file, prefetch_size, ggml_is_numa())); - if (lmlock) { - lmlock->init(mapping->addr); - } - } - - size_t done_size = 0; - for (llama_load_tensor & lt : tensors_map.tensors) { - if (progress_callback) { - progress_callback((float) done_size / data_size, progress_callback_user_data); - } - LLAMA_ASSERT(lt.ggml_tensor); // unused tensors should have been caught by load_data already - lt.data = (uint8_t *) lt.ggml_tensor->data; - - // allocate temp buffer if not using mmap - if (!use_mmap && lt.data == NULL) { - GGML_ASSERT(lt.ggml_tensor->backend != GGML_BACKEND_CPU); - lt.data = (uint8_t*)malloc(ggml_nbytes(lt.ggml_tensor)); - } - - load_data_for(lt); - - switch(lt.ggml_tensor->backend) { - case GGML_BACKEND_CPU: - lt.ggml_tensor->data = lt.data; - if (use_mmap && lmlock) { - lock_size += lt.size; - lmlock->grow_to(lock_size); - } - break; -#if defined(GGML_USE_CUBLAS) - case GGML_BACKEND_GPU: - case GGML_BACKEND_GPU_SPLIT: - ggml_cuda_transform_tensor(lt.data, lt.ggml_tensor); - if (!use_mmap) { - free(lt.data); - } - break; -#elif defined(GGML_USE_CLBLAST) - case GGML_BACKEND_GPU: - ggml_cl_transform_tensor(lt.data, lt.ggml_tensor); - if (!use_mmap) { - free(lt.data); - } - break; -#endif - default: - continue; - } - - done_size += lt.size; - } - } - - void load_data_for(llama_load_tensor & lt) { - if (use_mmap) { - lt.data = (uint8_t *) mapping->addr + lt.file_off; - } else { - llama_file & file = file_loader->file; - file.seek(lt.file_off, SEEK_SET); - file.read_raw(lt.data, lt.size); - } - - if (0) { - print_checksum(lt); - } - } - - static void print_checksum(llama_load_tensor & lt) { - uint32_t sum = 0; - for (size_t i = 0; i < lt.size; i++) { - uint8_t byte = lt.data[i]; - sum = byte + (sum << 6) + (sum << 16) - sum; // sdbm hash - } - LLAMA_LOG_INFO("%s checksum: %#08x (%s, size %zu)\n", lt.name.c_str(), sum, - llama_format_tensor_shape(lt.ne).c_str(), lt.size); - } - -}; - // -// kv cache +// kv cache helpers // -static bool kv_cache_init( +static bool llama_kv_cache_init( const struct llama_hparams & hparams, struct llama_kv_cache & cache, ggml_type wtype, @@ -873,7 +974,7 @@ static bool kv_cache_init( struct ggml_init_params params; params.mem_size = cache.buf.size; - params.mem_buffer = cache.buf.addr; + params.mem_buffer = cache.buf.data; params.no_alloc = false; cache.ctx = ggml_init(params); @@ -901,102 +1002,328 @@ static bool kv_cache_init( return true; } -struct llama_context_params llama_context_default_params() { - struct llama_context_params result = { - /*.seed =*/ LLAMA_DEFAULT_SEED, - /*.n_ctx =*/ 512, - /*.n_batch =*/ 512, - /*.n_gqa =*/ 1, - /*.rms_norm_eps =*/ LLAMA_DEFAULT_RMS_EPS, - /*.gpu_layers =*/ 0, - /*.main_gpu =*/ 0, - /*.tensor_split =*/ nullptr, - /*.rope_freq_base =*/ 10000.0f, - /*.rope_freq_scale =*/ 1.0f, - /*.progress_callback =*/ nullptr, - /*.progress_callback_user_data =*/ nullptr, - /*.low_vram =*/ false, - /*.mul_mat_q =*/ false, - /*.f16_kv =*/ true, - /*.logits_all =*/ false, - /*.vocab_only =*/ false, - /*.use_mmap =*/ true, - /*.use_mlock =*/ false, - /*.embedding =*/ false, - }; - - return result; -} - -struct llama_model_quantize_params llama_model_quantize_default_params() { - struct llama_model_quantize_params result = { - /*.nthread =*/ 0, - /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1, - /*.allow_requantize =*/ false, - /*.quantize_output_tensor =*/ true, - }; - - return result; -} - -int llama_max_devices() { - return LLAMA_MAX_DEVICES; -} - -bool llama_mmap_supported() { - return llama_mmap::SUPPORTED; -} - -bool llama_mlock_supported() { - return llama_mlock::SUPPORTED; -} - -void llama_backend_init(bool numa) { - ggml_time_init(); - - // needed to initialize f16 tables - { - struct ggml_init_params params = { 0, NULL, false }; - struct ggml_context * ctx = ggml_init(params); - ggml_free(ctx); - } - - if (numa) { - ggml_numa_init(); - } - -#ifdef GGML_USE_MPI - ggml_mpi_backend_init(); -#endif -} - -void llama_backend_free() { -#ifdef GGML_USE_MPI - ggml_mpi_backend_free(); -#endif -} - -int64_t llama_time_us() { - return ggml_time_us(); -} - // -// model loading +// model loading and saving // +enum llama_file_version { + GGUF_FILE_VERSION_V1 = 1, +}; + static const char * llama_file_version_name(llama_file_version version) { switch (version) { - case LLAMA_FILE_VERSION_GGML: return "'ggml' (old version with low tokenizer quality and no mmap support)"; - case LLAMA_FILE_VERSION_GGMF_V1: return "ggmf v1 (old version with no mmap support)"; - case LLAMA_FILE_VERSION_GGJT_V1: return "ggjt v1 (pre #1405)"; - case LLAMA_FILE_VERSION_GGJT_V2: return "ggjt v2 (pre #1508)"; - case LLAMA_FILE_VERSION_GGJT_V3: return "ggjt v3 (latest)"; + case GGUF_FILE_VERSION_V1: return "GGUF V1 (latest)"; } return "unknown"; } -const char * llama_ftype_name(enum llama_ftype ftype) { +static std::string llama_format_tensor_shape(const std::vector & ne) { + char buf[256]; + snprintf(buf, sizeof(buf), "%5u", ne.at(0)); + for (size_t i = 1; i < ne.size(); i++) { + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5u", ne.at(i)); + } + return buf; +} + +static std::string llama_format_tensor_shape(const struct ggml_tensor * t) { + char buf[256]; + snprintf(buf, sizeof(buf), "%5" PRId64, t->ne[0]); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, t->ne[i]); + } + return buf; +} + +struct llama_model_loader { + int n_kv = 0; + int n_tensors = 0; + int n_created = 0; + + int64_t n_elements = 0; + + bool use_mmap = false; + + llama_file file; + llama_ftype ftype; + llama_file_version fver; + + std::unique_ptr mapping; + + struct gguf_context * ctx_gguf = NULL; + struct ggml_context * ctx_meta = NULL; + + llama_model_loader(const std::string & fname, bool use_mmap) : file(fname.c_str(), "rb") { + struct gguf_init_params params = { + /*.no_alloc = */ true, + /*.ctx = */ &ctx_meta, + }; + + ctx_gguf = gguf_init_from_file(fname.c_str(), params); + if (!ctx_gguf) { + throw std::runtime_error(format("%s: failed to load model from %s\n", __func__, fname.c_str())); + } + + n_kv = gguf_get_n_kv(ctx_gguf); + n_tensors = gguf_get_n_tensors(ctx_gguf); + + fver = (enum llama_file_version) gguf_get_version(ctx_gguf); + + for (int i = 0; i < n_tensors; i++) { + const char * name = gguf_get_tensor_name(ctx_gguf, i); + struct ggml_tensor * t = ggml_get_tensor(ctx_meta, name); + n_elements += ggml_nelements(t); + } + + LLAMA_LOG_INFO("%s: loaded meta data with %d key-value pairs and %d tensors from %s (version %s)\n", + __func__, n_kv, n_tensors, fname.c_str(), llama_file_version_name(fver)); + + // determine file type based on the number of tensors for each quantization and print meta data + // TODO: make optional + { + std::map n_type; + + uint32_t n_type_max = 0; + enum ggml_type type_max = GGML_TYPE_F32; + + for (int i = 0; i < n_tensors; i++) { + const char * name = gguf_get_tensor_name(ctx_gguf, i); + struct ggml_tensor * meta = ggml_get_tensor(ctx_meta, name); + + n_type[meta->type]++; + + if (n_type_max < n_type[meta->type]) { + n_type_max = n_type[meta->type]; + type_max = meta->type; + } + + LLAMA_LOG_INFO("%s: - tensor %4d: %32s %-8s [ %s ]\n", __func__, i, name, ggml_type_name(meta->type), llama_format_tensor_shape(meta).c_str()); + } + + switch (type_max) { + case GGML_TYPE_F32: ftype = LLAMA_FTYPE_ALL_F32; break; + case GGML_TYPE_F16: ftype = LLAMA_FTYPE_MOSTLY_F16; break; + case GGML_TYPE_Q4_0: ftype = LLAMA_FTYPE_MOSTLY_Q4_0; break; + case GGML_TYPE_Q4_1: ftype = LLAMA_FTYPE_MOSTLY_Q4_1; break; + case GGML_TYPE_Q5_0: ftype = LLAMA_FTYPE_MOSTLY_Q5_0; break; + case GGML_TYPE_Q5_1: ftype = LLAMA_FTYPE_MOSTLY_Q5_1; break; + case GGML_TYPE_Q8_0: ftype = LLAMA_FTYPE_MOSTLY_Q8_0; break; + case GGML_TYPE_Q2_K: ftype = LLAMA_FTYPE_MOSTLY_Q2_K; break; + case GGML_TYPE_Q3_K: ftype = LLAMA_FTYPE_MOSTLY_Q3_K_M; break; + case GGML_TYPE_Q4_K: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; break; + case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break; + case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break; + default: + { + LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); + ftype = LLAMA_FTYPE_ALL_F32; + } break; + } + + for (int i = 0; i < n_kv; i++) { + const char * name = gguf_get_key(ctx_gguf, i); + const enum gguf_type type = gguf_get_kv_type(ctx_gguf, i); + + LLAMA_LOG_INFO("%s: - kv %3d: %42s %-8s\n", __func__, i, name, gguf_type_name(type)); + } + + // print type counts + for (auto & kv : n_type) { + if (kv.second == 0) { + continue; + } + + LLAMA_LOG_INFO("%s: - type %4s: %4d tensors\n", __func__, ggml_type_name(kv.first), kv.second); + } + } + + if (!llama_mmap::SUPPORTED) { + LLAMA_LOG_WARN("%s: mmap is not supported on this platform\n", __func__); + use_mmap = false; + } + + this->use_mmap = use_mmap; + } + + ~llama_model_loader() { + if (ctx_gguf) { + gguf_free(ctx_gguf); + } + if (ctx_meta) { + ggml_free(ctx_meta); + } + } + + const char * get_tensor_name(int i) const { + return gguf_get_tensor_name(ctx_gguf, i); + } + + struct ggml_tensor * get_tensor_meta(int i) const { + return ggml_get_tensor(ctx_meta, get_tensor_name(i)); + } + + void calc_sizes(size_t & ctx_size_p, size_t & mmapped_size_p) const { + ctx_size_p = 0; + mmapped_size_p = 0; + + for (int i = 0; i < n_tensors; i++) { + struct ggml_tensor * meta = get_tensor_meta(i); + ctx_size_p += sizeof(struct ggml_tensor) + GGML_OBJECT_SIZE; + (use_mmap ? mmapped_size_p : ctx_size_p) += ggml_nbytes_pad(meta); + } + } + + struct ggml_tensor * create_tensor_for(struct ggml_context * ctx, struct ggml_tensor * meta, ggml_backend backend) { + if (backend != GGML_BACKEND_CPU) { + ggml_set_no_alloc(ctx, true); + } + + struct ggml_tensor * tensor = ggml_dup_tensor(ctx, meta); + tensor->backend = backend; // TODO: ggml_set_backend + ggml_set_name(tensor, ggml_get_name(meta)); + + if (backend != GGML_BACKEND_CPU) { + ggml_set_no_alloc(ctx, use_mmap); + } + + n_created++; + + return tensor; + } + + struct ggml_tensor * create_tensor(struct ggml_context * ctx, const std::string & name, const std::vector & ne, ggml_backend backend) { + struct ggml_tensor * cur = ggml_get_tensor(ctx_meta, name.c_str()); + + if (cur == NULL) { + throw std::runtime_error(format("%s: tensor '%s' not found", __func__, name.c_str())); + } + + { + bool is_ok = true; + for (size_t i = 0; i < ne.size(); ++i) { + if (ne[i] != cur->ne[i]) { + is_ok = false; + break; + } + } + if (!is_ok) { + throw std::runtime_error( + format("%s: tensor '%s' has wrong shape; expected %s, got %s", + __func__, name.c_str(), + llama_format_tensor_shape(ne).c_str(), + llama_format_tensor_shape(cur).c_str())); + } + } + + return create_tensor_for(ctx, cur, backend); + } + + void done_getting_tensors() const { + if (n_created != n_tensors) { + throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created)); + } + } + + size_t file_offset(const char * name) const { + const int idx = gguf_find_tensor(ctx_gguf, name); + + if (idx < 0) { + throw std::runtime_error(format("%s: tensor '%s' not found in the file", __func__, name)); + } + + return gguf_get_data_offset(ctx_gguf) + gguf_get_tensor_offset(ctx_gguf, idx); + } + + void load_data_for(struct ggml_tensor * cur) const { + const size_t offs = file_offset(ggml_get_name(cur)); + + if (use_mmap) { + cur->data = (uint8_t *) mapping->addr + offs; + } else { + file.seek(offs, SEEK_SET); + file.read_raw(cur->data, ggml_nbytes(cur)); + } + } + + void load_all_data(struct ggml_context * ctx, llama_progress_callback progress_callback, void * progress_callback_user_data, llama_mlock * lmlock) { + size_t size_data = 0; + size_t size_lock = 0; + size_t size_pref = 0; // prefetch + + for (int i = 0; i < gguf_get_n_tensors(ctx_gguf); i++) { + struct ggml_tensor * cur = ggml_get_tensor(ctx, gguf_get_tensor_name(ctx_gguf, i)); + size_data += ggml_nbytes(cur); + if (cur->backend == GGML_BACKEND_CPU) { + size_pref += ggml_nbytes(cur); + } + } + + if (use_mmap) { + mapping.reset(new llama_mmap(&file, size_pref, ggml_is_numa())); + if (lmlock) { + lmlock->init(mapping->addr); + } + } + + size_t done_size = 0; + for (int i = 0; i < gguf_get_n_tensors(ctx_gguf); i++) { + struct ggml_tensor * cur = ggml_get_tensor(ctx, gguf_get_tensor_name(ctx_gguf, i)); + GGML_ASSERT(cur); // unused tensors should have been caught by load_data already + + if (progress_callback) { + progress_callback((float) done_size / size_data, progress_callback_user_data); + } + + // allocate temp buffer if not using mmap + if (!use_mmap && cur->data == NULL) { + GGML_ASSERT(cur->backend != GGML_BACKEND_CPU); + cur->data = malloc(ggml_nbytes(cur)); + } + + load_data_for(cur); + + switch (cur->backend) { + case GGML_BACKEND_CPU: + if (use_mmap && lmlock) { + size_lock += ggml_nbytes(cur); + lmlock->grow_to(size_lock); + } + break; +#if defined(GGML_USE_CUBLAS) + case GGML_BACKEND_GPU: + case GGML_BACKEND_GPU_SPLIT: + // old code: + //ggml_cuda_transform_tensor(lt.data, lt.ggml_tensor); + + // TODO: test if this works !! + ggml_cuda_transform_tensor(cur->data, cur); + if (!use_mmap) { + free(cur->data); + } + break; +#elif defined(GGML_USE_CLBLAST) + case GGML_BACKEND_GPU: + ggml_cl_transform_tensor(cur->data, cur); + if (!use_mmap) { + free(cur->data); + } + break; +#endif + default: + continue; + } + + done_size += ggml_nbytes(cur); + } + } +}; + +// +// load LLaMA models +// + +const char * llama_model_ftype_name(enum llama_ftype ftype) { switch (ftype) { case LLAMA_FTYPE_ALL_F32: return "all F32"; case LLAMA_FTYPE_MOSTLY_F16: return "mostly F16"; @@ -1007,8 +1334,9 @@ const char * llama_ftype_name(enum llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q5_0: return "mostly Q5_0"; case LLAMA_FTYPE_MOSTLY_Q5_1: return "mostly Q5_1"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "mostly Q8_0"; + // K-quants - case LLAMA_FTYPE_MOSTLY_Q2_K: return "mostly Q2_K"; + case LLAMA_FTYPE_MOSTLY_Q2_K: return "mostly Q2_K"; case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "mostly Q3_K - Small"; case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "mostly Q3_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "mostly Q3_K - Large"; @@ -1016,20 +1344,21 @@ const char * llama_ftype_name(enum llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "mostly Q4_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "mostly Q5_K - Small"; case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "mostly Q5_K - Medium"; - case LLAMA_FTYPE_MOSTLY_Q6_K: return "mostly Q6_K"; - default: return "unknown, may not work"; + case LLAMA_FTYPE_MOSTLY_Q6_K: return "mostly Q6_K"; + + default: return "unknown, may not work"; } } static const char * llama_model_type_name(e_model type) { switch (type) { - case MODEL_3B: return "3B"; - case MODEL_7B: return "7B"; + case MODEL_3B: return "3B"; + case MODEL_7B: return "7B"; case MODEL_13B: return "13B"; case MODEL_30B: return "30B"; case MODEL_65B: return "65B"; case MODEL_70B: return "70B"; - default: LLAMA_ASSERT(false); + default: GGML_ASSERT(false); } } @@ -1039,8 +1368,6 @@ static void llama_model_load_internal( llama_vocab & vocab, int n_ctx, int n_batch, - int n_gqa, - float rms_norm_eps, int n_gpu_layers, int main_gpu, const float * tensor_split, @@ -1054,22 +1381,83 @@ static void llama_model_load_internal( bool vocab_only, llama_progress_callback progress_callback, void * progress_callback_user_data) { - model.t_start_us = ggml_time_us(); std::unique_ptr ml(new llama_model_loader(fname, use_mmap)); - vocab = std::move(ml->file_loader->vocab); - model.hparams = ml->file_loader->hparams; model.n_gpu_layers = n_gpu_layers; - llama_file_version file_version = ml->file_loader->file_version; auto & hparams = model.hparams; - // TODO: read from file - hparams.f_rms_norm_eps = rms_norm_eps; + std::string general_name = "n/a"; + std::string general_arch = "n/a"; + // read hparams { + struct gguf_context * ctx = ml->ctx_gguf; + +#define GGUF_GET(dst, func, type, req, key) \ + { \ + const int kid = gguf_find_key(ctx, key); \ + if (kid >= 0) { \ + enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \ + if (ktype != (type)) { \ + throw std::runtime_error(format("key %s has wrong type: %s", key, gguf_type_name(ktype))); \ + } \ + (dst) = func(ctx, kid); \ + } else if (req) { \ + throw std::runtime_error(format("key not found in model: %s", key)); \ + } \ + } + + std::string tokenizer_name; + GGUF_GET(tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, "tokenizer.ggml.model"); + + if (tokenizer_name == "llama") { + vocab.type = LLAMA_VOCAB_TYPE_SPM; + } else if (tokenizer_name == "gpt2") { + vocab.type = LLAMA_VOCAB_TYPE_BPE; + } else { + LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str()); + LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__); + vocab.type = LLAMA_VOCAB_TYPE_SPM; + } + + // get hparams kv + GGUF_GET(hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, "tokenizer.ggml.tokens"); + GGUF_GET(hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.context_length"); + GGUF_GET(hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.embedding_length"); + GGUF_GET(hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.feed_forward_length"); + GGUF_GET(hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.attention.head_count"); + GGUF_GET(hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.block_count"); + GGUF_GET(hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.rope.dimension_count"); + GGUF_GET(hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, "llama.attention.layer_norm_rms_epsilon"); + + // n_head_kv is optional, default to n_head + hparams.n_head_kv = hparams.n_head; + GGUF_GET(hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "llama.attention.head_count_kv"); + + // TODO: manually setting rope scale should override this + // rope_freq_scale (inverse of the kv) is optional + float ropescale = 1.0f; + GGUF_GET(ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, "llama.rope.scale_linear"); + if (ropescale != 1.0f) { + rope_freq_scale = 1.0f/ropescale; + } + + // get general kv + GGUF_GET(general_name, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.name"); + GGUF_GET(general_arch, gguf_get_val_str, GGUF_TYPE_STRING, false, "general.architecture"); + + // special tokens + GGUF_GET(vocab.special_bos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.bos_token_id"); + GGUF_GET(vocab.special_eos_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.eos_token_id"); + GGUF_GET(vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.unknown_token_id"); + GGUF_GET(vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.separator_token_id"); + GGUF_GET(vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "tokenizer.ggml.padding_token_id"); + +#undef GGUF_GET + switch (hparams.n_layer) { case 26: model.type = e_model::MODEL_3B; break; case 32: model.type = e_model::MODEL_7B; break; @@ -1084,64 +1472,103 @@ static void llama_model_load_internal( } break; } + model.ftype = ml->ftype; + hparams.n_ctx = n_ctx; // LLaMAv2 - // TODO: temporary until GGUF - LLAMA_ASSERT(hparams.n_head % n_gqa == 0); - hparams.n_head_kv = hparams.n_head / n_gqa; - if (model.type == e_model::MODEL_65B && n_gqa == 8) { - LLAMA_LOG_WARN("%s: warning: assuming 70B model based on GQA == %d\n", __func__, n_gqa); - model.type = e_model::MODEL_70B; - hparams.f_ffn_mult = 1.3f; // from the params.json of the 70B model + // TODO: probably not needed + { + const auto n_gqa = hparams.n_gqa(); + + if (model.type == e_model::MODEL_65B && n_gqa == 8) { + LLAMA_LOG_WARN("%s: assuming 70B model based on GQA == %d\n", __func__, n_gqa); + model.type = e_model::MODEL_70B; + } } hparams.rope_freq_base = rope_freq_base; hparams.rope_freq_scale = rope_freq_scale; } - // ref: https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/model.py#L194-L199 - const uint32_t n_ff_raw = 2*(4*hparams.n_embd)/3; - const uint32_t n_ff_mult = hparams.f_ffn_mult*n_ff_raw; - const uint32_t n_ff = ((n_ff_mult + hparams.n_mult - 1)/hparams.n_mult)*hparams.n_mult; - //const uint32_t n_ff = 28672; + // read vocab + { + struct gguf_context * ctx = ml->ctx_gguf; + + vocab.id_to_token.resize(hparams.n_vocab); + + const int token_idx = gguf_find_key(ctx, "tokenizer.ggml.tokens"); + if (token_idx == -1) { + throw std::runtime_error("cannot find tokenizer vocab in model file\n"); + } + + const int score_idx = gguf_find_key(ctx, "tokenizer.ggml.scores"); + if (score_idx == -1) { + throw std::runtime_error("cannot find tokenizer scores in model file\n"); + } + + const float * scores = (const float * ) gguf_get_arr_data(ctx, score_idx); + + const int toktype_idx = gguf_find_key(ctx, "tokenizer.ggml.token_type"); + if (toktype_idx == -1) { + throw std::runtime_error("cannot find token type list in GGUF file\n"); + } + + const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); + + for (uint32_t i = 0; i < hparams.n_vocab; i++) { + std::string word = gguf_get_arr_str(ctx, token_idx, i); + + vocab.token_to_id[word] = i; + + auto & token_data = vocab.id_to_token[i]; + token_data.text = std::move(word); + token_data.score = scores[i]; + token_data.type = (llama_token_type) toktypes[i]; + + // determine the newline token: 0x0A == 10 == '\n' + if (token_data.text == "<0x0A>") { + vocab.linefeed_id = i; + } + } + } { - LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(file_version)); - LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); - LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx); - LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); - LLAMA_LOG_INFO("%s: n_mult = %u\n", __func__, hparams.n_mult); - LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head); - LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); - LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); - LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim - LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); - LLAMA_LOG_INFO("%s: rnorm_eps = %.1e\n", __func__, hparams.f_rms_norm_eps); - LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, n_ff); - LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); - LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); - LLAMA_LOG_INFO("%s: ftype = %u (%s)\n", __func__, hparams.ftype, llama_ftype_name(hparams.ftype)); - LLAMA_LOG_INFO("%s: model size = %s\n", __func__, llama_model_type_name(model.type)); - } + // hparams + LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml->fver)); + LLAMA_LOG_INFO("%s: arch = %s\n", __func__, general_arch.c_str()); + LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix + LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab); + LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train); + LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx); + LLAMA_LOG_INFO("%s: n_embd = %u\n", __func__, hparams.n_embd); + LLAMA_LOG_INFO("%s: n_head = %u\n", __func__, hparams.n_head); + LLAMA_LOG_INFO("%s: n_head_kv = %u\n", __func__, hparams.n_head_kv); + LLAMA_LOG_INFO("%s: n_layer = %u\n", __func__, hparams.n_layer); + LLAMA_LOG_INFO("%s: n_rot = %u\n", __func__, hparams.n_rot); // a.k.a. n_embd_head, n_head_dim + LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); + LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); + LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, hparams.rope_freq_base); + LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, hparams.rope_freq_scale); + LLAMA_LOG_INFO("%s: model type = %s\n", __func__, llama_model_type_name(model.type)); + LLAMA_LOG_INFO("%s: model ftype = %s\n", __func__, llama_model_ftype_name(model.ftype)); + LLAMA_LOG_INFO("%s: model size = %.2f B\n", __func__, ml->n_elements*1e-9); - if (file_version < LLAMA_FILE_VERSION_GGJT_V2) { - if (hparams.ftype != LLAMA_FTYPE_ALL_F32 && - hparams.ftype != LLAMA_FTYPE_MOSTLY_F16 && - hparams.ftype != LLAMA_FTYPE_MOSTLY_Q8_0) { - throw std::runtime_error(format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1405)")); - } - } + // general kv + LLAMA_LOG_INFO("%s: general.name = %s\n", __func__, general_name.c_str()); - if (file_version < LLAMA_FILE_VERSION_GGJT_V3) { - if (hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || - hparams.ftype == LLAMA_FTYPE_MOSTLY_Q4_1 || - hparams.ftype == LLAMA_FTYPE_MOSTLY_Q8_0) { - throw std::runtime_error(format("this format is no longer supported (see https://github.com/ggerganov/llama.cpp/pull/1508)")); - } + // special tokens + if (vocab.special_bos_id != -1) { LLAMA_LOG_INFO( "%s: BOS token = %d '%s'\n", __func__, vocab.special_bos_id, vocab.id_to_token[vocab.special_bos_id].text.c_str() ); } + if (vocab.special_eos_id != -1) { LLAMA_LOG_INFO( "%s: EOS token = %d '%s'\n", __func__, vocab.special_eos_id, vocab.id_to_token[vocab.special_eos_id].text.c_str() ); } + if (vocab.special_unk_id != -1) { LLAMA_LOG_INFO( "%s: UNK token = %d '%s'\n", __func__, vocab.special_unk_id, vocab.id_to_token[vocab.special_unk_id].text.c_str() ); } + if (vocab.special_sep_id != -1) { LLAMA_LOG_INFO( "%s: SEP token = %d '%s'\n", __func__, vocab.special_sep_id, vocab.id_to_token[vocab.special_sep_id].text.c_str() ); } + if (vocab.special_pad_id != -1) { LLAMA_LOG_INFO( "%s: PAD token = %d '%s'\n", __func__, vocab.special_pad_id, vocab.id_to_token[vocab.special_pad_id].text.c_str() ); } + if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); } } if (vocab_only) { + LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__); return; } @@ -1149,20 +1576,22 @@ static void llama_model_load_internal( size_t ctx_size; size_t mmapped_size; - ml->calc_sizes(&ctx_size, &mmapped_size); + + ml->calc_sizes(ctx_size, mmapped_size); + LLAMA_LOG_INFO("%s: ggml ctx size = %7.2f MB\n", __func__, ctx_size/1024.0/1024.0); // create the ggml context { model.buf.resize(ctx_size); if (use_mlock) { - model.mlock_buf.init (model.buf.addr); + model.mlock_buf.init (model.buf.data); model.mlock_buf.grow_to(model.buf.size); } struct ggml_init_params params = { /*.mem_size =*/ model.buf.size, - /*.mem_buffer =*/ model.buf.addr, + /*.mem_buffer =*/ model.buf.data, /*.no_alloc =*/ ml->use_mmap, }; @@ -1198,9 +1627,7 @@ static void llama_model_load_internal( const uint32_t n_layer = hparams.n_layer; const uint32_t n_vocab = hparams.n_vocab; - ml->ggml_ctx = ctx; - - model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU); + model.tok_embeddings = ml->create_tensor(ctx, TN_TOKEN_EMBD, {n_embd, n_vocab}, GGML_BACKEND_CPU); // "output" tensor { @@ -1221,8 +1648,8 @@ static void llama_model_load_internal( backend_output = GGML_BACKEND_CPU; } - model.norm = ml->get_tensor("norm.weight", {n_embd}, backend_norm); - model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output); + model.norm = ml->create_tensor(ctx, TN_OUTPUT_NORM, {n_embd}, backend_norm); + model.output = ml->create_tensor(ctx, TN_OUTPUT, {n_embd, n_vocab}, backend_output); if (backend_norm == GGML_BACKEND_GPU) { vram_weights += ggml_nbytes(model.norm); } @@ -1231,6 +1658,8 @@ static void llama_model_load_internal( } } + const uint32_t n_ff = hparams.n_ff; + const int i_gpu_start = n_layer - n_gpu_layers; model.layers.resize(n_layer); @@ -1239,21 +1668,18 @@ static void llama_model_load_internal( const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT auto & layer = model.layers[i]; + layer.attention_norm = ml->create_tensor(ctx, format(TN_ATTN_NORM, i), {n_embd}, backend); - std::string layers_i = "layers." + std::to_string(i); + layer.wq = ml->create_tensor(ctx, format(TN_ATTN_Q, i), {n_embd, n_embd}, backend_split); + layer.wk = ml->create_tensor(ctx, format(TN_ATTN_K, i), {n_embd, n_embd_gqa}, backend_split); + layer.wv = ml->create_tensor(ctx, format(TN_ATTN_V, i), {n_embd, n_embd_gqa}, backend_split); + layer.wo = ml->create_tensor(ctx, format(TN_ATTN_OUTPUT, i), {n_embd, n_embd}, backend_split); - layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend); + layer.ffn_norm = ml->create_tensor(ctx, format(TN_FFN_NORM, i), {n_embd}, backend); - layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend_split); - layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd_gqa}, backend_split); - layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd_gqa}, backend_split); - layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend_split); - - layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend); - - layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd, n_ff}, backend_split); - layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", { n_ff, n_embd}, backend_split); - layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd, n_ff}, backend_split); + layer.w1 = ml->create_tensor(ctx, format(TN_FFN_GATE, i), {n_embd, n_ff}, backend_split); + layer.w2 = ml->create_tensor(ctx, format(TN_FFN_DOWN, i), { n_ff, n_embd}, backend_split); + layer.w3 = ml->create_tensor(ctx, format(TN_FFN_UP, i), {n_embd, n_ff}, backend_split); if (backend == GGML_BACKEND_GPU) { vram_weights += @@ -1351,8 +1777,9 @@ static void llama_model_load_internal( } // populate `tensors_by_name` - for (llama_load_tensor & lt : ml->tensors_map.tensors) { - model.tensors_by_name.emplace_back(lt.name, lt.ggml_tensor); + for (int i = 0; i < ml->n_tensors; ++i) { + struct ggml_tensor * cur = ggml_get_tensor(ctx, ml->get_tensor_name(i)); + model.tensors_by_name.emplace_back(ggml_get_name(cur), cur); } (void) tensor_split; @@ -1362,7 +1789,7 @@ static void llama_model_load_internal( } #endif - ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &model.mlock_mmap : NULL); + ml->load_all_data(ctx, progress_callback, progress_callback_user_data, use_mlock ? &model.mlock_mmap : NULL); if (progress_callback) { progress_callback(1.0f, progress_callback_user_data); @@ -1381,8 +1808,6 @@ static bool llama_model_load( llama_vocab & vocab, int n_ctx, int n_batch, - int n_gqa, - float rms_norm_eps, int n_gpu_layers, int main_gpu, const float * tensor_split, @@ -1397,7 +1822,7 @@ static bool llama_model_load( llama_progress_callback progress_callback, void *progress_callback_user_data) { try { - llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers, + llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gpu_layers, main_gpu, tensor_split, mul_mat_q, rope_freq_base, rope_freq_scale, low_vram, memory_type, use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data); return true; @@ -1414,7 +1839,7 @@ static struct ggml_cgraph * llama_build_graph( int n_tokens, int n_past) { - LLAMA_ASSERT((!tokens && embd) || (tokens && !embd)); + GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT const int N = n_tokens; @@ -1423,7 +1848,7 @@ static struct ggml_cgraph * llama_build_graph( const auto & kv_self = lctx.kv_self; - LLAMA_ASSERT(!!kv_self.ctx); + GGML_ASSERT(!!kv_self.ctx); const int64_t n_embd = hparams.n_embd; const int64_t n_layer = hparams.n_layer; @@ -1433,21 +1858,20 @@ static struct ggml_cgraph * llama_build_graph( const int64_t n_embd_head = hparams.n_embd_head(); const int64_t n_embd_gqa = hparams.n_embd_gqa(); - LLAMA_ASSERT(n_embd_head == hparams.n_rot); + GGML_ASSERT(n_embd_head == hparams.n_rot); - const float freq_base = hparams.rope_freq_base; - const float freq_scale = hparams.rope_freq_scale; - const float rms_norm_eps = hparams.f_rms_norm_eps; + const float freq_base = hparams.rope_freq_base; + const float freq_scale = hparams.rope_freq_scale; + const float norm_rms_eps = hparams.f_norm_rms_eps; const int n_gpu_layers = model.n_gpu_layers; auto & mem_per_token = lctx.mem_per_token; auto & buf_compute = lctx.buf_compute; - struct ggml_init_params params = { /*.mem_size =*/ buf_compute.size, - /*.mem_buffer =*/ buf_compute.addr, + /*.mem_buffer =*/ buf_compute.data, /*.no_alloc =*/ false, }; @@ -1545,7 +1969,7 @@ static struct ggml_cgraph * llama_build_graph( // norm { - cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); + cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps); offload_func(cur); ggml_set_name(cur, "rms_norm_0"); @@ -1690,7 +2114,7 @@ static struct ggml_cgraph * llama_build_graph( { // norm { - cur = ggml_rms_norm(ctx0, inpFF, rms_norm_eps); + cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps); offload_func(cur); ggml_set_name(cur, "rms_norm_1"); @@ -1740,7 +2164,7 @@ static struct ggml_cgraph * llama_build_graph( // norm { - cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps); + cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps); offload_func_nr(cur); ggml_set_name(cur, "rms_norm_2"); @@ -1797,14 +2221,14 @@ static bool llama_eval_internal( int n_threads, const char * cgraph_fname) { - LLAMA_ASSERT((!tokens && embd) || (tokens && !embd)); + GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT - LLAMA_ASSERT(n_tokens > 0); - LLAMA_ASSERT(n_past >= 0); - LLAMA_ASSERT(n_threads > 0); + GGML_ASSERT(n_tokens > 0); + GGML_ASSERT(n_past >= 0); + GGML_ASSERT(n_threads > 0); // TODO: keep the values of n_batch and n_ctx - // LLAMA_ASSERT(n_tokens <= n_batch); - // LLAMA_ASSERT(n_past + n_tokens <= n_ctx); + // GGML_ASSERT(n_tokens <= n_batch); + // GGML_ASSERT(n_past + n_tokens <= n_ctx); const int64_t t_start_us = ggml_time_us(); @@ -1819,7 +2243,7 @@ static bool llama_eval_internal( const auto & kv_self = lctx.kv_self; - LLAMA_ASSERT(!!kv_self.ctx); + GGML_ASSERT(!!kv_self.ctx); const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; @@ -1843,8 +2267,8 @@ static bool llama_eval_internal( struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; - LLAMA_ASSERT(strcmp(res->name, "result_output") == 0); - LLAMA_ASSERT(strcmp(embeddings->name, "result_norm") == 0); + GGML_ASSERT(strcmp(res->name, "result_output") == 0); + GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); #if GGML_USE_MPI const int64_t n_layer = hparams.n_layer; @@ -1927,6 +2351,89 @@ static bool llama_eval_internal( // tokenizer // +static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) { + return vocab.type; +} + +static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL; +} + +static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNKNOWN; +} + +static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL; +} + +static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED; +} + +static bool llama_is_unused_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNUSED; +} + +static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE; +} + +static bool llama_is_bos_token(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(llama_is_control_token(vocab, id)); + return id == vocab.special_bos_id; +} + +static bool llama_is_eos_token(const llama_vocab & vocab, llama_token id ) { + GGML_ASSERT(llama_is_control_token(vocab, id)); + return id == vocab.special_eos_id; +} + +static bool llama_is_pad_token(const llama_vocab & vocab, llama_token id ) { + GGML_ASSERT(id < 0 || llama_is_control_token(vocab, id)); + return id == vocab.special_pad_id; +} + +static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(llama_is_byte_token(vocab, id)); + const auto& token_data = vocab.id_to_token.at(id); + auto buf = token_data.text.substr(3, 2); + return strtol(buf.c_str(), NULL, 16); +} + +static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { + char buf[7]; + int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch); + GGML_ASSERT(0 <= result && result < 7); + return vocab.token_to_id.at(buf); +} + +static std::string llama_escape_whitespace(const std::string& text) { + std::string result; + bool escaping = false; + result += "\xe2\x96\x81"; + for (size_t offs = 0; offs < text.length(); ++offs) { + if (text[offs] == ' ') { + if (!escaping) { + result += "\xe2\x96\x81"; + escaping = true; + } + } + else { + escaping = false; + result += text[offs]; + } + } + return result; +} + +static std::string llama_unescape_whitespace(const std::string& word) { + if (word.length() >= 3 && word.substr(0, 3) == "\xe2\x96\x81") { + return std::string(" ") + word.substr(3); + } + return word; +} + static size_t utf8_len(char src) { const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; uint8_t highbits = static_cast(src) >> 4; @@ -1968,10 +2475,11 @@ struct llama_tokenizer { size_t offs = 0; while (offs < text.size()) { llama_sp_symbol sym; - size_t char_len = std::min(text.size() - offs, utf8_len(text[offs])); + size_t len = utf8_len(text[offs]); + GGML_ASSERT(offs + len <= text.size()); sym.text = text.c_str() + offs; - sym.n = char_len; - offs += char_len; + sym.n = len; + offs += len; sym.prev = index - 1; sym.next = offs == text.size() ? -1 : index + 1; index++; @@ -2016,23 +2524,36 @@ struct llama_tokenizer { for (int i = 0; i != -1; i = symbols_[i].next) { auto & symbol = symbols_[i]; - auto token = vocab_.token_to_id.find(std::string(symbol.text, symbol.n)); - - if (token == vocab_.token_to_id.end()) { - // output any symbols that did not form tokens as bytes. - for (int j = 0; j < (int) symbol.n; ++j) { - // NOTE: old version, before #2420 - not sure what are the implications of this - //llama_vocab::id token_id = static_cast(symbol.text[j]) + 3; - llama_vocab::id token_id = vocab_.token_to_id.at(std::string(1, symbol.text[j])); - output.push_back(token_id); - } - } else { - output.push_back((*token).second); - } + resegment(symbol, output); } } private: + void resegment(llama_sp_symbol &symbol, std::vector &output) { + auto text = std::string(symbol.text, symbol.n); + auto token = vocab_.token_to_id.find(text); + + // Do we need to support is_unused? + if (token != vocab_.token_to_id.end()) { + output.push_back((*token).second); + return; + } + + const auto p = rev_merge.find(text); + + if (p == rev_merge.end()) { + // output any symbols that did not form tokens as bytes. + for (int j = 0; j < (int)symbol.n; ++j) { + llama_vocab::id token_id = llama_byte_to_token(vocab_, symbol.text[j]); + output.push_back(token_id); + } + return; + } + + resegment(symbols_[p->second.first], output); + resegment(symbols_[p->second.second], output); + } + void try_add_bigram(int left, int right) { if (left == -1 || right == -1) { return; @@ -2049,31 +2570,42 @@ private: return; } - const auto &tok_score = vocab_.id_to_token[(*token).second]; + const auto &tok_data = vocab_.id_to_token[(*token).second]; llama_sp_bigram bigram; bigram.left = left; bigram.right = right; - bigram.score = tok_score.score; + bigram.score = tok_data.score; bigram.size = text.size(); work_queue_.push(bigram); + + // Do we need to support is_unused? + rev_merge[text] = std::make_pair(left, right); } const llama_vocab & vocab_; std::vector symbols_; llama_sp_bigram::queue work_queue_; + std::map > rev_merge; }; -static std::vector llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) { +static std::vector llama_tokenize_internal(const llama_vocab & vocab, const std::string & raw_text, bool bos, bool escape) { llama_tokenizer tokenizer(vocab); std::vector output; - if (text.empty()) { + if (raw_text.empty()) { return output; } if (bos) { - output.push_back(llama_token_bos()); + output.push_back(vocab.special_bos_id); + } + + std::string text; + if (escape) { + text = llama_escape_whitespace(raw_text); + } else { + text = raw_text; } tokenizer.tokenize(text, output); @@ -2164,8 +2696,8 @@ std::pair, llama_partial_utf8> decode_utf8( // returns true iff pos points to the end of one of the definitions of a rule static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) { switch (pos->type) { - case LLAMA_GRETYPE_END: return true; - case LLAMA_GRETYPE_ALT: return true; + case LLAMA_GRETYPE_END: return true; // NOLINT + case LLAMA_GRETYPE_ALT: return true; // NOLINT default: return false; } } @@ -2178,7 +2710,8 @@ static std::pair llama_grammar_match_char( bool found = false; bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; - LLAMA_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); + + GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); // NOLINT do { if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { @@ -2203,7 +2736,7 @@ static bool llama_grammar_match_partial_char( const llama_partial_utf8 partial_utf8) { bool is_positive_char = pos->type == LLAMA_GRETYPE_CHAR; - LLAMA_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); + GGML_ASSERT(is_positive_char || pos->type == LLAMA_GRETYPE_CHAR_NOT); uint32_t partial_value = partial_utf8.value; int n_remain = partial_utf8.n_remain; @@ -2296,7 +2829,7 @@ static void llama_grammar_advance_stack( // end of alternate (LLAMA_GRETYPE_END, LLAMA_GRETYPE_ALT) or middle of char range // (LLAMA_GRETYPE_CHAR_ALT, LLAMA_GRETYPE_CHAR_RNG_UPPER); stack should never be left on // those - LLAMA_ASSERT(false); + GGML_ASSERT(false); } } @@ -2371,7 +2904,7 @@ static std::vector llama_grammar_reject_candidates_for_ } } - auto stack_pos_after = llama_grammar_match_char(stack_pos, 0).second; + const auto * stack_pos_after = llama_grammar_match_char(stack_pos, 0).second; // update top of stack to next element, if any std::vector stack_after(stack.begin(), stack.end() - 1); @@ -2393,7 +2926,7 @@ static std::vector llama_grammar_reject_candidates( const std::vector> & rules, const std::vector> & stacks, const std::vector & candidates) { - LLAMA_ASSERT(!stacks.empty()); // REVIEW + GGML_ASSERT(!stacks.empty()); // REVIEW if (candidates.empty()) { return std::vector(); @@ -2460,7 +2993,7 @@ void llama_grammar_free(struct llama_grammar * grammar) { // void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { - assert(candidates->size > 0); + GGML_ASSERT(candidates->size > 0); const int64_t t_start_sample_us = ggml_time_us(); @@ -2604,7 +3137,6 @@ void llama_sample_tail_free(struct llama_context * ctx, llama_token_data_array * } } - void llama_sample_typical(struct llama_context * ctx, llama_token_data_array * candidates, float p, size_t min_keep) { // Reference implementation: // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr @@ -2741,7 +3273,7 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l } void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) { - assert(ctx); + GGML_ASSERT(ctx); const int64_t t_start_sample_us = ggml_time_us(); bool allow_eos = false; @@ -2752,31 +3284,28 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c } } - const llama_token eos = llama_token_eos(); + const llama_token eos = llama_token_eos(ctx); std::vector, llama_partial_utf8>> candidates_decoded; std::vector candidates_grammar; for (size_t i = 0; i < candidates->size; ++i) { - const llama_token id = candidates->data[i].id; - const char * str = llama_token_to_str(ctx, id); + const llama_token id = candidates->data[i].id; + const std::string text = llama_token_to_text(ctx, id); if (id == eos) { if (!allow_eos) { candidates->data[i].logit = -INFINITY; } - } else if (*str == 0) { + } else if (text.empty()) { candidates->data[i].logit = -INFINITY; } else { - candidates_decoded.push_back(decode_utf8(str, grammar->partial_utf8)); - candidates_grammar.push_back({ - i, candidates_decoded.back().first.data(), candidates_decoded.back().second - }); + candidates_decoded.push_back(decode_utf8(text.c_str(), grammar->partial_utf8)); + candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); } } - const auto rejects = - llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); - for (auto & reject : rejects) { + const auto rejects = llama_grammar_reject_candidates(grammar->rules, grammar->stacks, candidates_grammar); + for (const auto & reject : rejects) { candidates->data[reject.index].logit = -INFINITY; } @@ -2804,10 +3333,12 @@ void llama_sample_classifier_free_guidance( float scale) { int64_t t_start_sample_us = ggml_time_us(); - assert(ctx); + GGML_ASSERT(ctx); + auto n_vocab = llama_n_vocab(ctx); - assert(n_vocab == (int)candidates->size); - assert(!candidates->sorted); + + GGML_ASSERT(n_vocab == (int)candidates->size); + GGML_ASSERT(!candidates->sorted); std::vector logits_base; logits_base.reserve(candidates->size); @@ -2831,7 +3362,8 @@ void llama_sample_classifier_free_guidance( } llama_token llama_sample_token_mirostat(struct llama_context * ctx, llama_token_data_array * candidates, float tau, float eta, int m, float * mu) { - assert(ctx); + GGML_ASSERT(ctx); + auto N = float(llama_n_vocab(ctx)); int64_t t_start_sample_us; t_start_sample_us = ggml_time_us(); @@ -2937,7 +3469,8 @@ llama_token llama_sample_token_greedy(struct llama_context * ctx, llama_token_da } llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_array * candidates) { - assert(ctx); + GGML_ASSERT(ctx); + const int64_t t_start_sample_us = ggml_time_us(); llama_sample_softmax(nullptr, candidates); @@ -2961,25 +3494,25 @@ llama_token llama_sample_token(struct llama_context * ctx, llama_token_data_arra void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar * grammar, llama_token token) { const int64_t t_start_sample_us = ggml_time_us(); - if (token == llama_token_eos()) { + if (token == llama_token_eos(ctx)) { for (const auto & stack : grammar->stacks) { if (stack.empty()) { return; } } - LLAMA_ASSERT(false); + GGML_ASSERT(false); } - const char * str = llama_token_to_str(ctx, token); + const std::string text = llama_token_to_text(ctx, token); // Note terminating 0 in decoded string - const auto decoded = decode_utf8(str, grammar->partial_utf8); + const auto decoded = decode_utf8(text.c_str(), grammar->partial_utf8); const auto & code_points = decoded.first; for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); } grammar->partial_utf8 = decoded.second; - LLAMA_ASSERT(!grammar->stacks.empty()); + GGML_ASSERT(!grammar->stacks.empty()); ctx->t_sample_us += ggml_time_us() - t_start_sample_us; } @@ -2988,37 +3521,37 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar // quantization // -static void llama_convert_tensor_internal(const llama_load_tensor & tensor, llama_buffer & output, const int nelements, const int nthread) { - if (output.size < nelements * sizeof(float)) { - output.resize(nelements * sizeof(float)); +static void llama_convert_tensor_internal(struct ggml_tensor * tensor, std::vector & output, const size_t nelements, const int nthread) { + if (output.size() < nelements) { + output.resize(nelements); } - float * f32_output = (float *) output.addr; + float * f32_output = (float *) output.data(); ggml_type_traits_t qtype; - if (ggml_is_quantized(tensor.type)) { - qtype = ggml_internal_get_type_traits(tensor.type); + if (ggml_is_quantized(tensor->type)) { + qtype = ggml_internal_get_type_traits(tensor->type); if (qtype.to_float == NULL) { - throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor.type))); + throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type))); } - } else if (tensor.type != GGML_TYPE_F16) { - throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor.type))); + } else if (tensor->type != GGML_TYPE_F16) { + throw std::runtime_error(format("cannot dequantize/convert tensor type %s", ggml_type_name(tensor->type))); } if (nthread < 2) { - if (tensor.type == GGML_TYPE_F16) { - ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor.data, f32_output, nelements); - } else if (ggml_is_quantized(tensor.type)) { - qtype.to_float(tensor.data, f32_output, nelements); + if (tensor->type == GGML_TYPE_F16) { + ggml_fp16_to_fp32_row((ggml_fp16_t *)tensor->data, f32_output, nelements); + } else if (ggml_is_quantized(tensor->type)) { + qtype.to_float(tensor->data, f32_output, nelements); } else { - LLAMA_ASSERT(false); // unreachable + GGML_ASSERT(false); // unreachable } return; } - auto block_size = tensor.type == GGML_TYPE_F16 ? 1 : (size_t)ggml_blck_size(tensor.type); - auto block_size_bytes = ggml_type_size(tensor.type); + auto block_size = tensor->type == GGML_TYPE_F16 ? 1 : (size_t)ggml_blck_size(tensor->type); + auto block_size_bytes = ggml_type_size(tensor->type); - LLAMA_ASSERT(nelements % block_size == 0); + GGML_ASSERT(nelements % block_size == 0); auto nblocks = nelements / block_size; auto blocks_per_thread = nblocks / nthread; auto spare_blocks = nblocks - (blocks_per_thread * nthread); // if blocks aren't divisible by thread count @@ -3036,20 +3569,18 @@ static void llama_convert_tensor_internal(const llama_load_tensor & tensor, llam qtype.to_float(inbuf, outbuf, nels); } }; - workers.push_back(std::thread(compute, tensor.type, tensor.data + in_buff_offs, f32_output + out_buff_offs, thr_elems)); + workers.push_back(std::thread(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems)); in_buff_offs += thr_block_bytes; out_buff_offs += thr_elems; } for (auto & worker : workers) { worker.join(); } - } static void llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, const llama_model_quantize_params * params) { ggml_type quantized_type; llama_ftype ftype = params->ftype; - int nthread = params->nthread; switch (params->ftype) { case LLAMA_FTYPE_MOSTLY_Q4_0: quantized_type = GGML_TYPE_Q4_0; break; @@ -3075,21 +3606,35 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s default: throw std::runtime_error(format("invalid output file type %d\n", ftype)); } + int nthread = params->nthread; + if (nthread <= 0) { nthread = std::thread::hardware_concurrency(); } std::unique_ptr model_loader(new llama_model_loader(fname_inp, /*use_mmap*/ false)); - llama_file_saver file_saver(fname_out.c_str(), model_loader->file_loader.get(), params->ftype); + + const size_t align = GGUF_DEFAULT_ALIGNMENT; + struct gguf_context * ctx_out = gguf_init_empty(); + + // copy the KV pairs from the input file + gguf_set_kv (ctx_out, model_loader->ctx_gguf); + gguf_set_val_u32(ctx_out, "general.quantization_version", GGML_QNT_VERSION); #ifdef GGML_USE_K_QUANTS int n_attention_wv = 0; int n_feed_forward_w2 = 0; - for (auto& tensor : model_loader->tensors_map.tensors) { - if (tensor.name.find("attention.wv.weight") != std::string::npos) { + + for (int i = 0; i < model_loader->n_tensors; ++i) { + struct ggml_tensor * meta = model_loader->get_tensor_meta(i); + + const std::string name = ggml_get_name(meta); + + // TODO: avoid hardcoded tensor names - use the TN_* constants + if (name.find("attn_v.weight") != std::string::npos) { ++n_attention_wv; } - else if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) { + else if (name.find("ffn_down.weight") != std::string::npos) { ++n_feed_forward_w2; } } @@ -3109,46 +3654,69 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s return i_layer < num_layers/8 || i_layer >= 7*num_layers/8 || (i_layer - num_layers/8)%3 == 2; }; - size_t idx = 0; - for (llama_load_tensor & tensor : model_loader->tensors_map.tensors) { - llama_buffer read_data; - read_data.resize(tensor.size); - tensor.data = read_data.addr; + int idx = 0; + + std::vector read_data; + std::vector work; + + // populate the original tensors so we get an initial meta data + for (int i = 0; i < model_loader->n_tensors; ++i) { + struct ggml_tensor * meta = model_loader->get_tensor_meta(i); + gguf_add_tensor(ctx_out, meta); + } + + std::ofstream fout(fname_out, std::ios::binary); + + const size_t meta_size = gguf_get_meta_size(ctx_out); + + LLAMA_LOG_INFO("%s: meta size = %zu bytes\n", __func__, meta_size); + + // placeholder for the meta data + ::zeros(fout, meta_size); + + for (int i = 0; i < model_loader->n_tensors; ++i) { + struct ggml_tensor * tensor = model_loader->get_tensor_meta(i); + + const std::string name = ggml_get_name(tensor); + + read_data.resize(ggml_nbytes(tensor)); + tensor->data = read_data.data(); model_loader->load_data_for(tensor); - LLAMA_LOG_INFO("[%4zu/%4zu] %36s - %16s, type = %6s, ", - ++idx, model_loader->tensors_map.tensors.size(), - tensor.name.c_str(), llama_format_tensor_shape(tensor.ne).c_str(), - ggml_type_name(tensor.type)); + LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ", + ++idx, model_loader->n_tensors, + ggml_get_name(tensor), + llama_format_tensor_shape(tensor).c_str(), + ggml_type_name(tensor->type)); // This used to be a regex, but has an extreme cost to compile times. - bool quantize = tensor.name.rfind("weight") == tensor.name.size() - 6; // ends with 'weight'? + bool quantize = name.rfind("weight") == name.size() - 6; // ends with 'weight'? // quantize only 2D tensors - quantize &= (tensor.ne.size() == 2); - quantize &= params->quantize_output_tensor || tensor.name != "output.weight"; - quantize &= quantized_type != tensor.type; + quantize &= (tensor->n_dims == 2); + quantize &= params->quantize_output_tensor || name != "output.weight"; + quantize &= quantized_type != tensor->type; enum ggml_type new_type; void * new_data; size_t new_size; - llama_buffer work; if (!quantize) { - new_type = tensor.type; - new_data = tensor.data; - new_size = tensor.size; - LLAMA_LOG_INFO("size = %8.3f MB\n", tensor.size/1024.0/1024.0); + new_type = tensor->type; + new_data = tensor->data; + new_size = ggml_nbytes(tensor); + LLAMA_LOG_INFO("size = %8.3f MB\n", ggml_nbytes(tensor)/1024.0/1024.0); } else { new_type = quantized_type; #ifdef GGML_USE_K_QUANTS - if (tensor.name == "output.weight") { - int nx = tensor.ne.at(0); - int ny = tensor.ne.at(1); + // TODO: avoid hardcoded tensor names - use the TN_* constants + if (name == TN_OUTPUT) { + int nx = tensor->ne[0]; + int ny = tensor->ne[1]; if (nx % QK_K == 0 && ny % QK_K == 0) { new_type = GGML_TYPE_Q6_K; } - } else if (tensor.name.find("attention.wv.weight") != std::string::npos) { + } else if (name.find("attn_v.weight") != std::string::npos) { if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && @@ -3156,32 +3724,32 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s else if (QK_K == 64 && (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_S) && (i_attention_wv < n_attention_wv/8 || i_attention_wv >= 7*n_attention_wv/8)) new_type = GGML_TYPE_Q6_K; ++i_attention_wv; - } else if (tensor.name.find("feed_forward.w2.weight") != std::string::npos) { + } else if (name.find("ffn_down.weight") != std::string::npos) { if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K; //else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < n_feed_forward_w2/8) new_type = GGML_TYPE_Q6_K; ++i_feed_forward_w2; - } else if (tensor.name.find("attention.wo.weight") != std::string::npos) { + } else if (name.find("attn_output.weight") != std::string::npos) { if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K; } bool convert_incompatible_tensor = false; if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K || new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K) { - int nx = tensor.ne.at(0); - int ny = tensor.ne.at(1); + int nx = tensor->ne[0]; + int ny = tensor->ne[1]; if (nx % QK_K != 0 || ny % QK_K != 0) { LLAMA_LOG_INFO("\n\nTensor sizes %d x %d are not divisible by %d, required for k-quants.\n",nx,ny,QK_K); convert_incompatible_tensor = true; } } if (convert_incompatible_tensor) { - if (tensor.name == "output.weight") { + if (name == TN_OUTPUT) { new_type = GGML_TYPE_F16; //fall back to F16 instead of just failing. LLAMA_LOG_WARN("F16 will be used for this tensor instead.\n"); - } else if (tensor.name == "tok_embeddings.weight") { + } else if (name == TN_TOKEN_EMBD) { new_type = GGML_TYPE_Q4_0; //fall back to Q4_0 instead of just failing. LLAMA_LOG_WARN("Q4_0 will be used for this tensor instead.\n"); } else { @@ -3190,27 +3758,28 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } #endif - float * f32_data; - size_t nelements = tensor.ne.at(0) * tensor.ne.at(1); - llama_buffer f32_conv_buf; + const size_t nelements = ggml_nelements(tensor); - if (tensor.type == GGML_TYPE_F32) { - f32_data = (float *) tensor.data; - } else if (ggml_is_quantized(tensor.type) && !params->allow_requantize) { - throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor.type))); + float * f32_data; + std::vector f32_conv_buf; + + if (tensor->type == GGML_TYPE_F32) { + f32_data = (float *) tensor->data; + } else if (ggml_is_quantized(tensor->type) && !params->allow_requantize) { + throw std::runtime_error(format("requantizing from type %s is disabled", ggml_type_name(tensor->type))); } else { llama_convert_tensor_internal(tensor, f32_conv_buf, nelements, nthread); - f32_data = (float *) f32_conv_buf.addr; + f32_data = (float *) f32_conv_buf.data(); } LLAMA_LOG_INFO("quantizing to %s .. ", ggml_type_name(new_type)); fflush(stdout); work.resize(nelements * 4); // upper bound on size - new_data = work.addr; + new_data = work.data(); std::vector hist_cur(1 << 4, 0); - int chunk_size = 32 * 512; + static const int chunk_size = 32 * 512; const int nchunk = (nelements + chunk_size - 1)/chunk_size; const int nthread_use = nthread > 1 ? std::max(1, std::min(nthread, nchunk)) : 1; if (nthread_use < 2) { @@ -3218,7 +3787,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } else { size_t counter = 0; new_size = 0; - auto compute = [&mutex, &counter, &hist_cur, &new_size, new_type, f32_data, new_data, nelements, chunk_size] () { + auto compute = [&mutex, &counter, &hist_cur, &new_size, new_type, f32_data, new_data, nelements]() { std::vector local_hist; size_t local_size = 0; while (true) { @@ -3253,7 +3822,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } } - LLAMA_LOG_INFO("size = %8.2f MB -> %8.2f MB | hist: ", tensor.size/1024.0/1024.0, new_size/1024.0/1024.0); + LLAMA_LOG_INFO("size = %8.2f MB -> %8.2f MB | hist: ", ggml_nbytes(tensor)/1024.0/1024.0, new_size/1024.0/1024.0); int64_t tot_count = 0; for (size_t i = 0; i < hist_cur.size(); i++) { hist_all[i] += hist_cur[i]; @@ -3267,14 +3836,34 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } LLAMA_LOG_INFO("\n"); } - total_size_org += tensor.size; + total_size_org += ggml_nbytes(tensor); total_size_new += new_size; - file_saver.write_tensor(tensor, new_type, new_data, new_size); + + // update the gguf meta data as we go + gguf_set_tensor_type(ctx_out, name.c_str(), new_type); + gguf_set_tensor_data(ctx_out, name.c_str(), new_data, new_size); + + // write tensor data + padding + fout.write((const char *) new_data, new_size); + zeros(fout, GGML_PAD(new_size, align) - new_size); } + // go back to beginning of file and write the updated meta data + { + fout.seekp(0); + std::vector data(gguf_get_meta_size(ctx_out)); + gguf_get_meta_data(ctx_out, data.data()); + fout.write((const char *) data.data(), data.size()); + } + + fout.close(); + + gguf_free(ctx_out); + LLAMA_LOG_INFO("%s: model size = %8.2f MB\n", __func__, total_size_org/1024.0/1024.0); LLAMA_LOG_INFO("%s: quant size = %8.2f MB\n", __func__, total_size_new/1024.0/1024.0); + // print histogram for all tensors { int64_t sum_all = 0; for (size_t i = 0; i < hist_all.size(); i++) { @@ -3291,238 +3880,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s } } - - -// -// interface implementation -// - -struct llama_model * llama_load_model_from_file( - const char * path_model, - struct llama_context_params params) { - ggml_time_init(); - - llama_model * model = new llama_model; - - ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - - if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.rms_norm_eps, params.n_gpu_layers, - params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale,params.low_vram, - memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback, - params.progress_callback_user_data)) { - LLAMA_LOG_ERROR("%s: failed to load model\n", __func__); - delete model; - return nullptr; - } - - return model; -} - -void llama_free_model(struct llama_model * model) { - delete model; -} - -struct llama_context * llama_new_context_with_model( - struct llama_model * model, - struct llama_context_params params) { - - if (!model) { - return nullptr; - } - - llama_context * ctx = new llama_context(*model); - - if (params.seed == LLAMA_DEFAULT_SEED) { - params.seed = time(NULL); - } - - unsigned cur_percentage = 0; - if (params.progress_callback == NULL) { - params.progress_callback_user_data = &cur_percentage; - params.progress_callback = [](float progress, void * ctx) { - unsigned * cur_percentage_p = (unsigned *) ctx; - unsigned percentage = (unsigned) (100 * progress); - while (percentage > *cur_percentage_p) { - *cur_percentage_p = percentage; - LLAMA_LOG_INFO("."); - if (percentage >= 100) { - LLAMA_LOG_INFO("\n"); - } - } - }; - } - - ctx->rng = std::mt19937(params.seed); - ctx->logits_all = params.logits_all; - - ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; - - // reserve memory for context buffers - if (!params.vocab_only) { - if (!kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, ctx->model.hparams.n_ctx, params.n_gpu_layers)) { - LLAMA_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); - llama_free(ctx); - return nullptr; - } - - { - const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v); - LLAMA_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); - } - - const auto & hparams = ctx->model.hparams; - - // resized during inference - if (params.logits_all) { - ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab); - } else { - ctx->logits.reserve(hparams.n_vocab); - } - - if (params.embedding){ - ctx->embedding.resize(hparams.n_embd); - } - -#ifdef LLAMA_USE_ALLOCATOR - { - static const size_t tensor_alignment = 32; - // the compute buffer is used to store the tensor and graph structs, while the allocator buffer is used for the tensor data - ctx->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead()); - - // create measure allocator - ctx->alloc = ggml_allocr_new_measure(tensor_alignment); - - // build worst-case graph - int n_tokens = std::min((int)hparams.n_ctx, params.n_batch); - int n_past = hparams.n_ctx - n_tokens; - llama_token token = llama_token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past); -#ifdef GGML_USE_METAL - if (params.n_gpu_layers > 0) { - ctx->ctx_metal = ggml_metal_init(1); - if (!ctx->ctx_metal) { - LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__); - llama_free(ctx); - return NULL; - } - ggml_metal_graph_find_concurrency(ctx->ctx_metal, gf, false); - ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); - } -#endif - // measure memory requirements for the graph - size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment; - - LLAMA_LOG_INFO("%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0); - - // debug - for comparison with scratch buffer - //size_t prev_req = - // MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type) + - // MEM_REQ_SCRATCH1().at(ctx->model.type) + - // MEM_REQ_EVAL().at(ctx->model.type); - //LLAMA_LOG_INFO("%s: (debug) equivalent with scratch buffer = %7.2f MB\n", __func__, prev_req / 1024.0 / 1024.0); - - // recreate allocator with exact memory requirements - ggml_allocr_free(ctx->alloc); - - ctx->buf_alloc.resize(alloc_size); - ctx->alloc = ggml_allocr_new(ctx->buf_alloc.addr, ctx->buf_alloc.size, tensor_alignment); -#ifdef GGML_USE_METAL - if (ctx->ctx_metal) { - ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); - } -#endif - } -#else - ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead()); -#endif - -#ifdef LLAMA_USE_SCRATCH - ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type)); - ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type)); -#endif - } - -#ifdef GGML_USE_METAL - if (params.n_gpu_layers > 0) { - // this allocates all Metal resources and memory buffers - - void * data_ptr = NULL; - size_t data_size = 0; - - if (params.use_mmap) { - data_ptr = ctx->model.mapping->addr; - data_size = ctx->model.mapping->size; - } else { - data_ptr = ggml_get_mem_buffer(ctx->model.ctx); - data_size = ggml_get_mem_size (ctx->model.ctx); - } - - const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx); - - LLAMA_LOG_INFO("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0); - -#define LLAMA_METAL_CHECK_BUF(result) \ - if (!(result)) { \ - LLAMA_LOG_ERROR("%s: failed to add buffer\n", __func__); \ - llama_free(ctx); \ - return NULL; \ - } - - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size)); - - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.addr, ctx->buf_compute.size, 0)); - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.addr, ctx->kv_self.buf.size, 0)); - - LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "alloc", ctx->buf_alloc.addr, ctx->buf_alloc.size, 0)); -#undef LLAMA_METAL_CHECK_BUF - } -#endif - -#ifdef GGML_USE_MPI - ctx->ctx_mpi = ggml_mpi_init(); - - if (ggml_mpi_rank(ctx->ctx_mpi) > 0) { - // Enter a blocking eval loop with dummy input, letting rank=0 drive the process - const std::vector tmp(ctx->model.hparams.n_ctx, llama_token_bos()); - while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; - llama_backend_free(); - exit(1); - } -#endif - - return ctx; -} - -struct llama_context * llama_init_from_file( - const char * path_model, - struct llama_context_params params) { - - struct llama_model * model = llama_load_model_from_file(path_model, params); - if (!model) { - return nullptr; - } - struct llama_context * ctx = llama_new_context_with_model(model, params); - ctx->model_owner = true; - return ctx; -} - -void llama_free(struct llama_context * ctx) { - delete ctx; -} - -int llama_model_quantize( - const char * fname_inp, - const char * fname_out, - const llama_model_quantize_params *params) { - try { - llama_model_quantize_internal(fname_inp, fname_out, params); - return 0; - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what()); - return 1; - } -} - +// TODO: after the GGUF PR, this likely won't work and needs to be updated int llama_apply_lora_from_file_internal(const struct llama_model & model, const char * path_lora, const char * path_base_model, int n_threads) { LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora); @@ -3538,10 +3896,6 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const { uint32_t magic; fin.read((char *) &magic, sizeof(magic)); - if (magic != LLAMA_FILE_MAGIC_GGLA) { - LLAMA_LOG_ERROR("%s: bad file magic\n", __func__); - return 1; - } uint32_t format_version; fin.read((char *) &format_version, sizeof(format_version)); @@ -3559,7 +3913,6 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling); - // create a temporary ggml context to store the lora tensors // todo: calculate size from biggest possible tensor std::vector lora_buf(1024ull * 1024ull * 1024ull); @@ -3573,36 +3926,33 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const // create a name -> tensor map of the model to accelerate lookups std::unordered_map model_tensors; - for (const auto & kv: model.tensors_by_name) { + for (const auto & kv : model.tensors_by_name) { model_tensors.insert(kv); } - // load base model std::unique_ptr model_loader; ggml_context * base_ctx = NULL; - llama_buffer base_buf; + std::vector base_buf; if (path_base_model) { LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model); model_loader.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true)); size_t ctx_size; size_t mmapped_size; - model_loader->calc_sizes(&ctx_size, &mmapped_size); + model_loader->calc_sizes(ctx_size, mmapped_size); base_buf.resize(ctx_size); ggml_init_params base_params; - base_params.mem_size = base_buf.size; - base_params.mem_buffer = base_buf.addr; + base_params.mem_size = base_buf.size(); + base_params.mem_buffer = base_buf.data(); base_params.no_alloc = model_loader->use_mmap; base_ctx = ggml_init(base_params); - model_loader->ggml_ctx = base_ctx; - // maybe this should in llama_model_loader if (model_loader->use_mmap) { - model_loader->mapping.reset(new llama_mmap(&model_loader->file_loader->file, /* prefetch */ 0, ggml_is_numa())); + model_loader->mapping.reset(new llama_mmap(&model_loader->file, /* prefetch */ 0, ggml_is_numa())); } } @@ -3707,19 +4057,18 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const ggml_tensor * base_t; if (model_loader) { + struct gguf_context * ctx_gguf = model_loader->ctx_gguf; + // load from base model - if (model_loader->tensors_map.name_to_idx.find(base_name) == model_loader->tensors_map.name_to_idx.end()) { + if (gguf_find_tensor(ctx_gguf, base_name.c_str()) < 0) { LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str()); return 1; } - size_t idx = model_loader->tensors_map.name_to_idx[base_name]; - llama_load_tensor & lt = model_loader->tensors_map.tensors[idx]; - base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU); - lt.data = (uint8_t *) lt.ggml_tensor->data; - model_loader->load_data_for(lt); - lt.ggml_tensor->data = lt.data; - } - else { + + // TODO: not tested!! maybe not working! + base_t = model_loader->create_tensor(base_ctx, base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU); + model_loader->load_data_for(base_t); + } else { base_t = dest_t; } @@ -3803,6 +4152,341 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const return 0; } +// +// interface implementation +// + +struct llama_context_params llama_context_default_params() { + struct llama_context_params result = { + /*.seed =*/ LLAMA_DEFAULT_SEED, + /*.n_ctx =*/ 512, + /*.n_batch =*/ 512, + /*.gpu_layers =*/ 0, + /*.main_gpu =*/ 0, + /*.tensor_split =*/ nullptr, + /*.rope_freq_base =*/ 10000.0f, + /*.rope_freq_scale =*/ 1.0f, + /*.progress_callback =*/ nullptr, + /*.progress_callback_user_data =*/ nullptr, + /*.low_vram =*/ false, + /*.mul_mat_q =*/ false, + /*.f16_kv =*/ true, + /*.logits_all =*/ false, + /*.vocab_only =*/ false, + /*.use_mmap =*/ true, + /*.use_mlock =*/ false, + /*.embedding =*/ false, + }; + + return result; +} + +struct llama_model_quantize_params llama_model_quantize_default_params() { + struct llama_model_quantize_params result = { + /*.nthread =*/ 0, + /*.ftype =*/ LLAMA_FTYPE_MOSTLY_Q5_1, + /*.allow_requantize =*/ false, + /*.quantize_output_tensor =*/ true, + }; + + return result; +} + +int llama_max_devices(void) { + return LLAMA_MAX_DEVICES; +} + +bool llama_mmap_supported(void) { + return llama_mmap::SUPPORTED; +} + +bool llama_mlock_supported(void) { + return llama_mlock::SUPPORTED; +} + +void llama_backend_init(bool numa) { + ggml_time_init(); + + // needed to initialize f16 tables + { + struct ggml_init_params params = { 0, NULL, false }; + struct ggml_context * ctx = ggml_init(params); + ggml_free(ctx); + } + + if (numa) { + ggml_numa_init(); + } + +#ifdef GGML_USE_MPI + ggml_mpi_backend_init(); +#endif +} + +void llama_backend_free(void) { +#ifdef GGML_USE_MPI + ggml_mpi_backend_free(); +#endif +} + +int64_t llama_time_us(void) { + return ggml_time_us(); +} + +struct llama_model * llama_load_model_from_file( + const char * path_model, + struct llama_context_params params) { + ggml_time_init(); + + llama_model * model = new llama_model; + + ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; + + if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gpu_layers, + params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale, + params.low_vram, memory_type, params.use_mmap, params.use_mlock, params.vocab_only, + params.progress_callback, params.progress_callback_user_data)) { + LLAMA_LOG_ERROR("%s: failed to load model\n", __func__); + delete model; + return nullptr; + } + + return model; +} + +void llama_free_model(struct llama_model * model) { + delete model; +} + +struct llama_context * llama_new_context_with_model( + struct llama_model * model, + struct llama_context_params params) { + + if (!model) { + return nullptr; + } + + llama_context * ctx = new llama_context(*model); + + if (params.seed == LLAMA_DEFAULT_SEED) { + params.seed = time(NULL); + } + + unsigned cur_percentage = 0; + if (params.progress_callback == NULL) { + params.progress_callback_user_data = &cur_percentage; + params.progress_callback = [](float progress, void * ctx) { + unsigned * cur_percentage_p = (unsigned *) ctx; + unsigned percentage = (unsigned) (100 * progress); + while (percentage > *cur_percentage_p) { + *cur_percentage_p = percentage; + LLAMA_LOG_INFO("."); + if (percentage >= 100) { + LLAMA_LOG_INFO("\n"); + } + } + }; + } + + ctx->rng = std::mt19937(params.seed); + ctx->logits_all = params.logits_all; + + ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; + + // reserve memory for context buffers + if (!params.vocab_only) { + if (!llama_kv_cache_init(ctx->model.hparams, ctx->kv_self, memory_type, ctx->model.hparams.n_ctx, params.n_gpu_layers)) { + LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); + llama_free(ctx); + return nullptr; + } + + { + const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v); + LLAMA_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); + } + + const auto & hparams = ctx->model.hparams; + + // resized during inference + if (params.logits_all) { + ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab); + } else { + ctx->logits.reserve(hparams.n_vocab); + } + + if (params.embedding){ + ctx->embedding.resize(hparams.n_embd); + } + +#ifdef LLAMA_USE_ALLOCATOR + { + static const size_t tensor_alignment = 32; + // the compute buffer is used to store the tensor and graph structs, while the allocator buffer is used for the tensor data + ctx->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead()); + + // create measure allocator + ctx->alloc = ggml_allocr_new_measure(tensor_alignment); + + // build worst-case graph + int n_tokens = std::min((int)hparams.n_ctx, params.n_batch); + int n_past = hparams.n_ctx - n_tokens; + llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph + ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past); +#ifdef GGML_USE_METAL + if (params.n_gpu_layers > 0) { + ctx->ctx_metal = ggml_metal_init(1); + if (!ctx->ctx_metal) { + LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__); + llama_free(ctx); + return NULL; + } + ggml_metal_graph_find_concurrency(ctx->ctx_metal, gf, false); + ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); + } +#endif + // measure memory requirements for the graph + size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment; + + LLAMA_LOG_INFO("%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0); + + // debug - for comparison with scratch buffer + //size_t prev_req = + // MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type) + + // MEM_REQ_SCRATCH1().at(ctx->model.type) + + // MEM_REQ_EVAL().at(ctx->model.type); + //LLAMA_LOG_INFO("%s: (debug) equivalent with scratch buffer = %7.2f MB\n", __func__, prev_req / 1024.0 / 1024.0); + + // recreate allocator with exact memory requirements + ggml_allocr_free(ctx->alloc); + + ctx->buf_alloc.resize(alloc_size); + ctx->alloc = ggml_allocr_new(ctx->buf_alloc.data, ctx->buf_alloc.size, tensor_alignment); +#ifdef GGML_USE_METAL + if (ctx->ctx_metal) { + ggml_allocr_set_parse_seq(ctx->alloc, ggml_metal_get_concur_list(ctx->ctx_metal), ggml_metal_if_optimized(ctx->ctx_metal)); + } +#endif + } +#else + ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead()); +#endif + +#ifdef LLAMA_USE_SCRATCH + ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type)); + ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type)); +#endif + } + +#ifdef GGML_USE_METAL + if (params.n_gpu_layers > 0) { + // this allocates all Metal resources and memory buffers + + void * data_ptr = NULL; + size_t data_size = 0; + + if (params.use_mmap) { + data_ptr = ctx->model.mapping->addr; + data_size = ctx->model.mapping->size; + } else { + data_ptr = ggml_get_mem_buffer(ctx->model.ctx); + data_size = ggml_get_mem_size (ctx->model.ctx); + } + + const size_t max_size = ggml_get_max_tensor_size(ctx->model.ctx); + + LLAMA_LOG_INFO("%s: max tensor size = %8.2f MB\n", __func__, max_size/1024.0/1024.0); + +#define LLAMA_METAL_CHECK_BUF(result) \ + if (!(result)) { \ + LLAMA_LOG_ERROR("%s: failed to add buffer\n", __func__); \ + llama_free(ctx); \ + return NULL; \ + } + + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "data", data_ptr, data_size, max_size)); + + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "eval", ctx->buf_compute.data, ctx->buf_compute.size, 0)); + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "kv", ctx->kv_self.buf.data, ctx->kv_self.buf.size, 0)); + + LLAMA_METAL_CHECK_BUF(ggml_metal_add_buffer(ctx->ctx_metal, "alloc", ctx->buf_alloc.data, ctx->buf_alloc.size, 0)); +#undef LLAMA_METAL_CHECK_BUF + } +#endif + +#ifdef GGML_USE_MPI + ctx->ctx_mpi = ggml_mpi_init(); + + if (ggml_mpi_rank(ctx->ctx_mpi) > 0) { + // Enter a blocking eval loop with dummy input, letting rank=0 drive the process + const std::vector tmp(ctx->model.hparams.n_ctx, llama_token_bos(ctx)); + while (!llama_eval(ctx, tmp.data(), tmp.size(), 0, 0)) {}; + llama_backend_free(); + exit(1); + } +#endif + + return ctx; +} + +struct llama_context * llama_init_from_file( + const char * path_model, + struct llama_context_params params) { + + struct llama_model * model = llama_load_model_from_file(path_model, params); + if (!model) { + return nullptr; + } + struct llama_context * ctx = llama_new_context_with_model(model, params); + ctx->model_owner = true; + return ctx; +} + +void llama_free(struct llama_context * ctx) { + delete ctx; +} + +int llama_n_vocab(const struct llama_context * ctx) { + return ctx->model.vocab.id_to_token.size(); +} + +int llama_n_ctx(const struct llama_context * ctx) { + return ctx->model.hparams.n_ctx; +} + +int llama_n_embd(const struct llama_context * ctx) { + return ctx->model.hparams.n_embd; +} + +int llama_model_n_vocab(const struct llama_model * model) { + return model->vocab.id_to_token.size(); +} + +int llama_model_n_ctx(const struct llama_model * model) { + return model->hparams.n_ctx; +} + +int llama_model_n_embd(const struct llama_model * model) { + return model->hparams.n_embd; +} + +int llama_model_type(const struct llama_model * model, char * buf, size_t buf_size) { + return snprintf(buf, buf_size, "LLaMA %s %s", llama_model_type_name(model->type), llama_model_ftype_name(model->ftype)); +} + +int llama_model_quantize( + const char * fname_inp, + const char * fname_out, + const llama_model_quantize_params * params) { + try { + llama_model_quantize_internal(fname_inp, fname_out, params); + return 0; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: failed to quantize: %s\n", __func__, err.what()); + return 1; + } +} + int llama_apply_lora_from_file(struct llama_context * ctx, const char * path_lora, const char * path_base_model, int n_threads) { try { return llama_apply_lora_from_file_internal(ctx->model, path_lora, path_base_model, n_threads); @@ -3865,6 +4549,46 @@ size_t llama_get_state_size(const struct llama_context * ctx) { return s_total; } +// llama_context_data +struct llama_data_context { + virtual void write(const void * src, size_t size) = 0; + virtual size_t get_size_written() = 0; + virtual ~llama_data_context() = default; +}; + +struct llama_data_buffer_context : llama_data_context { + uint8_t * ptr; + size_t size_written = 0; + + llama_data_buffer_context(uint8_t * p) : ptr(p) {} + + void write(const void * src, size_t size) override { + memcpy(ptr, src, size); + ptr += size; + size_written += size; + } + + size_t get_size_written() override { + return size_written; + } +}; + +struct llama_data_file_context : llama_data_context { + llama_file * file; + size_t size_written = 0; + + llama_data_file_context(llama_file * f) : file(f) {} + + void write(const void * src, size_t size) override { + file->write_raw(src, size); + size_written += size; + } + + size_t get_size_written() override { + return size_written; + } +}; + /** copy state data into either a buffer or file depending on the passed in context * * file context: @@ -3998,7 +4722,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { rng_ss.str(std::string(&rng_buf[0], rng_size)); rng_ss >> ctx->rng; - LLAMA_ASSERT(rng_ss.fail() == false); + GGML_ASSERT(rng_ss.fail() == false); } // set logits @@ -4009,7 +4733,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { memcpy(&logits_cap, inp, sizeof(logits_cap)); inp += sizeof(logits_cap); memcpy(&logits_size, inp, sizeof(logits_size)); inp += sizeof(logits_size); - LLAMA_ASSERT(ctx->logits.capacity() == logits_cap); + GGML_ASSERT(ctx->logits.capacity() == logits_cap); if (logits_size) { ctx->logits.resize(logits_size); @@ -4025,7 +4749,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { memcpy(&embedding_size, inp, sizeof(embedding_size)); inp += sizeof(embedding_size); - LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size); + GGML_ASSERT(ctx->embedding.capacity() == embedding_size); if (embedding_size) { memcpy(ctx->embedding.data(), inp, embedding_size * sizeof(float)); @@ -4048,7 +4772,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { memcpy(&kv_ntok, inp, sizeof(kv_ntok)); inp += sizeof(kv_ntok); if (kv_size) { - LLAMA_ASSERT(kv_self.buf.size == kv_size); + GGML_ASSERT(kv_self.buf.size == kv_size); const size_t elt_size = ggml_element_size(kv_self.k); @@ -4084,7 +4808,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { const size_t nread = inp - src; const size_t max_size = llama_get_state_size(ctx); - LLAMA_ASSERT(nread <= max_size); + GGML_ASSERT(nread <= max_size); return nread; } @@ -4192,7 +4916,6 @@ int llama_eval( return 0; } - int llama_eval_embd( struct llama_context * ctx, const float * embd, @@ -4218,7 +4941,7 @@ int llama_eval_export(struct llama_context * ctx, const char * fname) { const int n_batch = 1; const int n_ctx = 512 - n_batch; - const std::vector tmp(n_batch, llama_token_bos()); + const std::vector tmp(n_batch, llama_token_bos(ctx)); if (!llama_eval_internal(*ctx, tmp.data(), nullptr, tmp.size(), n_ctx, 1, fname)) { LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); @@ -4228,13 +4951,54 @@ int llama_eval_export(struct llama_context * ctx, const char * fname) { return 0; } -int llama_tokenize_with_model( - const struct llama_model * model, +float * llama_get_logits(struct llama_context * ctx) { + return ctx->logits.data(); +} + +float * llama_get_embeddings(struct llama_context * ctx) { + return ctx->embedding.data(); +} + +const char * llama_token_get_text(const struct llama_context * ctx, llama_token token) { + return ctx->model.vocab.id_to_token[token].text.c_str(); +} + +float llama_token_get_score(const struct llama_context * ctx, llama_token token) { + return ctx->model.vocab.id_to_token[token].score; +} + +llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token) { + return ctx->model.vocab.id_to_token[token].type; +} + +llama_token llama_token_bos(const struct llama_context * ctx) { + return ctx->model.vocab.special_bos_id; +} + +llama_token llama_token_eos(const struct llama_context * ctx) { + return ctx->model.vocab.special_eos_id; +} + +llama_token llama_token_nl(const struct llama_context * ctx) { + return ctx->model.vocab.linefeed_id; +} + +int llama_tokenize( + struct llama_context * ctx, const char * text, llama_token * tokens, int n_max_tokens, bool add_bos) { - auto res = llama_tokenize(model->vocab, text, add_bos); + return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos); +} + +int llama_tokenize_bpe( + struct llama_context * ctx, + const char * text, + llama_token * tokens, + int n_max_tokens, + bool add_bos) { + auto res = llama_tokenize_internal(ctx->model.vocab, text, add_bos, false); if (n_max_tokens < (int) res.size()) { LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); @@ -4248,94 +5012,75 @@ int llama_tokenize_with_model( return res.size(); } -int llama_tokenize( - struct llama_context * ctx, +int llama_tokenize_with_model( + const struct llama_model * model, const char * text, llama_token * tokens, int n_max_tokens, bool add_bos) { - return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos); -} + auto escape = llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM; + auto res = llama_tokenize_internal(model->vocab, text, add_bos, escape); -int llama_n_vocab_from_model(const struct llama_model * model) { - return model->vocab.id_to_token.size(); -} - -int llama_n_ctx_from_model(const struct llama_model * model) { - return model->hparams.n_ctx; -} - -int llama_n_embd_from_model(const struct llama_model * model) { - return model->hparams.n_embd; -} - -int llama_n_vocab(const struct llama_context * ctx) { - return ctx->model.vocab.id_to_token.size(); -} - -int llama_n_ctx(const struct llama_context * ctx) { - return ctx->model.hparams.n_ctx; -} - -int llama_n_embd(const struct llama_context * ctx) { - return ctx->model.hparams.n_embd; -} - -int llama_model_type(const struct llama_model * model, char * buf, size_t buf_size) { - return snprintf(buf, buf_size, "LLaMA %s %s", llama_model_type_name(model->type), llama_ftype_name(model->hparams.ftype)); -} - -int llama_get_vocab_from_model( - const struct llama_model * model, - const char * * strings, - float * scores, - int capacity) { - int n = std::min(capacity, (int) model->vocab.id_to_token.size()); - for (int i = 0; ivocab.id_to_token[i].tok.c_str(); - scores[i] = model->vocab.id_to_token[i].score; - } - return n; -} - -int llama_get_vocab( - const struct llama_context * ctx, - const char * * strings, - float * scores, - int capacity) { - return llama_get_vocab_from_model(&ctx->model, strings, scores, capacity); -} - -float * llama_get_logits(struct llama_context * ctx) { - return ctx->logits.data(); -} - -float * llama_get_embeddings(struct llama_context * ctx) { - return ctx->embedding.data(); -} - -const char * llama_token_to_str_with_model(const struct llama_model * model, llama_token token) { - if (token >= llama_n_vocab_from_model(model)) { - return nullptr; + if (n_max_tokens < (int) res.size()) { + LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); + return -((int) res.size()); } - return model->vocab.id_to_token[token].tok.c_str(); + for (size_t i = 0; i < res.size(); i++) { + tokens[i] = res[i]; + } + + return res.size(); } -const char * llama_token_to_str(const struct llama_context * ctx, llama_token token) { - return llama_token_to_str_with_model(&ctx->model, token); +int llama_token_to_str(const struct llama_context * ctx, llama_token token, char * buf, int length) { + return llama_token_to_str_with_model(&ctx->model, token, buf, length); } -llama_token llama_token_bos() { - return 1; +int llama_token_to_str_bpe(const struct llama_context * ctx, llama_token token, char * buf, int length) { + if (0 <= token && token < llama_model_n_vocab(&ctx->model)) { + std::string result = ctx->model.vocab.id_to_token[token].text; + if (length < (int) result.length()) { + return -result.length(); + } + memcpy(buf, result.c_str(), result.length()); + return result.length(); + } + return 0; } -llama_token llama_token_eos() { - return 2; -} - -llama_token llama_token_nl() { - return 13; +// does not write null-terminator to str +int llama_token_to_str_with_model(const struct llama_model * model, llama_token token, char * buf, int length) { + if (0 <= token && token < llama_model_n_vocab(model)) { + if (llama_is_normal_token(model->vocab, token)) { + std::string result = model->vocab.id_to_token[token].text; + if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) { + result = llama_unescape_whitespace(result); + } + if (length < (int) result.length()) { + return -result.length(); + } + memcpy(buf, result.c_str(), result.length()); + return result.length(); + } else if (llama_is_unknown_token(model->vocab, token)) { // NOLINT + if (length < 3) { + return -3; + } + buf[0] = '\xe2'; + buf[1] = '\x96'; + buf[2] = '\x85'; + return 3; + } else if (llama_is_control_token(model->vocab, token)) { + ; + } else if (llama_is_byte_token(model->vocab, token)) { + if (length < 1) { + return -1; + } + buf[0] = llama_token_to_byte(model->vocab, token); + return 1; + } + } + return 0; } struct llama_timings llama_get_timings(struct llama_context * ctx) { @@ -4403,7 +5148,6 @@ const std::vector>& llama_internal_ return ctx->model.tensors_by_name; } - void llama_log_set(llama_log_callback log_callback, void * user_data) { g_state.log_callback = log_callback ? log_callback : llama_log_callback_default; g_state.log_callback_user_data = user_data; diff --git a/llama.h b/llama.h index 9d732f914..aa5b7d69c 100644 --- a/llama.h +++ b/llama.h @@ -34,29 +34,18 @@ # define DEPRECATED(func, hint) func #endif -#define LLAMA_FILE_MAGIC_GGJT 0x67676a74u // 'ggjt' -#define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla' -#define LLAMA_FILE_MAGIC_GGMF 0x67676d66u // 'ggmf' -#define LLAMA_FILE_MAGIC_GGML 0x67676d6cu // 'ggml' -#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' +#define LLAMA_DEFAULT_SEED 0xFFFFFFFF -#define LLAMA_FILE_VERSION 3 -#define LLAMA_FILE_MAGIC LLAMA_FILE_MAGIC_GGJT -#define LLAMA_FILE_MAGIC_UNVERSIONED LLAMA_FILE_MAGIC_GGML -#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 1 +#define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' -#define LLAMA_DEFAULT_SEED 0xFFFFFFFF +#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN +#define LLAMA_SESSION_VERSION 1 #if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL) // Defined when llama.cpp is compiled with support for offloading model layers to GPU. #define LLAMA_SUPPORTS_GPU_OFFLOAD #endif -#ifndef LLAMA_DEFAULT_RMS_EPS -#define LLAMA_DEFAULT_RMS_EPS 5e-6f -#endif - #ifdef __cplusplus extern "C" { #endif @@ -72,6 +61,50 @@ extern "C" { typedef int llama_token; + enum llama_log_level { + LLAMA_LOG_LEVEL_ERROR = 2, + LLAMA_LOG_LEVEL_WARN = 3, + LLAMA_LOG_LEVEL_INFO = 4 + }; + + enum llama_vocab_type { + LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece + LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding + }; + + enum llama_token_type { + LLAMA_TOKEN_TYPE_UNDEFINED = 0, + LLAMA_TOKEN_TYPE_NORMAL = 1, + LLAMA_TOKEN_TYPE_UNKNOWN = 2, + LLAMA_TOKEN_TYPE_CONTROL = 3, + LLAMA_TOKEN_TYPE_USER_DEFINED = 4, + LLAMA_TOKEN_TYPE_UNUSED = 5, + LLAMA_TOKEN_TYPE_BYTE = 6, + }; + + // model file types + enum llama_ftype { + LLAMA_FTYPE_ALL_F32 = 0, + LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 + // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed + // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed + LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors + LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors + }; + typedef struct llama_token_data { llama_token id; // token id float logit; // log-odds of the token @@ -86,25 +119,10 @@ extern "C" { typedef void (*llama_progress_callback)(float progress, void *ctx); - enum llama_log_level { - LLAMA_LOG_LEVEL_ERROR = 2, - LLAMA_LOG_LEVEL_WARN = 3, - LLAMA_LOG_LEVEL_INFO = 4 - }; - - // Signature for logging events - // Note that text includes the new line character at the end for most events. - // If your logging mechanism cannot handle that, check if the last character is '\n' and strip it - // if it exists. - // It might not exist for progress report where '.' is output repeatedly. - typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data); - struct llama_context_params { uint32_t seed; // RNG seed, -1 for random int32_t n_ctx; // text context int32_t n_batch; // prompt processing batch size - int32_t n_gqa; // grouped-query attention (TEMP - will be moved to model hparams) - float rms_norm_eps; // rms norm epsilon (TEMP - will be moved to model hparams) int32_t n_gpu_layers; // number of layers to store in VRAM int32_t main_gpu; // the GPU that is used for scratch and small tensors @@ -129,33 +147,18 @@ extern "C" { bool use_mlock; // force system to keep model in RAM bool embedding; // embedding mode only }; - // model file types - enum llama_ftype { - LLAMA_FTYPE_ALL_F32 = 0, - LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16 - // LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed - // LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed - LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors - LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors - LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors - }; + + // Signature for logging events + // Note that text includes the new line character at the end for most events. + // If your logging mechanism cannot handle that, check if the last character is '\n' and strip it + // if it exists. + // It might not exist for progress report where '.' is output repeatedly. + typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data); // model quantization parameters typedef struct llama_model_quantize_params { int nthread; // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() - enum llama_ftype ftype; // quantize to this llama_ftype + enum llama_ftype ftype; // quantize to this llama_ftype bool allow_requantize; // allow quantizing non-f32/f16 tensors bool quantize_output_tensor; // quantize output.weight } llama_model_quantize_params; @@ -208,27 +211,16 @@ extern "C" { int32_t n_eval; }; - // Set callback for all future logging events. - // If this is not called, or NULL is supplied, everything is output on stderr. - LLAMA_API void llama_log_set(llama_log_callback log_callback, void * user_data); + LLAMA_API struct llama_context_params llama_context_default_params(void); + LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void); - LLAMA_API int llama_max_devices(); - - LLAMA_API struct llama_context_params llama_context_default_params(); - LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(); - - LLAMA_API bool llama_mmap_supported(); - LLAMA_API bool llama_mlock_supported(); - - // TODO: not great API - very likely to change // Initialize the llama + ggml backend // If numa is true, use NUMA optimizations // Call once at the start of the program LLAMA_API void llama_backend_init(bool numa); - // Call once at the end of the program - currently only used for MPI - LLAMA_API void llama_backend_free(); - LLAMA_API int64_t llama_time_us(); + // Call once at the end of the program - currently only used for MPI + LLAMA_API void llama_backend_free(void); LLAMA_API struct llama_model * llama_load_model_from_file( const char * path_model, @@ -240,17 +232,26 @@ extern "C" { struct llama_model * model, struct llama_context_params params); - // Various functions for loading a ggml llama model. - // Allocate (almost) all memory needed for the model. - // Return NULL on failure - LLAMA_API DEPRECATED(struct llama_context * llama_init_from_file( - const char * path_model, - struct llama_context_params params), - "please use llama_load_model_from_file combined with llama_new_context_with_model instead"); - // Frees all allocated memory LLAMA_API void llama_free(struct llama_context * ctx); + LLAMA_API int64_t llama_time_us(void); + + LLAMA_API int llama_max_devices (void); + LLAMA_API bool llama_mmap_supported (void); + LLAMA_API bool llama_mlock_supported(void); + + LLAMA_API int llama_n_vocab(const struct llama_context * ctx); + LLAMA_API int llama_n_ctx (const struct llama_context * ctx); + LLAMA_API int llama_n_embd (const struct llama_context * ctx); + + LLAMA_API int llama_model_n_vocab(const struct llama_model * model); + LLAMA_API int llama_model_n_ctx (const struct llama_model * model); + LLAMA_API int llama_model_n_embd (const struct llama_model * model); + + // Get a string describing the model type + LLAMA_API int llama_model_type(const struct llama_model * model, char * buf, size_t buf_size); + // Returns 0 on success LLAMA_API int llama_model_quantize( const char * fname_inp, @@ -272,9 +273,9 @@ extern "C" { LLAMA_API int llama_model_apply_lora_from_file( const struct llama_model * model, - const char * path_lora, - const char * path_base_model, - int n_threads); + const char * path_lora, + const char * path_base_model, + int n_threads); // Returns the number of tokens in the KV cache LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx); @@ -324,11 +325,40 @@ extern "C" { // IMPORTANT: do not use for anything else other than debugging and testing! LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname); + // Token logits obtained from the last call to llama_eval() + // The logits for the last token are stored in the last row + // Can be mutated in order to change the probabilities of the next token + // Rows: n_tokens + // Cols: n_vocab + LLAMA_API float * llama_get_logits(struct llama_context * ctx); + + // Get the embeddings for the input + // shape: [n_embd] (1-dimensional) + LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); + + // + // Vocab + // + + LLAMA_API const char * llama_token_get_text(const struct llama_context * ctx, llama_token token); + + LLAMA_API float llama_token_get_score(const struct llama_context * ctx, llama_token token); + + LLAMA_API llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token); + + // Special tokens + LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence + LLAMA_API llama_token llama_token_eos(const struct llama_context * ctx); // end-of-sentence + LLAMA_API llama_token llama_token_nl (const struct llama_context * ctx); // next-line + + // + // Tokenization + // + // Convert the provided text into tokens. // The tokens pointer must be large enough to hold the resulting tokens. // Returns the number of tokens on success, no more than n_max_tokens // Returns a negative number on failure - the number of tokens that would have been returned - // TODO: not sure if correct LLAMA_API int llama_tokenize( struct llama_context * ctx, const char * text, @@ -336,6 +366,13 @@ extern "C" { int n_max_tokens, bool add_bos); + LLAMA_API int llama_tokenize_bpe( + struct llama_context * ctx, + const char * text, + llama_token * tokens, + int n_max_tokens, + bool add_bos); + LLAMA_API int llama_tokenize_with_model( const struct llama_model * model, const char * text, @@ -343,57 +380,30 @@ extern "C" { int n_max_tokens, bool add_bos); - LLAMA_API int llama_n_vocab(const struct llama_context * ctx); - LLAMA_API int llama_n_ctx (const struct llama_context * ctx); - LLAMA_API int llama_n_embd (const struct llama_context * ctx); - - LLAMA_API int llama_n_vocab_from_model(const struct llama_model * model); - LLAMA_API int llama_n_ctx_from_model (const struct llama_model * model); - LLAMA_API int llama_n_embd_from_model (const struct llama_model * model); - - LLAMA_API int llama_model_type(const struct llama_model * model, char * buf, size_t buf_size); - - // Get the vocabulary as output parameters. - // Returns number of results. - LLAMA_API int llama_get_vocab( - const struct llama_context * ctx, - const char * * strings, - float * scores, - int capacity); - - LLAMA_API int llama_get_vocab_from_model( - const struct llama_model * model, - const char * * strings, - float * scores, - int capacity); - - // Token logits obtained from the last call to llama_eval() - // The logits for the last token are stored in the last row - // Can be mutated in order to change the probabilities of the next token - // Rows: n_tokens - // Cols: n_vocab - LLAMA_API float * llama_get_logits(struct llama_context * ctx); - - // Get the embeddings for the input - // shape: [n_embd] (1-dimensional) - LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); - // Token Id -> String. Uses the vocabulary in the provided context - LLAMA_API const char * llama_token_to_str( + // Does not write null terminator to the buffer + LLAMA_API int llama_token_to_str( const struct llama_context * ctx, - llama_token token); + llama_token token, + char * buf, + int length); - LLAMA_API const char * llama_token_to_str_with_model( + LLAMA_API int llama_token_to_str_bpe( + const struct llama_context * ctx, + llama_token token, + char * buf, + int length); + + LLAMA_API int llama_token_to_str_with_model( const struct llama_model * model, - llama_token token); - - // Special tokens - LLAMA_API llama_token llama_token_bos(); // beginning-of-sentence - LLAMA_API llama_token llama_token_eos(); // end-of-sentence - LLAMA_API llama_token llama_token_nl(); // next-line + llama_token token, + char * buf, + int length); + // // Grammar // + LLAMA_API struct llama_grammar * llama_grammar_init( const llama_grammar_element ** rules, size_t n_rules, @@ -401,7 +411,9 @@ extern "C" { LLAMA_API void llama_grammar_free(struct llama_grammar * grammar); + // // Sampling functions + // /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty); @@ -470,6 +482,10 @@ extern "C" { // Print system information LLAMA_API const char * llama_print_system_info(void); + // Set callback for all future logging events. + // If this is not called, or NULL is supplied, everything is output on stderr. + LLAMA_API void llama_log_set(llama_log_callback log_callback, void * user_data); + #ifdef __cplusplus } #endif @@ -479,10 +495,11 @@ extern "C" { #include #include + struct ggml_tensor; const std::vector>& llama_internal_get_tensor_map(struct llama_context * ctx); -#endif +#endif // LLAMA_API_INTERNAL #endif // LLAMA_H diff --git a/models/.editorconfig b/models/.editorconfig new file mode 100644 index 000000000..78b36ca08 --- /dev/null +++ b/models/.editorconfig @@ -0,0 +1 @@ +root = true diff --git a/models/ggml-vocab-llama.gguf b/models/ggml-vocab-llama.gguf new file mode 100644 index 000000000..63bfaf672 Binary files /dev/null and b/models/ggml-vocab-llama.gguf differ diff --git a/models/ggml-vocab.bin b/models/ggml-vocab.bin deleted file mode 100644 index 38f63493a..000000000 Binary files a/models/ggml-vocab.bin and /dev/null differ diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 276f39b3b..4ccefe932 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -1,17 +1,36 @@ -function(llama_add_test source) +function(llama_build_executable source) get_filename_component(TEST_TARGET ${source} NAME_WE) add_executable(${TEST_TARGET} ${source}) install(TARGETS ${TEST_TARGET} RUNTIME) - target_link_libraries(${TEST_TARGET} PRIVATE llama) + target_link_libraries(${TEST_TARGET} PRIVATE llama common) +endfunction() + +function(llama_test_executable name source) + get_filename_component(TEST_TARGET ${source} NAME_WE) + # add_executable(${TEST_TARGET} ${source}) + # install(TARGETS ${TEST_TARGET} RUNTIME) + # target_link_libraries(${TEST_TARGET} PRIVATE llama) + add_test(NAME ${name} COMMAND $ ${ARGN}) +endfunction() + +function(llama_build_and_test_executable source) + get_filename_component(TEST_TARGET ${source} NAME_WE) + add_executable(${TEST_TARGET} ${source}) + install(TARGETS ${TEST_TARGET} RUNTIME) + target_link_libraries(${TEST_TARGET} PRIVATE llama common) add_test(NAME ${TEST_TARGET} COMMAND $ ${ARGN}) endfunction() -# llama_add_test(test-double-float.cpp) # SLOW -llama_add_test(test-quantize-fns.cpp) -llama_add_test(test-quantize-perf.cpp) -llama_add_test(test-sampling.cpp) -llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin) -llama_add_test(test-grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp) -llama_add_test(test-llama-grammar.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/grammar-parser.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../llama.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../examples/common.cpp) -llama_add_test(test-grad0.cpp) # SLOW -# llama_add_test(test-opt.cpp) # SLOW +# llama_build_and_test_executable(test-double-float.cpp) # SLOW +llama_build_and_test_executable(test-quantize-fns.cpp) +llama_build_and_test_executable(test-quantize-perf.cpp) +llama_build_and_test_executable(test-sampling.cpp) +llama_build_executable(test-tokenizer-0.cpp) +llama_test_executable (test-tokenizer-0.llama test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf) +llama_build_executable(test-tokenizer-1.cpp) +llama_test_executable (test-tokenizer-1.llama test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama.gguf) +#llama_test_executable(test-tokenizer-1.aquila test-tokenizer-1.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-aquila.gguf) +llama_build_and_test_executable(test-grammar-parser.cpp) +llama_build_and_test_executable(test-llama-grammar.cpp) +llama_build_and_test_executable(test-grad0.cpp) # SLOW +# llama_build_and_test_executable(test-opt.cpp) # SLOW diff --git a/tests/test-grammar-parser.cpp b/tests/test-grammar-parser.cpp index 7022988b4..a0b5b043d 100644 --- a/tests/test-grammar-parser.cpp +++ b/tests/test-grammar-parser.cpp @@ -3,7 +3,8 @@ #endif #include "llama.h" -#include "examples/grammar-parser.cpp" +#include "grammar-parser.h" + #include int main() diff --git a/tests/test-llama-grammar.cpp b/tests/test-llama-grammar.cpp index 81c31e9e2..73dd33dd2 100644 --- a/tests/test-llama-grammar.cpp +++ b/tests/test-llama-grammar.cpp @@ -2,9 +2,9 @@ #undef NDEBUG #endif -#include "llama.cpp" -#include "examples/common.cpp" -#include "examples/grammar-parser.cpp" +#include "llama.cpp" // TODO: not great +#include "grammar-parser.h" + #include int main() diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp index 87fde1645..81764565b 100644 --- a/tests/test-tokenizer-0.cpp +++ b/tests/test-tokenizer-0.cpp @@ -1,22 +1,47 @@ #include "llama.h" +#include "common.h" #include #include #include #include -static const std::map> & k_tests() -{ +static std::string unescape_whitespace(llama_context* ctx, const std::vector& tokens) { + std::string result; + for (size_t i = 0; i < tokens.size(); ++i) { + result += llama_token_to_str(ctx, tokens[i]); + } + return result; +} + +static const std::map> & k_tests() { static std::map> _k_tests = { - { "Hello World", { 1, 10994, 2787, }, }, - { " Hello World", { 1, 15043, 2787, }, }, - { " Hello World!", { 1, 15043, 2787, 29991, }, }, - { " this is 🦙.cpp", { 1, 445, 338, 29871, 243, 162, 169, 156, 29889, 8223, }, }, - { "w048 7tuijk dsdfhu", { 1, 29893, 29900, 29946, 29947, 29871, 29955, 9161, 13535, 18031, 2176, 6905, }, }, - { "нещо на Български", { 1, 821, 4851, 665, 1386, 29713, 1305, }, }, + { " ", {1, 259, }, }, + { "\t", { 1, 29871, 12, }, }, + { "\n", { 1, 29871, 13, }, }, + { "\t\n", { 1, 29871, 12, 13, }, }, + { "Hello world", { 1, 15043, 3186, }, }, + { " Hello world", { 1, 29871, 15043, 3186, }, }, + { "Hello World", { 1, 15043, 2787, }, }, + { " Hello World", { 1, 29871, 15043, 2787, }, }, + { " Hello World!", { 1, 29871, 15043, 2787, 29991, }, }, + { " this is 🦙.cpp", { 1, 29871, 445, 338, 29871, 243, 162, 169, 156, 29889, 8223, }, }, + { "w048 7tuijk dsdfhu", { 1, 281, 29900, 29946, 29947, 29871, 29955, 9161, 13535, 18031, 2176, 6905, }, }, + { "нещо на Български", { 1, 1538, 4851, 665, 1386, 29713, 1305, }, }, + { "កាន់តែពិសេសអាចខលចេញ", { 1, 29871, 31849, 31324, 31934, 228, 162, 142, 228, 161, + 146, 228, 162, 133, 228, 161, 153, 228, 161, 186, + 31708, 228, 162, 132, 31708, 228, 161, 165, 31324, 228, + 161, 136, 228, 161, 132, 228, 161, 158, 228, 161, + 136, 228, 162, 132, 228, 161, 140, }, }, + { "🚀 (normal) 😶‍🌫️ (multiple emojis concatenated) ✅ (only emoji that has its own token)", + { 1, 29871, 243, 162, 157, 131, 313, 8945, 29897, 29871, + 243, 162, 155, 185, 30722, 243, 162, 143, 174, 30598, + 313, 20787, 953, 3848, 275, 16125, 630, 29897, 29871, 31681, + 313, 6194, 953, 29877, 2397, 393, 756, 967, 1914, 5993, 29897, }, }, }; + return _k_tests; -}; +} int main(int argc, char **argv) { if (argc < 2) { @@ -64,10 +89,12 @@ int main(int argc, char **argv) { return 2; } + bool success = true; + for (const auto & test_kv : k_tests()) { - std::vector res(test_kv.first.size()); - const int n = llama_tokenize(ctx, test_kv.first.c_str(), res.data(), int(res.size()), true); - res.resize(n); + std::vector res = llama_tokenize(ctx, test_kv.first, true); + fprintf(stderr, "%s : '%s' tokenized to '%s'\n", + __func__, test_kv.first.c_str(), unescape_whitespace(ctx, res).c_str()); bool correct = res.size() == test_kv.second.size(); @@ -78,7 +105,8 @@ int main(int argc, char **argv) { } if (!correct) { - fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str()); + fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str()); + fprintf(stderr, "%s : detokenized to: '%s'\n", __func__, unescape_whitespace(ctx, test_kv.second).c_str()); fprintf(stderr, "%s : expected tokens: ", __func__); for (const auto & t : test_kv.second) { fprintf(stderr, "%6d, ", t); @@ -90,9 +118,7 @@ int main(int argc, char **argv) { } fprintf(stderr, "\n"); - llama_free_model(model); - llama_free(ctx); - return 3; + success = false; } } @@ -101,5 +127,5 @@ int main(int argc, char **argv) { llama_backend_free(); - return 0; + return success ? 0 : 3; } diff --git a/tests/test-tokenizer-1.cpp b/tests/test-tokenizer-1.cpp new file mode 100644 index 000000000..d8db7cd96 --- /dev/null +++ b/tests/test-tokenizer-1.cpp @@ -0,0 +1,131 @@ +#include "llama.h" +#include "common.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +static std::string escape_whitespace(const std::string& text) { + std::string result; + bool escaping = false; + result += "\xe2\x96\x81"; + for (size_t offs = 0; offs < text.length(); ++offs) { + if (text[offs] == ' ') { + if (!escaping) { + result += "\xe2\x96\x81"; + escaping = true; + } + } + else { + escaping = false; + result += text[offs]; + } + } + return result; +} + +static std::string unescape_whitespace(llama_context * ctx, const std::vector & tokens) { + std::string result; + for (size_t i = 0; i < tokens.size(); ++i) { + result += llama_token_to_str(ctx, tokens[i]); + } + return result; +} + +int main(int argc, char **argv) { + if (argc < 2) { + fprintf(stderr, "Usage: %s \n", argv[0]); + return 1; + } + + const std::string fname = argv[1]; + + fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str()); + + llama_model * model; + llama_context * ctx; + + llama_backend_init(false); + + // load the vocab + { + auto lparams = llama_context_default_params(); + + lparams.vocab_only = true; + + model = llama_load_model_from_file(fname.c_str(), lparams); + + if (model == NULL) { + fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); + return 1; + } + + ctx = llama_new_context_with_model(model, lparams); + + if (ctx == NULL) { + fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); + llama_free_model(model); + return 1; + } + } + + const int n_vocab = llama_n_vocab(ctx); + + for (int i = 0; i < n_vocab; ++i) { + std::string forward = llama_token_to_str_bpe(ctx, i); + std::vector tokens = llama_tokenize_bpe(ctx, forward, false); + if (tokens.size() == 1) { + if (i != tokens[0]) { + std::string backward = llama_token_to_str(ctx, tokens[0]); + fprintf(stderr, "%s : error: token %d is string %s but bpe returns token %d %s\n", + __func__, i, llama_token_to_str(ctx, i).c_str(), tokens[0], backward.c_str()); + return 2; + } + } else { + llama_token_type type = llama_token_get_type(ctx, i); + if (type == LLAMA_TOKEN_TYPE_UNKNOWN || type == LLAMA_TOKEN_TYPE_CONTROL || type == LLAMA_TOKEN_TYPE_BYTE) { + fprintf(stderr, "%s : info: token %d is string %s and bpe returns tokens %s\n", + __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str()); + } else { + fprintf(stderr, "%s : error: token %d is string %s but bpe returns tokens %s\n", + __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str()); + return 2; + } + } + } + +#ifdef _WIN32 + std::wstring_convert, char16_t> u16converter; + for (char16_t ch = 0x0000; ch < 0xffff; ++ch) { + std::u16string u16str(1, ch); + std::string str = u16converter.to_bytes(u16str); + std::vector tokens = llama_tokenize(ctx, escape_whitespace(str).c_str(), false); + if (tokens.size() == 1) { + fprintf(stderr, "%s : info: %s tokenized to %d \n", + __func__, str.c_str(), tokens[0]); + } + } + + std::wstring_convert, char32_t> u32converter; + for (char32_t ch = 0x0000; ch < 0x0010ffff; ++ch) { + std::u32string u32str(1, ch); + std::string str = u32converter.to_bytes(u32str); + std::vector tokens = llama_tokenize(ctx, escape_whitespace(str).c_str(), false); + if (tokens.size() == 1) { + fprintf(stderr, "%s : info: %s tokenized to %d \n", __func__, str.c_str(), tokens[0]); + } + } +#endif + + llama_free_model(model); + llama_free(ctx); + + llama_backend_free(); + + return 0; +}