mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-11 19:21:46 +00:00
[llama] Add resegment post processing of tokenizer
This commit is contained in:
parent
acc111caf9
commit
e818537027
31
llama.cpp
31
llama.cpp
@ -1771,21 +1771,35 @@ struct llama_tokenizer {
|
|||||||
|
|
||||||
for (int i = 0; i != -1; i = symbols_[i].next) {
|
for (int i = 0; i != -1; i = symbols_[i].next) {
|
||||||
auto & symbol = symbols_[i];
|
auto & symbol = symbols_[i];
|
||||||
auto token = vocab_.token_to_id.find(std::string(symbol.text, symbol.n));
|
resegment(symbol, output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (token == vocab_.token_to_id.end()) {
|
private:
|
||||||
|
void resegment(llama_sp_symbol &symbol, std::vector<llama_vocab::id> &output) {
|
||||||
|
auto text = std::string(symbol.text, symbol.n);
|
||||||
|
auto token = vocab_.token_to_id.find(text);
|
||||||
|
|
||||||
|
if (token != vocab_.token_to_id.end()) {
|
||||||
|
output.push_back((*token).second);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto p = rev_merge.find(text);
|
||||||
|
|
||||||
|
if (p == rev_merge.end()) {
|
||||||
// output any symbols that did not form tokens as bytes.
|
// output any symbols that did not form tokens as bytes.
|
||||||
for (int j = 0; j < (int) symbol.n; ++j) {
|
for (int j = 0; j < (int) symbol.n; ++j) {
|
||||||
llama_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
|
llama_vocab::id token_id = static_cast<uint8_t>(symbol.text[j]) + 3;
|
||||||
output.push_back(token_id);
|
output.push_back(token_id);
|
||||||
}
|
}
|
||||||
} else {
|
return;
|
||||||
output.push_back((*token).second);
|
}
|
||||||
}
|
|
||||||
}
|
resegment(symbols_[p->second.first], output);
|
||||||
|
resegment(symbols_[p->second.second], output);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
|
||||||
void try_add_bigram(int left, int right) {
|
void try_add_bigram(int left, int right) {
|
||||||
if (left == -1 || right == -1) {
|
if (left == -1 || right == -1) {
|
||||||
return;
|
return;
|
||||||
@ -1810,11 +1824,14 @@ private:
|
|||||||
bigram.score = tok_score.score;
|
bigram.score = tok_score.score;
|
||||||
bigram.size = text.size();
|
bigram.size = text.size();
|
||||||
work_queue_.push(bigram);
|
work_queue_.push(bigram);
|
||||||
|
|
||||||
|
rev_merge[text] = std::make_pair(left, right);
|
||||||
}
|
}
|
||||||
|
|
||||||
const llama_vocab & vocab_;
|
const llama_vocab & vocab_;
|
||||||
std::vector<llama_sp_symbol> symbols_;
|
std::vector<llama_sp_symbol> symbols_;
|
||||||
llama_sp_bigram::queue work_queue_;
|
llama_sp_bigram::queue work_queue_;
|
||||||
|
std::map<std::string, std::pair<int, int> > rev_merge;
|
||||||
};
|
};
|
||||||
|
|
||||||
static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
|
static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, const std::string & text, bool bos) {
|
||||||
|
Loading…
Reference in New Issue
Block a user