refactor tokenizer

This commit is contained in:
zhenweijin 2024-09-11 09:42:55 +08:00
parent 722ec1eb51
commit d949c5844d
5 changed files with 382 additions and 103 deletions

View File

@ -50,7 +50,7 @@ struct naive_trie {
res.first->second.insert(key + 1, len - 1, value); res.first->second.insert(key + 1, len - 1, value);
} }
} }
std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) { std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) const {
if (len == 0 || offset == len) { if (len == 0 || offset == len) {
return std::make_pair(key, offset); return std::make_pair(key, offset);
} }
@ -187,10 +187,17 @@ struct llm_bigram_spm {
size_t size; size_t size;
}; };
struct llm_tokenizer_spm { struct llm_tokenizer_spm : llm_tokenizer {
llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {} llm_tokenizer_spm(const llama_vocab & vocab) : llm_tokenizer(vocab) {}
};
struct llm_tokenizer_spm_session {
llm_tokenizer_spm_session(const llm_tokenizer & tokenizer) :
spm_tokenizer(static_cast<const llm_tokenizer_spm &>(tokenizer)) {}
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) { void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
// split string into utf8 chars // split string into utf8 chars
int index = 0; int index = 0;
size_t offs = 0; size_t offs = 0;
@ -250,6 +257,7 @@ struct llm_tokenizer_spm {
private: private:
void resegment(llm_symbol & symbol, std::vector<llama_vocab::id> & output) { void resegment(llm_symbol & symbol, std::vector<llama_vocab::id> & output) {
const auto & vocab = spm_tokenizer.vocab;
auto text = std::string(symbol.text, symbol.n); auto text = std::string(symbol.text, symbol.n);
auto token = vocab.token_to_id.find(text); auto token = vocab.token_to_id.find(text);
@ -279,7 +287,7 @@ private:
if (left == -1 || right == -1) { if (left == -1 || right == -1) {
return; return;
} }
const auto & vocab = spm_tokenizer.vocab;
const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n); const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
auto token = vocab.token_to_id.find(text); auto token = vocab.token_to_id.find(text);
@ -305,11 +313,10 @@ private:
rev_merge[text] = std::make_pair(left, right); rev_merge[text] = std::make_pair(left, right);
} }
const llama_vocab & vocab; const llm_tokenizer_spm & spm_tokenizer;
std::vector<llm_symbol> symbols; std::vector<llm_symbol> symbols;
llm_bigram_spm::queue work_queue; llm_bigram_spm::queue work_queue;
std::map<std::string, std::pair<int, int>> rev_merge; std::map<std::string, std::pair<int, int>> rev_merge;
}; };
@ -352,8 +359,8 @@ struct llm_bigram_bpe {
size_t size; size_t size;
}; };
struct llm_tokenizer_bpe { struct llm_tokenizer_bpe : llm_tokenizer {
llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) { llm_tokenizer_bpe(const llama_vocab & vocab) : llm_tokenizer(vocab) {
GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE); GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
switch (vocab.type_pre) { switch (vocab.type_pre) {
case LLAMA_VOCAB_PRE_TYPE_LLAMA3: case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
@ -462,11 +469,20 @@ struct llm_tokenizer_bpe {
} }
} }
std::vector<std::string> regex_exprs;
};
struct llm_tokenizer_bpe_session {
llm_tokenizer_bpe_session(const llm_tokenizer & tokenizer) :
bpe_tokenizer(static_cast<const llm_tokenizer_bpe &>(tokenizer)) {}
void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) const { void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) const {
output.push_back(token_id); output.push_back(token_id);
} }
bool append_bos(std::vector<llama_vocab::id> & output) const { bool append_bos(std::vector<llama_vocab::id> & output) const {
const auto & vocab = bpe_tokenizer.vocab;
if (vocab.tokenizer_add_bos) { if (vocab.tokenizer_add_bos) {
GGML_ASSERT(vocab.special_bos_id != -1); GGML_ASSERT(vocab.special_bos_id != -1);
output.push_back(vocab.special_bos_id); output.push_back(vocab.special_bos_id);
@ -476,6 +492,7 @@ struct llm_tokenizer_bpe {
} }
bool append_eos(std::vector<llama_vocab::id> & output) const { bool append_eos(std::vector<llama_vocab::id> & output) const {
const auto & vocab = bpe_tokenizer.vocab;
if (vocab.tokenizer_add_eos) { if (vocab.tokenizer_add_eos) {
GGML_ASSERT(vocab.special_eos_id != -1); GGML_ASSERT(vocab.special_eos_id != -1);
output.push_back(vocab.special_eos_id); output.push_back(vocab.special_eos_id);
@ -485,6 +502,7 @@ struct llm_tokenizer_bpe {
} }
void check_double_bos_eos(const std::vector<llama_vocab::id> & output) const { void check_double_bos_eos(const std::vector<llama_vocab::id> & output) const {
const auto & vocab = bpe_tokenizer.vocab;
if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) { if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
LLAMA_LOG_WARN( LLAMA_LOG_WARN(
"%s: Added a BOS token to the prompt as specified by the model but the prompt " "%s: Added a BOS token to the prompt as specified by the model but the prompt "
@ -501,8 +519,8 @@ struct llm_tokenizer_bpe {
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) { void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
int final_prev_index = -1; int final_prev_index = -1;
const auto word_collection = unicode_regex_split(text, bpe_tokenizer.regex_exprs);
const auto word_collection = unicode_regex_split(text, regex_exprs); const auto & vocab = bpe_tokenizer.vocab;
symbols_final.clear(); symbols_final.clear();
@ -609,7 +627,7 @@ private:
if (left == -1 || right == -1) { if (left == -1 || right == -1) {
return; return;
} }
const auto & vocab = bpe_tokenizer.vocab;
std::string left_token = std::string(symbols[left].text, symbols[left].n); std::string left_token = std::string(symbols[left].text, symbols[left].n);
std::string right_token = std::string(symbols[right].text, symbols[right].n); std::string right_token = std::string(symbols[right].text, symbols[right].n);
@ -632,13 +650,10 @@ private:
work_queue.push(bigram); work_queue.push(bigram);
} }
const llama_vocab & vocab; const llm_tokenizer_bpe & bpe_tokenizer;
std::vector<std::string> regex_exprs;
std::vector<llm_symbol> symbols; std::vector<llm_symbol> symbols;
std::vector<llm_symbol> symbols_final; std::vector<llm_symbol> symbols_final;
llm_bigram_bpe::queue work_queue; llm_bigram_bpe::queue work_queue;
}; };
@ -646,15 +661,20 @@ private:
// WPM tokenizer // WPM tokenizer
// //
struct llm_tokenizer_wpm { struct llm_tokenizer_wpm : llm_tokenizer {
llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {} llm_tokenizer_wpm(const llama_vocab & vocab) : llm_tokenizer(vocab) {}
};
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) const { struct llm_tokenizer_wpm_session {
llm_tokenizer_wpm_session(const llm_tokenizer & tokenizer)
: wpm_tokenizer(static_cast<const llm_tokenizer_wpm &>(tokenizer)) {}
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
const auto & vocab = wpm_tokenizer.vocab;
const auto & token_map = vocab.token_to_id; const auto & token_map = vocab.token_to_id;
// normalize and split by whitespace // normalize and split by whitespace
std::vector<std::string> words = preprocess(text); std::vector<std::string> words = preprocess(text);
// bos token prepended already // bos token prepended already
// find the longest tokens that form the words // find the longest tokens that form the words
@ -751,15 +771,16 @@ struct llm_tokenizer_wpm {
//(cpt >= 0xFF00 && cpt <= 0xFFEF); //(cpt >= 0xFF00 && cpt <= 0xFFEF);
} }
const llama_vocab & vocab; private:
const llm_tokenizer_wpm & wpm_tokenizer;
}; };
// //
// UGM tokenizer // UGM tokenizer
// //
struct llm_tokenizer_ugm { struct llm_tokenizer_ugm : llm_tokenizer {
llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) { llm_tokenizer_ugm(const llama_vocab & vocab) : llm_tokenizer(vocab) {
if (vocab.precompiled_charsmap.size() > 0) { if (vocab.precompiled_charsmap.size() > 0) {
size_t charsmap_offset = 0; size_t charsmap_offset = 0;
@ -805,6 +826,31 @@ struct llm_tokenizer_ugm {
unknown_token_score = min_score - unknown_token_score_penalty; unknown_token_score = min_score - unknown_token_score_penalty;
} }
// escaped space symbol - U+2581 (Lower One Eighth Block)
const std::string escaped_space = "\xE2\x96\x81";
const char * prefix_replacements = NULL;
size_t prefix_replacements_size = 0;
const uint32_t * xcda_array = NULL;
size_t xcda_array_size = 0;
struct naive_trie user_defined_token_matcher;
float min_score = FLT_MAX;
float max_score = -FLT_MAX;
float unknown_token_score_penalty = 10.0;
float unknown_token_score;
struct naive_trie token_matcher;
};
struct llm_tokenizer_ugm_session {
llm_tokenizer_ugm_session(const llm_tokenizer & tokenizer)
: ugm_tokenizer(static_cast<const llm_tokenizer_ugm &>(tokenizer)) {}
/* This implementation is based on SentencePiece optimized Viterbi algorithm for /* This implementation is based on SentencePiece optimized Viterbi algorithm for
* unigram language models. The general idea is to: * unigram language models. The general idea is to:
* - move along the input sequence in steps of one UTF code point, * - move along the input sequence in steps of one UTF code point,
@ -821,6 +867,7 @@ struct llm_tokenizer_ugm {
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) { void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
// get current size of output (for reversal later) // get current size of output (for reversal later)
size_t output_size = output.size(); size_t output_size = output.size();
const auto & vocab = ugm_tokenizer.vocab;
// normalize the input first // normalize the input first
std::string normalized; std::string normalized;
@ -843,7 +890,7 @@ struct llm_tokenizer_ugm {
// traverse the token matcher trie to find a matching token // traverse the token matcher trie to find a matching token
bool single_codepoint_token_found = false; bool single_codepoint_token_found = false;
const struct best_tokenization & current_best = tokenization_results[input_offset]; const struct best_tokenization & current_best = tokenization_results[input_offset];
const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]); const struct naive_trie * node = ugm_tokenizer.token_matcher.traverse(normalized[prefix_offset++]);
while (prefix_offset <= input_len && node != NULL) { while (prefix_offset <= input_len && node != NULL) {
// check if we found valid token in prefix // check if we found valid token in prefix
@ -873,7 +920,7 @@ struct llm_tokenizer_ugm {
// if we didn't find a valid token corresponding to the whole UTF code point // if we didn't find a valid token corresponding to the whole UTF code point
// then use unknown token as the tokenization of this UTF code point // then use unknown token as the tokenization of this UTF code point
if (!single_codepoint_token_found) { if (!single_codepoint_token_found) {
const double challenger_score = current_best.score_sum + unknown_token_score; const double challenger_score = current_best.score_sum + ugm_tokenizer.unknown_token_score;
prefix_offset = input_offset + n_utf8_code_units; prefix_offset = input_offset + n_utf8_code_units;
struct best_tokenization & current_champ = tokenization_results[prefix_offset]; struct best_tokenization & current_champ = tokenization_results[prefix_offset];
if (challenger_score > current_champ.score_sum) { if (challenger_score > current_champ.score_sum) {
@ -905,7 +952,6 @@ struct llm_tokenizer_ugm {
} }
private: private:
const llama_vocab & vocab;
// helper structure for returning normalization results // helper structure for returning normalization results
struct normalization_result { struct normalization_result {
@ -917,8 +963,9 @@ private:
void normalize(const std::string& input, std::string * normalized) { void normalize(const std::string& input, std::string * normalized) {
normalized->clear(); normalized->clear();
normalized->reserve(input.size() * 3); normalized->reserve(input.size() * 3);
const auto & vocab = ugm_tokenizer.vocab;
const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " "; const std::string space = vocab.tokenizer_escape_whitespaces ? ugm_tokenizer.escaped_space : " ";
bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix; bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
@ -1000,13 +1047,21 @@ private:
size_t xcda_array_size; size_t xcda_array_size;
}; };
// this structure stores the best tokenization so far at input_offset
struct best_tokenization {
llama_token token_id;
size_t input_offset;
float score_sum;
};
struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) { struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
if (input_offset == input.size()) { if (input_offset == input.size()) {
return { &input[input_offset], 0, 0 }; return { &input[input_offset], 0, 0 };
} }
// if input prefix matches some user-defined token return this token as normalization result // if input prefix matches some user-defined token return this token as normalization result
auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset); auto user_defined_token_match =
ugm_tokenizer.user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
if (user_defined_token_match.second > 0) { if (user_defined_token_match.second > 0) {
return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second }; return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second };
} }
@ -1014,8 +1069,8 @@ private:
size_t longest_prefix_length = 0; size_t longest_prefix_length = 0;
size_t longest_prefix_offset = 0; size_t longest_prefix_offset = 0;
if (xcda_array_size > 0) { if (ugm_tokenizer.xcda_array_size > 0) {
struct xcda_array_view xcda_view(xcda_array, xcda_array_size); struct xcda_array_view xcda_view(ugm_tokenizer.xcda_array, ugm_tokenizer.xcda_array_size);
// Find the longest normalized sequence matching the input prefix by walking // Find the longest normalized sequence matching the input prefix by walking
// the XOR-compressed compact double array (XCDA) starting from the root node // the XOR-compressed compact double array (XCDA) starting from the root node
@ -1051,10 +1106,10 @@ private:
if (longest_prefix_length > 0) { if (longest_prefix_length > 0) {
// we have a match, so return the replacement sequence // we have a match, so return the replacement sequence
if (longest_prefix_offset >= prefix_replacements_size) { if (longest_prefix_offset >= ugm_tokenizer.prefix_replacements_size) {
throw std::runtime_error("Index out of array bounds in precompiled charsmap!"); throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
} }
const char * prefix_replacement = &prefix_replacements[longest_prefix_offset]; const char * prefix_replacement = &(ugm_tokenizer.prefix_replacements)[longest_prefix_offset];
return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length }; return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length };
} else { } else {
// check if the input prefix contains a valid sequence of UTF-8 code units // check if the input prefix contains a valid sequence of UTF-8 code units
@ -1070,31 +1125,7 @@ private:
} }
} }
// escaped space symbol - U+2581 (Lower One Eighth Block) const llm_tokenizer_ugm & ugm_tokenizer;
const std::string escaped_space = "\xE2\x96\x81";
const char * prefix_replacements = NULL;
size_t prefix_replacements_size = 0;
const uint32_t * xcda_array = NULL;
size_t xcda_array_size = 0;
struct naive_trie user_defined_token_matcher;
// this structure stores the best tokenization so far at input_offset
struct best_tokenization {
llama_token token_id;
size_t input_offset;
float score_sum;
};
float min_score = FLT_MAX;
float max_score = -FLT_MAX;
float unknown_token_score_penalty = 10.0;
float unknown_token_score;
struct naive_trie token_matcher;
}; };
// //
@ -1155,8 +1186,8 @@ static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escape
return output; return output;
} }
struct llm_tokenizer_rwkv { struct llm_tokenizer_rwkv : llm_tokenizer {
llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) { llm_tokenizer_rwkv(const llama_vocab & vocab) : llm_tokenizer(vocab) {
// RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens. // RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
// For now, we decode the vocab here into the lookup we'll use for tokenization. // For now, we decode the vocab here into the lookup we'll use for tokenization.
@ -1168,11 +1199,19 @@ struct llm_tokenizer_rwkv {
} }
} }
struct naive_trie token_matcher;
};
struct llm_tokenizer_rwkv_session {
llm_tokenizer_rwkv_session(const llm_tokenizer & tokenizer)
: rwkv_tokenizer(static_cast<const llm_tokenizer_rwkv &>(tokenizer)) {}
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) { void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
uint32_t position = 0; uint32_t position = 0;
const auto & vocab = rwkv_tokenizer.vocab;
while (position < text.size()) { while (position < text.size()) {
const struct naive_trie * node = token_matcher.traverse(text[position]); const struct naive_trie * node = rwkv_tokenizer.token_matcher.traverse(text[position]);
if (node == NULL) { if (node == NULL) {
// no matching token found, add unknown token // no matching token found, add unknown token
output.push_back(vocab.special_unk_id); output.push_back(vocab.special_unk_id);
@ -1197,9 +1236,8 @@ struct llm_tokenizer_rwkv {
} }
} }
const llama_vocab & vocab; private:
const llm_tokenizer_rwkv & rwkv_tokenizer;
struct naive_trie token_matcher;
}; };
// //
@ -1362,9 +1400,11 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
} }
} }
std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) { std::vector<llama_vocab::id> llama_tokenize_internal(const llm_tokenizer * tokenizer,
std::string raw_text, bool add_special, bool parse_special) {
std::vector<llama_vocab::id> output; std::vector<llama_vocab::id> output;
std::forward_list<fragment_buffer_variant> fragment_buffer; std::forward_list<fragment_buffer_variant> fragment_buffer;
const llama_vocab & vocab = tokenizer->vocab;
if (!raw_text.empty()) { if (!raw_text.empty()) {
fragment_buffer.emplace_front(raw_text, 0, raw_text.length()); fragment_buffer.emplace_front(raw_text, 0, raw_text.length());
@ -1399,9 +1439,9 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
#ifdef PRETOKENIZERDEBUG #ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif #endif
llm_tokenizer_spm tokenizer(vocab);
llama_escape_whitespace(raw_text); llama_escape_whitespace(raw_text);
tokenizer.tokenize(raw_text, output); llm_tokenizer_spm_session session(*tokenizer);
session.tokenize(raw_text, output);
is_prev_special = false; is_prev_special = false;
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token); output.push_back(fragment.token);
@ -1423,10 +1463,11 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
} break; } break;
case LLAMA_VOCAB_TYPE_BPE: case LLAMA_VOCAB_TYPE_BPE:
{ {
llm_tokenizer_bpe tokenizer(vocab); llm_tokenizer_bpe_session session(*tokenizer);
// it calls some other methods that are not exist in llm_tokenizer,
// here just cast it to bpe tokenizer object
if (add_special) { if (add_special) {
tokenizer.append_bos(output); session.append_bos(output);
} }
for (const auto & fragment : fragment_buffer) { for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@ -1435,15 +1476,15 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
#ifdef PRETOKENIZERDEBUG #ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif #endif
tokenizer.tokenize(raw_text, output); session.tokenize(raw_text, output);
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
tokenizer.append(fragment.token, output); session.append(fragment.token, output);
} }
} }
if (add_special) { if (add_special) {
tokenizer.append_eos(output); session.append_eos(output);
tokenizer.check_double_bos_eos(output); session.check_double_bos_eos(output);
} }
} break; } break;
case LLAMA_VOCAB_TYPE_WPM: case LLAMA_VOCAB_TYPE_WPM:
@ -1453,7 +1494,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
output.push_back(vocab.special_cls_id); output.push_back(vocab.special_cls_id);
} }
llm_tokenizer_wpm tokenizer(vocab); llm_tokenizer_wpm_session session(*tokenizer);
for (const auto & fragment : fragment_buffer) { for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@ -1462,7 +1503,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
#ifdef PRETOKENIZERDEBUG #ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif #endif
tokenizer.tokenize(raw_text, output); session.tokenize(raw_text, output);
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token); output.push_back(fragment.token);
} }
@ -1475,12 +1516,11 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
} break; } break;
case LLAMA_VOCAB_TYPE_UGM: case LLAMA_VOCAB_TYPE_UGM:
{ {
llm_tokenizer_ugm tokenizer(vocab);
if (add_special && vocab.tokenizer_add_bos != 0) { if (add_special && vocab.tokenizer_add_bos != 0) {
GGML_ASSERT(vocab.special_bos_id != -1); GGML_ASSERT(vocab.special_bos_id != -1);
output.push_back(vocab.special_bos_id); output.push_back(vocab.special_bos_id);
} }
llm_tokenizer_ugm_session session(*tokenizer);
for (const auto & fragment : fragment_buffer) { for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
@ -1488,7 +1528,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
#ifdef PRETOKENIZERDEBUG #ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif #endif
tokenizer.tokenize(raw_text, output); session.tokenize(raw_text, output);
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token); output.push_back(fragment.token);
} }
@ -1508,6 +1548,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
} break; } break;
case LLAMA_VOCAB_TYPE_RWKV: case LLAMA_VOCAB_TYPE_RWKV:
{ {
llm_tokenizer_rwkv_session session(*tokenizer);
for (const auto & fragment : fragment_buffer) { for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length); auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
@ -1516,8 +1557,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str()); LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
#endif #endif
llm_tokenizer_rwkv tokenizer(vocab); session.tokenize(raw_text, output);
tokenizer.tokenize(raw_text, output);
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token); output.push_back(fragment.token);
} }
@ -1530,6 +1570,32 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
return output; return output;
} }
llm_tokenizer * llama_create_tokenizer(const llama_vocab & vocab) {
llm_tokenizer * tokenizer;
switch (vocab.type) {
case LLAMA_VOCAB_TYPE_SPM:
tokenizer = new llm_tokenizer_spm(vocab);
break;
case LLAMA_VOCAB_TYPE_BPE:
tokenizer = new llm_tokenizer_bpe(vocab);
break;
case LLAMA_VOCAB_TYPE_WPM:
tokenizer = new llm_tokenizer_wpm(vocab);
break;
case LLAMA_VOCAB_TYPE_UGM:
tokenizer = new llm_tokenizer_ugm(vocab);
break;
case LLAMA_VOCAB_TYPE_RWKV:
tokenizer = new llm_tokenizer_rwkv(vocab);
break;
default:
GGML_ABORT("fatal error");
}
return tokenizer;
}
llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch) { llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch) {
GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE); GGML_ASSERT(llama_vocab_get_type(vocab) != LLAMA_VOCAB_TYPE_NONE);
static const char * hex = "0123456789ABCDEF"; static const char * hex = "0123456789ABCDEF";
@ -1634,14 +1700,14 @@ llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
} }
int32_t llama_tokenize_impl( int32_t llama_tokenize_impl(
const struct llama_vocab & vocab, const llm_tokenizer * tokenizer,
const char * text, const char * text,
int32_t text_len, int32_t text_len,
llama_token * tokens, llama_token * tokens,
int32_t n_tokens_max, int32_t n_tokens_max,
bool add_special, bool add_special,
bool parse_special) { bool parse_special) {
auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special); auto res = llama_tokenize_internal(tokenizer, std::string(text, text_len), add_special, parse_special);
if (n_tokens_max < (int) res.size()) { if (n_tokens_max < (int) res.size()) {
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__); // LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
return -((int) res.size()); return -((int) res.size());

View File

@ -64,6 +64,13 @@ struct llama_vocab {
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const; int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
}; };
struct llm_tokenizer {
llm_tokenizer(const llama_vocab & vocab) : vocab(vocab) {}
virtual ~llm_tokenizer() = default;
const llama_vocab & vocab;
};
// //
// internal API // internal API
// //
@ -71,11 +78,13 @@ struct llama_vocab {
// TODO: rename to llama_tokenize_impl // TODO: rename to llama_tokenize_impl
// TODO: This should probably be in llama.h // TODO: This should probably be in llama.h
std::vector<llama_vocab::id> llama_tokenize_internal( std::vector<llama_vocab::id> llama_tokenize_internal(
const llama_vocab & vocab, const llm_tokenizer * tokenizer,
std::string raw_text, std::string raw_text,
bool add_special, bool add_special,
bool parse_special = false); bool parse_special = false);
llm_tokenizer * llama_create_tokenizer(const llama_vocab & vocab);
// TODO: move the API below as member functions of llama_vocab // TODO: move the API below as member functions of llama_vocab
llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch); llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
@ -106,7 +115,7 @@ llama_token llama_token_eot_impl (const struct llama_vocab & vocab);
llama_token llama_token_eom_impl (const struct llama_vocab & vocab); llama_token llama_token_eom_impl (const struct llama_vocab & vocab);
int32_t llama_tokenize_impl( int32_t llama_tokenize_impl(
const struct llama_vocab & vocab, const llm_tokenizer * tokenizer,
const char * text, const char * text,
int32_t text_len, int32_t text_len,
llama_token * tokens, llama_token * tokens,

View File

@ -2848,6 +2848,7 @@ struct llama_model {
llama_hparams hparams = {}; llama_hparams hparams = {};
llama_vocab vocab; llama_vocab vocab;
llm_tokenizer * tokenizer;
struct ggml_tensor * tok_embd; struct ggml_tensor * tok_embd;
struct ggml_tensor * type_embd; struct ggml_tensor * type_embd;
@ -2923,6 +2924,8 @@ struct llama_model {
while (!lora_adapters.empty()) { while (!lora_adapters.empty()) {
llama_lora_adapter_free(*lora_adapters.begin()); llama_lora_adapter_free(*lora_adapters.begin());
} }
delete tokenizer;
} }
}; };
@ -6404,6 +6407,8 @@ static void llm_load_vocab(
} }
GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size()); GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
model.tokenizer = llama_create_tokenizer(vocab);
// determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n' // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
if (vocab.type == LLAMA_VOCAB_TYPE_SPM) { if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
// For Fill-In-the-Middle (FIM)/infill models which where converted // For Fill-In-the-Middle (FIM)/infill models which where converted
@ -6453,11 +6458,11 @@ static void llm_load_vocab(
} else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) { } else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
vocab.linefeed_id = vocab.special_pad_id; vocab.linefeed_id = vocab.special_pad_id;
} else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) { } else if (vocab.type == LLAMA_VOCAB_TYPE_RWKV) {
const std::vector<int> ids = llama_tokenize_internal(vocab, "\n", false); const std::vector<int> ids = llama_tokenize_internal(model.tokenizer, "\n", false);
GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
vocab.linefeed_id = ids[0]; vocab.linefeed_id = ids[0];
} else { } else {
const std::vector<int> ids = llama_tokenize_internal(vocab, "\xC4\x8A", false); // U+010A const std::vector<int> ids = llama_tokenize_internal(model.tokenizer, "\xC4\x8A", false); // U+010A
GGML_ASSERT(!ids.empty() && "model vocab missing newline token"); GGML_ASSERT(!ids.empty() && "model vocab missing newline token");
vocab.linefeed_id = ids[0]; vocab.linefeed_id = ids[0];
} }
@ -20885,7 +20890,7 @@ int32_t llama_tokenize(
int32_t n_tokens_max, int32_t n_tokens_max,
bool add_special, bool add_special,
bool parse_special) { bool parse_special) {
return llama_tokenize_impl(model->vocab, text, text_len, tokens, n_tokens_max, add_special, parse_special); return llama_tokenize_impl(model->tokenizer, text, text_len, tokens, n_tokens_max, add_special, parse_special);
} }
int32_t llama_token_to_piece( int32_t llama_token_to_piece(

View File

@ -84,6 +84,25 @@ llama_test(test-tokenizer-0 NAME test-tokenizer-0-qwen2 ARGS ${CMAKE
llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf) llama_test(test-tokenizer-0 NAME test-tokenizer-0-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf) llama_test(test-tokenizer-0 NAME test-tokenizer-0-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
# build test-tokenizer-parallel target once and add many tests
add_executable(test-tokenizer-parallel test-tokenizer-parallel.cpp)
target_link_libraries(test-tokenizer-parallel PRIVATE common)
install(TARGETS test-tokenizer-parallel RUNTIME)
llama_test(test-tokenizer-parallel NAME test-tokenizer-parallel-bert-bge ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-bert-bge.gguf)
llama_test(test-tokenizer-parallel NAME test-tokenizer-parallel-command-r ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-command-r.gguf)
llama_test(test-tokenizer-parallel NAME test-tokenizer-parallel-deepseek-coder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-coder.gguf)
llama_test(test-tokenizer-parallel NAME test-tokenizer-parallel-deepseek-llm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-deepseek-llm.gguf)
llama_test(test-tokenizer-parallel NAME test-tokenizer-parallel-falcon ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-falcon.gguf)
llama_test(test-tokenizer-parallel NAME test-tokenizer-parallel-gpt-2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-gpt-2.gguf)
llama_test(test-tokenizer-parallel NAME test-tokenizer-parallel-llama-bpe ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-bpe.gguf)
llama_test(test-tokenizer-parallel NAME test-tokenizer-parallel-llama-spm ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-llama-spm.gguf)
llama_test(test-tokenizer-parallel NAME test-tokenizer-parallel-mpt ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-mpt.gguf)
llama_test(test-tokenizer-parallel NAME test-tokenizer-parallel-phi-3 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-phi-3.gguf)
llama_test(test-tokenizer-parallel NAME test-tokenizer-parallel-qwen2 ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-qwen2.gguf)
llama_test(test-tokenizer-parallel NAME test-tokenizer-parallel-refact ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-refact.gguf)
llama_test(test-tokenizer-parallel NAME test-tokenizer-parallel-starcoder ARGS ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab-starcoder.gguf)
# build test-tokenizer-1-bpe target once and add many tests # build test-tokenizer-1-bpe target once and add many tests
add_executable(test-tokenizer-1-bpe test-tokenizer-1-bpe.cpp) add_executable(test-tokenizer-1-bpe test-tokenizer-1-bpe.cpp)
target_link_libraries(test-tokenizer-1-bpe PRIVATE common) target_link_libraries(test-tokenizer-1-bpe PRIVATE common)

View File

@ -0,0 +1,180 @@
#include "llama.h"
#include "common.h"
#include "console.h"
#include <cstdio>
#include <string>
#include <map>
#include <vector>
#include <fstream>
#include <thread>
using llama_tests = std::map<std::string, std::vector<llama_token>>;
static llama_tests read_tests(const std::string & fname_inp, const std::string & fname_out) {
llama_tests tests;
std::ifstream ifs_inp(fname_inp);
if (!ifs_inp) {
fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_inp.c_str());
return tests;
}
std::string sraw((std::istreambuf_iterator<char>(ifs_inp)), std::istreambuf_iterator<char>());
std::ifstream ifs_out(fname_out);
if (!ifs_out) {
fprintf(stderr, "%s : error: could not open file '%s'\n", __func__, fname_out.c_str());
return tests;
}
std::vector<std::string> sout;
for (std::string line; std::getline(ifs_out, line);) {
sout.push_back(line);
}
const std::string sep = "\n__ggml_vocab_test__\n";
std::vector<std::string> sinp;
size_t pos = 0;
while (pos < sraw.size()) {
const size_t next = sraw.find(sep, pos);
if (next == std::string::npos) {
sinp.push_back(sraw.substr(pos));
break;
}
sinp.push_back(sraw.substr(pos, next - pos));
pos = next + sep.size();
}
if (sinp.size() != sout.size()) {
fprintf(stderr, "%s : error: input and output files have different number of tests\n", __func__);
return tests;
}
for (size_t i = 0; i < sinp.size(); ++i) {
const std::string & s = sinp[i];
const std::string & o = string_strip(sout[i]);
std::vector<llama_token> toks;
size_t pos = 0;
while (pos < o.size()) {
size_t next = o.find(' ', pos);
if (next == std::string::npos) {
next = o.size();
}
const std::string stok = o.substr(pos, next - pos);
toks.push_back(std::stoi(stok));
pos = next + 1;
}
tests[s] = toks;
}
return tests;
}
int main(int argc, char const *argv[]) {
if (argc < 2) {
fprintf(stderr, "Usage: %s vocab-file \n", argv[0]);
return 1;
}
const std::string fname = argv[1];
const std::string fname_inp = fname + ".inp";
const std::string fname_out = fname + ".out";
fprintf(stderr, "%s : reading vocab from: '%s'\n", __func__, fname.c_str());
llama_model * model;
llama_context * ctx;
llama_backend_init();
// load the vocab
{
auto mparams = llama_model_default_params();
mparams.vocab_only = true;
model = llama_load_model_from_file(fname.c_str(), mparams);
if (model == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
return 1;
}
auto cparams = llama_context_default_params();
ctx = llama_new_context_with_model(model, cparams);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
llama_free_model(model);
return 1;
}
}
#ifdef _WIN32
// We need this for unicode console support
console::init(false, false);
atexit([]() { console::cleanup(); });
#endif
const int nthread = std::thread::hardware_concurrency();
std::vector<std::thread> threads(nthread);
bool success = true;
const auto k_tests = [&]() -> llama_tests {
const auto res = read_tests(fname_inp, fname_out);
if (res.empty()) {
fprintf(stderr, "%s : error: no tests found\n", __func__);
exit(1);
}
return res;
}();
const bool add_special = false;
for (int i = 0; i < nthread; i++) {
threads[i] = std::thread([&]() {
for (const auto & test_kv : k_tests) {
const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special, false);
bool correct = res.size() == test_kv.second.size();
for (int i = 0; i < (int) res.size() && correct; ++i) {
if (test_kv.second[i] != res[i]) {
correct = false;
}
}
if (!correct) {
success = false;
}
}
});
}
for (int i = 0; i < nthread; i++) {
threads[i].join();
}
llama_free_model(model);
llama_free(ctx);
llama_backend_free();
printf("\n");
printf("Tests %s\n", success ? "passed" : "failed");
return success ? 0 : 3;
}