[llama] Add resegment post-processing to the tokenizer

This commit is contained in:
Howard Su 2023-07-02 22:06:03 +08:00
parent acc111caf9
commit e818537027

View File

@ -1771,21 +1771,35 @@ struct llama_tokenizer {
for (int i = 0; i != -1; i = symbols_[i].next) {
auto & symbol = symbols_[i];
auto token = vocab_.token_to_id.find(std::string(symbol.text, symbol.n));
resegment(symbol, output);
}
}
if (token == vocab_.token_to_id.end()) {
private:
void resegment(llama_sp_symbol &symbol, std::vector<llama_vocab::id> &output) {
auto text = std::string(symbol.text, symbol.n);
auto token = vocab_.token_to_id.find(text);
if (token != vocab_.token_to_id.end()) {
output.push_back((*token).second);
return;
}
const auto p = rev_merge.find(text);
if (p == rev_merge.end()) {
// output any symbols that did not form tokens as bytes.
for (int j = 0; j < (int) symbol.n; ++j) {
llama_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
output.push_back(token_id);
}
} else {
output.push_back((*token).second);
}
}
return;
}
resegment(symbols_[p->second.first], output);
resegment(symbols_[p->second.second], output);
}
private:
void try_add_bigram(int left, int right) {
if (left == -1 || right == -1) {
return;
@ -1810,11 +1824,14 @@ private:
bigram.score = tok_score.score;
bigram.size = text.size();
work_queue_.push(bigram);
rev_merge[text] = std::make_pair(left, right);
}
const llama_vocab & vocab_;
std::vector<llama_sp_symbol> symbols_;
llama_sp_bigram::queue work_queue_;
std::map<std::string, std::pair<int, int> > rev_merge;
};
static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {