10X faster BPE tokenizer (#2876)

* 10X faster BPE tokenizer

* Remove comment that no longer applies

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow 2023-08-29 23:55:03 +03:00 committed by GitHub
parent 53885d7256
commit e37e69dcc3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -3211,7 +3211,7 @@ private:
struct llm_bigram_bpe { struct llm_bigram_bpe {
struct comparator { struct comparator {
bool operator()(llm_bigram_bpe & l, llm_bigram_bpe & r) { bool operator()(const llm_bigram_bpe & l, const llm_bigram_bpe & r) const {
return l.rank > r.rank || (l.rank == r.rank && l.left > r.left); return l.rank > r.rank || (l.rank == r.rank && l.left > r.left);
} }
}; };
@ -3359,23 +3359,22 @@ private:
} }
// probably not 100% correct // probably not 100% correct
// TODO: this is quite slow - how to make it more efficient? static std::vector<std::string> bpe_gpt2_preprocess(const std::string & text) {
static std::vector<std::string> bpe_gpt2_preprocess(std::string text) {
std::vector<std::string> words; std::vector<std::string> words;
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53 // ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53
const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)"; const std::string pattern = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
const std::regex re(pattern); const std::regex re(pattern);
std::smatch m;
while (std::regex_search(text, m, re)) { auto words_begin = std::sregex_iterator(text.begin(), text.end(), re);
for (auto x : m) { auto words_end = std::sregex_iterator();
words.push_back(x); auto n_words = std::distance(words_begin, words_end);
} words.reserve(n_words);
text = m.suffix(); for (auto it = words_begin; it != words_end; ++it) {
words.push_back(it->str());
} }
return words; return words;
} }
const llama_vocab & vocab; const llama_vocab & vocab;