cmpnct_gpt2bpe.hpp : remove non-general stuff

This commit is contained in:
klosax 2023-08-19 13:19:02 +02:00 committed by GitHub
parent 8945d47f52
commit 6a2e520095
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -14,6 +14,7 @@
//----- //-----
// Unicode GPT2 Byte Pair Encoding Tokenizer // Unicode GPT2 Byte Pair Encoding Tokenizer
// Adapted from https://github.com/cmp-nct/ggllm.cpp // Adapted from https://github.com/cmp-nct/ggllm.cpp
// Removed loading of merges from HF json and parts made for a specific vocab
//----- //-----
// Unicode library (from cmpnct_unicode.cpp) // Unicode library (from cmpnct_unicode.cpp)
@ -439,11 +440,10 @@ private:
struct gpt2bpe_vocab { struct gpt2bpe_vocab {
using id = int32_t; using id = int32_t;
using token = std::string; using token = std::string;
std::map<std::string, uint32_t> max_token_length; // max length, for each 2byte prefix
std::map<std::string, uint32_t> max_token_length; // max length, for each 2byte prefix
std::map<std::pair<std::string,std::string>, int> bpe_ranks; std::map<std::pair<std::string,std::string>, int> bpe_ranks;
std::vector<std::pair<std::string, std::string>> bpe_merges; std::vector<std::pair<std::string, std::string>> bpe_merges;
std::map<std::string, int> special_tokens;
id special_bos_id = -1; id special_bos_id = -1;
id special_eos_id = -1; id special_eos_id = -1;
@ -476,22 +476,6 @@ struct gpt2bpe_vocab {
bpe_ranks.emplace(bpe_merges_[i], i); bpe_ranks.emplace(bpe_merges_[i], i);
} }
bpe_merges = bpe_merges_; bpe_merges = bpe_merges_;
// populate special tokens too (0-11 and if available 65024++)
#if 0
for (int i = 0; i < 12; i++) {
special_tokens[id_to_token[i].tok] = i;
}
for (int i = 65024; i < (int)id_to_token.size(); i++) {
special_tokens[id_to_token[i].tok] = i;
}
#endif
// token_to_id["</s>"] = 11; // bugfix for TII instruct training (blocks stopwords)
// special_tokens["</s>"] = 11; // bugfix for TII instruct training (blocks stopwords)
return bpe_merges_.size(); return bpe_merges_.size();
} }
@ -508,10 +492,6 @@ struct gpt2bpe_vocab {
}).base(), str.end()); }).base(), str.end());
} }
// removed, merges loaded from gguf model file:
// requires the standard HF type tokenizer.json (pretty printed)
// std::vector<std::pair<std::string, std::string>> parse_json_to_bpe_merges(const std::string& filename) {
// get max token length available for a prefix of 2 bytes (string at least 2 bytes long) // get max token length available for a prefix of 2 bytes (string at least 2 bytes long)
int get_max_token_length(const std::string& string) const { int get_max_token_length(const std::string& string) const {
if (string.size() < 2) if (string.size() < 2)
@ -609,45 +589,27 @@ struct gpt2bpe_tokenizer {
{ {
work_queue_ = ggllm_bpe_bigram::queue(); work_queue_ = ggllm_bpe_bigram::queue();
symbols_.clear(); symbols_.clear();
bool is_special = false;
for (auto it = vocab_.special_tokens.begin(); it != vocab_.special_tokens.end(); ++it)
{
std::string special_token = it->first;
if (word.compare(special_token) == 0)
{
ggllm_bpe_symbol sym;
sym.text = word.c_str();
sym.n = word.size();
sym.prev = -1;
sym.next = -1;
symbols_.emplace_back(sym);
is_special = true;
break;
}
}
int index = 0; int index = 0;
size_t offset = 0; size_t offset = 0;
if (!is_special)
{
while (offset < word.size()) while (offset < word.size())
{ {
ggllm_bpe_symbol sym; ggllm_bpe_symbol sym;
size_t char_len = std::min(word.size() - offset, (size_t) CNCTUnicode::utf8_len(word[offset])); size_t char_len = std::min(word.size() - offset, (size_t) CNCTUnicode::utf8_len(word[offset]));
sym.text = word.c_str() + offset; sym.text = word.c_str() + offset;
sym.n = 1; sym.n = 1;
sym.n = char_len; sym.n = char_len;
offset += sym.n; offset += sym.n;
sym.prev = index - 1; sym.prev = index - 1;
sym.next = offset == word.size() ? -1 : index + 1; sym.next = offset == word.size() ? -1 : index + 1;
index++; index++;
symbols_.emplace_back(sym); symbols_.emplace_back(sym);
}
for (size_t i = 1; i < symbols_.size(); ++i) {
add_new_bigram(i - 1, i);
}
} }
for (size_t i = 1; i < symbols_.size(); ++i) {
add_new_bigram(i - 1, i);
}
// build token(s) // build token(s)
while (!work_queue_.empty()) while (!work_queue_.empty())
{ {
@ -790,17 +752,6 @@ private:
bpe_encoded_words.reserve(text.size()); bpe_encoded_words.reserve(text.size());
text_utf = CNCTUnicode::split_utf8_enhanced(text); text_utf = CNCTUnicode::split_utf8_enhanced(text);
std::map<std::string, int> special_tokens = vocab_.special_tokens;
int smallest_len_special_tokens = 0;
if (special_tokens.size())
{
smallest_len_special_tokens = special_tokens.begin()->first.size();
for (auto it = special_tokens.begin(); it != special_tokens.end(); ++it)
{
if (it->first.size() < (size_t)smallest_len_special_tokens)
smallest_len_special_tokens = it->first.size();
}
}
for (int i = 0; i < (int)text_utf.size(); i++) for (int i = 0; i < (int)text_utf.size(); i++)
{ {
@ -813,41 +764,6 @@ private:
const CNCTString &utf_char_next_next = (i+2 < (int)text_utf.size()) ? text_utf[i+2] : CNCTString(); const CNCTString &utf_char_next_next = (i+2 < (int)text_utf.size()) ? text_utf[i+2] : CNCTString();
// const CNCTString &utf_char_prev = (i > 0) ? text_utf[i-1] : CNCTString(); // const CNCTString &utf_char_prev = (i > 0) ? text_utf[i-1] : CNCTString();
// handling special tokens
bool special_token_found = false;
if (bytes_remain >= (int)smallest_len_special_tokens)
for (auto it = special_tokens.begin(); it != special_tokens.end(); ++it)
{
if ((bytes_remain) < (int)it->first.size())
continue;
if (str_is_equal(text_pos, it->first.c_str(), it->first.size()))
{
if (token.size())
{
bpe_words.emplace_back(token); // push previous content as token
token.clear();
collecting = false;
collecting_letter = false;
collecting_numeric = false;
collecting_special = false;
collecting_whitespace_lookahead = false;
}
bpe_words.emplace_back(it->first); // push special token as token
// we now advance i until the token is fulfilled by the utf_chars
int st_bytes = (int)it->first.size();
for (;st_bytes;st_bytes -= text_utf[i++].str.size());
i--;
special_token_found = true;
break;
}
}
if (special_token_found) continue;
// handling contractions // handling contractions
if (!split_condition && bytes_remain >= 2) if (!split_condition && bytes_remain >= 2)
{ {