From e8185370276ef43d0a98265c97704fb86c02fa40 Mon Sep 17 00:00:00 2001 From: Howard Su Date: Sun, 2 Jul 2023 22:06:03 +0800 Subject: [PATCH] [llama] Add resegment post processing of tokenizer --- llama.cpp | 39 ++++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/llama.cpp b/llama.cpp index 7419b03b6..c9488dc6f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1771,21 +1771,35 @@ struct llama_tokenizer { for (int i = 0; i != -1; i = symbols_[i].next) { auto & symbol = symbols_[i]; - auto token = vocab_.token_to_id.find(std::string(symbol.text, symbol.n)); - - if (token == vocab_.token_to_id.end()) { - // output any symbols that did not form tokens as bytes. - for (int j = 0; j < (int) symbol.n; ++j) { - llama_vocab::id token_id = static_cast(symbol.text[j]) + 3; - output.push_back(token_id); - } - } else { - output.push_back((*token).second); - } + resegment(symbol, output); } } private: + void resegment(llama_sp_symbol &symbol, std::vector &output) { + auto text = std::string(symbol.text, symbol.n); + auto token = vocab_.token_to_id.find(text); + + if (token != vocab_.token_to_id.end()) { + output.push_back((*token).second); + return; + } + + const auto p = rev_merge.find(text); + + if (p == rev_merge.end()) { + // output any symbols that did not form tokens as bytes. + for (int j = 0; j < (int) symbol.n; ++j) { + llama_vocab::id token_id = static_cast(symbol.text[j]) + 3; + output.push_back(token_id); + } + return; + } + + resegment(symbols_[p->second.first], output); + resegment(symbols_[p->second.second], output); + } + void try_add_bigram(int left, int right) { if (left == -1 || right == -1) { return; @@ -1810,11 +1824,14 @@ private: bigram.score = tok_score.score; bigram.size = text.size(); work_queue_.push(bigram); + + rev_merge[text] = std::make_pair(left, right); } const llama_vocab & vocab_; std::vector symbols_; llama_sp_bigram::queue work_queue_; + std::map > rev_merge; }; static std::vector llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {