From a927b0f3dd9a86ee042cd2bdcc8c9da4a855926b Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Fri, 21 Jun 2024 08:51:28 +0300
Subject: [PATCH] llama : optimize long word tokenization with WPM (#8034)

ggml-ci
---
 llama.cpp   | 17 ++++++++++++-----
 unicode.cpp |  1 +
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 9ca0b7479..a05a52b42 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2293,6 +2293,8 @@ struct llama_vocab {
     enum llama_vocab_type     type     = LLAMA_VOCAB_TYPE_SPM;
     enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
 
+    int max_token_len = 0; // used for optimizing longest token search
+
     std::unordered_map<token, id> token_to_id;
     std::vector<token_data>       id_to_token;
 
@@ -4939,6 +4941,7 @@ static void llm_load_vocab(
             GGML_ASSERT(unicode_cpts_from_utf8(word).size() > 0);
 
             vocab.token_to_id[word] = i;
+            vocab.max_token_len = std::max(vocab.max_token_len, (int) word.size());
 
             auto & token_data = vocab.id_to_token[i];
             token_data.text = std::move(word);
@@ -5249,6 +5252,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token        = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
     if (vocab.special_eot_id    != -1) { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,    vocab.id_to_token[vocab.special_eot_id].text.c_str()    ); }
 
+    LLAMA_LOG_INFO("%s: max token length = %d\n", __func__, vocab.max_token_len);
+
     if (model.arch == LLM_ARCH_DEEPSEEK2) {
         LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
         LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
@@ -13488,7 +13493,7 @@ private:
 struct llm_tokenizer_wpm {
     llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
 
-    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
+    void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) const {
         const auto & token_map = vocab.token_to_id;
 
         // normalize and split by whitespace
@@ -13497,7 +13502,7 @@ struct llm_tokenizer_wpm {
         // bos token prepended already
         // find the longest tokens that form the words
-        for (const std::string &word : words) {
+        for (const std::string & word : words) {
             // skip empty words
             if (word.size() == 0) {
                 continue;
             }
@@ -13514,7 +13519,7 @@ struct llm_tokenizer_wpm {
             for (int i = 0; i < n; ++i) {
                 // loop through possible match length
                 bool match = false;
-                for (int j = n; j > i; j--) {
+                for (int j = std::min(n, i + vocab.max_token_len + 1); j > i; j--) {
                     auto it = token_map.find(word1.substr(i, j - i));
                     if (it != token_map.end()) {
                         output.push_back(it->second);
@@ -13537,7 +13542,8 @@ struct llm_tokenizer_wpm {
         }
     }
 
-    std::vector<std::string> preprocess(const std::string & text) {
+    // TODO: reduce string copies by using cpts_offs array
+    std::vector<std::string> preprocess(const std::string & text) const {
         const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
         std::vector<std::string> words(1, "");
 
@@ -13832,6 +13838,8 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
                     output.push_back(vocab.special_cls_id);
                 }
 
+                llm_tokenizer_wpm tokenizer(vocab);
+
                 for (const auto & fragment : fragment_buffer) {
                     if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
                         auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@@ -13839,7 +13847,6 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
 #ifdef PRETOKENIZERDEBUG
                         LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
 #endif
-                        llm_tokenizer_wpm tokenizer(vocab);
                         tokenizer.tokenize(raw_text, output);
                     } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                         output.push_back(fragment.token);
diff --git a/unicode.cpp b/unicode.cpp
index 913c34b9b..c0b76bf20 100644
--- a/unicode.cpp
+++ b/unicode.cpp
@@ -596,6 +596,7 @@ std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & c
 
 std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
     std::vector<uint32_t> result;
+    result.reserve(utf8.size());
     size_t offset = 0;
     while (offset < utf8.size()) {
         result.push_back(unicode_cpt_from_utf8(utf8, offset));
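
Note (not part of the patch): the core change is the inner loop bound, j = std::min(n, i + vocab.max_token_len + 1), which stops the longest-match search from testing substrings longer than the longest token in the vocabulary. Below is a minimal standalone C++ sketch of that idea on a toy greedy longest-match tokenizer; the vocabulary, token ids, unk_id parameter and helper name are invented for illustration, and it omits the "\u2581" word prefix and "##" continuation handling of the real WPM tokenizer.

// sketch.cpp: toy greedy longest-match tokenization with a max-token-length bound
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

static std::vector<int> greedy_longest_match(
        const std::string & word,
        const std::unordered_map<std::string, int> & token_to_id,
        int max_token_len,
        int unk_id) {
    std::vector<int> out;
    const int n = (int) word.size();
    for (int i = 0; i < n; ) {
        bool match = false;
        // without the bound the inner loop would start at j = n, so a single long
        // word costs O(n^2) substring lookups; capping the candidate length at
        // max_token_len reduces this to O(n * max_token_len)
        for (int j = std::min(n, i + max_token_len); j > i; j--) {
            auto it = token_to_id.find(word.substr(i, j - i));
            if (it != token_to_id.end()) {
                out.push_back(it->second);
                match = true;
                i = j;
                break;
            }
        }
        if (!match) {
            out.push_back(unk_id); // no token covers this position
            ++i;
        }
    }
    return out;
}

int main() {
    const std::unordered_map<std::string, int> token_to_id = {
        {"un", 0}, {"break", 1}, {"able", 2},
    };
    // track the longest token while building the vocabulary, as the patch does at load time
    int max_token_len = 0;
    for (const auto & kv : token_to_id) {
        max_token_len = std::max(max_token_len, (int) kv.first.size());
    }
    for (int id : greedy_longest_match("unbreakable", token_to_id, max_token_len, /*unk_id=*/3)) {
        std::cout << id << ' '; // prints: 0 1 2
    }
    std::cout << '\n';
    return 0;
}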