diff --git a/README.md b/README.md index 60f14a1fb..4fd4bd427 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++ ### Hot topics - +- ‼️ BPE tokenizer update: existing Falcon and Starcoder `.gguf` models will need to be reconverted: [#3252](https://github.com/ggerganov/llama.cpp/pull/3252) - ‼️ Breaking change: `rope_freq_base` and `rope_freq_scale` must be set to zero to use the model default values: [#3401](https://github.com/ggerganov/llama.cpp/pull/3401) - Parallel decoding + continuous batching support added: [#3228](https://github.com/ggerganov/llama.cpp/pull/3228) \ **Devs should become familiar with the new API** @@ -89,15 +89,17 @@ as the main playground for developing new features for the [ggml](https://github - [X] [Vicuna](https://github.com/ggerganov/llama.cpp/discussions/643#discussioncomment-5533894) - [X] [Koala](https://bair.berkeley.edu/blog/2023/04/03/koala/) - [X] [OpenBuddy 🐶 (Multilingual)](https://github.com/OpenBuddy/OpenBuddy) -- [X] [Pygmalion 7B / Metharme 7B](#using-pygmalion-7b--metharme-7b) +- [X] [Pygmalion/Metharme](#using-pygmalion-7b--metharme-7b) - [X] [WizardLM](https://github.com/nlpxucan/WizardLM) -- [X] [Baichuan-7B](https://huggingface.co/baichuan-inc/baichuan-7B) and its derivations (such as [baichuan-7b-sft](https://huggingface.co/hiyouga/baichuan-7b-sft)) -- [X] [Aquila-7B](https://huggingface.co/BAAI/Aquila-7B) / [AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B) +- [X] [Baichuan 1 & 2](https://huggingface.co/models?search=baichuan-inc/Baichuan) + [derivations](https://huggingface.co/hiyouga/baichuan-7b-sft) +- [X] [Aquila 1 & 2](https://huggingface.co/models?search=BAAI/Aquila) - [X] [Starcoder models](https://github.com/ggerganov/llama.cpp/pull/3187) - [X] [Mistral AI v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) - [X] [Refact](https://huggingface.co/smallcloudai/Refact-1_6B-fim) -- [X] [Bloom](https://github.com/ggerganov/llama.cpp/pull/3553) +- [X] [Persimmon 8B](https://github.com/ggerganov/llama.cpp/pull/3410) - [X] [MPT](https://github.com/ggerganov/llama.cpp/pull/3417) +- [X] [Bloom](https://github.com/ggerganov/llama.cpp/pull/3553) + **Bindings:** @@ -206,7 +208,7 @@ https://user-images.githubusercontent.com/1991296/224442907-7693d4be-acaa-4e01-8 ## Usage -Here are the steps for the LLaMA-7B model. +Here are the end-to-end binary build and model conversion steps for the LLaMA-7B model. ### Get the Code @@ -573,6 +575,18 @@ python3 convert.py models/7B/ When running the larger models, make sure you have enough disk space to store all the intermediate files. +### Running on Windows with prebuilt binaries + +You will find prebuilt Windows binaries on the release page. + +Simply download and extract the latest zip package of choice: (e.g. `llama-b1380-bin-win-avx2-x64.zip`) + +From the unzipped folder, open a terminal/cmd window here and place a pre-converted `.gguf` model file. Test out the main example like so: + +``` +.\main -m llama-2-7b.Q4_0.gguf -n 128 +``` + ### Memory/Disk Requirements As the models are currently fully loaded into memory, you will need adequate disk space to save them and sufficient RAM to load them. At the moment, memory and disk requirements are the same. diff --git a/ci/run.sh b/ci/run.sh index 34c9129c1..2e3343831 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -208,6 +208,8 @@ function gg_run_open_llama_3b_v2 { (time ./bin/perplexity --model ${model_q5_k} -f ${wiki_test_60} -c 128 -b 128 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log (time ./bin/perplexity --model ${model_q6_k} -f ${wiki_test_60} -c 128 -b 128 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + function check_ppl { qnt="$1" ppl=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1) @@ -296,6 +298,7 @@ function gg_sum_open_llama_3b_v2 { gg_printf '- q4_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q4_k.log)" gg_printf '- q5_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q5_k.log)" gg_printf '- q6_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q6_k.log)" + gg_printf '- save-load-state: \n```\n%s\n```\n' "$(cat $OUT/${ci}-save-load-state.log)" gg_printf '- shakespeare (f16):\n```\n%s\n```\n' "$(cat $OUT/${ci}-ppl-shakespeare-f16.log)" gg_printf '- shakespeare (f16 lora):\n```\n%s\n```\n' "$(cat $OUT/${ci}-ppl-shakespeare-lora-f16.log)" gg_printf '- shakespeare (q8_0):\n```\n%s\n```\n' "$(cat $OUT/${ci}-ppl-shakespeare-q8_0.log)" @@ -382,6 +385,8 @@ function gg_run_open_llama_7b_v2 { (time ./bin/perplexity --model ${model_q5_k} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q5_k.log (time ./bin/perplexity --model ${model_q6_k} -f ${wiki_test} -t 1 -ngl 999 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-tg-q6_k.log + (time ./bin/save-load-state --model ${model_q4_0} ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log + function check_ppl { qnt="$1" ppl=$(echo "$2" | grep -oE "[0-9]+\.[0-9]+" | tail -n 1) @@ -470,6 +475,7 @@ function gg_sum_open_llama_7b_v2 { gg_printf '- q4_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q4_k.log)" gg_printf '- q5_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q5_k.log)" gg_printf '- q6_k:\n```\n%s\n```\n' "$(cat $OUT/${ci}-tg-q6_k.log)" + gg_printf '- save-load-state: \n```\n%s\n```\n' "$(cat $OUT/${ci}-save-load-state.log)" gg_printf '- shakespeare (f16):\n```\n%s\n```\n' "$(cat $OUT/${ci}-ppl-shakespeare-f16.log)" gg_printf '- shakespeare (f16 lora):\n```\n%s\n```\n' "$(cat $OUT/${ci}-ppl-shakespeare-lora-f16.log)" #gg_printf '- shakespeare (q8_0):\n```\n%s\n```\n' "$(cat $OUT/${ci}-ppl-shakespeare-q8_0.log)" diff --git a/common/common.cpp b/common/common.cpp index 9c4f7df20..3e4b8a8cb 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -879,21 +879,23 @@ std::tuple llama_init_from_gpt_par std::vector llama_tokenize( const struct llama_context * ctx, const std::string & text, - bool add_bos) { - return llama_tokenize(llama_get_model(ctx), text, add_bos); + bool add_bos, + bool special) { + return llama_tokenize(llama_get_model(ctx), text, add_bos, special); } std::vector llama_tokenize( const struct llama_model * model, const std::string & text, - bool add_bos) { + bool add_bos, + bool special) { // upper limit for the number of tokens int n_tokens = text.length() + add_bos; std::vector result(n_tokens); - n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos); + n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special); if (n_tokens < 0) { result.resize(-n_tokens); - int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos); + int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, special); GGML_ASSERT(check == -n_tokens); } else { result.resize(n_tokens); diff --git a/common/common.h b/common/common.h index 36fd44166..08c603231 100644 --- a/common/common.h +++ b/common/common.h @@ -137,12 +137,14 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param std::vector llama_tokenize( const struct llama_context * ctx, const std::string & text, - bool add_bos); + bool add_bos, + bool special = false); std::vector llama_tokenize( const struct llama_model * model, const std::string & text, - bool add_bos); + bool add_bos, + bool special = false); // tokenizes a token into a piece // should work similar to Python's `tokenizer.id_to_piece` diff --git a/common/train.cpp b/common/train.cpp index 35a4cf9e6..972eaefe0 100644 --- a/common/train.cpp +++ b/common/train.cpp @@ -863,7 +863,7 @@ size_t tokenize_file( (int) buf.size(), out_tokens.data(), (int) out_tokens.size(), - false); + false, false); if (n_tokens < 0) { out_tokens.resize(-n_tokens); n_tokens = llama_tokenize( @@ -872,7 +872,7 @@ size_t tokenize_file( (int) buf.size(), out_tokens.data(), (int) out_tokens.size(), - false); + false, false); } if (n_tokens >= 0) { out_tokens.resize(n_tokens); @@ -966,7 +966,7 @@ size_t tokenize_file( (int) buf_sample.size(), tok_sample.data(), (int) tok_sample.size(), - false); + false, false); if (n_tokens < 0) { tok_sample.resize(-n_tokens); n_tokens = llama_tokenize(llama_get_model(lctx), @@ -974,7 +974,7 @@ size_t tokenize_file( (int) buf_sample.size(), tok_sample.data(), (int) tok_sample.size(), - false); + false, false); GGML_ASSERT(n_tokens >= 0); } GGML_ASSERT(n_tokens <= (int) tok_sample.size()); diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 938f30512..05d1bb9d0 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -209,7 +209,7 @@ llama_print_timings(context) private func tokenize(text: String, add_bos: Bool) -> [llama_token] { let n_tokens = text.count + (add_bos ? 1 : 0) let tokens = UnsafeMutablePointer.allocate(capacity: n_tokens) - let tokenCount = llama_tokenize(model, text, Int32(text.count), tokens, Int32(n_tokens), add_bos) + let tokenCount = llama_tokenize(model, text, Int32(text.count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false) var swiftTokens: [llama_token] = [] for i in 0 ..< tokenCount { swiftTokens.append(tokens[Int(i)]) diff --git a/examples/llava/llava-utils.h b/examples/llava/llava-utils.h index 79e237c86..4e71351dd 100644 --- a/examples/llava/llava-utils.h +++ b/examples/llava/llava-utils.h @@ -49,9 +49,9 @@ inline bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) { return eval_tokens(ctx_llama, tokens, 1, n_past); } -inline bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past){ +inline bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){ std::string str2 = str; - std::vector embd_inp = ::llama_tokenize(ctx_llama, str2, true); + std::vector embd_inp = ::llama_tokenize(ctx_llama, str2, add_bos); eval_tokens(ctx_llama, embd_inp, n_batch, n_past); return true; } diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 8384d9d78..b24cb2e6f 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -97,6 +97,7 @@ int main(int argc, char ** argv) { ctx_params.n_ctx = params.n_ctx < 2048 ? 2048 : params.n_ctx; // we need a longer context size to process image embeddings ctx_params.n_threads = params.n_threads; ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + ctx_params.seed = params.seed; llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); @@ -106,7 +107,8 @@ int main(int argc, char ** argv) { } // make sure that the correct mmproj was used, i.e., compare apples to apples - int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama)); + const int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama)); + if (n_img_embd != n_llama_embd) { printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_img_embd, n_llama_embd); @@ -125,14 +127,14 @@ int main(int argc, char ** argv) { const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict; - // GG: are we sure that the should be a trailing whitespace at the end of this string? - eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params.n_batch, &n_past); + eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params.n_batch, &n_past, true); eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past); - eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past); - eval_string(ctx_llama, "\nASSISTANT:", params.n_batch, &n_past); + eval_string(ctx_llama, (params.prompt + "\nASSISTANT:").c_str(), params.n_batch, &n_past, false); // generate the response + printf("\n"); + printf("prompt: '%s'\n", params.prompt.c_str()); printf("\n"); for (int i = 0; i < max_tgt_len; i++) { diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 55f73356f..7313d06a0 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -238,7 +238,7 @@ int main(int argc, char ** argv) { if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) { LOG("tokenize the prompt\n"); - embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos); + embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); } else { LOG("use session tokens\n"); embd_inp = session_tokens; @@ -260,10 +260,10 @@ int main(int argc, char ** argv) { if (ctx_guidance) { LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt)); - guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos); + guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, add_bos, true); LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp)); - std::vector original_inp = ::llama_tokenize(ctx, params.prompt, add_bos); + std::vector original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true); LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp)); original_prompt_len = original_inp.size(); @@ -320,8 +320,8 @@ int main(int argc, char ** argv) { } // prefix & suffix for instruct mode - const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos); - const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false); + const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true); + const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true); LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx)); LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx)); @@ -383,6 +383,12 @@ int main(int argc, char ** argv) { if (!params.antiprompt.empty()) { for (const auto & antiprompt : params.antiprompt) { LOG_TEE("Reverse prompt: '%s'\n", antiprompt.c_str()); + if (params.verbose_prompt) { + auto tmp = ::llama_tokenize(ctx, antiprompt, false, true); + for (int i = 0; i < (int) tmp.size(); i++) { + LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str()); + } + } } } @@ -392,10 +398,22 @@ int main(int argc, char ** argv) { if (!params.input_prefix.empty()) { LOG_TEE("Input prefix: '%s'\n", params.input_prefix.c_str()); + if (params.verbose_prompt) { + auto tmp = ::llama_tokenize(ctx, params.input_prefix, true, true); + for (int i = 0; i < (int) tmp.size(); i++) { + LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str()); + } + } } if (!params.input_suffix.empty()) { LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str()); + if (params.verbose_prompt) { + auto tmp = ::llama_tokenize(ctx, params.input_suffix, false, true); + for (int i = 0; i < (int) tmp.size(); i++) { + LOG_TEE("%6d -> '%s'\n", tmp[i], llama_token_to_piece(ctx, tmp[i]).c_str()); + } + } } } LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n", @@ -717,7 +735,7 @@ int main(int argc, char ** argv) { if (params.interactive) { if (!params.antiprompt.empty()) { // tokenize and inject first reverse prompt - const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false); + const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true); embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end()); is_antiprompt = true; } @@ -744,8 +762,7 @@ int main(int argc, char ** argv) { std::string buffer; if (!params.input_prefix.empty()) { LOG("appending input prefix: '%s'\n", params.input_prefix.c_str()); - buffer += params.input_prefix; - printf("%s", buffer.c_str()); + printf("%s", params.input_prefix.c_str()); } // color user input only @@ -767,7 +784,6 @@ int main(int argc, char ** argv) { // append input suffix if any if (!params.input_suffix.empty()) { LOG("appending input suffix: '%s'\n", params.input_suffix.c_str()); - buffer += params.input_suffix; printf("%s", params.input_suffix.c_str()); } @@ -782,10 +798,14 @@ int main(int argc, char ** argv) { embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); } - const auto line_inp = ::llama_tokenize(ctx, buffer, false); + const auto line_pfx = ::llama_tokenize(ctx, params.input_prefix, false, true); + const auto line_inp = ::llama_tokenize(ctx, buffer, false, false); + const auto line_sfx = ::llama_tokenize(ctx, params.input_suffix, false, true); LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp)); + embd_inp.insert(embd_inp.end(), line_pfx.begin(), line_pfx.end()); embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); + embd_inp.insert(embd_inp.end(), line_sfx.begin(), line_sfx.end()); // instruct mode: insert response suffix if (params.instruct) { diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index f9e3c98a3..38d05f4d3 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -8,10 +8,7 @@ int main(int argc, char ** argv) { gpt_params params; - llama_sampling_params & sparams = params.sampling_params; - params.seed = 42; - params.n_threads = 4; - sparams.repeat_last_n = 64; + params.prompt = "The quick brown fox"; if (!gpt_params_parse(argc, argv, params)) { @@ -25,56 +22,49 @@ int main(int argc, char ** argv) { } auto n_past = 0; - auto last_n_tokens_data = std::vector(sparams.repeat_last_n, 0); + + std::string result0; + std::string result1; // init llama_model * model; llama_context * ctx; - std::tie(model, ctx) = llama_init_from_gpt_params( params ); - if (model == nullptr) { - return 1; - } - if (ctx == nullptr) { - llama_free_model(model); + std::tie(model, ctx) = llama_init_from_gpt_params(params); + if (model == nullptr || ctx == nullptr) { + fprintf(stderr, "%s : failed to init\n", __func__); return 1; } + + // tokenize prompt auto tokens = llama_tokenize(ctx, params.prompt, true); - auto n_prompt_tokens = tokens.size(); - if (n_prompt_tokens < 1) { - fprintf(stderr, "%s : failed to tokenize prompt\n", __func__); - llama_free(ctx); - llama_free_model(model); - return 1; - } // evaluate prompt - llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0)); + llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), n_past, 0)); + n_past += tokens.size(); - last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens); - n_past += n_prompt_tokens; - - const size_t state_size = llama_get_state_size(ctx); - uint8_t * state_mem = new uint8_t[state_size]; - - // Save state (rng, logits, embedding and kv_cache) to file + // save state (rng, logits, embedding and kv_cache) to file { - FILE *fp_write = fopen("dump_state.bin", "wb"); - llama_copy_state_data(ctx, state_mem); // could also copy directly to memory mapped file - fwrite(state_mem, 1, state_size, fp_write); - fclose(fp_write); + std::vector state_mem(llama_get_state_size(ctx)); + + { + FILE *fp_write = fopen("dump_state.bin", "wb"); + llama_copy_state_data(ctx, state_mem.data()); // could also copy directly to memory mapped file + fwrite(state_mem.data(), 1, state_mem.size(), fp_write); + fclose(fp_write); + } } // save state (last tokens) - const auto last_n_tokens_data_saved = std::vector(last_n_tokens_data); const auto n_past_saved = n_past; // first run - printf("\n%s", params.prompt.c_str()); + printf("\nfirst run: %s", params.prompt.c_str()); for (auto i = 0; i < params.n_predict; i++) { auto * logits = llama_get_logits(ctx); auto n_vocab = llama_n_vocab(model); + std::vector candidates; candidates.reserve(n_vocab); for (llama_token token_id = 0; token_id < n_vocab; token_id++) { @@ -83,9 +73,10 @@ int main(int argc, char ** argv) { llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; auto next_token = llama_sample_token(ctx, &candidates_p); auto next_token_str = llama_token_to_piece(ctx, next_token); - last_n_tokens_data.push_back(next_token); printf("%s", next_token_str.c_str()); + result0 += next_token_str; + if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx); @@ -103,32 +94,28 @@ int main(int argc, char ** argv) { // make new context auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params)); - // Load state (rng, logits, embedding and kv_cache) from file - { - FILE *fp_read = fopen("dump_state.bin", "rb"); - if (state_size != llama_get_state_size(ctx2)) { - fprintf(stderr, "\n%s : failed to validate state size\n", __func__); - llama_free(ctx2); - llama_free_model(model); - return 1; - } + printf("\nsecond run: %s", params.prompt.c_str()); - const size_t ret = fread(state_mem, 1, state_size, fp_read); - if (ret != state_size) { + // load state (rng, logits, embedding and kv_cache) from file + { + std::vector state_mem(llama_get_state_size(ctx2)); + + FILE * fp_read = fopen("dump_state.bin", "rb"); + + const size_t ret = fread(state_mem.data(), 1, state_mem.size(), fp_read); + if (ret != state_mem.size()) { fprintf(stderr, "\n%s : failed to read state\n", __func__); llama_free(ctx2); llama_free_model(model); return 1; } - llama_set_state_data(ctx2, state_mem); // could also read directly from memory mapped file + llama_set_state_data(ctx2, state_mem.data()); + fclose(fp_read); } - delete[] state_mem; - // restore state (last tokens) - last_n_tokens_data = last_n_tokens_data_saved; n_past = n_past_saved; // second run @@ -143,10 +130,11 @@ int main(int argc, char ** argv) { llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; auto next_token = llama_sample_token(ctx2, &candidates_p); auto next_token_str = llama_token_to_piece(ctx2, next_token); - last_n_tokens_data.push_back(next_token); printf("%s", next_token_str.c_str()); - if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0))) { + result1 += next_token_str; + + if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1, n_past, 0))) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx2); llama_free_model(model); @@ -155,10 +143,17 @@ int main(int argc, char ** argv) { n_past += 1; } - printf("\n\n"); + printf("\n"); llama_free(ctx2); llama_free_model(model); + if (result0 != result1) { + fprintf(stderr, "\n%s : error : the 2 generations are different\n", __func__); + return 1; + } + + fprintf(stderr, "\n%s : success\n", __func__); + return 0; } diff --git a/examples/server/README.md b/examples/server/README.md index 6d3ce8031..31d65a57f 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -114,8 +114,6 @@ node index.js *Options:* - `prompt`: Provide the prompt for this completion as a string or as an array of strings or numbers representing tokens. Internally, the prompt is compared to the previous completion and only the "unseen" suffix is evaluated. If the prompt is a string or an array with the first element given as a string, a `bos` token is inserted in the front like `main` does. - `temperature`: Adjust the randomness of the generated text (default: 0.8). `top_k`: Limit the next token selection to the K most probable tokens (default: 40). @@ -124,7 +122,7 @@ node index.js `n_predict`: Set the maximum number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. (default: -1, -1 = infinity). - `n_keep`: Specify the number of tokens from the prompt to retain when context size is exceeded and tokens need to be discarded. + `n_keep`: Specify the number of tokens from the initial prompt to retain when the model resets its internal context. By default, this value is set to 0 (meaning no tokens are kept). Use `-1` to retain all tokens from the initial prompt. `stream`: It allows receiving each predicted token in real-time instead of waiting for the completion to finish. To enable this, set to `true`. @@ -162,42 +160,6 @@ node index.js `n_probs`: If greater than 0, the response also contains the probabilities of top N tokens for each generated token (default: 0) - *Result JSON:* - - Note: When using streaming mode (`stream`) only `content` and `stop` will be returned until end of completion. - - `content`: Completion result as a string (excluding `stopping_word` if any). In case of streaming mode, will contain the next token as a string. - - `stop`: Boolean for use with `stream` to check whether the generation has stopped (Note: This is not related to stopping words array `stop` from input options) - - `generation_settings`: The provided options above excluding `prompt` but including `n_ctx`, `model` - - `model`: The path to the model loaded with `-m` - - `prompt`: The provided `prompt` - - `stopped_eos`: Indicating whether the completion has stopped because it encountered the EOS token - - `stopped_limit`: Indicating whether the completion stopped because `n_predict` tokens were generated before stop words or EOS was encountered - - `stopped_word`: Indicating whether the completion stopped due to encountering a stopping word from `stop` JSON array provided - - `stopping_word`: The stopping word encountered which stopped the generation (or "" if not stopped due to a stopping word) - - `timings`: Hash of timing information about the completion such as the number of tokens `predicted_per_second` - - `tokens_cached`: Number of tokens from the prompt which could be re-used from previous completion (`n_past`) - - `tokens_evaluated`: Number of tokens evaluated in total from the prompt - - `truncated`: Boolean indicating if the context size was exceeded during generation, i.e. the number of tokens provided in the prompt (`tokens_evaluated`) plus tokens generated (`tokens predicted`) exceeded the context size (`n_ctx`) - - `slot_id`: Assign the completion task to an specific slot. If is -1 the task will be assigned to a Idle slot (default: -1) - - `cache_prompt`: Save the prompt and generation for avoid reprocess entire prompt if a part of this isn't change (default: false) - - `system_prompt`: Change the system prompt (initial prompt of all slots), this is useful for chat applications. [See more](#change-system-prompt-on-runtime) - - **POST** `/tokenize`: Tokenize a given text. *Options:* diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index be693b3ac..1ce6cef29 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -253,13 +253,14 @@ static void init_model(struct my_llama_model * model) { set_param_model(model); // measure data size - struct ggml_allocr * alloc = NULL; - alloc = ggml_allocr_new_measure(tensor_alignment); - alloc_model(alloc, model); + size_t size = 0; + for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + size += GGML_PAD(ggml_nbytes(t), tensor_alignment); + } // allocate data - model->data.resize(ggml_allocr_max_size(alloc) + tensor_alignment); - ggml_allocr_free(alloc); + struct ggml_allocr * alloc = NULL; + model->data.resize(size + tensor_alignment); alloc = ggml_allocr_new(model->data.data(), model->data.size(), tensor_alignment); alloc_model(alloc, model); ggml_allocr_free(alloc); @@ -1094,11 +1095,9 @@ int main(int argc, char ** argv) { struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab, n_tokens, n_batch); // measure required memory for input tensors - alloc = ggml_allocr_new_measure(tensor_alignment); - ggml_allocr_alloc(alloc, tokens_input); - ggml_allocr_alloc(alloc, target_probs); - size_t max_input_size = ggml_allocr_max_size(alloc) + tensor_alignment; - ggml_allocr_free(alloc); + size_t max_input_size = GGML_PAD(ggml_nbytes(tokens_input), tensor_alignment) + + GGML_PAD(ggml_nbytes(target_probs), tensor_alignment) + + tensor_alignment; printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f)); // allocate input tensors diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp index 33d0691eb..22fd0e3a7 100644 --- a/ggml-opencl.cpp +++ b/ggml-opencl.cpp @@ -1568,7 +1568,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr ggml_cl_pool_free(d_D, d_size); } -static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) { +static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) { GGML_ASSERT(fp16_support); const int64_t ne00 = src0->ne[0]; @@ -1598,6 +1598,10 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr const int y_ne = ne11 * ne10; const int d_ne = ne11 * ne01; + GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * y_ne); + GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * d_ne); + ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata; + size_t x_size; size_t y_size; size_t d_size; @@ -1634,7 +1638,6 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr // convert src1 to fp16 // TODO: use multiple threads - ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i13 * ne12 + i12); char * src1i = (char *) src1->data + i13*nb13 + i12*nb12; if (src1_cont_rows) { if (src1_cont_cols) { @@ -1897,8 +1900,8 @@ void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * } size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - if (ggml_cl_mul_mat_use_f16(src0, src1, dst)) { - return ggml_nelements(src1) * sizeof(ggml_fp16_t); + if (src0->type == GGML_TYPE_F16 && ggml_cl_mul_mat_use_f16(src0, src1, dst)) { + return sizeof(ggml_fp16_t) * std::max(src1->ne[0] * src1->ne[1], dst->ne[0] * dst->ne[1]); } return 0; } diff --git a/k_quants.c b/k_quants.c index 558f5fda8..e168a87bb 100644 --- a/k_quants.c +++ b/k_quants.c @@ -462,12 +462,9 @@ void quantize_row_q2_K(const float * restrict x, void * restrict vy, int k) { } size_t ggml_quantize_q2_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { - const int nb = k / QK_K; + (void)hist; // TODO: collect histograms - // TODO - collect histograms - although, at a second thought, I don't really care about them - (void)hist; - - for (int j = 0; j < nb; j += k) { + for (int j = 0; j < n; j += k) { block_q2_K * restrict y = (block_q2_K *)dst + j/QK_K; quantize_row_q2_K_reference(src + j, y, k); } @@ -678,12 +675,9 @@ void quantize_row_q3_K(const float * restrict x, void * restrict vy, int k) { } size_t ggml_quantize_q3_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { - const int nb = k / QK_K; + (void)hist; // TODO: collect histograms - // TODO - collect histograms - although, at a second thought, I don't really care about them - (void)hist; - - for (int j = 0; j < nb; j += k) { + for (int j = 0; j < n; j += k) { block_q3_K * restrict y = (block_q3_K *)dst + j/QK_K; quantize_row_q3_K_reference(src + j, y, k); } @@ -846,9 +840,9 @@ void quantize_row_q4_K(const float * restrict x, void * restrict vy, int k) { size_t ggml_quantize_q4_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { assert(k % QK_K == 0); - const int nb = k / QK_K; (void)hist; // TODO: collect histograms - for (int j = 0; j < nb; j += k) { + + for (int j = 0; j < n; j += k) { block_q4_K * restrict y = (block_q4_K *)dst + j/QK_K; quantize_row_q4_K_reference(src + j, y, k); } @@ -1052,9 +1046,9 @@ void quantize_row_q5_K(const float * restrict x, void * restrict vy, int k) { size_t ggml_quantize_q5_K(const float * restrict src, void * restrict dst, int n, int k, int64_t * restrict hist) { assert(k % QK_K == 0); - const int nb = k / QK_K; - (void)hist; - for (int j = 0; j < nb; j += k) { + (void)hist; // TODO: collect histograms + + for (int j = 0; j < n; j += k) { block_q5_K * restrict y = (block_q5_K *)dst + j/QK_K; quantize_row_q5_K_reference(src + j, y, k); } @@ -1200,11 +1194,9 @@ void quantize_row_q6_K(const float * restrict x, void * restrict vy, int k) { size_t ggml_quantize_q6_K(const float * src, void * dst, int n, int k, int64_t * hist) { assert(k % QK_K == 0); - const int nb = k / QK_K; + (void)hist; // TODO: collect histograms - (void)hist; // TODO - - for (int j = 0; j < nb; j += k) { + for (int j = 0; j < n; j += k) { block_q6_K * restrict y = (block_q6_K *)dst + j/QK_K; quantize_row_q6_K_reference(src + j, y, k); } diff --git a/llama.cpp b/llama.cpp index 5329bd828..04a779e04 100644 --- a/llama.cpp +++ b/llama.cpp @@ -75,6 +75,7 @@ #include #include #include +#include #if defined(_MSC_VER) #pragma warning(disable: 4244 4267) // possible loss of data @@ -1183,6 +1184,8 @@ struct llama_vocab { std::unordered_map token_to_id; std::vector id_to_token; + std::unordered_map special_tokens_cache; + std::map, int> bpe_ranks; // default LLaMA special tokens @@ -2125,7 +2128,7 @@ static void llm_load_hparams( } // TODO: This should probably be in llama.h -static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos); +static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special = false); static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch); static void llm_load_vocab( @@ -2241,6 +2244,101 @@ static void llm_load_vocab( GGUF_GET_KEY(ctx, vocab.special_unk_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_UNK_ID)); GGUF_GET_KEY(ctx, vocab.special_sep_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_SEP_ID)); GGUF_GET_KEY(ctx, vocab.special_pad_id, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_TOKENIZER_PAD_ID)); + + // build special tokens cache + { + // TODO: It is unclear (to me) at this point, whether special tokes are guaranteed to be of a deterministic type, + // and will always be correctly labeled in 'added_tokens.json' etc. + // The assumption is, since special tokens aren't meant to be exposed to end user, they are designed + // to be unmatchable by the tokenizer, therefore tokens from the vocab, which are unmatchable by the tokenizer + // are special tokens. + // From testing, this appears to corelate 1:1 with special tokens. + // + + // Counting special tokens and verifying in only one direction + // is sufficient to detect difference in those two sets. + // + uint32_t special_tokens_count_by_type = 0; + uint32_t special_tokens_count_from_verification = 0; + + bool special_tokens_definition_mismatch = false; + + for (const auto & t : vocab.token_to_id) { + const auto & token = t.first; + const auto & id = t.second; + + // Count all non-normal tokens in the vocab while iterating + if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) { + special_tokens_count_by_type++; + } + + // Skip single character tokens + if (token.length() > 1) { + bool is_tokenizable = false; + + // Split token string representation in two, in all possible ways + // and check if both halves can be matched to a valid token + for (unsigned i = 1; i < token.length();) { + const auto left = token.substr(0, i); + const auto right = token.substr(i); + + // check if we didnt partition in the middle of a utf sequence + auto utf = utf8_len(left.at(left.length() - 1)); + + if (utf == 1) { + if (vocab.token_to_id.find(left) != vocab.token_to_id.end() && + vocab.token_to_id.find(right) != vocab.token_to_id.end() ) { + is_tokenizable = true; + break; + } + i++; + } else { + // skip over the rest of multibyte utf sequence + i += utf - 1; + } + } + + if (!is_tokenizable) { + // Some tokens are multibyte, but they are utf sequences with equivalent text length of 1 + // it's faster to re-filter them here, since there are way less candidates now + + // Calculate a total "utf" length of a token string representation + size_t utf8_str_len = 0; + for (unsigned i = 0; i < token.length();) { + utf8_str_len++; + i += utf8_len(token.at(i)); + } + + // And skip the ones which are one character + if (utf8_str_len > 1) { + // At this point what we have left are special tokens only + vocab.special_tokens_cache[token] = id; + + // Count manually found special tokens + special_tokens_count_from_verification++; + + // If this manually found special token is not marked as such, flag a mismatch + if (vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL) { + special_tokens_definition_mismatch = true; + } + } + } + } + } + + if (special_tokens_definition_mismatch || special_tokens_count_from_verification != special_tokens_count_by_type) { + LLAMA_LOG_WARN("%s: mismatch in special tokens definition ( %u/%zu vs %u/%zu ).\n", + __func__, + special_tokens_count_from_verification, vocab.id_to_token.size(), + special_tokens_count_by_type, vocab.id_to_token.size() + ); + } else { + LLAMA_LOG_INFO("%s: special tokens definition check successful ( %u/%zu ).\n", + __func__, + special_tokens_count_from_verification, vocab.id_to_token.size() + ); + } + } } static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { @@ -5805,6 +5903,13 @@ static int llama_decode_internal( ggml_allocr_alloc_graph(lctx.alloc, gf); + struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; + struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; + + GGML_ASSERT(strcmp(res->name, "result_output") == 0); + GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); + + #ifdef GGML_USE_CUBLAS for (int i = 0; i < gf->n_leafs; i++) { ggml_tensor * node = gf->leafs[i]; @@ -5822,6 +5927,12 @@ static int llama_decode_internal( } ggml_cuda_set_mul_mat_q(cparams.mul_mat_q); + + // HACK: ggml-alloc may change the tensor backend when reusing a parent, so force output to be on the CPU here if needed + if (!lctx.embedding.empty()) { + embeddings->backend = GGML_BACKEND_CPU; + } + res->backend = GGML_BACKEND_CPU; #endif // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); @@ -5846,12 +5957,6 @@ static int llama_decode_internal( n_threads = 1; } - struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; - struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2]; - - GGML_ASSERT(strcmp(res->name, "result_output") == 0); - GGML_ASSERT(strcmp(embeddings->name, "result_norm") == 0); - #if GGML_USE_MPI const int64_t n_layer = hparams.n_layer; ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer); @@ -6464,7 +6569,137 @@ private: llm_bigram_bpe::queue work_queue; }; -static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos) { +typedef enum FRAGMENT_BUFFER_VARIANT_TYPE{ + FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN, + FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT +} FRAGMENT_BUFFER_VARIANT_TYPE; + +struct fragment_buffer_variant{ + fragment_buffer_variant(llama_vocab::id _token) + : + type(FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN), + token(_token), + raw_text(_dummy), + offset(0), + length(0){} + fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length) + : + type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT), + token((llama_vocab::id)-1), + raw_text(_raw_text), + offset(_offset), + length(_length){ + GGML_ASSERT( _offset >= 0 ); + GGML_ASSERT( _length >= 1 ); + GGML_ASSERT( offset + length <= raw_text.length() ); + } + + const FRAGMENT_BUFFER_VARIANT_TYPE type; + const llama_vocab::id token; + const std::string _dummy; + const std::string & raw_text; + const uint64_t offset; + const uint64_t length; +}; + +// #define PRETOKENIZERDEBUG + +static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list & buffer) +{ + // for each special token + for (const auto & st: vocab.special_tokens_cache) { + const auto & special_token = st.first; + const auto & special_id = st.second; + + // for each text fragment + std::forward_list::iterator it = buffer.begin(); + while (it != buffer.end()) { + auto & fragment = (*it); + + // if a fragment is text ( not yet processed ) + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + auto * raw_text = &(fragment.raw_text); + + auto raw_text_base_offset = fragment.offset; + auto raw_text_base_length = fragment.length; + + // loop over the text + while (true) { + // find the first occurence of a given special token in this fragment + // passing offset argument only limit the "search area" but match coordinates + // are still relative to the source full raw_text + auto match = raw_text->find(special_token, raw_text_base_offset); + + // no occurences found, stop processing this fragment for a given special token + if (match == std::string::npos) break; + + // check if match is within bounds of offset <-> length + if (match + special_token.length() > raw_text_base_offset + raw_text_base_length) break; + +#ifdef PRETOKENIZERDEBUG + fprintf(stderr, "FF: (%ld %ld %ld) '%s'\n", raw_text->length(), raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str()); +#endif + auto source = std::distance(buffer.begin(), it); + + // if match is further than base offset + // then we have some text to the left of it + if (match > raw_text_base_offset) { + // left + const int64_t left_reminder_offset = raw_text_base_offset + 0; + const int64_t left_reminder_length = match - raw_text_base_offset; + buffer.emplace_after(it, (*raw_text), left_reminder_offset, left_reminder_length); + +#ifdef PRETOKENIZERDEBUG + fprintf(stderr, "FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str()); +#endif + it++; + } + + // special token + buffer.emplace_after(it, special_id); + it++; + + // right + if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) { + const int64_t right_reminder_offset = match + special_token.length(); + const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length()); + buffer.emplace_after(it, (*raw_text), right_reminder_offset, right_reminder_length); + +#ifdef PRETOKENIZERDEBUG + fprintf(stderr, "FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str()); +#endif + + it++; + + if (source == 0) { + buffer.erase_after(buffer.before_begin()); + } else { + buffer.erase_after(std::next(buffer.begin(), (source-1))); + } + + // repeat for the right side + raw_text_base_offset = right_reminder_offset; + raw_text_base_length = right_reminder_length; + +#ifdef PRETOKENIZERDEBUG + fprintf(stderr, "RR: (%ld %ld) '%s'\n", raw_text_base_offset, raw_text_base_length, raw_text->substr(raw_text_base_offset, raw_text_base_length).c_str()); +#endif + } else { + if (source == 0) { + buffer.erase_after(buffer.before_begin()); + } else { + buffer.erase_after(std::next(buffer.begin(), (source-1))); + } + break; + } + } + } + it++; + } + } +} + +static std::vector llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool bos, bool special) { std::vector output; // OG tokenizer behavior: @@ -6480,20 +6715,58 @@ static std::vector llama_tokenize_internal(const llama_vocab & return output; } + std::forward_list fragment_buffer; + fragment_buffer.emplace_front( raw_text, 0, raw_text.length() ); + + if (special) tokenizer_st_partition( vocab, fragment_buffer ); + switch (vocab.type) { case LLAMA_VOCAB_TYPE_SPM: { - // without adding this leading whitespace, we do not get the same results as the original tokenizer - raw_text = " " + raw_text; + for (const auto & fragment: fragment_buffer) + { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) + { + // without adding this leading whitespace, we do not get the same results as the original tokenizer - llm_tokenizer_spm tokenizer(vocab); - llama_escape_whitespace(raw_text); - tokenizer.tokenize(raw_text, output); + // TODO: It's likely possible to get rid of this string copy entirely + // by modifying llm_tokenizer_x to operate with string offsets like pre-tokenizer + // and passing 'add space prefix' as bool argument + // + auto raw_text = (special ? "" : " ") + fragment.raw_text.substr(fragment.offset, fragment.length); + +#ifdef PRETOKENIZERDEBUG + fprintf(stderr,"TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); +#endif + llm_tokenizer_spm tokenizer(vocab); + llama_escape_whitespace(raw_text); + tokenizer.tokenize(raw_text, output); + } + else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + { + output.push_back(fragment.token); + } + } } break; case LLAMA_VOCAB_TYPE_BPE: { - llm_tokenizer_bpe tokenizer(vocab); - tokenizer.tokenize(raw_text, output); + for (const auto & fragment: fragment_buffer) + { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) + { + auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); + +#ifdef PRETOKENIZERDEBUG + fprintf(stderr,"TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); +#endif + llm_tokenizer_bpe tokenizer(vocab); + tokenizer.tokenize(raw_text, output); + } + else // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + { + output.push_back(fragment.token); + } + } } break; } @@ -9407,15 +9680,15 @@ llama_token llama_token_eot(const struct llama_context * ctx) { return ctx->model.vocab.special_eot_id; } - int llama_tokenize( const struct llama_model * model, const char * text, int text_len, llama_token * tokens, int n_max_tokens, - bool add_bos) { - auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos); + bool add_bos, + bool special) { + auto res = llama_tokenize_internal(model->vocab, std::string(text, text_len), add_bos, special); if (n_max_tokens < (int) res.size()) { // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); diff --git a/llama.h b/llama.h index a78015ada..b13f23123 100644 --- a/llama.h +++ b/llama.h @@ -511,17 +511,20 @@ extern "C" { // Tokenization // - // Convert the provided text into tokens. - // The tokens pointer must be large enough to hold the resulting tokens. - // Returns the number of tokens on success, no more than n_max_tokens - // Returns a negative number on failure - the number of tokens that would have been returned + /// @details Convert the provided text into tokens. + /// @param tokens The tokens pointer must be large enough to hold the resulting tokens. + /// @return Returns the number of tokens on success, no more than n_max_tokens + /// @return Returns a negative number on failure - the number of tokens that would have been returned + /// @param special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext. + /// Does not insert a leading space. LLAMA_API int llama_tokenize( const struct llama_model * model, const char * text, int text_len, llama_token * tokens, int n_max_tokens, - bool add_bos); + bool add_bos, + bool special); // Token Id -> Piece. // Uses the vocabulary in the provided context.