Unicode codepoint flags for custom regexes (#7245)

* Replace CODEPOINT_TYPE_* with codepoint_flags
* Update and bugfix brute force random test
* Deterministic brute force random test
* Unicode normalization NFD
* Get rid of BOM
jaime-m-p 2024-05-18 01:09:13 +02:00 committed by GitHub
parent 0fc1e820a9
commit b43272afa2
7 changed files with 7299 additions and 2409 deletions
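
In short: the old API returned one exclusive CODEPOINT_TYPE_* integer per codepoint, while the new codepoint_flags packs the \p{...} category and helper properties (whitespace, case, NFD) into independent bits of a single uint16_t, so call sites can test several properties with one lookup. A minimal before/after sketch of a call site, using only declarations that appear in the diffs below (cpt stands for any uint32_t codepoint):

    // before: one exclusive category per codepoint
    const int type = unicode_cpt_type(cpt);
    if (type == CODEPOINT_TYPE_SEPARATOR) { /* ... */ }

    // after: independent bits, several properties per lookup
    const codepoint_flags flags = unicode_cpt_flags(cpt);
    if (flags.is_separator || flags.is_whitespace) { /* ... */ }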

llama.cpp

@@ -12576,16 +12576,16 @@ struct llm_tokenizer_wpm {
         // to lowercase, pad chinese characters, pad punctuation
         std::string new_str = "";
         for (uint32_t code : cpts_nfd) {
-            int type = unicode_cpt_type(code);
-            if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) {
+            const codepoint_flags flags = unicode_cpt_flags(code);
+            if (flags.is_accent_mark || flags.is_control) {
                 continue;
             }
             code = unicode_tolower(code);
-            if (type == CODEPOINT_TYPE_SEPARATOR) {
+            if (flags.is_separator || flags.is_whitespace) {  //####FIXME: is_separator ?
                 code = ' ';
             }
             std::string s = unicode_cpt_to_utf8(code);
-            if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) {
+            if (flags.is_punctuation || is_ascii_punct(code) || is_chinese_char(code)) {
                 new_str += " ";
                 new_str += s;
                 new_str += " ";

scripts/gen-unicode-data.py

@@ -1,64 +1,134 @@
 import regex
+import ctypes
+import unicodedata
 
-def get_matches(regex_expr):
-    regex_expr_compiled = regex.compile(regex_expr)
-    unicode_ranges = []
-    current_range = None
 
-    for codepoint in range(0x110000):
-        char = chr(codepoint)
-        if regex_expr_compiled.match(char):
-            if current_range is None:
-                current_range = [codepoint, codepoint]
-            else:
-                current_range[1] = codepoint
-        elif current_range is not None:
-            unicode_ranges.append(tuple(current_range))
-            current_range = None
-
-    if current_range is not None:
-        unicode_ranges.append(tuple(current_range))
-    return unicode_ranges
-
-def print_cat(mode, cat, ranges):
-    if mode == "range":
-        print("const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_{} = {{".format(cat))  # noqa: NP100
-    if mode == "map":
-        print("const std::map<uint32_t, uint32_t> unicode_map_{} = {{".format(cat))  # noqa: NP100
-    for i, values in enumerate(ranges):
-        end = ",\n" if (i % 4 == 3 or i + 1 == len(ranges)) else ", "
-        values = ["0x%08X" % value for value in values]
-        print("{" + ", ".join(values) + "}", end=end)  # noqa: NP100
-    print("};")  # noqa: NP100
-    print("")  # noqa: NP100
-
-print_cat("range", "number",      get_matches(r'\p{N}'))
-print_cat("range", "letter",      get_matches(r'\p{L}'))
-print_cat("range", "separator",   get_matches(r'\p{Z}'))
-print_cat("range", "accent_mark", get_matches(r'\p{M}'))
-print_cat("range", "punctuation", get_matches(r'\p{P}'))
-print_cat("range", "symbol",      get_matches(r'\p{S}'))
-print_cat("range", "control",     get_matches(r'\p{C}'))
-print_cat("range", "whitespace",  get_matches(r'\s'))
-
-map_lowercase = []
-map_uppercase = []
-for codepoint in range(0x110000):
-    char = chr(codepoint)
-    lower = ord(char.lower()[0])
-    upper = ord(char.upper()[0])
-    if codepoint != lower:
-        map_lowercase.append((codepoint, lower))
-    if codepoint != upper:
-        map_uppercase.append((codepoint, upper))
-print_cat("map", "lowercase", map_lowercase)
-print_cat("map", "uppercase", map_uppercase)
-
-# TODO: generate unicode_map_nfd
+
+class CoodepointFlags (ctypes.Structure):
+    _fields_ = [  # see definition in unicode.h
+        ("is_undefined",   ctypes.c_uint16, 1),
+        ("is_number",      ctypes.c_uint16, 1),  # regex: \p{N}
+        ("is_letter",      ctypes.c_uint16, 1),  # regex: \p{L}
+        ("is_separator",   ctypes.c_uint16, 1),  # regex: \p{Z}
+        ("is_accent_mark", ctypes.c_uint16, 1),  # regex: \p{M}
+        ("is_punctuation", ctypes.c_uint16, 1),  # regex: \p{P}
+        ("is_symbol",      ctypes.c_uint16, 1),  # regex: \p{S}
+        ("is_control",     ctypes.c_uint16, 1),  # regex: \p{C}
+    ]
+
+
+assert (ctypes.sizeof(CoodepointFlags) == 2)
+
+
+MAX_CODEPOINTS = 0x110000
+
+regex_number      = regex.compile(r'\p{N}')
+regex_letter      = regex.compile(r'\p{L}')
+regex_separator   = regex.compile(r'\p{Z}')
+regex_accent_mark = regex.compile(r'\p{M}')
+regex_punctuation = regex.compile(r'\p{P}')
+regex_symbol      = regex.compile(r'\p{S}')
+regex_control     = regex.compile(r'\p{C}')
+regex_whitespace  = regex.compile(r'\s')
+
+codepoint_flags = (CoodepointFlags * MAX_CODEPOINTS)()
+table_whitespace = []
+table_lowercase = []
+table_uppercase = []
+table_nfd = []
+
+for codepoint in range(MAX_CODEPOINTS):
+    # convert codepoint to unicode character
+    char = chr(codepoint)
+
+    # regex categories
+    flags = codepoint_flags[codepoint]
+    flags.is_number      = bool(regex_number.match(char))
+    flags.is_letter      = bool(regex_letter.match(char))
+    flags.is_separator   = bool(regex_separator.match(char))
+    flags.is_accent_mark = bool(regex_accent_mark.match(char))
+    flags.is_punctuation = bool(regex_punctuation.match(char))
+    flags.is_symbol      = bool(regex_symbol.match(char))
+    flags.is_control     = bool(regex_control.match(char))
+    flags.is_undefined   = bytes(flags)[0] == 0
+    assert (not flags.is_undefined)
+
+    # whitespaces
+    if bool(regex_whitespace.match(char)):
+        table_whitespace.append(codepoint)
+
+    # lowercase conversion
+    lower = ord(char.lower()[0])
+    if codepoint != lower:
+        table_lowercase.append((codepoint, lower))
+
+    # uppercase conversion
+    upper = ord(char.upper()[0])
+    if codepoint != upper:
+        table_uppercase.append((codepoint, upper))
+
+    # NFD normalization
+    norm = ord(unicodedata.normalize('NFD', char)[0])
+    if codepoint != norm:
+        table_nfd.append((codepoint, norm))
+
+
+# group ranges with same flags
+ranges_flags = [(0, codepoint_flags[0])]  # start, flags
+for codepoint, flags in enumerate(codepoint_flags):
+    if bytes(flags) != bytes(ranges_flags[-1][1]):
+        ranges_flags.append((codepoint, flags))
+ranges_flags.append((MAX_CODEPOINTS, CoodepointFlags()))
+
+
+# group ranges with same nfd
+ranges_nfd = [(0, 0, 0)]  # start, last, nfd
+for codepoint, norm in table_nfd:
+    start = ranges_nfd[-1][0]
+    if ranges_nfd[-1] != (start, codepoint - 1, norm):
+        ranges_nfd.append(None)
+        start = codepoint
+    ranges_nfd[-1] = (start, codepoint, norm)
+
+
+# Generate 'unicode-data.cpp'
+
+
+def out(line=""):
+    print(line, end='\n')  # noqa
+
+
+out("""\
+// generated with scripts/gen-unicode-data.py
+
+#include "unicode-data.h"
+
+#include <cstdint>
+#include <vector>
+#include <unordered_map>
+#include <unordered_set>
+""")
+
+out("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = {  // start, flags // last=next_start-1")
+for codepoint, flags in ranges_flags:
+    flags = int.from_bytes(bytes(flags), "little")
+    out("{0x%06X, 0x%04X}," % (codepoint, flags))
+out("};\n")
+
+out("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
+out(", ".join("0x%06X" % cpt for cpt in table_whitespace))
+out("};\n")
+
+out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")
+for tuple in table_lowercase:
+    out("{0x%06X, 0x%06X}," % tuple)
+out("};\n")
+
+out("const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase = {")
+for tuple in table_uppercase:
+    out("{0x%06X, 0x%06X}," % tuple)
+out("};\n")
+
+out("const std::vector<range_nfd> unicode_ranges_nfd = {  // start, last, nfd")
+for triple in ranges_nfd:
+    out("{0x%06X, 0x%06X, 0x%06X}," % triple)
+out("};\n")

tests/test-tokenizer-random.py

@@ -1,5 +1,5 @@
 # Test libllama tokenizer == AutoTokenizer.
-# Brute force random tokens/text generation.
+# Brute force random words/text generation.
 #
 # Sample usage:
 #
@@ -12,10 +12,10 @@ import argparse
 import subprocess
 import random
-from typing import Iterator
+from typing import Callable, Iterator
 
 import cffi
-from transformers import AutoTokenizer, PreTrainedTokenizerBase
+from transformers import AutoTokenizer
 
 logger = logging.getLogger("test-tokenizer-random-bpe")
@@ -145,28 +145,35 @@ def generator_custom_text() -> Iterator[str]:
 def generator_custom_text_edge_cases() -> Iterator[str]:
     """Edge cases found while debugging"""
     yield from [
         '\x1f-a',     # unicode_ranges_control, {0x00001C, 0x00001F}
         '¼-a',        # unicode_ranges_digit, 0x00BC
         '½-a',        # unicode_ranges_digit, 0x00BD
         '¾-a',        # unicode_ranges_digit, 0x00BE
         'a 〇b',      # unicode_ranges_digit, 0x3007
         'Ⅵ-a',       # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms
         '\uFEFF//',   # unicode_ranges_control, 0xFEFF (BOM)
-        '<s>a'        # TODO: Phi-3 fail
+        'Cửa Việt',   # llama-3, ignore_merges = true
+        '<s>a',       # TODO: Phi-3 fail
+        'a\na',       # TODO: Bert fail
     ]
 
 
-def generator_random_chars(iterations = 100) -> Iterator[str]:
+def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
+    """Brute force check all vocab words"""
+    yield from vocab
+
+
+def generator_random_chars(iterations=100) -> Iterator[str]:
     """Brute force random text with simple characters"""
     WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
-    CHARS = list(set("""
+    CHARS = list(sorted(set("""
 ABCDEFGHIJKLMNOPQRSTUVWXYZ
 abcdefghijklmnopqrstuvwxyz
 ÁÉÍÓÚÀÈÌÒÙÂÊÎÔÛÄËÏÖÜ
 áéíóúàèìòùâêîôûäëïöü
 .-,*/-+ª!"·$%&/()=?¿[]{}<>\\|@#~½¬~;:_
-    """))
+    """)))
 
     rand = random.Random()
     for m in range(iterations):
@@ -181,13 +188,13 @@ def generator_random_chars(iterations = 100) -> Iterator[str]:
     yield "".join(text)
 
 
-def generator_random_vocab_chars(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
+def generator_random_vocab_chars(vocab: list[str], iterations=100) -> Iterator[str]:
     """Brute force random text with vocab characters"""
-    vocab_ids = list(tokenizer.vocab.values())
-    vocab_text = tokenizer.decode(vocab_ids, skip_special_tokens=True)
-    vocab_chars = list(set(vocab_text))
-    del vocab_ids, vocab_text
+    vocab_chars = set()
+    for word in vocab:
+        vocab_chars.update(word)
+    vocab_chars = list(sorted(vocab_chars))
 
     rand = random.Random()
     for m in range(iterations):
@@ -196,19 +203,11 @@ def generator_random_vocab_chars(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
     yield "".join(text)
 
 
-def generator_random_vocab_tokens(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
-    """Brute force random text from vocab tokens"""
-    space_id = tokenizer.encode(" ", add_special_tokens=False)[0]
-    vocab_ids = list(tokenizer.vocab.values())
-    vocab_ids = list(sorted(vocab_ids + vocab_ids))
-    for i in range(1, len(vocab_ids), 2):
-        vocab_ids[i] = space_id
-    vocab_tokens = tokenizer.decode(vocab_ids, skip_special_tokens=True)
-    vocab_tokens = vocab_tokens.split(" ")
-    del vocab_ids
+def generator_random_vocab_words(vocab: list[str], iterations=100) -> Iterator[str]:
+    """Brute force random text from vocab words"""
+    vocab = [w.strip() for w in vocab]
+    yield from vocab
 
     rand = random.Random()
     for m in range(iterations):
@@ -217,14 +216,13 @@ def generator_random_vocab_tokens(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
         num_words = rand.randint(300, 400)
         for i in range(num_words):
             k = rand.randint(1, 3)
-            tokens = rand.choices(vocab_tokens, k=k)
-            tokens = [t.strip(" \n\r\t") for t in tokens]
+            words = rand.choices(vocab, k=k)
             sep = rand.choice(" \n\r\t")
-            text.append("".join(tokens) + sep)
+            text.append("".join(words) + sep)
         yield "".join(text)
 
 
-def generator_random_bytes(iterations = 100) -> Iterator[str]:
+def generator_random_bytes(iterations=100) -> Iterator[str]:
     """Brute force random bytes"""
     WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
@@ -242,10 +240,10 @@ def generator_random_bytes(iterations = 100) -> Iterator[str]:
     yield "".join(text)
 
 
-def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, generator: Iterator[str]):
+def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
 
     def find_first_mismatch(ids1: list[int], ids2: list[int]):
-        for i, (a,b) in enumerate(zip(ids1, ids2)):
+        for i, (a, b) in enumerate(zip(ids1, ids2)):
             if a != b:
                 return i
         if len(ids1) == len(ids2):
@@ -255,15 +253,12 @@ def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, generator: Iterator[str]):
     t0 = time.perf_counter()
     logger.info("%s: %s" % (generator.__name__, "ini"))
     for text in generator:
-        ids1 = model.tokenize(text, add_special=False, parse_special=False)
-        ids2 = tokenizer.encode(text, add_special_tokens=False)
+        ids1 = func_tokenize1(text)
+        ids2 = func_tokenize2(text)
         if ids1 != ids2:
             i = find_first_mismatch(ids1, ids2)
             ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
             ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1]
-            text2 = tokenizer.decode(ids2, skip_special_tokens=True)
-            assert (text2 in text)
-            logger.info(" Text: " + repr(text2))
             logger.info(" TokenIDs: " + str(ids1))
             logger.info(" Expected: " + str(ids2))
             raise Exception()
@@ -271,25 +266,37 @@ def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, generator: Iterator[str]):
     logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0))
 
 
-if __name__ == "__main__":
+def main(argv: list[str] = None):
     parser = argparse.ArgumentParser()
     parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
     parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
     parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
-    args = parser.parse_args()
+    args = parser.parse_args(argv)
 
     logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
 
-    model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=2048))
+    model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
     tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
 
-    test_compare_tokenizer(model, tokenizer, generator_custom_text())
-    test_compare_tokenizer(model, tokenizer, generator_custom_text_edge_cases())
-    test_compare_tokenizer(model, tokenizer, generator_random_chars(10_000))
-    test_compare_tokenizer(model, tokenizer, generator_random_vocab_chars(tokenizer, 10_000))
-    test_compare_tokenizer(model, tokenizer, generator_random_vocab_tokens(tokenizer, 10_000))
-    # test_compare_tokenizer(model, tokenizer, generator_random_bytes(10_000)) # FAIL
+    def func_tokenize2(text: str):
+        return tokenizer.encode(text, add_special_tokens=False)
+
+    parse_special = all(len(func_tokenize2(t)) == 1 for t in tokenizer.all_special_tokens)
+
+    def func_tokenize1(text: str):
+        return model.tokenize(text, add_special=False, parse_special=parse_special)
+
+    vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 10_000))
+    # test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_bytes(10_000)) # FAIL
 
     model.free()
+
+
+if __name__ == "__main__":
+    main()

unicode-data.cpp: diff not shown (too large; this file holds the tables regenerated by scripts/gen-unicode-data.py)

unicode-data.h

@@ -1,17 +1,20 @@
 #pragma once
 
 #include <cstdint>
-#include <map>
-#include <utility>
 #include <vector>
+#include <unordered_map>
+#include <unordered_set>
 
-extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_number;
-extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_letter;
-extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_separator;
-extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace;
-extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_accent_mark;
-extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation;
-extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
-extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
-extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
-extern const std::map<char32_t, char32_t> unicode_map_lowercase;
+struct range_nfd {
+    uint32_t first;
+    uint32_t last;
+    uint32_t nfd;
+};
+
+static const uint32_t MAX_CODEPOINTS = 0x110000;
+
+extern const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags;
+extern const std::unordered_set<uint32_t> unicode_set_whitespace;
+extern const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase;
+extern const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase;
+extern const std::vector<range_nfd> unicode_ranges_nfd;

unicode.cpp

@@ -1,4 +1,4 @@
 #include "unicode.h"
 #include "unicode-data.h"
 
 #include <cassert>
@@ -109,57 +109,49 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset)
 //    return result;
 //}
 
-static std::unordered_map<uint32_t, int> unicode_cpt_type_map() {
-    std::unordered_map<uint32_t, int> cpt_types;
-    for (auto p : unicode_ranges_number) {
-        for (auto i = p.first; i <= p.second; ++i) {
-            cpt_types[i] = CODEPOINT_TYPE_NUMBER;
-        }
-    }
-    for (auto p : unicode_ranges_letter) {
-        for (auto i = p.first; i <= p.second; ++i) {
-            cpt_types[i] = CODEPOINT_TYPE_LETTER;
-        }
-    }
-    for (auto p : unicode_ranges_separator) {
-        for (auto i = p.first; i <= p.second; ++i) {
-            cpt_types[i] = CODEPOINT_TYPE_SEPARATOR;
-        }
-    }
-    for (auto p : unicode_ranges_accent_mark) {
-        for (auto i = p.first; i <= p.second; ++i) {
-            cpt_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
-        }
-    }
-    for (auto p : unicode_ranges_punctuation) {
-        for (auto i = p.first; i <= p.second; ++i) {
-            cpt_types[i] = CODEPOINT_TYPE_PUNCTUATION;
-        }
-    }
-    for (auto p : unicode_ranges_symbol) {
-        for (auto i = p.first; i <= p.second; ++i) {
-            cpt_types[i] = CODEPOINT_TYPE_SYMBOL;
-        }
-    }
-    for (auto p : unicode_ranges_control) {
-        for (auto i = p.first; i <= p.second; ++i) {
-            cpt_types[i] = CODEPOINT_TYPE_CONTROL;
-        }
-    }
-    return cpt_types;
+static std::vector<codepoint_flags> unicode_cpt_flags_array() {
+    std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
+
+    assert (unicode_ranges_flags.front().first == 0);
+    assert (unicode_ranges_flags.back().first == MAX_CODEPOINTS);
+    for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
+        const auto range_ini = unicode_ranges_flags[i-1];  // codepoint_ini, flags
+        const auto range_end = unicode_ranges_flags[i];    // codepoint_end, flags
+        for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
+            cpt_flags[cpt] = range_ini.second;
+        }
+    }
+
+    for (auto cpt : unicode_set_whitespace) {
+        cpt_flags[cpt].is_whitespace = true;
+    }
+
+    for (auto p : unicode_map_lowercase) {
+        cpt_flags[p.second].is_lowercase = true;
+    }
+
+    for (auto p : unicode_map_uppercase) {
+        cpt_flags[p.second].is_uppercase = true;
+    }
+
+    for (auto &range : unicode_ranges_nfd) {  // start, last, nfd
+        cpt_flags[range.nfd].is_nfd = true;
+    }
+
+    return cpt_flags;
 }
 
 static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
     std::unordered_map<uint8_t, std::string> map;
-    for (int ch = u'!'; ch <= u'~'; ++ch) {
+    for (int ch = 0x21; ch <= 0x7E; ++ch) {  // u'!' to u'~'
         assert(0 <= ch && ch < 256);
         map[ch] = unicode_cpt_to_utf8(ch);
     }
-    for (int ch = u'¡'; ch <= u'¬'; ++ch) {
+    for (int ch = 0xA1; ch <= 0xAC; ++ch) {  // u'¡' to u'¬'
         assert(0 <= ch && ch < 256);
         map[ch] = unicode_cpt_to_utf8(ch);
     }
-    for (int ch = u'®'; ch <= u'ÿ'; ++ch) {
+    for (int ch = 0xAE; ch <= 0xFF; ++ch) {  // u'®' to u'ÿ'
         assert(0 <= ch && ch < 256);
         map[ch] = unicode_cpt_to_utf8(ch);
     }
@@ -175,15 +167,15 @@ static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
 
 static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
     std::unordered_map<std::string, uint8_t> map;
-    for (int ch = u'!'; ch <= u'~'; ++ch) {
+    for (int ch = 0x21; ch <= 0x7E; ++ch) {  // u'!' to u'~'
         assert(0 <= ch && ch < 256);
         map[unicode_cpt_to_utf8(ch)] = ch;
     }
-    for (int ch = u'¡'; ch <= u'¬'; ++ch) {
+    for (int ch = 0xA1; ch <= 0xAC; ++ch) {  // u'¡' to u'¬'
         assert(0 <= ch && ch < 256);
         map[unicode_cpt_to_utf8(ch)] = ch;
     }
-    for (int ch = u'®'; ch <= u'ÿ'; ++ch) {
+    for (int ch = 0xAE; ch <= 0xFF; ++ch) {  // u'®' to u'ÿ'
         assert(0 <= ch && ch < 256);
         map[unicode_cpt_to_utf8(ch)] = ch;
     }
@@ -238,8 +230,9 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
         return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
     };
 
-    auto _get_cpt_type = [&] (const size_t pos) -> int {
-        return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
+    auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
+        static const codepoint_flags undef(codepoint_flags::UNDEFINED);
+        return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
     };
 
     size_t _prev_end = offset_ini;
@@ -261,7 +254,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
 
     for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
         const char32_t cpt = _get_cpt(pos);
-        const int cpt_type = _get_cpt_type(pos);
+        const auto flags = _get_flags(pos);
 
         // regex: 's|'t|'re|'ve|'m|'ll|'d
         if (cpt == '\'' && pos+1 < offset_end) {
@@ -281,39 +274,37 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
             }
         }
 
-        char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
-        int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
+        auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
 
         // regex: <space>?\p{L}+
-        if (cpt2_type == CODEPOINT_TYPE_LETTER) {
+        if (flags2.is_letter) {
             pos += (cpt == ' ');
-            while (cpt2_type == CODEPOINT_TYPE_LETTER) {
-                cpt2_type = _get_cpt_type(++pos);
+            while (flags2.is_letter) {
+                flags2 = _get_flags(++pos);
             }
             _add_token(pos);
             continue;
         }
 
         // regex: <space>?\p{N}+
-        if (cpt2_type == CODEPOINT_TYPE_NUMBER) {
+        if (flags2.is_number) {
             pos += (cpt == ' ');
-            while (cpt2_type == CODEPOINT_TYPE_NUMBER) {
-                cpt2_type = _get_cpt_type(++pos);
+            while (flags2.is_number) {
+                flags2 = _get_flags(++pos);
             }
             _add_token(pos);
             continue;
         }
 
         // regex: <space>?[^\s\p{L}\p{N}]+
-        if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
+        if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
             pos += (cpt == ' ');
-            while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
-                cpt2_type = _get_cpt_type(++pos);
-                cpt2 = _get_cpt(pos);
+            while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
+                flags2 = _get_flags(++pos);
             }
             _add_token(pos);
             continue;
         }
 
         size_t num_whitespaces = 0;
-        while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
+        while (_get_flags(pos+num_whitespaces).is_whitespace) {
             num_whitespaces++;
         }
@@ -357,8 +348,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
         return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
     };
 
-    auto _get_cpt_type = [&] (const size_t pos) -> int {
-        return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
+    auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
+        static const codepoint_flags undef(codepoint_flags::UNDEFINED);
+        return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
     };
 
     size_t _prev_end = offset_ini;
@@ -380,7 +372,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
 
     for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
         const char32_t cpt = _get_cpt(pos);
-        const int cpt_type = _get_cpt_type(pos);
+        const auto flags = _get_flags(pos);
 
         // regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
         if (cpt == '\'' && pos+1 < offset_end) {
@@ -401,10 +393,10 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
             }
         }
 
         // regex: [^\r\n\p{L}\p{N}]?\p{L}+  //####FIXME: the first \p{L} is correct?
-        if (cpt != '\r' && cpt != '\n' && /*cpt_type != CODEPOINT_TYPE_LETTER &&*/ cpt_type != CODEPOINT_TYPE_NUMBER) {
-            if (cpt_type == CODEPOINT_TYPE_LETTER || _get_cpt_type(pos+1) == CODEPOINT_TYPE_LETTER) {  // one or more letters
+        if (!(cpt == '\r' || cpt == '\n' || /*flags.is_letter |*/ flags.is_number)) {
+            if (flags.is_letter || _get_flags(pos+1).is_letter) {  // one or more letters
                 pos++;
-                while (_get_cpt_type(pos) == CODEPOINT_TYPE_LETTER) {
+                while (_get_flags(pos).is_letter) {
                     pos++;
                 }
                 _add_token(pos);
@@ -413,9 +405,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
         }
 
         // regex: \p{N}{1,3}
-        if (cpt_type == CODEPOINT_TYPE_NUMBER) {
+        if (flags.is_number) {
             size_t ini = pos;
-            while (_get_cpt_type(pos) == CODEPOINT_TYPE_NUMBER) {
+            while (_get_flags(pos).is_number) {
                 if (++pos - ini >= 3 ) {
                     _add_token(pos);
                     ini = pos;
} }
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]* // regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt); auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type); if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
pos += (cpt == ' '); pos += (cpt == ' ');
while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) { while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
cpt2_type = _get_cpt_type(++pos); flags2 = _get_flags(++pos);
cpt2 = _get_cpt(pos);
} }
char32_t cpt2 = _get_cpt(pos);
while (cpt2 == '\r' || cpt2 == '\n') { while (cpt2 == '\r' || cpt2 == '\n') {
cpt2 = _get_cpt(++pos); cpt2 = _get_cpt(++pos);
} }
@@ -443,7 +434,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
 
         size_t num_whitespaces = 0;
         size_t last_end_r_or_n = 0;
-        while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
+        while (_get_flags(pos+num_whitespaces).is_whitespace) {
             char32_t cpt2 = _get_cpt(pos+num_whitespaces);
             if (cpt2 == '\r' || cpt2 == '\n') {
                 last_end_r_or_n = pos + num_whitespaces + 1;
@@ -589,15 +580,14 @@ std::string unicode_cpt_to_utf8(uint32_t cp) {
 }
 
 std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
-    std::vector<uint32_t> result;
-    result.reserve(cpts.size());
+    auto comp = [] (const uint32_t cpt, const range_nfd & range) {
+        return cpt < range.first;
+    };
+    std::vector<uint32_t> result(cpts.size());
     for (size_t i = 0; i < cpts.size(); ++i) {
-        auto it = unicode_map_nfd.find(cpts[i]);
-        if (it == unicode_map_nfd.end()) {
-            result.push_back(cpts[i]);
-        } else {
-            result.push_back(it->second);
-        }
+        const uint32_t cpt = cpts[i];
+        auto it = std::upper_bound(unicode_ranges_nfd.cbegin(), unicode_ranges_nfd.cend(), cpt, comp) - 1;
+        result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
     }
     return result;
 }
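
Two details of this NFD replacement are easy to miss: the generator stores only the first codepoint of each decomposition (the combining marks are dropped, not emitted), and consecutive codepoints sharing the same target collapse into one range_nfd entry. For example, 'À'..'Å' (U+00C0..U+00C5) all decompose to sequences beginning with 'A' (U+0041), so they should form a single {0x0000C0, 0x0000C5, 0x000041} entry; assuming that entry is present in the generated table, the binary search above behaves like this sketch:

    // UTF-8 input "Àb"; accents are stripped rather than decomposed
    std::vector<uint32_t> in  = unicode_cpts_from_utf8("\xC3\x80" "b");  // {0xC0, 0x62}
    std::vector<uint32_t> out = unicode_cpts_normalize_nfd(in);          // {0x41, 0x62}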
@@ -611,31 +601,19 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
     return result;
 }
 
-int unicode_cpt_type(uint32_t cp) {
-    static std::unordered_map<uint32_t, int> cpt_types = unicode_cpt_type_map();
-    const auto it = cpt_types.find(cp);
-    return it == cpt_types.end() ? CODEPOINT_TYPE_UNIDENTIFIED : it->second;
+codepoint_flags unicode_cpt_flags(const uint32_t cp) {
+    static const codepoint_flags undef(codepoint_flags::UNDEFINED);
+    static const auto cpt_flags = unicode_cpt_flags_array();
+    return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
 }
 
-int unicode_cpt_type(const std::string & utf8) {
-    if (utf8.length() == 0) {
-        return CODEPOINT_TYPE_UNIDENTIFIED;
+codepoint_flags unicode_cpt_flags(const std::string & utf8) {
+    static const codepoint_flags undef(codepoint_flags::UNDEFINED);
+    if (utf8.empty()) {
+        return undef;  // undefined
     }
     size_t offset = 0;
-    return unicode_cpt_type(unicode_cpt_from_utf8(utf8, offset));
-}
-
-bool unicode_cpt_is_whitespace(uint32_t cp) {
-    static const std::unordered_set<uint32_t> is_whitespace = [] {
-        std::unordered_set<uint32_t> is_whitespace;
-        for (auto p : unicode_ranges_whitespace) {
-            for (auto i = p.first; i <= p.second; ++i) {
-                is_whitespace.insert(i);
-            }
-        }
-        return is_whitespace;
-    }();
-    return (bool)is_whitespace.count(cp);
+    return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
 }
 
 std::string unicode_byte_to_utf8(uint8_t byte) {
@@ -656,21 +634,21 @@ char32_t unicode_tolower(char32_t cp) {
 std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
     // unicode categories
     static const std::map<std::string, int> k_ucat_enum = {
-        { "\\p{N}", CODEPOINT_TYPE_NUMBER },
-        { "\\p{L}", CODEPOINT_TYPE_LETTER },
-        { "\\p{P}", CODEPOINT_TYPE_PUNCTUATION },
+        { "\\p{N}", codepoint_flags::NUMBER },
+        { "\\p{L}", codepoint_flags::LETTER },
+        { "\\p{P}", codepoint_flags::PUNCTUATION },
     };
 
     static const std::map<int, int> k_ucat_cpt = {
-        { CODEPOINT_TYPE_NUMBER,      0xD1 },
-        { CODEPOINT_TYPE_LETTER,      0xD2 },
-        { CODEPOINT_TYPE_PUNCTUATION, 0xD3 },
+        { codepoint_flags::NUMBER,      0xD1 },
+        { codepoint_flags::LETTER,      0xD2 },
+        { codepoint_flags::PUNCTUATION, 0xD3 },
     };
 
     static const std::map<int, std::string> k_ucat_map = {
-        { CODEPOINT_TYPE_NUMBER,      "\x30-\x39" },                  // 0-9
-        { CODEPOINT_TYPE_LETTER,      "\x41-\x5A\x61-\x7A" },         // A-Za-z
-        { CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" },  // !-#%-*,-/:-;?-@\[-\]_\{\}
+        { codepoint_flags::NUMBER,      "\x30-\x39" },                // 0-9
+        { codepoint_flags::LETTER,      "\x41-\x5A\x61-\x7A" },       // A-Za-z
+        { codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" },  // !-#%-*,-/:-;?-@\[-\]_\{\}
     };
 
     // compute collapsed codepoints only if needed by at least one regex
@@ -701,10 +679,10 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
                 continue;
             }
 
-            const int cpt_type = unicode_cpt_type(cpts[i]);
+            const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag();
 
-            if (k_ucat_cpt.find(cpt_type) != k_ucat_cpt.end()) {
-                text_collapsed[i] = k_ucat_cpt.at(cpt_type);
+            if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) {
+                text_collapsed[i] = k_ucat_cpt.at(cpt_flag);
             } else {
                 text_collapsed[i] = (char) 0xD0;  // fallback
             }
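
For context on the collapsed path: when no custom split implementation exists, both the text and the \p{N}/\p{L}/\p{P} classes are rewritten into a single-byte alphabet that std::regex/std::wregex can handle. ASCII codepoints are kept literal by code outside this hunk (which is why k_ucat_map lists the plain ASCII ranges), while every other codepoint collapses to a marker byte: 0xD1 for numbers, 0xD2 for letters, 0xD3 for punctuation, 0xD0 for anything else. A conceptual illustration, not code from the commit:

    // text:      "x1½y€"   ('½' is \p{N}; '€' is \p{S}, which has no marker of its own)
    // collapsed: { 'x', '1', 0xD1, 'y', 0xD0 }
    // pattern:   "\p{N}" is rewritten along the lines of "[\x30-\x39\xD1]"
    //            (literal ASCII digits, plus the collapsed-number byte)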

unicode.h

@@ -4,24 +4,56 @@
 #include <string>
 #include <vector>
 
-#define CODEPOINT_TYPE_UNIDENTIFIED 0
-#define CODEPOINT_TYPE_NUMBER       1
-#define CODEPOINT_TYPE_LETTER       2
-#define CODEPOINT_TYPE_SEPARATOR    3
-#define CODEPOINT_TYPE_ACCENT_MARK  4
-#define CODEPOINT_TYPE_PUNCTUATION  5
-#define CODEPOINT_TYPE_SYMBOL       6
-#define CODEPOINT_TYPE_CONTROL      7
+struct codepoint_flags {
+    enum {
+        UNDEFINED       = 0x0001,
+        NUMBER          = 0x0002,  // regex: \p{N}
+        LETTER          = 0x0004,  // regex: \p{L}
+        SEPARATOR       = 0x0008,  // regex: \p{Z}
+        ACCENT_MARK     = 0x0010,  // regex: \p{M}
+        PUNCTUATION     = 0x0020,  // regex: \p{P}
+        SYMBOL          = 0x0040,  // regex: \p{S}
+        CONTROL         = 0x0080,  // regex: \p{C}
+        MASK_CATEGORIES = 0x00FF,
+    };
+
+    // codepoint type
+    uint16_t is_undefined   : 1;
+    uint16_t is_number      : 1;  // regex: \p{N}
+    uint16_t is_letter      : 1;  // regex: \p{L}
+    uint16_t is_separator   : 1;  // regex: \p{Z}
+    uint16_t is_accent_mark : 1;  // regex: \p{M}
+    uint16_t is_punctuation : 1;  // regex: \p{P}
+    uint16_t is_symbol      : 1;  // regex: \p{S}
+    uint16_t is_control     : 1;  // regex: \p{C}
+
+    // helper flags
+    uint16_t is_whitespace  : 1;  // regex: \s
+    uint16_t is_lowercase   : 1;
+    uint16_t is_uppercase   : 1;
+    uint16_t is_nfd         : 1;
+
+    // decode from uint16
+    inline codepoint_flags(const uint16_t flags=0) {
+        *reinterpret_cast<uint16_t*>(this) = flags;
+    }
+
+    inline uint16_t as_uint() const {
+        return *reinterpret_cast<const uint16_t*>(this);
+    }
+
+    inline uint16_t category_flag() const {
+        return this->as_uint() & MASK_CATEGORIES;
+    }
+};
+
 
 std::string unicode_cpt_to_utf8(uint32_t cp);
 
 std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
 
 std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
 
-int unicode_cpt_type(uint32_t cp);
-int unicode_cpt_type(const std::string & utf8);
-
-bool unicode_cpt_is_whitespace(uint32_t cp);
+codepoint_flags unicode_cpt_flags(const uint32_t cp);
+codepoint_flags unicode_cpt_flags(const std::string & utf8);
 
 std::string unicode_byte_to_utf8(uint8_t byte);
 uint8_t unicode_utf8_to_byte(const std::string & utf8);
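
A short usage sketch of the resulting header API; everything referenced here is declared above:

    #include "unicode.h"

    void example(const uint32_t cpt) {
        const codepoint_flags flags = unicode_cpt_flags(cpt);
        const bool word_char = flags.is_letter || flags.is_number;  // query bits directly
        const uint16_t cat   = flags.category_flag();  // helper bits masked off: one of the enum categories
        const uint16_t raw   = flags.as_uint();        // packed value; codepoint_flags(raw) round-trips it
        (void) word_char; (void) cat; (void) raw;
    }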