[llama] Add resegment post processing of tokenizer

This commit is contained in:
Howard Su 2023-07-02 22:06:03 +08:00
parent acc111caf9
commit e818537027

View File

@ -1771,21 +1771,35 @@ struct llama_tokenizer {
for (int i = 0; i != -1; i = symbols_[i].next) { for (int i = 0; i != -1; i = symbols_[i].next) {
auto & symbol = symbols_[i]; auto & symbol = symbols_[i];
auto token = vocab_.token_to_id.find(std::string(symbol.text, symbol.n)); resegment(symbol, output);
}
}
if (token == vocab_.token_to_id.end()) { private:
void resegment(llama_sp_symbol &symbol, std::vector<llama_vocab::id> &output) {
auto text = std::string(symbol.text, symbol.n);
auto token = vocab_.token_to_id.find(text);
if (token != vocab_.token_to_id.end()) {
output.push_back((*token).second);
return;
}
const auto p = rev_merge.find(text);
if (p == rev_merge.end()) {
// output any symbols that did not form tokens as bytes. // output any symbols that did not form tokens as bytes.
for (int j = 0; j < (int) symbol.n; ++j) { for (int j = 0; j < (int) symbol.n; ++j) {
llama_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3; llama_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
output.push_back(token_id); output.push_back(token_id);
} }
} else { return;
output.push_back((*token).second); }
}
} resegment(symbols_[p->second.first], output);
resegment(symbols_[p->second.second], output);
} }
private:
void try_add_bigram(int left, int right) { void try_add_bigram(int left, int right) {
if (left == -1 || right == -1) { if (left == -1 || right == -1) {
return; return;
@ -1810,11 +1824,14 @@ private:
bigram.score = tok_score.score; bigram.score = tok_score.score;
bigram.size = text.size(); bigram.size = text.size();
work_queue_.push(bigram); work_queue_.push(bigram);
rev_merge[text] = std::make_pair(left, right);
} }
const llama_vocab & vocab_; const llama_vocab & vocab_;
std::vector<llama_sp_symbol> symbols_; std::vector<llama_sp_symbol> symbols_;
llama_sp_bigram::queue work_queue_; llama_sp_bigram::queue work_queue_;
std::map<std::string, std::pair<int, int> > rev_merge;
}; };
static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) { static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {