mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 10:24:35 +00:00
vocab : refactor tokenizer to reduce init overhead (#9449)
* refactor tokenizer * llama : make llm_tokenizer more private ggml-ci * refactor tokenizer * refactor tokenizer * llama : make llm_tokenizer more private ggml-ci * remove unused files * remove unused fileds to avoid unused filed build error * avoid symbol link error * Update src/llama.cpp * Update src/llama.cpp --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
9a913110cf
commit
6102037bbb
@ -201,7 +201,7 @@ static void print_sample_weights(TransformerWeights *w){
|
|||||||
|
|
||||||
//////////////////////////////////////// ggml structs and functions required to load models, configs and save the model.
|
//////////////////////////////////////// ggml structs and functions required to load models, configs and save the model.
|
||||||
|
|
||||||
struct llama_vocab {
|
struct my_llama_vocab {
|
||||||
using id = int32_t;
|
using id = int32_t;
|
||||||
using token = std::string;
|
using token = std::string;
|
||||||
using ttype = llama_token_type;
|
using ttype = llama_token_type;
|
||||||
@ -525,7 +525,7 @@ static std::string llama_escape_whitespaces(const std::string & text) {
|
|||||||
return out.str();
|
return out.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void load_vocab(const char * filename, const Config * config, struct llama_vocab * vocab) {
|
static void load_vocab(const char * filename, const Config * config, struct my_llama_vocab * vocab) {
|
||||||
if (is_ggml_file(filename)) {
|
if (is_ggml_file(filename)) {
|
||||||
LOG_INF("%s: Loading vocabulary from gguf file %s\n", __func__, filename);
|
LOG_INF("%s: Loading vocabulary from gguf file %s\n", __func__, filename);
|
||||||
struct ggml_context * ctx_data = NULL;
|
struct ggml_context * ctx_data = NULL;
|
||||||
@ -583,13 +583,13 @@ static void load_vocab(const char * filename, const Config * config, struct llam
|
|||||||
const int n_vocab = config->vocab_size;
|
const int n_vocab = config->vocab_size;
|
||||||
/* uint32_t max_token_length = */ file.read_u32(); // unused
|
/* uint32_t max_token_length = */ file.read_u32(); // unused
|
||||||
vocab->id_to_token.resize(n_vocab);
|
vocab->id_to_token.resize(n_vocab);
|
||||||
for (llama_vocab::id id=0; id<n_vocab; ++id) {
|
for (my_llama_vocab::id id=0; id<n_vocab; ++id) {
|
||||||
float_t score = file.read_f32();
|
float_t score = file.read_f32();
|
||||||
uint32_t len = file.read_u32();
|
uint32_t len = file.read_u32();
|
||||||
std::string text = file.read_string(len);
|
std::string text = file.read_string(len);
|
||||||
|
|
||||||
unsigned char byte_val;
|
unsigned char byte_val;
|
||||||
llama_vocab::ttype type = LLAMA_TOKEN_TYPE_NORMAL;
|
my_llama_vocab::ttype type = LLAMA_TOKEN_TYPE_NORMAL;
|
||||||
if (id == UNKNOWN_TOKEN_ID) {
|
if (id == UNKNOWN_TOKEN_ID) {
|
||||||
text = "<unk>";
|
text = "<unk>";
|
||||||
type = LLAMA_TOKEN_TYPE_UNKNOWN;
|
type = LLAMA_TOKEN_TYPE_UNKNOWN;
|
||||||
@ -631,7 +631,7 @@ static void convert_weights_ak_to_gg(struct ggml_tensor * gg_weights, const floa
|
|||||||
}
|
}
|
||||||
|
|
||||||
static void save_as_llama_model(
|
static void save_as_llama_model(
|
||||||
struct llama_vocab * vocab, struct my_llama_model * model, TransformerWeights* w, const char * filename
|
struct my_llama_vocab * vocab, struct my_llama_model * model, TransformerWeights* w, const char * filename
|
||||||
) {
|
) {
|
||||||
// convert AK weights into GG weights one by one.
|
// convert AK weights into GG weights one by one.
|
||||||
// w->token_embedding_table -> model->tok_embeddings
|
// w->token_embedding_table -> model->tok_embeddings
|
||||||
@ -671,7 +671,7 @@ static void save_as_llama_model(
|
|||||||
std::vector<const char*> tokens;
|
std::vector<const char*> tokens;
|
||||||
std::vector<float> scores;
|
std::vector<float> scores;
|
||||||
std::vector<llama_token_type> token_types;
|
std::vector<llama_token_type> token_types;
|
||||||
for (const llama_vocab::token_data & token_data : vocab->id_to_token) {
|
for (const my_llama_vocab::token_data & token_data : vocab->id_to_token) {
|
||||||
tokens.push_back(token_data.text.c_str());
|
tokens.push_back(token_data.text.c_str());
|
||||||
scores.push_back(token_data.score);
|
scores.push_back(token_data.score);
|
||||||
token_types.push_back(token_data.type);
|
token_types.push_back(token_data.type);
|
||||||
@ -905,7 +905,7 @@ int main(int argc, char ** argv) {
|
|||||||
fclose(file);
|
fclose(file);
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_vocab vocab;
|
struct my_llama_vocab vocab;
|
||||||
load_vocab(params.fn_vocab_model, &config, &vocab);
|
load_vocab(params.fn_vocab_model, &config, &vocab);
|
||||||
|
|
||||||
struct my_llama_model model;
|
struct my_llama_model model;
|
||||||
|
@ -50,7 +50,7 @@ struct naive_trie {
|
|||||||
res.first->second.insert(key + 1, len - 1, value);
|
res.first->second.insert(key + 1, len - 1, value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) {
|
std::pair<const char *, size_t> get_longest_prefix(const char * key, size_t len, size_t offset = 0) const {
|
||||||
if (len == 0 || offset == len) {
|
if (len == 0 || offset == len) {
|
||||||
return std::make_pair(key, offset);
|
return std::make_pair(key, offset);
|
||||||
}
|
}
|
||||||
@ -79,6 +79,15 @@ struct naive_trie {
|
|||||||
// impl
|
// impl
|
||||||
//
|
//
|
||||||
|
|
||||||
|
struct llm_tokenizer {
|
||||||
|
llm_tokenizer() {}
|
||||||
|
virtual ~llm_tokenizer() = default;
|
||||||
|
};
|
||||||
|
|
||||||
|
llama_vocab::~llama_vocab() {
|
||||||
|
delete tokenizer;
|
||||||
|
}
|
||||||
|
|
||||||
int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
|
int llama_vocab::find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
|
||||||
GGML_ASSERT(token_left.find(' ') == std::string::npos);
|
GGML_ASSERT(token_left.find(' ') == std::string::npos);
|
||||||
GGML_ASSERT(token_left.find('\n') == std::string::npos);
|
GGML_ASSERT(token_left.find('\n') == std::string::npos);
|
||||||
@ -187,10 +196,15 @@ struct llm_bigram_spm {
|
|||||||
size_t size;
|
size_t size;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llm_tokenizer_spm {
|
struct llm_tokenizer_spm : llm_tokenizer {
|
||||||
llm_tokenizer_spm(const llama_vocab & vocab) : vocab(vocab) {}
|
llm_tokenizer_spm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct llm_tokenizer_spm_session {
|
||||||
|
llm_tokenizer_spm_session(const llama_vocab & vocab) : vocab(vocab) {}
|
||||||
|
|
||||||
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
||||||
|
|
||||||
// split string into utf8 chars
|
// split string into utf8 chars
|
||||||
int index = 0;
|
int index = 0;
|
||||||
size_t offs = 0;
|
size_t offs = 0;
|
||||||
@ -271,7 +285,7 @@ private:
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
resegment(symbols[p->second.first], output);
|
resegment(symbols[p->second.first], output);
|
||||||
resegment(symbols[p->second.second], output);
|
resegment(symbols[p->second.second], output);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -279,7 +293,6 @@ private:
|
|||||||
if (left == -1 || right == -1) {
|
if (left == -1 || right == -1) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
|
const std::string text = std::string(symbols[left].text, symbols[left].n + symbols[right].n);
|
||||||
auto token = vocab.token_to_id.find(text);
|
auto token = vocab.token_to_id.find(text);
|
||||||
|
|
||||||
@ -306,10 +319,11 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
const llama_vocab & vocab;
|
const llama_vocab & vocab;
|
||||||
|
// currently unused
|
||||||
|
// const llm_tokenizer_spm * spm_tokenizer;
|
||||||
|
|
||||||
std::vector<llm_symbol> symbols;
|
std::vector<llm_symbol> symbols;
|
||||||
llm_bigram_spm::queue work_queue;
|
llm_bigram_spm::queue work_queue;
|
||||||
|
|
||||||
std::map<std::string, std::pair<int, int>> rev_merge;
|
std::map<std::string, std::pair<int, int>> rev_merge;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -352,8 +366,8 @@ struct llm_bigram_bpe {
|
|||||||
size_t size;
|
size_t size;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llm_tokenizer_bpe {
|
struct llm_tokenizer_bpe : llm_tokenizer {
|
||||||
llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
|
llm_tokenizer_bpe(const llama_vocab & vocab) : llm_tokenizer() {
|
||||||
GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
|
GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
|
||||||
switch (vocab.type_pre) {
|
switch (vocab.type_pre) {
|
||||||
case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
|
case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
|
||||||
@ -476,7 +490,14 @@ struct llm_tokenizer_bpe {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) const {
|
std::vector<std::string> regex_exprs;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct llm_tokenizer_bpe_session {
|
||||||
|
llm_tokenizer_bpe_session(const llama_vocab & vocab) : vocab(vocab),
|
||||||
|
bpe_tokenizer(static_cast<const llm_tokenizer_bpe *>(vocab.tokenizer)) {}
|
||||||
|
|
||||||
|
static void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) {
|
||||||
output.push_back(token_id);
|
output.push_back(token_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -515,12 +536,11 @@ struct llm_tokenizer_bpe {
|
|||||||
|
|
||||||
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
||||||
int final_prev_index = -1;
|
int final_prev_index = -1;
|
||||||
|
const auto word_collection = unicode_regex_split(text, bpe_tokenizer->regex_exprs);
|
||||||
const auto word_collection = unicode_regex_split(text, regex_exprs);
|
|
||||||
|
|
||||||
symbols_final.clear();
|
symbols_final.clear();
|
||||||
|
|
||||||
for (auto & word : word_collection) {
|
for (const auto & word : word_collection) {
|
||||||
work_queue = llm_bigram_bpe::queue();
|
work_queue = llm_bigram_bpe::queue();
|
||||||
symbols.clear();
|
symbols.clear();
|
||||||
|
|
||||||
@ -623,7 +643,6 @@ private:
|
|||||||
if (left == -1 || right == -1) {
|
if (left == -1 || right == -1) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string left_token = std::string(symbols[left].text, symbols[left].n);
|
std::string left_token = std::string(symbols[left].text, symbols[left].n);
|
||||||
std::string right_token = std::string(symbols[right].text, symbols[right].n);
|
std::string right_token = std::string(symbols[right].text, symbols[right].n);
|
||||||
|
|
||||||
@ -647,12 +666,10 @@ private:
|
|||||||
}
|
}
|
||||||
|
|
||||||
const llama_vocab & vocab;
|
const llama_vocab & vocab;
|
||||||
|
const llm_tokenizer_bpe * bpe_tokenizer;
|
||||||
std::vector<std::string> regex_exprs;
|
|
||||||
|
|
||||||
std::vector<llm_symbol> symbols;
|
std::vector<llm_symbol> symbols;
|
||||||
std::vector<llm_symbol> symbols_final;
|
std::vector<llm_symbol> symbols_final;
|
||||||
|
|
||||||
llm_bigram_bpe::queue work_queue;
|
llm_bigram_bpe::queue work_queue;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -660,15 +677,17 @@ private:
|
|||||||
// WPM tokenizer
|
// WPM tokenizer
|
||||||
//
|
//
|
||||||
|
|
||||||
struct llm_tokenizer_wpm {
|
struct llm_tokenizer_wpm : llm_tokenizer {
|
||||||
llm_tokenizer_wpm(const llama_vocab & vocab): vocab(vocab) {}
|
llm_tokenizer_wpm(const llama_vocab & /*vocab*/) : llm_tokenizer() {}
|
||||||
|
};
|
||||||
|
|
||||||
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) const {
|
struct llm_tokenizer_wpm_session {
|
||||||
|
llm_tokenizer_wpm_session(const llama_vocab & vocab) : vocab(vocab) {}
|
||||||
|
|
||||||
|
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
||||||
const auto & token_map = vocab.token_to_id;
|
const auto & token_map = vocab.token_to_id;
|
||||||
|
|
||||||
// normalize and split by whitespace
|
// normalize and split by whitespace
|
||||||
std::vector<std::string> words = preprocess(text);
|
std::vector<std::string> words = preprocess(text);
|
||||||
|
|
||||||
// bos token prepended already
|
// bos token prepended already
|
||||||
|
|
||||||
// find the longest tokens that form the words
|
// find the longest tokens that form the words
|
||||||
@ -713,7 +732,7 @@ struct llm_tokenizer_wpm {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: reduce string copies by using cpts_offs array
|
// TODO: reduce string copies by using cpts_offs array
|
||||||
std::vector<std::string> preprocess(const std::string & text) const {
|
static std::vector<std::string> preprocess(const std::string & text) {
|
||||||
const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
|
const std::vector<uint32_t> cpts_nfd = unicode_cpts_normalize_nfd(unicode_cpts_from_utf8(text));
|
||||||
std::vector<std::string> words(1, "");
|
std::vector<std::string> words(1, "");
|
||||||
|
|
||||||
@ -765,15 +784,18 @@ struct llm_tokenizer_wpm {
|
|||||||
//(cpt >= 0xFF00 && cpt <= 0xFFEF);
|
//(cpt >= 0xFF00 && cpt <= 0xFFEF);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
const llama_vocab & vocab;
|
const llama_vocab & vocab;
|
||||||
|
// currently unused
|
||||||
|
// const llm_tokenizer_wpm * wpm_tokenizer;
|
||||||
};
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
// UGM tokenizer
|
// UGM tokenizer
|
||||||
//
|
//
|
||||||
|
|
||||||
struct llm_tokenizer_ugm {
|
struct llm_tokenizer_ugm : llm_tokenizer {
|
||||||
llm_tokenizer_ugm(const llama_vocab & vocab) : vocab(vocab) {
|
llm_tokenizer_ugm(const llama_vocab & vocab) : llm_tokenizer() {
|
||||||
if (vocab.precompiled_charsmap.size() > 0) {
|
if (vocab.precompiled_charsmap.size() > 0) {
|
||||||
size_t charsmap_offset = 0;
|
size_t charsmap_offset = 0;
|
||||||
|
|
||||||
@ -819,6 +841,30 @@ struct llm_tokenizer_ugm {
|
|||||||
unknown_token_score = min_score - unknown_token_score_penalty;
|
unknown_token_score = min_score - unknown_token_score_penalty;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// escaped space symbol - U+2581 (Lower One Eighth Block)
|
||||||
|
const std::string escaped_space = "\xE2\x96\x81";
|
||||||
|
|
||||||
|
const char * prefix_replacements = NULL;
|
||||||
|
size_t prefix_replacements_size = 0;
|
||||||
|
|
||||||
|
const uint32_t * xcda_array = NULL;
|
||||||
|
size_t xcda_array_size = 0;
|
||||||
|
|
||||||
|
struct naive_trie user_defined_token_matcher;
|
||||||
|
|
||||||
|
float min_score = FLT_MAX;
|
||||||
|
float max_score = -FLT_MAX;
|
||||||
|
|
||||||
|
float unknown_token_score_penalty = 10.0;
|
||||||
|
float unknown_token_score;
|
||||||
|
|
||||||
|
struct naive_trie token_matcher;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct llm_tokenizer_ugm_session {
|
||||||
|
llm_tokenizer_ugm_session(const llama_vocab & vocab) : vocab(vocab),
|
||||||
|
ugm_tokenizer(static_cast<const llm_tokenizer_ugm *>(vocab.tokenizer)) {}
|
||||||
|
|
||||||
/* This implementation is based on SentencePiece optimized Viterbi algorithm for
|
/* This implementation is based on SentencePiece optimized Viterbi algorithm for
|
||||||
* unigram language models. The general idea is to:
|
* unigram language models. The general idea is to:
|
||||||
* - move along the input sequence in steps of one UTF code point,
|
* - move along the input sequence in steps of one UTF code point,
|
||||||
@ -857,7 +903,7 @@ struct llm_tokenizer_ugm {
|
|||||||
// traverse the token matcher trie to find a matching token
|
// traverse the token matcher trie to find a matching token
|
||||||
bool single_codepoint_token_found = false;
|
bool single_codepoint_token_found = false;
|
||||||
const struct best_tokenization & current_best = tokenization_results[input_offset];
|
const struct best_tokenization & current_best = tokenization_results[input_offset];
|
||||||
const struct naive_trie * node = token_matcher.traverse(normalized[prefix_offset++]);
|
const struct naive_trie * node = ugm_tokenizer->token_matcher.traverse(normalized[prefix_offset++]);
|
||||||
|
|
||||||
while (prefix_offset <= input_len && node != NULL) {
|
while (prefix_offset <= input_len && node != NULL) {
|
||||||
// check if we found valid token in prefix
|
// check if we found valid token in prefix
|
||||||
@ -887,7 +933,7 @@ struct llm_tokenizer_ugm {
|
|||||||
// if we didn't find a valid token corresponding to the whole UTF code point
|
// if we didn't find a valid token corresponding to the whole UTF code point
|
||||||
// then use unknown token as the tokenization of this UTF code point
|
// then use unknown token as the tokenization of this UTF code point
|
||||||
if (!single_codepoint_token_found) {
|
if (!single_codepoint_token_found) {
|
||||||
const double challenger_score = current_best.score_sum + unknown_token_score;
|
const double challenger_score = current_best.score_sum + ugm_tokenizer->unknown_token_score;
|
||||||
prefix_offset = input_offset + n_utf8_code_units;
|
prefix_offset = input_offset + n_utf8_code_units;
|
||||||
struct best_tokenization & current_champ = tokenization_results[prefix_offset];
|
struct best_tokenization & current_champ = tokenization_results[prefix_offset];
|
||||||
if (challenger_score > current_champ.score_sum) {
|
if (challenger_score > current_champ.score_sum) {
|
||||||
@ -919,7 +965,6 @@ struct llm_tokenizer_ugm {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const llama_vocab & vocab;
|
|
||||||
|
|
||||||
// helper structure for returning normalization results
|
// helper structure for returning normalization results
|
||||||
struct normalization_result {
|
struct normalization_result {
|
||||||
@ -932,7 +977,7 @@ private:
|
|||||||
normalized->clear();
|
normalized->clear();
|
||||||
normalized->reserve(input.size() * 3);
|
normalized->reserve(input.size() * 3);
|
||||||
|
|
||||||
const std::string space = vocab.tokenizer_escape_whitespaces ? escaped_space : " ";
|
const std::string space = vocab.tokenizer_escape_whitespaces ? ugm_tokenizer->escaped_space : " ";
|
||||||
|
|
||||||
bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
|
bool shall_prepend_space = !vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
|
||||||
bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
|
bool shall_append_space = vocab.tokenizer_treat_whitespace_as_suffix && vocab.tokenizer_add_space_prefix;
|
||||||
@ -1014,13 +1059,21 @@ private:
|
|||||||
size_t xcda_array_size;
|
size_t xcda_array_size;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// this structure stores the best tokenization so far at input_offset
|
||||||
|
struct best_tokenization {
|
||||||
|
llama_token token_id;
|
||||||
|
size_t input_offset;
|
||||||
|
float score_sum;
|
||||||
|
};
|
||||||
|
|
||||||
struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
|
struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
|
||||||
if (input_offset == input.size()) {
|
if (input_offset == input.size()) {
|
||||||
return { &input[input_offset], 0, 0 };
|
return { &input[input_offset], 0, 0 };
|
||||||
}
|
}
|
||||||
|
|
||||||
// if input prefix matches some user-defined token return this token as normalization result
|
// if input prefix matches some user-defined token return this token as normalization result
|
||||||
auto user_defined_token_match = user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
|
auto user_defined_token_match =
|
||||||
|
ugm_tokenizer->user_defined_token_matcher.get_longest_prefix(&input[input_offset], input.size() - input_offset);
|
||||||
if (user_defined_token_match.second > 0) {
|
if (user_defined_token_match.second > 0) {
|
||||||
return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second };
|
return { &input[input_offset], user_defined_token_match.second, user_defined_token_match.second };
|
||||||
}
|
}
|
||||||
@ -1028,8 +1081,8 @@ private:
|
|||||||
size_t longest_prefix_length = 0;
|
size_t longest_prefix_length = 0;
|
||||||
size_t longest_prefix_offset = 0;
|
size_t longest_prefix_offset = 0;
|
||||||
|
|
||||||
if (xcda_array_size > 0) {
|
if (ugm_tokenizer->xcda_array_size > 0) {
|
||||||
struct xcda_array_view xcda_view(xcda_array, xcda_array_size);
|
struct xcda_array_view xcda_view(ugm_tokenizer->xcda_array, ugm_tokenizer->xcda_array_size);
|
||||||
|
|
||||||
// Find the longest normalized sequence matching the input prefix by walking
|
// Find the longest normalized sequence matching the input prefix by walking
|
||||||
// the XOR-compressed compact double array (XCDA) starting from the root node
|
// the XOR-compressed compact double array (XCDA) starting from the root node
|
||||||
@ -1065,50 +1118,27 @@ private:
|
|||||||
|
|
||||||
if (longest_prefix_length > 0) {
|
if (longest_prefix_length > 0) {
|
||||||
// we have a match, so return the replacement sequence
|
// we have a match, so return the replacement sequence
|
||||||
if (longest_prefix_offset >= prefix_replacements_size) {
|
if (longest_prefix_offset >= ugm_tokenizer->prefix_replacements_size) {
|
||||||
throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
|
throw std::runtime_error("Index out of array bounds in precompiled charsmap!");
|
||||||
}
|
}
|
||||||
const char * prefix_replacement = &prefix_replacements[longest_prefix_offset];
|
const char * prefix_replacement = &(ugm_tokenizer->prefix_replacements)[longest_prefix_offset];
|
||||||
return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length };
|
return { prefix_replacement, strlen(prefix_replacement), longest_prefix_length };
|
||||||
} else {
|
}
|
||||||
// check if the input prefix contains a valid sequence of UTF-8 code units
|
|
||||||
try {
|
// check if the input prefix contains a valid sequence of UTF-8 code units
|
||||||
// if yes, return this sequence unmodified
|
try {
|
||||||
size_t prefix_offset = input_offset;
|
// if yes, return this sequence unmodified
|
||||||
unicode_cpt_from_utf8(input, prefix_offset);
|
size_t prefix_offset = input_offset;
|
||||||
return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset };
|
unicode_cpt_from_utf8(input, prefix_offset);
|
||||||
} catch (std::invalid_argument & /*ex*/) {
|
return { &input[input_offset], prefix_offset - input_offset, prefix_offset - input_offset };
|
||||||
// if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER
|
} catch (std::invalid_argument & /*ex*/) {
|
||||||
return { "\xEF\xBF\xBD", 3, 1 };
|
// if no, consume 1 byte and return U+FFFD - REPLACEMENT CHARACTER
|
||||||
}
|
return { "\xEF\xBF\xBD", 3, 1 };
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// escaped space symbol - U+2581 (Lower One Eighth Block)
|
const llama_vocab & vocab;
|
||||||
const std::string escaped_space = "\xE2\x96\x81";
|
const llm_tokenizer_ugm * ugm_tokenizer;
|
||||||
|
|
||||||
const char * prefix_replacements = NULL;
|
|
||||||
size_t prefix_replacements_size = 0;
|
|
||||||
|
|
||||||
const uint32_t * xcda_array = NULL;
|
|
||||||
size_t xcda_array_size = 0;
|
|
||||||
|
|
||||||
struct naive_trie user_defined_token_matcher;
|
|
||||||
|
|
||||||
// this structure stores the best tokenization so far at input_offset
|
|
||||||
struct best_tokenization {
|
|
||||||
llama_token token_id;
|
|
||||||
size_t input_offset;
|
|
||||||
float score_sum;
|
|
||||||
};
|
|
||||||
|
|
||||||
float min_score = FLT_MAX;
|
|
||||||
float max_score = -FLT_MAX;
|
|
||||||
|
|
||||||
float unknown_token_score_penalty = 10.0;
|
|
||||||
float unknown_token_score;
|
|
||||||
|
|
||||||
struct naive_trie token_matcher;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
@ -1169,8 +1199,8 @@ static std::vector<uint8_t> llama_unescape_rwkv_token(const std::string & escape
|
|||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llm_tokenizer_rwkv {
|
struct llm_tokenizer_rwkv : llm_tokenizer {
|
||||||
llm_tokenizer_rwkv(const llama_vocab & vocab): vocab(vocab) {
|
llm_tokenizer_rwkv(const llama_vocab & vocab) : llm_tokenizer() {
|
||||||
// RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
|
// RWKV supports arbitrary byte tokens, but the vocab struct only supports string tokens.
|
||||||
// For now, we decode the vocab here into the lookup we'll use for tokenization.
|
// For now, we decode the vocab here into the lookup we'll use for tokenization.
|
||||||
|
|
||||||
@ -1182,11 +1212,17 @@ struct llm_tokenizer_rwkv {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct naive_trie token_matcher;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct llm_tokenizer_rwkv_session {
|
||||||
|
llm_tokenizer_rwkv_session(const llama_vocab & vocab) : vocab(vocab),
|
||||||
|
rwkv_tokenizer(static_cast<const llm_tokenizer_rwkv &>(*vocab.tokenizer)) {}
|
||||||
|
|
||||||
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
||||||
uint32_t position = 0;
|
uint32_t position = 0;
|
||||||
|
|
||||||
while (position < text.size()) {
|
while (position < text.size()) {
|
||||||
const struct naive_trie * node = token_matcher.traverse(text[position]);
|
const struct naive_trie * node = rwkv_tokenizer.token_matcher.traverse(text[position]);
|
||||||
if (node == NULL) {
|
if (node == NULL) {
|
||||||
// no matching token found, add unknown token
|
// no matching token found, add unknown token
|
||||||
output.push_back(vocab.special_unk_id);
|
output.push_back(vocab.special_unk_id);
|
||||||
@ -1211,11 +1247,33 @@ struct llm_tokenizer_rwkv {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
const llama_vocab & vocab;
|
const llama_vocab & vocab;
|
||||||
|
const llm_tokenizer_rwkv & rwkv_tokenizer;
|
||||||
struct naive_trie token_matcher;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
void llama_vocab::init_tokenizer() {
|
||||||
|
switch (type) {
|
||||||
|
case LLAMA_VOCAB_TYPE_SPM:
|
||||||
|
tokenizer = new llm_tokenizer_spm(*this);
|
||||||
|
break;
|
||||||
|
case LLAMA_VOCAB_TYPE_BPE:
|
||||||
|
tokenizer = new llm_tokenizer_bpe(*this);
|
||||||
|
break;
|
||||||
|
case LLAMA_VOCAB_TYPE_WPM:
|
||||||
|
tokenizer = new llm_tokenizer_wpm(*this);
|
||||||
|
break;
|
||||||
|
case LLAMA_VOCAB_TYPE_UGM:
|
||||||
|
tokenizer = new llm_tokenizer_ugm(*this);
|
||||||
|
break;
|
||||||
|
case LLAMA_VOCAB_TYPE_RWKV:
|
||||||
|
tokenizer = new llm_tokenizer_rwkv(*this);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ABORT("unsupported vocab type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// (de-) tokenize
|
// (de-) tokenize
|
||||||
//
|
//
|
||||||
@ -1277,7 +1335,7 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
|
|||||||
|
|
||||||
// if a fragment is text ( not yet processed )
|
// if a fragment is text ( not yet processed )
|
||||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||||
auto & raw_text = fragment.raw_text;
|
const auto & raw_text = fragment.raw_text;
|
||||||
|
|
||||||
auto raw_text_base_offset = fragment.offset;
|
auto raw_text_base_offset = fragment.offset;
|
||||||
auto raw_text_base_length = fragment.length;
|
auto raw_text_base_length = fragment.length;
|
||||||
@ -1376,7 +1434,13 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab, std::string raw_text, bool add_special, bool parse_special) {
|
std::vector<llama_vocab::id> llama_tokenize_internal(
|
||||||
|
const llama_vocab & vocab,
|
||||||
|
std::string raw_text,
|
||||||
|
bool add_special,
|
||||||
|
bool parse_special) {
|
||||||
|
GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
|
||||||
|
|
||||||
std::vector<llama_vocab::id> output;
|
std::vector<llama_vocab::id> output;
|
||||||
std::forward_list<fragment_buffer_variant> fragment_buffer;
|
std::forward_list<fragment_buffer_variant> fragment_buffer;
|
||||||
|
|
||||||
@ -1413,9 +1477,9 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
|
|||||||
#ifdef PRETOKENIZERDEBUG
|
#ifdef PRETOKENIZERDEBUG
|
||||||
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
|
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
|
||||||
#endif
|
#endif
|
||||||
llm_tokenizer_spm tokenizer(vocab);
|
|
||||||
llama_escape_whitespace(raw_text);
|
llama_escape_whitespace(raw_text);
|
||||||
tokenizer.tokenize(raw_text, output);
|
llm_tokenizer_spm_session session(vocab);
|
||||||
|
session.tokenize(raw_text, output);
|
||||||
is_prev_special = false;
|
is_prev_special = false;
|
||||||
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
||||||
output.push_back(fragment.token);
|
output.push_back(fragment.token);
|
||||||
@ -1437,10 +1501,11 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
|
|||||||
} break;
|
} break;
|
||||||
case LLAMA_VOCAB_TYPE_BPE:
|
case LLAMA_VOCAB_TYPE_BPE:
|
||||||
{
|
{
|
||||||
llm_tokenizer_bpe tokenizer(vocab);
|
llm_tokenizer_bpe_session session(vocab);
|
||||||
|
// it calls some other methods that are not exist in llm_tokenizer,
|
||||||
|
// here just cast it to bpe tokenizer object
|
||||||
if (add_special) {
|
if (add_special) {
|
||||||
tokenizer.append_bos(output);
|
session.append_bos(output);
|
||||||
}
|
}
|
||||||
for (const auto & fragment : fragment_buffer) {
|
for (const auto & fragment : fragment_buffer) {
|
||||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||||
@ -1449,15 +1514,15 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
|
|||||||
#ifdef PRETOKENIZERDEBUG
|
#ifdef PRETOKENIZERDEBUG
|
||||||
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
|
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
|
||||||
#endif
|
#endif
|
||||||
tokenizer.tokenize(raw_text, output);
|
session.tokenize(raw_text, output);
|
||||||
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
||||||
tokenizer.append(fragment.token, output);
|
session.append(fragment.token, output);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (add_special) {
|
if (add_special) {
|
||||||
tokenizer.append_eos(output);
|
session.append_eos(output);
|
||||||
tokenizer.check_double_bos_eos(output);
|
session.check_double_bos_eos(output);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case LLAMA_VOCAB_TYPE_WPM:
|
case LLAMA_VOCAB_TYPE_WPM:
|
||||||
@ -1467,7 +1532,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
|
|||||||
output.push_back(vocab.special_cls_id);
|
output.push_back(vocab.special_cls_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
llm_tokenizer_wpm tokenizer(vocab);
|
llm_tokenizer_wpm_session session(vocab);
|
||||||
|
|
||||||
for (const auto & fragment : fragment_buffer) {
|
for (const auto & fragment : fragment_buffer) {
|
||||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||||
@ -1476,7 +1541,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
|
|||||||
#ifdef PRETOKENIZERDEBUG
|
#ifdef PRETOKENIZERDEBUG
|
||||||
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
|
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
|
||||||
#endif
|
#endif
|
||||||
tokenizer.tokenize(raw_text, output);
|
session.tokenize(raw_text, output);
|
||||||
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
||||||
output.push_back(fragment.token);
|
output.push_back(fragment.token);
|
||||||
}
|
}
|
||||||
@ -1489,12 +1554,11 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
|
|||||||
} break;
|
} break;
|
||||||
case LLAMA_VOCAB_TYPE_UGM:
|
case LLAMA_VOCAB_TYPE_UGM:
|
||||||
{
|
{
|
||||||
llm_tokenizer_ugm tokenizer(vocab);
|
|
||||||
|
|
||||||
if (add_special && vocab.tokenizer_add_bos != 0) {
|
if (add_special && vocab.tokenizer_add_bos != 0) {
|
||||||
GGML_ASSERT(vocab.special_bos_id != -1);
|
GGML_ASSERT(vocab.special_bos_id != -1);
|
||||||
output.push_back(vocab.special_bos_id);
|
output.push_back(vocab.special_bos_id);
|
||||||
}
|
}
|
||||||
|
llm_tokenizer_ugm_session session(vocab);
|
||||||
|
|
||||||
for (const auto & fragment : fragment_buffer) {
|
for (const auto & fragment : fragment_buffer) {
|
||||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||||
@ -1502,7 +1566,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
|
|||||||
#ifdef PRETOKENIZERDEBUG
|
#ifdef PRETOKENIZERDEBUG
|
||||||
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
|
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
|
||||||
#endif
|
#endif
|
||||||
tokenizer.tokenize(raw_text, output);
|
session.tokenize(raw_text, output);
|
||||||
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
||||||
output.push_back(fragment.token);
|
output.push_back(fragment.token);
|
||||||
}
|
}
|
||||||
@ -1522,6 +1586,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
|
|||||||
} break;
|
} break;
|
||||||
case LLAMA_VOCAB_TYPE_RWKV:
|
case LLAMA_VOCAB_TYPE_RWKV:
|
||||||
{
|
{
|
||||||
|
llm_tokenizer_rwkv_session session(vocab);
|
||||||
for (const auto & fragment : fragment_buffer) {
|
for (const auto & fragment : fragment_buffer) {
|
||||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||||
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
||||||
@ -1530,8 +1595,7 @@ std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab & vocab,
|
|||||||
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
|
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
llm_tokenizer_rwkv tokenizer(vocab);
|
session.tokenize(raw_text, output);
|
||||||
tokenizer.tokenize(raw_text, output);
|
|
||||||
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
||||||
output.push_back(fragment.token);
|
output.push_back(fragment.token);
|
||||||
}
|
}
|
||||||
@ -1644,13 +1708,13 @@ llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_tokenize_impl(
|
int32_t llama_tokenize_impl(
|
||||||
const struct llama_vocab & vocab,
|
const struct llama_vocab & vocab,
|
||||||
const char * text,
|
const char * text,
|
||||||
int32_t text_len,
|
int32_t text_len,
|
||||||
llama_token * tokens,
|
llama_token * tokens,
|
||||||
int32_t n_tokens_max,
|
int32_t n_tokens_max,
|
||||||
bool add_special,
|
bool add_special,
|
||||||
bool parse_special) {
|
bool parse_special) {
|
||||||
auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special);
|
auto res = llama_tokenize_internal(vocab, std::string(text, text_len), add_special, parse_special);
|
||||||
if (n_tokens_max < (int) res.size()) {
|
if (n_tokens_max < (int) res.size()) {
|
||||||
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
|
// LLAMA_LOG_ERROR("%s: too many tokens\n", __func__);
|
||||||
@ -1775,6 +1839,8 @@ int32_t llama_detokenize_impl(
|
|||||||
int32_t text_len_max,
|
int32_t text_len_max,
|
||||||
bool remove_special,
|
bool remove_special,
|
||||||
bool unparse_special) {
|
bool unparse_special) {
|
||||||
|
GGML_ASSERT(vocab.tokenizer && "Tokenizer not initialized. Call llama_vocab::init_tokenizer() first.");
|
||||||
|
|
||||||
int32_t avail = text_len_max;
|
int32_t avail = text_len_max;
|
||||||
int32_t total = 0;
|
int32_t total = 0;
|
||||||
|
|
||||||
|
@ -8,6 +8,8 @@
|
|||||||
#include <map>
|
#include <map>
|
||||||
#include <set>
|
#include <set>
|
||||||
|
|
||||||
|
struct llm_tokenizer;
|
||||||
|
|
||||||
struct llama_vocab {
|
struct llama_vocab {
|
||||||
using id = llama_token;
|
using id = llama_token;
|
||||||
using token = std::string;
|
using token = std::string;
|
||||||
@ -65,7 +67,14 @@ struct llama_vocab {
|
|||||||
|
|
||||||
std::vector<char> precompiled_charsmap;
|
std::vector<char> precompiled_charsmap;
|
||||||
|
|
||||||
|
llm_tokenizer * tokenizer = nullptr;
|
||||||
|
|
||||||
|
llama_vocab() = default;
|
||||||
|
~llama_vocab();
|
||||||
|
|
||||||
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
|
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
|
||||||
|
|
||||||
|
void init_tokenizer();
|
||||||
};
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -6464,6 +6464,8 @@ static void llm_load_vocab(
|
|||||||
}
|
}
|
||||||
GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
|
GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
|
||||||
|
|
||||||
|
vocab.init_tokenizer();
|
||||||
|
|
||||||
// determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
|
// determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
|
||||||
if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
|
if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
|
||||||
// For Fill-In-the-Middle (FIM)/infill models which where converted
|
// For Fill-In-the-Middle (FIM)/infill models which where converted
|
||||||
|
@ -7,6 +7,7 @@
|
|||||||
#include <map>
|
#include <map>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
//static const std::map<std::string, std::vector<llama_token>> & k_tests() {
|
//static const std::map<std::string, std::vector<llama_token>> & k_tests() {
|
||||||
// static std::map<std::string, std::vector<llama_token>> _k_tests = {
|
// static std::map<std::string, std::vector<llama_token>> _k_tests = {
|
||||||
@ -194,45 +195,64 @@ int main(int argc, char **argv) {
|
|||||||
|
|
||||||
const bool add_special = false;
|
const bool add_special = false;
|
||||||
|
|
||||||
for (const auto & test_kv : k_tests) {
|
// multi-threaded tokenization
|
||||||
const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special, false);
|
const int nthread = std::thread::hardware_concurrency();
|
||||||
|
std::vector<std::thread> threads(nthread);
|
||||||
|
|
||||||
printf("\n");
|
for (int i = 0; i < nthread; i++) {
|
||||||
printf("src: '%s'\n", test_kv.first.c_str());
|
threads[i] = std::thread([&, i]() {
|
||||||
printf("res: '%s'\n", llama_detokenize(ctx, res).c_str());
|
for (const auto & test_kv : k_tests) {
|
||||||
printf("tok: ");
|
const std::vector<llama_token> res = llama_tokenize(ctx, test_kv.first, add_special, false);
|
||||||
for (const auto & tok : res) {
|
|
||||||
printf("%d ", tok);
|
|
||||||
}
|
|
||||||
printf("\n");
|
|
||||||
|
|
||||||
bool correct = res.size() == test_kv.second.size();
|
// here only print the result of the first thread
|
||||||
for (int i = 0; i < (int) res.size() && correct; ++i) {
|
// because the other threads are running the same tests
|
||||||
if (test_kv.second[i] != res[i]) {
|
if (i != 0) {
|
||||||
correct = false;
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
printf("\n");
|
||||||
|
printf("src: '%s'\n", test_kv.first.c_str());
|
||||||
|
printf("res: '%s'\n", llama_detokenize(ctx, res).c_str());
|
||||||
|
printf("tok: ");
|
||||||
|
for (const auto & tok : res) {
|
||||||
|
printf("%d ", tok);
|
||||||
|
}
|
||||||
|
printf("\n");
|
||||||
|
|
||||||
|
bool correct = res.size() == test_kv.second.size();
|
||||||
|
for (int i = 0; i < (int) res.size() && correct; ++i) {
|
||||||
|
if (test_kv.second[i] != res[i]) {
|
||||||
|
correct = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!correct) {
|
||||||
|
fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
|
||||||
|
fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
|
||||||
|
llama_detokenize(ctx, res).c_str(),
|
||||||
|
llama_detokenize(ctx, test_kv.second).c_str());
|
||||||
|
fprintf(stderr, "%s : expected tokens: ", __func__);
|
||||||
|
for (const auto & t : test_kv.second) {
|
||||||
|
fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
|
||||||
|
}
|
||||||
|
fprintf(stderr, "\n");
|
||||||
|
fprintf(stderr, "%s : got tokens: ", __func__);
|
||||||
|
for (const auto & t : res) {
|
||||||
|
fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
|
||||||
|
}
|
||||||
|
fprintf(stderr, "\n");
|
||||||
|
|
||||||
|
success = false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
|
|
||||||
if (!correct) {
|
|
||||||
fprintf(stderr, "%s : failed test: '%s'\n", __func__, test_kv.first.c_str());
|
|
||||||
fprintf(stderr, "%s : detokenized to: '%s' instead of '%s'\n", __func__,
|
|
||||||
llama_detokenize(ctx, res).c_str(),
|
|
||||||
llama_detokenize(ctx, test_kv.second).c_str());
|
|
||||||
fprintf(stderr, "%s : expected tokens: ", __func__);
|
|
||||||
for (const auto & t : test_kv.second) {
|
|
||||||
fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
|
|
||||||
}
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
fprintf(stderr, "%s : got tokens: ", __func__);
|
|
||||||
for (const auto & t : res) {
|
|
||||||
fprintf(stderr, "%6d '%s', ", t, llama_token_to_piece(ctx, t).c_str());
|
|
||||||
}
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
|
|
||||||
success = false;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < nthread; i++) {
|
||||||
|
threads[i].join();
|
||||||
|
}
|
||||||
|
|
||||||
|
// single threaded tokenization
|
||||||
if (!fname_text.empty()) {
|
if (!fname_text.empty()) {
|
||||||
fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str());
|
fprintf(stderr, "%s : tokenizing: '%s'\n", __func__, fname_text.c_str());
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user