mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
tokenizer : BPE fixes (#7530)
* Random test: add_bos_token, add_eos_token * Random test: add BPE models for testing * Custom regex split fails with codepoint 0 * Fix falcon punctuation regex * Refactor llm_tokenizer_bpe: move code to constructor * Move 'add_special_bos/eos' logic to llm_tokenizer_bpe * Move tokenizer flags to vocab structure. * Default values for special_add_bos/eos * Build vocab.special_tokens_cache using vocab token types * Generalize 'jina-v2' per token attributes * Fix unicode whitespaces (deepseek-coder, deepseek-llm) * Skip missing byte tokens (falcon) * Better unicode data generation * Replace char32_t with uint32_t
This commit is contained in:
parent
91c188d6c2
commit
37bef89433
179
llama.cpp
179
llama.cpp
@ -2310,16 +2310,17 @@ struct llama_vocab {
|
|||||||
id special_cls_id = -1;
|
id special_cls_id = -1;
|
||||||
id special_mask_id = -1;
|
id special_mask_id = -1;
|
||||||
|
|
||||||
int special_add_bos = -1; // -1 unknown, 1 add, 0 don't add.
|
|
||||||
int special_add_eos = -1; // -1 unknown, 1 add, 0 don't add.
|
|
||||||
|
|
||||||
id linefeed_id = 13;
|
id linefeed_id = 13;
|
||||||
id special_prefix_id = -1;
|
id special_prefix_id = -1;
|
||||||
id special_suffix_id = -1;
|
id special_suffix_id = -1;
|
||||||
id special_middle_id = -1;
|
id special_middle_id = -1;
|
||||||
id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
|
id special_eot_id = -1; // TODO: move above after "eos_id", and here add "file separator" token
|
||||||
|
|
||||||
bool add_space_prefix = true;
|
// tokenizer flags
|
||||||
|
bool tokenizer_add_space_prefix = true;
|
||||||
|
bool tokenizer_add_bos = false;
|
||||||
|
bool tokenizer_add_eos = false;
|
||||||
|
bool tokenizer_ignore_merges = false;
|
||||||
|
|
||||||
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
|
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
|
||||||
GGML_ASSERT(token_left.find(' ') == std::string::npos);
|
GGML_ASSERT(token_left.find(' ') == std::string::npos);
|
||||||
@ -4770,7 +4771,7 @@ static void llm_load_vocab(
|
|||||||
|
|
||||||
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
|
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
|
||||||
if (add_space_prefix_keyidx != -1) {
|
if (add_space_prefix_keyidx != -1) {
|
||||||
vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
|
vocab.tokenizer_add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
|
||||||
} // The default value of add_space_prefix is true.
|
} // The default value of add_space_prefix is true.
|
||||||
} else if (tokenizer_model == "bert") {
|
} else if (tokenizer_model == "bert") {
|
||||||
vocab.type = LLAMA_VOCAB_TYPE_WPM;
|
vocab.type = LLAMA_VOCAB_TYPE_WPM;
|
||||||
@ -4783,13 +4784,13 @@ static void llm_load_vocab(
|
|||||||
vocab.special_pad_id = 0;
|
vocab.special_pad_id = 0;
|
||||||
vocab.special_cls_id = 101;
|
vocab.special_cls_id = 101;
|
||||||
vocab.special_mask_id = 103;
|
vocab.special_mask_id = 103;
|
||||||
vocab.add_space_prefix = false;
|
vocab.tokenizer_add_space_prefix = false;
|
||||||
} else if (tokenizer_model == "gpt2") {
|
} else if (tokenizer_model == "gpt2") {
|
||||||
vocab.type = LLAMA_VOCAB_TYPE_BPE;
|
vocab.type = LLAMA_VOCAB_TYPE_BPE;
|
||||||
|
|
||||||
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
|
const int add_space_prefix_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_ADD_PREFIX).c_str());
|
||||||
if (add_space_prefix_keyidx != -1) {
|
if (add_space_prefix_keyidx != -1) {
|
||||||
vocab.add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
|
vocab.tokenizer_add_space_prefix = gguf_get_val_bool(ctx, add_space_prefix_keyidx);
|
||||||
}
|
}
|
||||||
|
|
||||||
// read bpe merges and populate bpe ranks
|
// read bpe merges and populate bpe ranks
|
||||||
@ -4847,6 +4848,8 @@ static void llm_load_vocab(
|
|||||||
tokenizer_pre == "llama-v3" ||
|
tokenizer_pre == "llama-v3" ||
|
||||||
tokenizer_pre == "llama-bpe") {
|
tokenizer_pre == "llama-bpe") {
|
||||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
|
||||||
|
vocab.tokenizer_ignore_merges = true;
|
||||||
|
vocab.tokenizer_add_bos = true;
|
||||||
} else if (
|
} else if (
|
||||||
tokenizer_pre == "deepseek-llm") {
|
tokenizer_pre == "deepseek-llm") {
|
||||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM;
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM;
|
||||||
@ -4897,6 +4900,14 @@ static void llm_load_vocab(
|
|||||||
} else {
|
} else {
|
||||||
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
|
||||||
}
|
}
|
||||||
|
} else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
|
||||||
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||||
|
vocab.tokenizer_add_bos = true;
|
||||||
|
vocab.tokenizer_add_eos = false;
|
||||||
|
} else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
|
||||||
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||||
|
vocab.tokenizer_add_bos = true;
|
||||||
|
vocab.tokenizer_add_eos = false;
|
||||||
} else {
|
} else {
|
||||||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||||
}
|
}
|
||||||
@ -5041,10 +5052,10 @@ static void llm_load_vocab(
|
|||||||
bool temp = true;
|
bool temp = true;
|
||||||
|
|
||||||
if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
|
if (ml.get_key(LLM_KV_TOKENIZER_ADD_BOS, temp, false)) {
|
||||||
vocab.special_add_bos = int(temp);
|
vocab.tokenizer_add_bos = temp;
|
||||||
}
|
}
|
||||||
if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
|
if (ml.get_key(LLM_KV_TOKENIZER_ADD_EOS, temp, false)) {
|
||||||
vocab.special_add_eos = int(temp);
|
vocab.tokenizer_add_eos = temp;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -5144,7 +5155,7 @@ static void llm_load_vocab(
|
|||||||
);
|
);
|
||||||
|
|
||||||
// set attributes by model/tokenizer name
|
// set attributes by model/tokenizer name
|
||||||
if (_contains_any(tokenizer_pre, {"jina-v2-es", "jina-v2-de"})) {
|
if (_contains_any(tokenizer_pre, {"jina-v2-de", "jina-v2-es", "jina-v2-code"})) {
|
||||||
_set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
|
_set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
|
||||||
} else if (_contains_any(model_name, {"phi-3", "phi3"})) {
|
} else if (_contains_any(model_name, {"phi-3", "phi3"})) {
|
||||||
for (auto id : vocab.cache_special_tokens) {
|
for (auto id : vocab.cache_special_tokens) {
|
||||||
@ -13158,113 +13169,143 @@ struct llm_bigram_bpe {
|
|||||||
};
|
};
|
||||||
|
|
||||||
struct llm_tokenizer_bpe {
|
struct llm_tokenizer_bpe {
|
||||||
llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {}
|
llm_tokenizer_bpe(const llama_vocab & vocab): vocab(vocab) {
|
||||||
|
GGML_ASSERT(vocab.type == LLAMA_VOCAB_TYPE_BPE);
|
||||||
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
|
||||||
int final_prev_index = -1;
|
|
||||||
bool ignore_merges = false;
|
|
||||||
|
|
||||||
std::vector<std::string> word_collection;
|
|
||||||
switch (vocab.type) {
|
|
||||||
case LLAMA_VOCAB_TYPE_BPE:
|
|
||||||
switch (vocab.type_pre) {
|
switch (vocab.type_pre) {
|
||||||
case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
|
case LLAMA_VOCAB_PRE_TYPE_LLAMA3:
|
||||||
ignore_merges = true;
|
regex_exprs = {
|
||||||
word_collection = unicode_regex_split(text, {
|
|
||||||
// original regex from tokenizer.json
|
// original regex from tokenizer.json
|
||||||
//"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
//"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
|
|
||||||
// adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
|
// adapted: https://github.com/ggerganov/llama.cpp/pull/6920#issuecomment-2080233989
|
||||||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_DBRX:
|
case LLAMA_VOCAB_PRE_TYPE_DBRX:
|
||||||
case LLAMA_VOCAB_PRE_TYPE_SMAUG:
|
case LLAMA_VOCAB_PRE_TYPE_SMAUG:
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
// same as llama3
|
// same as llama3
|
||||||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
|
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM:
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
"[\r\n]",
|
"[\r\n]",
|
||||||
"\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
|
"\\s?[A-Za-zµÀ-ÖØ-öø-ƺƼ-ƿDŽ-ʓʕ-ʯͰ-ͳͶͷͻ-ͽͿΆΈ-ΊΌΎ-ΡΣ-ϵϷ-ҁҊ-ԯԱ-ՖႠ-ჅᎠ-Ᏽᏸ-ᏽᲐ-ᲺᲽ-Ჿᴀ-ᴫᵫ-ᵷᵹ-ᶚḀ-ἕἘ-Ἕἠ-ὅὈ-Ὅὐ-ὗὙὛὝὟ-ώᾀ-ᾴᾶ-ᾼιῂ-ῄῆ-ῌῐ-ΐῖ-Ίῠ-Ῥῲ-ῴῶ-ῼℂℇℊ-ℓℕℙ-ℝℤΩℨK-ℭℯ-ℴℹℼ-ℿⅅ-ⅉⅎↃↄⰀ-ⱻⱾ-ⳤⳫ-ⳮⳲⳳꙀ-ꙭꚀ-ꚛꜢ-ꝯꝱ-ꞇꞋ-ꞎꭰ-ꮿff-stﬓ-ﬗA-Za-z𐐀-𐑏𐒰-𐓓𐓘-𐓻𐲀-𐲲𐳀-𐳲𑢠-𑣟𞤀-𞥃]+",
|
||||||
"\\s?[!-/:-~!-/:-~‘-‟ -。]+",
|
"\\s?[!-/:-~!-/:-~‘-‟ -。]+",
|
||||||
"\\s+$",
|
"\\s+$",
|
||||||
"[一-龥ࠀ-一가-]+",
|
"[一-龥ࠀ-一가-]+",
|
||||||
"\\p{N}+",
|
"\\p{N}+",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
|
case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER:
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
"[\r\n]",
|
"[\r\n]",
|
||||||
"\\s?\\p{L}+",
|
"\\s?\\p{L}+",
|
||||||
"\\s?\\p{P}+",
|
"\\s?\\p{P}+",
|
||||||
"[一-龥ࠀ-一가-]+",
|
"[一-龥ࠀ-一가-]+",
|
||||||
"\\p{N}",
|
"\\p{N}",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_FALCON:
|
case LLAMA_VOCAB_PRE_TYPE_FALCON:
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
"[\\p{P}\\$\\+<=>\\^~\\|]+",
|
"[\\p{P}\\$\\+<=>\\^~\\|`]+",
|
||||||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
||||||
"[0-9][0-9][0-9]",
|
"[0-9][0-9][0-9]",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_MPT:
|
case LLAMA_VOCAB_PRE_TYPE_MPT:
|
||||||
// TODO: MPT pre-tokenization regexes are unknown
|
// TODO: MPT pre-tokenization regexes are unknown
|
||||||
// the following are close, but not exact. run the following:
|
// the following are close, but not exact. run the following:
|
||||||
// ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
|
// ./bin/test-tokenizer-0 ../models/ggml-vocab-mpt.gguf
|
||||||
GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
|
GGML_ASSERT("MPT pre-tokenization regexes are unknown - fixes needed");
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
"\\s?\\p{L}+",
|
"\\s?\\p{L}+",
|
||||||
"\\s?\\p{P}+",
|
"\\s?\\p{P}+",
|
||||||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_STARCODER:
|
case LLAMA_VOCAB_PRE_TYPE_STARCODER:
|
||||||
case LLAMA_VOCAB_PRE_TYPE_REFACT:
|
case LLAMA_VOCAB_PRE_TYPE_REFACT:
|
||||||
case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
|
case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
"\\p{N}",
|
"\\p{N}",
|
||||||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_GPT2:
|
case LLAMA_VOCAB_PRE_TYPE_GPT2:
|
||||||
case LLAMA_VOCAB_PRE_TYPE_OLMO:
|
case LLAMA_VOCAB_PRE_TYPE_OLMO:
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
|
case LLAMA_VOCAB_PRE_TYPE_STABLELM2:
|
||||||
case LLAMA_VOCAB_PRE_TYPE_QWEN2:
|
case LLAMA_VOCAB_PRE_TYPE_QWEN2:
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
// original regex from tokenizer.json
|
// original regex from tokenizer.json
|
||||||
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||||
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
case LLAMA_VOCAB_PRE_TYPE_PORO:
|
case LLAMA_VOCAB_PRE_TYPE_PORO:
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
" ?[^(\\s|.,!?…。,、।۔،)]+",
|
" ?[^(\\s|.,!?…。,、।۔،)]+",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
// default regex for BPE tokenization pre-processing
|
// default regex for BPE tokenization pre-processing
|
||||||
word_collection = unicode_regex_split(text, {
|
regex_exprs = {
|
||||||
"[\\p{P}\\$\\+<=>\\^~\\|]+",
|
"[\\p{P}\\$\\+<=>\\^~\\|]+",
|
||||||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
||||||
"\\p{N}+",
|
"\\p{N}+",
|
||||||
"[0-9][0-9][0-9]",
|
"[0-9][0-9][0-9]",
|
||||||
});
|
};
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
break;
|
|
||||||
default:
|
|
||||||
GGML_ASSERT(false);
|
|
||||||
break;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void append(const llama_vocab::id token_id, std::vector<llama_vocab::id> & output) const {
|
||||||
|
output.push_back(token_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool append_bos(std::vector<llama_vocab::id> & output) const {
|
||||||
|
if (vocab.tokenizer_add_bos) {
|
||||||
|
GGML_ASSERT(vocab.special_bos_id != -1);
|
||||||
|
output.push_back(vocab.special_bos_id);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool append_eos(std::vector<llama_vocab::id> & output) const {
|
||||||
|
if (vocab.tokenizer_add_eos) {
|
||||||
|
GGML_ASSERT(vocab.special_eos_id != -1);
|
||||||
|
output.push_back(vocab.special_eos_id);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
void check_double_bos_eos(const std::vector<llama_vocab::id> & output) const {
|
||||||
|
if (vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
|
||||||
|
LLAMA_LOG_WARN(
|
||||||
|
"%s: Added a BOS token to the prompt as specified by the model but the prompt "
|
||||||
|
"also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
|
||||||
|
"Are you sure this is what you want?\n", __FUNCTION__);
|
||||||
|
}
|
||||||
|
if (vocab.tokenizer_add_eos && output.size() >= 2 && *(output.end()-2) == vocab.special_eos_id) {
|
||||||
|
LLAMA_LOG_WARN(
|
||||||
|
"%s: Added a EOS token to the prompt as specified by the model but the prompt "
|
||||||
|
"also ends with a EOS token. So now the final prompt ends with 2 EOS tokens. "
|
||||||
|
"Are you sure this is what you want?\n", __FUNCTION__);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void tokenize(const std::string & text, std::vector<llama_vocab::id> & output) {
|
||||||
|
int final_prev_index = -1;
|
||||||
|
|
||||||
|
const auto word_collection = unicode_regex_split(text, regex_exprs);
|
||||||
|
|
||||||
symbols_final.clear();
|
symbols_final.clear();
|
||||||
|
|
||||||
for (auto & word : word_collection) {
|
for (auto & word : word_collection) {
|
||||||
@ -13274,7 +13315,7 @@ struct llm_tokenizer_bpe {
|
|||||||
int index = 0;
|
int index = 0;
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
|
|
||||||
if (ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
|
if (vocab.tokenizer_ignore_merges && vocab.token_to_id.find(word) != vocab.token_to_id.end()) {
|
||||||
symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
|
symbols.emplace_back(llm_symbol{-1, -1, word.c_str(), word.size()});
|
||||||
offset = word.size();
|
offset = word.size();
|
||||||
}
|
}
|
||||||
@ -13355,10 +13396,9 @@ struct llm_tokenizer_bpe {
|
|||||||
for (auto j = str.begin(); j != str.end(); ++j) {
|
for (auto j = str.begin(); j != str.end(); ++j) {
|
||||||
std::string byte_str(1, *j);
|
std::string byte_str(1, *j);
|
||||||
auto token_multibyte = vocab.token_to_id.find(byte_str);
|
auto token_multibyte = vocab.token_to_id.find(byte_str);
|
||||||
if (token_multibyte == vocab.token_to_id.end()) {
|
if (token_multibyte != vocab.token_to_id.end()) {
|
||||||
throw std::runtime_error("ERROR: byte not found in vocab");
|
output.push_back(token_multibyte->second);
|
||||||
}
|
}
|
||||||
output.push_back((*token_multibyte).second);
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
output.push_back((*token).second);
|
output.push_back((*token).second);
|
||||||
@ -13397,6 +13437,8 @@ private:
|
|||||||
|
|
||||||
const llama_vocab & vocab;
|
const llama_vocab & vocab;
|
||||||
|
|
||||||
|
std::vector<std::string> regex_exprs;
|
||||||
|
|
||||||
std::vector<llm_symbol> symbols;
|
std::vector<llm_symbol> symbols;
|
||||||
std::vector<llm_symbol> symbols_final;
|
std::vector<llm_symbol> symbols_final;
|
||||||
|
|
||||||
@ -13677,7 +13719,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
|
|||||||
|
|
||||||
bool is_prev_special = false;
|
bool is_prev_special = false;
|
||||||
|
|
||||||
if (add_special && vocab.special_add_bos != 0) {
|
if (add_special && vocab.tokenizer_add_bos) {
|
||||||
GGML_ASSERT(vocab.special_bos_id != -1);
|
GGML_ASSERT(vocab.special_bos_id != -1);
|
||||||
output.push_back(vocab.special_bos_id);
|
output.push_back(vocab.special_bos_id);
|
||||||
is_prev_special = true;
|
is_prev_special = true;
|
||||||
@ -13687,7 +13729,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
|
|||||||
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
|
||||||
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
|
||||||
|
|
||||||
if (vocab.add_space_prefix) {
|
if (vocab.tokenizer_add_space_prefix) {
|
||||||
if (!output.size() || is_prev_special) { // prefix with space if first token
|
if (!output.size() || is_prev_special) { // prefix with space if first token
|
||||||
raw_text = " " + raw_text;
|
raw_text = " " + raw_text;
|
||||||
}
|
}
|
||||||
@ -13705,23 +13747,24 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (add_special && vocab.special_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
|
if (add_special && vocab.tokenizer_add_bos && output.size() >= 2 && output[1] == vocab.special_bos_id) {
|
||||||
LLAMA_LOG_WARN(
|
LLAMA_LOG_WARN(
|
||||||
"%s: Added a BOS token to the prompt as specified by the model but the prompt "
|
"%s: Added a BOS token to the prompt as specified by the model but the prompt "
|
||||||
"also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
|
"also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
|
||||||
"Are you sure this is what you want?\n", __FUNCTION__);
|
"Are you sure this is what you want?\n", __FUNCTION__);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (add_special && vocab.special_add_eos == 1) {
|
if (add_special && vocab.tokenizer_add_eos) {
|
||||||
GGML_ASSERT(vocab.special_eos_id != -1);
|
GGML_ASSERT(vocab.special_eos_id != -1);
|
||||||
output.push_back(vocab.special_eos_id);
|
output.push_back(vocab.special_eos_id);
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case LLAMA_VOCAB_TYPE_BPE:
|
case LLAMA_VOCAB_TYPE_BPE:
|
||||||
{
|
{
|
||||||
if (add_special && vocab.special_add_bos != 0) {
|
llm_tokenizer_bpe tokenizer(vocab);
|
||||||
GGML_ASSERT(vocab.special_bos_id != -1);
|
|
||||||
output.push_back(vocab.special_bos_id);
|
if (add_special) {
|
||||||
|
tokenizer.append_bos(output);
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto & fragment : fragment_buffer) {
|
for (const auto & fragment : fragment_buffer) {
|
||||||
@ -13731,23 +13774,15 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
|
|||||||
#ifdef PRETOKENIZERDEBUG
|
#ifdef PRETOKENIZERDEBUG
|
||||||
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
|
LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", raw_text.length(), fragment.offset, fragment.length, raw_text.c_str());
|
||||||
#endif
|
#endif
|
||||||
llm_tokenizer_bpe tokenizer(vocab);
|
|
||||||
tokenizer.tokenize(raw_text, output);
|
tokenizer.tokenize(raw_text, output);
|
||||||
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
|
||||||
output.push_back(fragment.token);
|
tokenizer.append(fragment.token, output);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (add_special && vocab.special_add_bos != 0 && output.size() >= 2 && output[1] == vocab.special_bos_id) {
|
if (add_special) {
|
||||||
LLAMA_LOG_WARN(
|
tokenizer.append_eos(output);
|
||||||
"%s: Added a BOS token to the prompt as specified by the model but the prompt "
|
tokenizer.check_double_bos_eos(output);
|
||||||
"also starts with a BOS token. So now the final prompt starts with 2 BOS tokens. "
|
|
||||||
"Are you sure this is what you want?\n", __FUNCTION__);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (add_special && vocab.special_add_eos == 1) {
|
|
||||||
GGML_ASSERT(vocab.special_add_eos != -1);
|
|
||||||
output.push_back(vocab.special_eos_id);
|
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case LLAMA_VOCAB_TYPE_WPM:
|
case LLAMA_VOCAB_TYPE_WPM:
|
||||||
@ -18320,11 +18355,11 @@ llama_token llama_token_nl(const struct llama_model * model) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_add_bos_token(const struct llama_model * model) {
|
int32_t llama_add_bos_token(const struct llama_model * model) {
|
||||||
return model->vocab.special_add_bos;
|
return model->vocab.tokenizer_add_bos;
|
||||||
}
|
}
|
||||||
|
|
||||||
int32_t llama_add_eos_token(const struct llama_model * model) {
|
int32_t llama_add_eos_token(const struct llama_model * model) {
|
||||||
return model->vocab.special_add_eos;
|
return model->vocab.tokenizer_add_eos;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token llama_token_prefix(const struct llama_model * model) {
|
llama_token llama_token_prefix(const struct llama_model * model) {
|
||||||
|
@ -1,83 +1,143 @@
|
|||||||
import regex
|
import array
|
||||||
import ctypes
|
|
||||||
import unicodedata
|
import unicodedata
|
||||||
|
import requests
|
||||||
|
|
||||||
class CoodepointFlags (ctypes.Structure):
|
|
||||||
_fields_ = [ # see definition in unicode.h
|
|
||||||
("is_undefined", ctypes.c_uint16, 1),
|
|
||||||
("is_number", ctypes.c_uint16, 1), # regex: \p{N}
|
|
||||||
("is_letter", ctypes.c_uint16, 1), # regex: \p{L}
|
|
||||||
("is_separator", ctypes.c_uint16, 1), # regex: \p{Z}
|
|
||||||
("is_accent_mark", ctypes.c_uint16, 1), # regex: \p{M}
|
|
||||||
("is_punctuation", ctypes.c_uint16, 1), # regex: \p{P}
|
|
||||||
("is_symbol", ctypes.c_uint16, 1), # regex: \p{S}
|
|
||||||
("is_control", ctypes.c_uint16, 1), # regex: \p{C}
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
assert (ctypes.sizeof(CoodepointFlags) == 2)
|
|
||||||
|
|
||||||
|
|
||||||
MAX_CODEPOINTS = 0x110000
|
MAX_CODEPOINTS = 0x110000
|
||||||
|
|
||||||
regex_number = regex.compile(r'\p{N}')
|
UNICODE_DATA_URL = "https://www.unicode.org/Public/UCD/latest/ucd/UnicodeData.txt"
|
||||||
regex_letter = regex.compile(r'\p{L}')
|
|
||||||
regex_separator = regex.compile(r'\p{Z}')
|
|
||||||
regex_accent_mark = regex.compile(r'\p{M}')
|
|
||||||
regex_punctuation = regex.compile(r'\p{P}')
|
|
||||||
regex_symbol = regex.compile(r'\p{S}')
|
|
||||||
regex_control = regex.compile(r'\p{C}')
|
|
||||||
regex_whitespace = regex.compile(r'\s')
|
|
||||||
|
|
||||||
codepoint_flags = (CoodepointFlags * MAX_CODEPOINTS)()
|
|
||||||
|
# see https://www.unicode.org/L2/L1999/UnicodeData.html
|
||||||
|
def unicode_data_iter():
|
||||||
|
res = requests.get(UNICODE_DATA_URL)
|
||||||
|
res.raise_for_status()
|
||||||
|
data = res.content.decode()
|
||||||
|
|
||||||
|
prev = []
|
||||||
|
|
||||||
|
for line in data.splitlines():
|
||||||
|
# ej: 0000;<control>;Cc;0;BN;;;;;N;NULL;;;;
|
||||||
|
line = line.split(";")
|
||||||
|
|
||||||
|
cpt = int(line[0], base=16)
|
||||||
|
assert cpt < MAX_CODEPOINTS
|
||||||
|
|
||||||
|
cpt_lower = int(line[-2] or "0", base=16)
|
||||||
|
assert cpt_lower < MAX_CODEPOINTS
|
||||||
|
|
||||||
|
cpt_upper = int(line[-3] or "0", base=16)
|
||||||
|
assert cpt_upper < MAX_CODEPOINTS
|
||||||
|
|
||||||
|
categ = line[2].strip()
|
||||||
|
assert len(categ) == 2
|
||||||
|
|
||||||
|
bidir = line[4].strip()
|
||||||
|
assert len(categ) == 2
|
||||||
|
|
||||||
|
name = line[1]
|
||||||
|
if name.endswith(", First>"):
|
||||||
|
prev = (cpt, cpt_lower, cpt_upper, categ, bidir)
|
||||||
|
continue
|
||||||
|
if name.endswith(", Last>"):
|
||||||
|
assert prev[1:] == (0, 0, categ, bidir)
|
||||||
|
for c in range(prev[0], cpt):
|
||||||
|
yield (c, cpt_lower, cpt_upper, categ, bidir)
|
||||||
|
|
||||||
|
yield (cpt, cpt_lower, cpt_upper, categ, bidir)
|
||||||
|
|
||||||
|
|
||||||
|
# see definition in unicode.h
|
||||||
|
CODEPOINT_FLAG_UNDEFINED = 0x0001 #
|
||||||
|
CODEPOINT_FLAG_NUMBER = 0x0002 # \p{N}
|
||||||
|
CODEPOINT_FLAG_LETTER = 0x0004 # \p{L}
|
||||||
|
CODEPOINT_FLAG_SEPARATOR = 0x0008 # \p{Z}
|
||||||
|
CODEPOINT_FLAG_MARK = 0x0010 # \p{M}
|
||||||
|
CODEPOINT_FLAG_PUNCTUATION = 0x0020 # \p{P}
|
||||||
|
CODEPOINT_FLAG_SYMBOL = 0x0040 # \p{S}
|
||||||
|
CODEPOINT_FLAG_CONTROL = 0x0080 # \p{C}
|
||||||
|
|
||||||
|
UNICODE_CATEGORY_TO_FLAG = {
|
||||||
|
"Cn": CODEPOINT_FLAG_UNDEFINED, # Undefined
|
||||||
|
"Cc": CODEPOINT_FLAG_CONTROL, # Control
|
||||||
|
"Cf": CODEPOINT_FLAG_CONTROL, # Format
|
||||||
|
"Co": CODEPOINT_FLAG_CONTROL, # Private Use
|
||||||
|
"Cs": CODEPOINT_FLAG_CONTROL, # Surrrogate
|
||||||
|
"Ll": CODEPOINT_FLAG_LETTER, # Lowercase Letter
|
||||||
|
"Lm": CODEPOINT_FLAG_LETTER, # Modifier Letter
|
||||||
|
"Lo": CODEPOINT_FLAG_LETTER, # Other Letter
|
||||||
|
"Lt": CODEPOINT_FLAG_LETTER, # Titlecase Letter
|
||||||
|
"Lu": CODEPOINT_FLAG_LETTER, # Uppercase Letter
|
||||||
|
"L&": CODEPOINT_FLAG_LETTER, # Cased Letter
|
||||||
|
"Mc": CODEPOINT_FLAG_MARK, # Spacing Mark
|
||||||
|
"Me": CODEPOINT_FLAG_MARK, # Enclosing Mark
|
||||||
|
"Mn": CODEPOINT_FLAG_MARK, # Nonspacing Mark
|
||||||
|
"Nd": CODEPOINT_FLAG_NUMBER, # Decimal Number
|
||||||
|
"Nl": CODEPOINT_FLAG_NUMBER, # Letter Number
|
||||||
|
"No": CODEPOINT_FLAG_NUMBER, # Other Number
|
||||||
|
"Pc": CODEPOINT_FLAG_PUNCTUATION, # Connector Punctuation
|
||||||
|
"Pd": CODEPOINT_FLAG_PUNCTUATION, # Dash Punctuation
|
||||||
|
"Pe": CODEPOINT_FLAG_PUNCTUATION, # Close Punctuation
|
||||||
|
"Pf": CODEPOINT_FLAG_PUNCTUATION, # Final Punctuation
|
||||||
|
"Pi": CODEPOINT_FLAG_PUNCTUATION, # Initial Punctuation
|
||||||
|
"Po": CODEPOINT_FLAG_PUNCTUATION, # Other Punctuation
|
||||||
|
"Ps": CODEPOINT_FLAG_PUNCTUATION, # Open Punctuation
|
||||||
|
"Sc": CODEPOINT_FLAG_SYMBOL, # Currency Symbol
|
||||||
|
"Sk": CODEPOINT_FLAG_SYMBOL, # Modifier Symbol
|
||||||
|
"Sm": CODEPOINT_FLAG_SYMBOL, # Math Symbol
|
||||||
|
"So": CODEPOINT_FLAG_SYMBOL, # Other Symbol
|
||||||
|
"Zl": CODEPOINT_FLAG_SEPARATOR, # Line Separator
|
||||||
|
"Zp": CODEPOINT_FLAG_SEPARATOR, # Paragraph Separator
|
||||||
|
"Zs": CODEPOINT_FLAG_SEPARATOR, # Space Separator
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
codepoint_flags = array.array('H', [CODEPOINT_FLAG_UNDEFINED]) * MAX_CODEPOINTS
|
||||||
table_whitespace = []
|
table_whitespace = []
|
||||||
table_lowercase = []
|
table_lowercase = []
|
||||||
table_uppercase = []
|
table_uppercase = []
|
||||||
table_nfd = []
|
table_nfd = []
|
||||||
|
|
||||||
for codepoint in range(MAX_CODEPOINTS):
|
for (cpt, cpt_lower, cpt_upper, categ, bidir) in unicode_data_iter():
|
||||||
# convert codepoint to unicode character
|
# convert codepoint to unicode character
|
||||||
char = chr(codepoint)
|
char = chr(cpt)
|
||||||
|
|
||||||
# regex categories
|
# codepoint category flags
|
||||||
flags = codepoint_flags[codepoint]
|
codepoint_flags[cpt] = UNICODE_CATEGORY_TO_FLAG[categ]
|
||||||
flags.is_number = bool(regex_number.match(char))
|
|
||||||
flags.is_letter = bool(regex_letter.match(char))
|
|
||||||
flags.is_separator = bool(regex_separator.match(char))
|
|
||||||
flags.is_accent_mark = bool(regex_accent_mark.match(char))
|
|
||||||
flags.is_punctuation = bool(regex_punctuation.match(char))
|
|
||||||
flags.is_symbol = bool(regex_symbol.match(char))
|
|
||||||
flags.is_control = bool(regex_control.match(char))
|
|
||||||
flags.is_undefined = bytes(flags)[0] == 0
|
|
||||||
assert (not flags.is_undefined)
|
|
||||||
|
|
||||||
# whitespaces
|
|
||||||
if bool(regex_whitespace.match(char)):
|
|
||||||
table_whitespace.append(codepoint)
|
|
||||||
|
|
||||||
# lowercase conversion
|
# lowercase conversion
|
||||||
lower = ord(char.lower()[0])
|
if cpt_lower:
|
||||||
if codepoint != lower:
|
table_lowercase.append((cpt, cpt_lower))
|
||||||
table_lowercase.append((codepoint, lower))
|
|
||||||
|
|
||||||
# uppercase conversion
|
# uppercase conversion
|
||||||
upper = ord(char.upper()[0])
|
if cpt_upper:
|
||||||
if codepoint != upper:
|
table_uppercase.append((cpt, cpt_upper))
|
||||||
table_uppercase.append((codepoint, upper))
|
|
||||||
|
|
||||||
# NFD normalization
|
# NFD normalization
|
||||||
norm = ord(unicodedata.normalize('NFD', char)[0])
|
norm = ord(unicodedata.normalize('NFD', char)[0])
|
||||||
if codepoint != norm:
|
if cpt != norm:
|
||||||
table_nfd.append((codepoint, norm))
|
table_nfd.append((cpt, norm))
|
||||||
|
|
||||||
|
|
||||||
|
# whitespaces, see "<White_Space>" https://www.unicode.org/Public/UCD/latest/ucd/PropList.txt
|
||||||
|
table_whitespace.extend(range(0x0009, 0x000D + 1))
|
||||||
|
table_whitespace.extend(range(0x2000, 0x200A + 1))
|
||||||
|
table_whitespace.extend([0x0020, 0x0085, 0x00A0, 0x1680, 0x2028, 0x2029, 0x202F, 0x205F, 0x3000])
|
||||||
|
|
||||||
|
|
||||||
|
# sort by codepoint
|
||||||
|
table_whitespace.sort()
|
||||||
|
table_lowercase.sort()
|
||||||
|
table_uppercase.sort()
|
||||||
|
table_nfd.sort()
|
||||||
|
|
||||||
|
|
||||||
# group ranges with same flags
|
# group ranges with same flags
|
||||||
ranges_flags = [(0, codepoint_flags[0])] # start, flags
|
ranges_flags = [(0, codepoint_flags[0])] # start, flags
|
||||||
for codepoint, flags in enumerate(codepoint_flags):
|
for codepoint, flags in enumerate(codepoint_flags):
|
||||||
if bytes(flags) != bytes(ranges_flags[-1][1]):
|
if flags != ranges_flags[-1][1]:
|
||||||
ranges_flags.append((codepoint, flags))
|
ranges_flags.append((codepoint, flags))
|
||||||
ranges_flags.append((MAX_CODEPOINTS, CoodepointFlags()))
|
ranges_flags.append((MAX_CODEPOINTS, 0x0000))
|
||||||
|
|
||||||
|
|
||||||
# group ranges with same nfd
|
# group ranges with same nfd
|
||||||
@ -90,8 +150,8 @@ for codepoint, norm in table_nfd:
|
|||||||
ranges_nfd[-1] = (start, codepoint, norm)
|
ranges_nfd[-1] = (start, codepoint, norm)
|
||||||
|
|
||||||
|
|
||||||
# Generate 'unicode-data.cpp'
|
# Generate 'unicode-data.cpp':
|
||||||
|
# python ./scripts//gen-unicode-data.py > unicode-data.cpp
|
||||||
|
|
||||||
def out(line=""):
|
def out(line=""):
|
||||||
print(line, end='\n') # noqa
|
print(line, end='\n') # noqa
|
||||||
@ -110,12 +170,12 @@ out("""\
|
|||||||
|
|
||||||
out("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = { // start, flags // last=next_start-1")
|
out("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = { // start, flags // last=next_start-1")
|
||||||
for codepoint, flags in ranges_flags:
|
for codepoint, flags in ranges_flags:
|
||||||
flags = int.from_bytes(bytes(flags), "little")
|
|
||||||
out("{0x%06X, 0x%04X}," % (codepoint, flags))
|
out("{0x%06X, 0x%04X}," % (codepoint, flags))
|
||||||
out("};\n")
|
out("};\n")
|
||||||
|
|
||||||
out("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
|
out("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
|
||||||
out(", ".join("0x%06X" % cpt for cpt in table_whitespace))
|
for codepoint in table_whitespace:
|
||||||
|
out("0x%06X," % codepoint)
|
||||||
out("};\n")
|
out("};\n")
|
||||||
|
|
||||||
out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")
|
out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")
|
||||||
|
@ -11,13 +11,15 @@ import logging
|
|||||||
import argparse
|
import argparse
|
||||||
import subprocess
|
import subprocess
|
||||||
import random
|
import random
|
||||||
|
import unicodedata
|
||||||
|
|
||||||
from typing import Callable, Iterator
|
from typing import Callable, Iterator
|
||||||
|
|
||||||
import cffi
|
import cffi
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger("test-tokenizer-random-bpe")
|
|
||||||
|
logger = logging.getLogger("test-tokenizer-random")
|
||||||
|
|
||||||
|
|
||||||
class LibLlama:
|
class LibLlama:
|
||||||
@ -155,9 +157,14 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
|
|||||||
'Cửa Việt', # llama-3, ignore_merges = true
|
'Cửa Việt', # llama-3, ignore_merges = true
|
||||||
'<s>a', # Phi-3 fail
|
'<s>a', # Phi-3 fail
|
||||||
'<unk><|endoftext|><s>', # Phi-3 fail
|
'<unk><|endoftext|><s>', # Phi-3 fail
|
||||||
'a\na', # TODO: Bert fail
|
'a\na', # bert fail
|
||||||
|
'"`', # falcon
|
||||||
|
' \u2e4e', # falcon
|
||||||
|
'a\xa0\xa0\x00b', # jina-v2-es
|
||||||
|
'one <mask>', # jina-v2-es <mask> lstrip=true
|
||||||
'a </s> b', # rstrip phi-3
|
'a </s> b', # rstrip phi-3
|
||||||
'a <mask> b', # lstrip jina-v2
|
'a <mask> b', # lstrip jina-v2
|
||||||
|
'\xa0aC', # deepseek
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -189,17 +196,23 @@ def generator_random_added_tokens(tokenizer, iterations=100) -> Iterator[str]:
|
|||||||
for m in range(iterations):
|
for m in range(iterations):
|
||||||
rand.seed(m)
|
rand.seed(m)
|
||||||
words = rand.choices(all_tokens, k=500)
|
words = rand.choices(all_tokens, k=500)
|
||||||
if words[0] == tokenizer.bos_token: # skip spam warning of double BOS
|
if words and words[0] == tokenizer.bos_token: # skip spam warning of double BOS
|
||||||
while len(words) > 1 and words[1] == tokenizer.bos_token: # leave one starting BOS
|
while len(words) > 1 and words[1] == tokenizer.bos_token: # leave one starting BOS
|
||||||
words.pop(0)
|
words.pop(0)
|
||||||
if tokenizer.add_bos_token: # drop all starting BOS
|
if tokenizer.add_bos_token: # drop all starting BOS
|
||||||
words.pop(0)
|
words.pop(0)
|
||||||
|
if words and words[-1] == tokenizer.eos_token: # skip spam warning of double EOS
|
||||||
|
while len(words) > 1 and words[-2] == tokenizer.eos_token: # leave one trailing EOS
|
||||||
|
words.pop(-1)
|
||||||
|
if tokenizer.add_bos_token: # drop all trailing EOS
|
||||||
|
words.pop(-1)
|
||||||
yield "".join(words)
|
yield "".join(words)
|
||||||
|
|
||||||
|
|
||||||
def generator_random_chars(iterations=100) -> Iterator[str]:
|
def generator_random_chars(iterations=100) -> Iterator[str]:
|
||||||
"""Brute force random text with simple characters"""
|
"""Brute force random text with simple characters"""
|
||||||
|
|
||||||
|
NUM_WORDS = 400
|
||||||
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
|
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
|
||||||
CHARS = list(sorted(set("""
|
CHARS = list(sorted(set("""
|
||||||
ABCDEFGHIJKLMNOPQRSTUVWXYZ
|
ABCDEFGHIJKLMNOPQRSTUVWXYZ
|
||||||
@ -213,12 +226,50 @@ def generator_random_chars(iterations=100) -> Iterator[str]:
|
|||||||
for m in range(iterations):
|
for m in range(iterations):
|
||||||
rand.seed(m)
|
rand.seed(m)
|
||||||
text = []
|
text = []
|
||||||
num_words = rand.randint(300, 400)
|
for _ in range(NUM_WORDS):
|
||||||
for i in range(num_words):
|
|
||||||
k = rand.randint(1, 7)
|
k = rand.randint(1, 7)
|
||||||
word = rand.choices(CHARS, k=k)
|
word = rand.choices(CHARS, k=k)
|
||||||
space = rand.choice(WHITESPACES)
|
word.append(rand.choice(WHITESPACES))
|
||||||
text.append("".join(word) + space)
|
text.append("".join(word))
|
||||||
|
yield "".join(text)
|
||||||
|
|
||||||
|
|
||||||
|
def generator_unicodes() -> Iterator[str]:
|
||||||
|
"""Iterate unicode characters"""
|
||||||
|
|
||||||
|
MAX_CODEPOINTS = 0x30000 # 0x110000
|
||||||
|
|
||||||
|
def _valid(cpt):
|
||||||
|
if cpt >= 0x30000: # unassigned and supplementary
|
||||||
|
return False
|
||||||
|
if 0x00D800 <= cpt <= 0x00F8FF: # Surrogates
|
||||||
|
return False
|
||||||
|
if unicodedata.category(chr(cpt)) == "Cn":
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
characters = [chr(cpt) for cpt in range(1, MAX_CODEPOINTS) if _valid(cpt)]
|
||||||
|
|
||||||
|
yield from characters
|
||||||
|
|
||||||
|
|
||||||
|
def generator_random_unicodes(iterations=100) -> Iterator[str]:
|
||||||
|
"""Brute force random text with unicode characters"""
|
||||||
|
|
||||||
|
NUM_WORDS = 200
|
||||||
|
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
|
||||||
|
|
||||||
|
characters = list(generator_unicodes())
|
||||||
|
|
||||||
|
rand = random.Random()
|
||||||
|
for m in range(iterations):
|
||||||
|
rand.seed(m)
|
||||||
|
text = []
|
||||||
|
for _ in range(NUM_WORDS):
|
||||||
|
k = rand.randint(1, 7)
|
||||||
|
word = rand.choices(characters, k=k)
|
||||||
|
word.append(rand.choice(WHITESPACES))
|
||||||
|
text.append("".join(word))
|
||||||
yield "".join(text)
|
yield "".join(text)
|
||||||
|
|
||||||
|
|
||||||
@ -256,25 +307,7 @@ def generator_random_vocab_words(vocab: list[str], iterations=100) -> Iterator[s
|
|||||||
yield "".join(text)
|
yield "".join(text)
|
||||||
|
|
||||||
|
|
||||||
def generator_random_bytes(iterations=100) -> Iterator[str]:
|
def compare_tokenizers(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
|
||||||
"""Brute force random bytes"""
|
|
||||||
|
|
||||||
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
|
|
||||||
|
|
||||||
rand = random.Random()
|
|
||||||
for m in range(iterations):
|
|
||||||
rand.seed(m)
|
|
||||||
text = []
|
|
||||||
num_words = rand.randint(300, 400)
|
|
||||||
for i in range(num_words):
|
|
||||||
k = rand.randint(1, 8)
|
|
||||||
word = [chr(r) for r in rand.randbytes(k) if r]
|
|
||||||
word.append(rand.choice(WHITESPACES))
|
|
||||||
text.append("".join(word))
|
|
||||||
yield "".join(text)
|
|
||||||
|
|
||||||
|
|
||||||
def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
|
|
||||||
|
|
||||||
def find_first_mismatch(ids1: list[int], ids2: list[int]):
|
def find_first_mismatch(ids1: list[int], ids2: list[int]):
|
||||||
for i, (a, b) in enumerate(zip(ids1, ids2)):
|
for i, (a, b) in enumerate(zip(ids1, ids2)):
|
||||||
@ -284,20 +317,34 @@ def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, g
|
|||||||
return -1
|
return -1
|
||||||
return min(len(ids1), len(ids2))
|
return min(len(ids1), len(ids2))
|
||||||
|
|
||||||
t0 = time.perf_counter()
|
t_tokenizer1 = 0
|
||||||
|
t_tokenizer2 = 0
|
||||||
|
t_start = time.perf_counter()
|
||||||
|
num_errors = 10
|
||||||
|
|
||||||
logger.info("%s: %s" % (generator.__name__, "ini"))
|
logger.info("%s: %s" % (generator.__name__, "ini"))
|
||||||
for text in generator:
|
for text in generator:
|
||||||
|
# print(repr(text), hex(ord(text[0])), text.encode())
|
||||||
|
t0 = time.perf_counter()
|
||||||
ids1 = func_tokenize1(text)
|
ids1 = func_tokenize1(text)
|
||||||
|
t1 = time.perf_counter()
|
||||||
ids2 = func_tokenize2(text)
|
ids2 = func_tokenize2(text)
|
||||||
|
t2 = time.perf_counter()
|
||||||
|
t_tokenizer1 += t1 - t0
|
||||||
|
t_tokenizer2 += t2 - t1
|
||||||
if ids1 != ids2:
|
if ids1 != ids2:
|
||||||
i = find_first_mismatch(ids1, ids2)
|
i = find_first_mismatch(ids1, ids2)
|
||||||
ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
|
ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
|
||||||
ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
|
ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
|
||||||
logger.info(" TokenIDs: " + str(ids1))
|
logger.error(" TokenIDs: " + str(ids1))
|
||||||
logger.info(" Expected: " + str(ids2))
|
logger.error(" Expected: " + str(ids2))
|
||||||
raise Exception()
|
# raise Exception()
|
||||||
t1 = time.perf_counter()
|
num_errors += 1
|
||||||
logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0))
|
if num_errors > 10:
|
||||||
|
break
|
||||||
|
|
||||||
|
t_total = time.perf_counter() - t_start
|
||||||
|
logger.info("%s: end, tok1: %.3f tok2: %.3f total: %.3f" % (generator.__name__, t_tokenizer1, t_tokenizer2, t_total))
|
||||||
|
|
||||||
|
|
||||||
def main(argv: list[str] = None):
|
def main(argv: list[str] = None):
|
||||||
@ -307,7 +354,8 @@ def main(argv: list[str] = None):
|
|||||||
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
||||||
args = parser.parse_args(argv)
|
args = parser.parse_args(argv)
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
logging.basicConfig(level = logging.DEBUG if args.verbose else logging.INFO)
|
||||||
|
logger.info(f"VOCABFILE: '{args.vocab_file}'")
|
||||||
|
|
||||||
model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
|
model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
|
tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
|
||||||
@ -321,18 +369,22 @@ def main(argv: list[str] = None):
|
|||||||
ids = func_tokenize2("a")
|
ids = func_tokenize2("a")
|
||||||
assert 1 <= len(ids) <= 3
|
assert 1 <= len(ids) <= 3
|
||||||
add_bos_token = len(ids) > 1 and tokenizer.bos_token_id == ids[0]
|
add_bos_token = len(ids) > 1 and tokenizer.bos_token_id == ids[0]
|
||||||
|
add_eos_token = len(ids) > 1 and tokenizer.eos_token_id == ids[-1]
|
||||||
tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", add_bos_token)
|
tokenizer.add_bos_token = getattr(tokenizer, "add_bos_token", add_bos_token)
|
||||||
|
tokenizer.add_eos_token = getattr(tokenizer, "add_eos_token", add_eos_token)
|
||||||
|
|
||||||
vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
|
vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
|
||||||
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
|
|
||||||
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
|
compare_tokenizers(func_tokenize1, func_tokenize2, generator_custom_text())
|
||||||
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
|
compare_tokenizers(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
|
||||||
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer))
|
compare_tokenizers(func_tokenize1, func_tokenize2, generator_unicodes())
|
||||||
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000))
|
compare_tokenizers(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
|
||||||
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
|
compare_tokenizers(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer))
|
||||||
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
|
compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000))
|
||||||
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 5_000))
|
compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
|
||||||
# test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_bytes(10_000)) # FAIL
|
compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_unicodes(10_000))
|
||||||
|
compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
|
||||||
|
compare_tokenizers(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 5_000))
|
||||||
|
|
||||||
model.free()
|
model.free()
|
||||||
|
|
||||||
@ -340,20 +392,40 @@ def main(argv: list[str] = None):
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# main()
|
# main()
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level = logging.DEBUG,
|
||||||
|
format = "%(asctime)s.%(msecs)03d %(name)s %(levelname)s %(message)s",
|
||||||
|
datefmt = "%Y-%m-%d %H:%M:%S",
|
||||||
|
filename = logger.name + ".log",
|
||||||
|
filemode = "a"
|
||||||
|
)
|
||||||
|
|
||||||
path_tokenizers = "./models/tokenizers/"
|
path_tokenizers = "./models/tokenizers/"
|
||||||
path_vocab_format = "./models/ggml-vocab-%s.gguf"
|
path_vocab_format = "./models/ggml-vocab-%s.gguf"
|
||||||
|
|
||||||
# import os
|
# import os
|
||||||
# tokenizers = os.listdir(path_tokenizers)
|
# tokenizers = os.listdir(path_tokenizers)
|
||||||
tokenizers = [
|
tokenizers = [
|
||||||
"llama-spm", # SPM
|
# "llama-spm", # SPM
|
||||||
"phi-3", # SPM
|
# "phi-3", # SPM
|
||||||
"jina-v2-en", # WPM
|
# "bert-bge", # WPM
|
||||||
"bert-bge", # WPM
|
# "jina-v2-en", # WPM
|
||||||
|
"gpt-2", # BPE
|
||||||
|
"llama-bpe", # BPE
|
||||||
|
"falcon", # BPE
|
||||||
|
"starcoder", # BPE
|
||||||
|
"jina-v2-es", # BPE
|
||||||
|
"jina-v2-de", # BPE
|
||||||
|
"jina-v2-code", # BPE
|
||||||
|
"smaug-bpe", # BPE
|
||||||
|
"phi-2", # BPE
|
||||||
|
"deepseek-coder", # BPE
|
||||||
|
"deepseek-llm", # BPE
|
||||||
]
|
]
|
||||||
|
|
||||||
for tokenizer in tokenizers:
|
for tokenizer in tokenizers:
|
||||||
print("\n" + "=" * 50 + "\n" + tokenizer + "\n") # noqa
|
logger.info("=" * 50)
|
||||||
|
logger.info(f"TOKENIZER: '{tokenizer}'")
|
||||||
vocab_file = path_vocab_format % tokenizer
|
vocab_file = path_vocab_format % tokenizer
|
||||||
dir_tokenizer = path_tokenizers + "/" + tokenizer
|
dir_tokenizer = path_tokenizers + "/" + tokenizer
|
||||||
main([vocab_file, dir_tokenizer, "--verbose"])
|
main([vocab_file, dir_tokenizer, "--verbose"])
|
||||||
|
1652
unicode-data.cpp
1652
unicode-data.cpp
File diff suppressed because it is too large
Load Diff
29
unicode.cpp
29
unicode.cpp
@ -226,8 +226,9 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
|
|||||||
assert(offset_end <= cpts.size());
|
assert(offset_end <= cpts.size());
|
||||||
start = offset_end;
|
start = offset_end;
|
||||||
|
|
||||||
|
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
|
||||||
auto _get_cpt = [&] (const size_t pos) -> uint32_t {
|
auto _get_cpt = [&] (const size_t pos) -> uint32_t {
|
||||||
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
|
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
|
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
|
||||||
@ -309,7 +310,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
|
|||||||
}
|
}
|
||||||
|
|
||||||
// regex: \s+(?!\S)
|
// regex: \s+(?!\S)
|
||||||
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
|
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
|
||||||
pos += num_whitespaces - 1;
|
pos += num_whitespaces - 1;
|
||||||
_add_token(pos);
|
_add_token(pos);
|
||||||
continue;
|
continue;
|
||||||
@ -344,8 +345,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
|||||||
assert(offset_end <= cpts.size());
|
assert(offset_end <= cpts.size());
|
||||||
start = offset_end;
|
start = offset_end;
|
||||||
|
|
||||||
|
static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF;
|
||||||
auto _get_cpt = [&] (const size_t pos) -> uint32_t {
|
auto _get_cpt = [&] (const size_t pos) -> uint32_t {
|
||||||
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
|
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
|
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
|
||||||
@ -450,7 +452,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
|||||||
}
|
}
|
||||||
|
|
||||||
// regex: \s+(?!\S)
|
// regex: \s+(?!\S)
|
||||||
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != 0) {
|
if (num_whitespaces > 1 && _get_cpt(pos+num_whitespaces) != OUT_OF_RANGE) {
|
||||||
pos += num_whitespaces - 1;
|
pos += num_whitespaces - 1;
|
||||||
_add_token(pos);
|
_add_token(pos);
|
||||||
continue;
|
continue;
|
||||||
@ -679,10 +681,14 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag();
|
const auto flags = unicode_cpt_flags(cpts[i]);
|
||||||
|
|
||||||
if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) {
|
if (flags.is_whitespace) {
|
||||||
text_collapsed[i] = k_ucat_cpt.at(cpt_flag);
|
//NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
|
||||||
|
//text_collapsed[i] = (char) 0x85; // <Next Line> as whitespace fallback
|
||||||
|
text_collapsed[i] = (char) 0x0B; // <vertical tab> as whitespace fallback
|
||||||
|
} else if (k_ucat_cpt.find(flags.category_flag()) != k_ucat_cpt.end()) {
|
||||||
|
text_collapsed[i] = k_ucat_cpt.at(flags.category_flag());
|
||||||
} else {
|
} else {
|
||||||
text_collapsed[i] = (char) 0xD0; // fallback
|
text_collapsed[i] = (char) 0xD0; // fallback
|
||||||
}
|
}
|
||||||
@ -766,9 +772,16 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
|||||||
bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
|
bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets);
|
||||||
} else {
|
} else {
|
||||||
// no unicode category used, we can use std::wregex directly
|
// no unicode category used, we can use std::wregex directly
|
||||||
const std::wstring wtext = unicode_wstring_from_utf8(text);
|
|
||||||
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
|
const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr);
|
||||||
|
|
||||||
|
// std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
|
||||||
|
std::wstring wtext(cpts.begin(), cpts.end());
|
||||||
|
for (size_t i = 0; i < wtext.size(); ++i) {
|
||||||
|
if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) {
|
||||||
|
wtext[i] = 0x0B;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//printf("text: %s\n", text.c_str());
|
//printf("text: %s\n", text.c_str());
|
||||||
//printf("regex_expr: %s\n", regex_expr.c_str());
|
//printf("regex_expr: %s\n", regex_expr.c_str());
|
||||||
bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
|
bpe_offsets = unicode_regex_split_stl(wtext, wregex_expr, bpe_offsets);
|
||||||
|
Loading…
Reference in New Issue
Block a user