mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 10:24:35 +00:00
Unicode codepoint flags for custom regexs (#7245)
* Replace CODEPOINT_TYPE_* with codepoint_flags * Update and bugfix brute force random test * Deterministic brute force random test * Unicode normalization NFD * Get rid of BOM
This commit is contained in:
parent
0fc1e820a9
commit
b43272afa2
@ -12576,16 +12576,16 @@ struct llm_tokenizer_wpm {
|
|||||||
// to lowercase, pad chinese characters, pad punctuation
|
// to lowercase, pad chinese characters, pad punctuation
|
||||||
std::string new_str = "";
|
std::string new_str = "";
|
||||||
for (uint32_t code : cpts_nfd) {
|
for (uint32_t code : cpts_nfd) {
|
||||||
int type = unicode_cpt_type(code);
|
const codepoint_flags flags = unicode_cpt_flags(code);
|
||||||
if (type == CODEPOINT_TYPE_ACCENT_MARK || type == CODEPOINT_TYPE_CONTROL) {
|
if (flags.is_accent_mark || flags.is_control) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
code = unicode_tolower(code);
|
code = unicode_tolower(code);
|
||||||
if (type == CODEPOINT_TYPE_SEPARATOR) {
|
if (flags.is_separator || flags.is_whitespace) { //####FIXME: is_separator ?
|
||||||
code = ' ';
|
code = ' ';
|
||||||
}
|
}
|
||||||
std::string s = unicode_cpt_to_utf8(code);
|
std::string s = unicode_cpt_to_utf8(code);
|
||||||
if (type == CODEPOINT_TYPE_PUNCTUATION || is_ascii_punct(code) || is_chinese_char(code)) {
|
if (flags.is_punctuation || is_ascii_punct(code) || is_chinese_char(code)) {
|
||||||
new_str += " ";
|
new_str += " ";
|
||||||
new_str += s;
|
new_str += s;
|
||||||
new_str += " ";
|
new_str += " ";
|
||||||
|
@ -1,64 +1,134 @@
|
|||||||
import regex
|
import regex
|
||||||
|
import ctypes
|
||||||
|
import unicodedata
|
||||||
|
|
||||||
|
|
||||||
def get_matches(regex_expr):
|
class CoodepointFlags (ctypes.Structure):
|
||||||
regex_expr_compiled = regex.compile(regex_expr)
|
_fields_ = [ # see definition in unicode.h
|
||||||
unicode_ranges = []
|
("is_undefined", ctypes.c_uint16, 1),
|
||||||
current_range = None
|
("is_number", ctypes.c_uint16, 1), # regex: \p{N}
|
||||||
|
("is_letter", ctypes.c_uint16, 1), # regex: \p{L}
|
||||||
for codepoint in range(0x110000):
|
("is_separator", ctypes.c_uint16, 1), # regex: \p{Z}
|
||||||
char = chr(codepoint)
|
("is_accent_mark", ctypes.c_uint16, 1), # regex: \p{M}
|
||||||
if regex_expr_compiled.match(char):
|
("is_punctuation", ctypes.c_uint16, 1), # regex: \p{P}
|
||||||
if current_range is None:
|
("is_symbol", ctypes.c_uint16, 1), # regex: \p{S}
|
||||||
current_range = [codepoint, codepoint]
|
("is_control", ctypes.c_uint16, 1), # regex: \p{C}
|
||||||
else:
|
]
|
||||||
current_range[1] = codepoint
|
|
||||||
elif current_range is not None:
|
|
||||||
unicode_ranges.append(tuple(current_range))
|
|
||||||
current_range = None
|
|
||||||
|
|
||||||
if current_range is not None:
|
|
||||||
unicode_ranges.append(tuple(current_range))
|
|
||||||
|
|
||||||
return unicode_ranges
|
|
||||||
|
|
||||||
|
|
||||||
def print_cat(mode, cat, ranges):
|
assert (ctypes.sizeof(CoodepointFlags) == 2)
|
||||||
if mode == "range":
|
|
||||||
print("const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_{} = {{".format(cat)) # noqa: NP100
|
|
||||||
if mode == "map":
|
|
||||||
print("const std::map<uint32_t, uint32_t> unicode_map_{} = {{".format(cat)) # noqa: NP100
|
|
||||||
for i, values in enumerate(ranges):
|
|
||||||
end = ",\n" if (i % 4 == 3 or i + 1 == len(ranges)) else ", "
|
|
||||||
values = ["0x%08X" % value for value in values]
|
|
||||||
print("{" + ", ".join(values) + "}", end=end) # noqa: NP100
|
|
||||||
print("};") # noqa: NP100
|
|
||||||
print("") # noqa: NP100
|
|
||||||
|
|
||||||
|
|
||||||
print_cat("range", "number", get_matches(r'\p{N}'))
|
MAX_CODEPOINTS = 0x110000
|
||||||
print_cat("range", "letter", get_matches(r'\p{L}'))
|
|
||||||
print_cat("range", "separator", get_matches(r'\p{Z}'))
|
|
||||||
print_cat("range", "accent_mark", get_matches(r'\p{M}'))
|
|
||||||
print_cat("range", "punctuation", get_matches(r'\p{P}'))
|
|
||||||
print_cat("range", "symbol", get_matches(r'\p{S}'))
|
|
||||||
print_cat("range", "control", get_matches(r'\p{C}'))
|
|
||||||
|
|
||||||
print_cat("range", "whitespace", get_matches(r'\s'))
|
regex_number = regex.compile(r'\p{N}')
|
||||||
|
regex_letter = regex.compile(r'\p{L}')
|
||||||
|
regex_separator = regex.compile(r'\p{Z}')
|
||||||
|
regex_accent_mark = regex.compile(r'\p{M}')
|
||||||
|
regex_punctuation = regex.compile(r'\p{P}')
|
||||||
|
regex_symbol = regex.compile(r'\p{S}')
|
||||||
|
regex_control = regex.compile(r'\p{C}')
|
||||||
|
regex_whitespace = regex.compile(r'\s')
|
||||||
|
|
||||||
|
codepoint_flags = (CoodepointFlags * MAX_CODEPOINTS)()
|
||||||
|
table_whitespace = []
|
||||||
|
table_lowercase = []
|
||||||
|
table_uppercase = []
|
||||||
|
table_nfd = []
|
||||||
|
|
||||||
map_lowercase = []
|
for codepoint in range(MAX_CODEPOINTS):
|
||||||
map_uppercase = []
|
# convert codepoint to unicode character
|
||||||
for codepoint in range(0x110000):
|
|
||||||
char = chr(codepoint)
|
char = chr(codepoint)
|
||||||
|
|
||||||
|
# regex categories
|
||||||
|
flags = codepoint_flags[codepoint]
|
||||||
|
flags.is_number = bool(regex_number.match(char))
|
||||||
|
flags.is_letter = bool(regex_letter.match(char))
|
||||||
|
flags.is_separator = bool(regex_separator.match(char))
|
||||||
|
flags.is_accent_mark = bool(regex_accent_mark.match(char))
|
||||||
|
flags.is_punctuation = bool(regex_punctuation.match(char))
|
||||||
|
flags.is_symbol = bool(regex_symbol.match(char))
|
||||||
|
flags.is_control = bool(regex_control.match(char))
|
||||||
|
flags.is_undefined = bytes(flags)[0] == 0
|
||||||
|
assert (not flags.is_undefined)
|
||||||
|
|
||||||
|
# whitespaces
|
||||||
|
if bool(regex_whitespace.match(char)):
|
||||||
|
table_whitespace.append(codepoint)
|
||||||
|
|
||||||
|
# lowercase conversion
|
||||||
lower = ord(char.lower()[0])
|
lower = ord(char.lower()[0])
|
||||||
upper = ord(char.upper()[0])
|
|
||||||
if codepoint != lower:
|
if codepoint != lower:
|
||||||
map_lowercase.append((codepoint, lower))
|
table_lowercase.append((codepoint, lower))
|
||||||
|
|
||||||
|
# uppercase conversion
|
||||||
|
upper = ord(char.upper()[0])
|
||||||
if codepoint != upper:
|
if codepoint != upper:
|
||||||
map_uppercase.append((codepoint, upper))
|
table_uppercase.append((codepoint, upper))
|
||||||
print_cat("map", "lowercase", map_lowercase)
|
|
||||||
print_cat("map", "uppercase", map_uppercase)
|
# NFD normalization
|
||||||
|
norm = ord(unicodedata.normalize('NFD', char)[0])
|
||||||
|
if codepoint != norm:
|
||||||
|
table_nfd.append((codepoint, norm))
|
||||||
|
|
||||||
|
|
||||||
# TODO: generate unicode_map_nfd
|
# group ranges with same flags
|
||||||
|
ranges_flags = [(0, codepoint_flags[0])] # start, flags
|
||||||
|
for codepoint, flags in enumerate(codepoint_flags):
|
||||||
|
if bytes(flags) != bytes(ranges_flags[-1][1]):
|
||||||
|
ranges_flags.append((codepoint, flags))
|
||||||
|
ranges_flags.append((MAX_CODEPOINTS, CoodepointFlags()))
|
||||||
|
|
||||||
|
|
||||||
|
# group ranges with same nfd
|
||||||
|
ranges_nfd = [(0, 0, 0)] # start, last, nfd
|
||||||
|
for codepoint, norm in table_nfd:
|
||||||
|
start = ranges_nfd[-1][0]
|
||||||
|
if ranges_nfd[-1] != (start, codepoint - 1, norm):
|
||||||
|
ranges_nfd.append(None)
|
||||||
|
start = codepoint
|
||||||
|
ranges_nfd[-1] = (start, codepoint, norm)
|
||||||
|
|
||||||
|
|
||||||
|
# Generate 'unicode-data.cpp'
|
||||||
|
|
||||||
|
|
||||||
|
def out(line=""):
|
||||||
|
print(line, end='\n') # noqa
|
||||||
|
|
||||||
|
|
||||||
|
out("""\
|
||||||
|
// generated with scripts/gen-unicode-data.py
|
||||||
|
|
||||||
|
#include "unicode-data.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <vector>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
""")
|
||||||
|
|
||||||
|
out("const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags = { // start, flags // last=next_start-1")
|
||||||
|
for codepoint, flags in ranges_flags:
|
||||||
|
flags = int.from_bytes(bytes(flags), "little")
|
||||||
|
out("{0x%06X, 0x%04X}," % (codepoint, flags))
|
||||||
|
out("};\n")
|
||||||
|
|
||||||
|
out("const std::unordered_set<uint32_t> unicode_set_whitespace = {")
|
||||||
|
out(", ".join("0x%06X" % cpt for cpt in table_whitespace))
|
||||||
|
out("};\n")
|
||||||
|
|
||||||
|
out("const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase = {")
|
||||||
|
for tuple in table_lowercase:
|
||||||
|
out("{0x%06X, 0x%06X}," % tuple)
|
||||||
|
out("};\n")
|
||||||
|
|
||||||
|
out("const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase = {")
|
||||||
|
for tuple in table_uppercase:
|
||||||
|
out("{0x%06X, 0x%06X}," % tuple)
|
||||||
|
out("};\n")
|
||||||
|
|
||||||
|
out("const std::vector<range_nfd> unicode_ranges_nfd = { // start, last, nfd")
|
||||||
|
for triple in ranges_nfd:
|
||||||
|
out("{0x%06X, 0x%06X, 0x%06X}," % triple)
|
||||||
|
out("};\n")
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
# Test libllama tokenizer == AutoTokenizer.
|
# Test libllama tokenizer == AutoTokenizer.
|
||||||
# Brute force random tokens/text generation.
|
# Brute force random words/text generation.
|
||||||
#
|
#
|
||||||
# Sample usage:
|
# Sample usage:
|
||||||
#
|
#
|
||||||
@ -12,10 +12,10 @@ import argparse
|
|||||||
import subprocess
|
import subprocess
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from typing import Iterator
|
from typing import Callable, Iterator
|
||||||
|
|
||||||
import cffi
|
import cffi
|
||||||
from transformers import AutoTokenizer, PreTrainedTokenizerBase
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
logger = logging.getLogger("test-tokenizer-random-bpe")
|
logger = logging.getLogger("test-tokenizer-random-bpe")
|
||||||
|
|
||||||
@ -145,28 +145,35 @@ def generator_custom_text() -> Iterator[str]:
|
|||||||
def generator_custom_text_edge_cases() -> Iterator[str]:
|
def generator_custom_text_edge_cases() -> Iterator[str]:
|
||||||
"""Edge cases found while debugging"""
|
"""Edge cases found while debugging"""
|
||||||
yield from [
|
yield from [
|
||||||
'\x1f-a', # unicode_ranges_control, {0x00001C, 0x00001F}
|
'\x1f-a', # unicode_ranges_control, {0x00001C, 0x00001F}
|
||||||
'¼-a', # unicode_ranges_digit, 0x00BC
|
'¼-a', # unicode_ranges_digit, 0x00BC
|
||||||
'½-a', # unicode_ranges_digit, 0x00BD
|
'½-a', # unicode_ranges_digit, 0x00BD
|
||||||
'¾-a', # unicode_ranges_digit, 0x00BE
|
'¾-a', # unicode_ranges_digit, 0x00BE
|
||||||
'a 〇b', # unicode_ranges_digit, 0x3007
|
'a 〇b', # unicode_ranges_digit, 0x3007
|
||||||
'Ⅵ-a', # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms
|
'Ⅵ-a', # unicode_ranges_digit, {0x00002150, 0x0000218F} // Number Forms
|
||||||
'\uFEFF//', # unicode_ranges_control, 0xFEFF (BOM)
|
'\uFEFF//', # unicode_ranges_control, 0xFEFF (BOM)
|
||||||
'<s>a' # TODO: Phi-3 fail
|
'Cửa Việt', # llama-3, ignore_merges = true
|
||||||
|
'<s>a', # TODO: Phi-3 fail
|
||||||
|
'a\na', # TODO: Bert fail
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def generator_random_chars(iterations = 100) -> Iterator[str]:
|
def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
|
||||||
|
"""Brute force check all vocab words"""
|
||||||
|
yield from vocab
|
||||||
|
|
||||||
|
|
||||||
|
def generator_random_chars(iterations=100) -> Iterator[str]:
|
||||||
"""Brute force random text with simple characters"""
|
"""Brute force random text with simple characters"""
|
||||||
|
|
||||||
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
|
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
|
||||||
CHARS = list(set("""
|
CHARS = list(sorted(set("""
|
||||||
ABCDEFGHIJKLMNOPQRSTUVWXYZ
|
ABCDEFGHIJKLMNOPQRSTUVWXYZ
|
||||||
abcdefghijklmnopqrstuvwxyz
|
abcdefghijklmnopqrstuvwxyz
|
||||||
ÁÉÍÓÚÀÈÌÒÙÂÊÎÔÛÄËÏÖÜ
|
ÁÉÍÓÚÀÈÌÒÙÂÊÎÔÛÄËÏÖÜ
|
||||||
áéíóúàèìòùâêîôûäëïöü
|
áéíóúàèìòùâêîôûäëïöü
|
||||||
.-,*/-+ª!"·$%&/()=?¿[]{}<>\\|@#~½¬~;:_
|
.-,*/-+ª!"·$%&/()=?¿[]{}<>\\|@#~½¬~;:_
|
||||||
"""))
|
""")))
|
||||||
|
|
||||||
rand = random.Random()
|
rand = random.Random()
|
||||||
for m in range(iterations):
|
for m in range(iterations):
|
||||||
@ -181,13 +188,13 @@ def generator_random_chars(iterations = 100) -> Iterator[str]:
|
|||||||
yield "".join(text)
|
yield "".join(text)
|
||||||
|
|
||||||
|
|
||||||
def generator_random_vocab_chars(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
|
def generator_random_vocab_chars(vocab: list[str], iterations=100) -> Iterator[str]:
|
||||||
"""Brute force random text with vocab characters"""
|
"""Brute force random text with vocab characters"""
|
||||||
|
|
||||||
vocab_ids = list(tokenizer.vocab.values())
|
vocab_chars = set()
|
||||||
vocab_text = tokenizer.decode(vocab_ids, skip_special_tokens=True)
|
for word in vocab:
|
||||||
vocab_chars = list(set(vocab_text))
|
vocab_chars.update(word)
|
||||||
del vocab_ids, vocab_text
|
vocab_chars = list(sorted(vocab_chars))
|
||||||
|
|
||||||
rand = random.Random()
|
rand = random.Random()
|
||||||
for m in range(iterations):
|
for m in range(iterations):
|
||||||
@ -196,19 +203,11 @@ def generator_random_vocab_chars(tokenizer: PreTrainedTokenizerBase, iterations
|
|||||||
yield "".join(text)
|
yield "".join(text)
|
||||||
|
|
||||||
|
|
||||||
def generator_random_vocab_tokens(tokenizer: PreTrainedTokenizerBase, iterations = 100) -> Iterator[str]:
|
def generator_random_vocab_words(vocab: list[str], iterations=100) -> Iterator[str]:
|
||||||
"""Brute force random text from vocab tokens"""
|
"""Brute force random text from vocab words"""
|
||||||
|
|
||||||
space_id = tokenizer.encode(" ", add_special_tokens=False)[0]
|
vocab = [w.strip() for w in vocab]
|
||||||
vocab_ids = list(tokenizer.vocab.values())
|
yield from vocab
|
||||||
vocab_ids = list(sorted(vocab_ids + vocab_ids))
|
|
||||||
for i in range(1, len(vocab_ids), 2):
|
|
||||||
vocab_ids[i] = space_id
|
|
||||||
vocab_tokens = tokenizer.decode(vocab_ids, skip_special_tokens=True)
|
|
||||||
vocab_tokens = vocab_tokens.split(" ")
|
|
||||||
del vocab_ids
|
|
||||||
|
|
||||||
yield from vocab_tokens
|
|
||||||
|
|
||||||
rand = random.Random()
|
rand = random.Random()
|
||||||
for m in range(iterations):
|
for m in range(iterations):
|
||||||
@ -217,14 +216,13 @@ def generator_random_vocab_tokens(tokenizer: PreTrainedTokenizerBase, iterations
|
|||||||
num_words = rand.randint(300, 400)
|
num_words = rand.randint(300, 400)
|
||||||
for i in range(num_words):
|
for i in range(num_words):
|
||||||
k = rand.randint(1, 3)
|
k = rand.randint(1, 3)
|
||||||
tokens = rand.choices(vocab_tokens, k=k)
|
words = rand.choices(vocab, k=k)
|
||||||
tokens = [t.strip(" \n\r\t") for t in tokens]
|
|
||||||
sep = rand.choice(" \n\r\t")
|
sep = rand.choice(" \n\r\t")
|
||||||
text.append("".join(tokens) + sep)
|
text.append("".join(words) + sep)
|
||||||
yield "".join(text)
|
yield "".join(text)
|
||||||
|
|
||||||
|
|
||||||
def generator_random_bytes(iterations = 100) -> Iterator[str]:
|
def generator_random_bytes(iterations=100) -> Iterator[str]:
|
||||||
"""Brute force random bytes"""
|
"""Brute force random bytes"""
|
||||||
|
|
||||||
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
|
WHITESPACES = list(" " * 20 + "\n" * 5 + "\r\n" * 5 + "\t" * 5)
|
||||||
@ -242,10 +240,10 @@ def generator_random_bytes(iterations = 100) -> Iterator[str]:
|
|||||||
yield "".join(text)
|
yield "".join(text)
|
||||||
|
|
||||||
|
|
||||||
def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerBase, generator: Iterator[str]):
|
def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, generator: Iterator[str]):
|
||||||
|
|
||||||
def find_first_mismatch(ids1: list[int], ids2: list[int]):
|
def find_first_mismatch(ids1: list[int], ids2: list[int]):
|
||||||
for i, (a,b) in enumerate(zip(ids1, ids2)):
|
for i, (a, b) in enumerate(zip(ids1, ids2)):
|
||||||
if a != b:
|
if a != b:
|
||||||
return i
|
return i
|
||||||
if len(ids1) == len(ids2):
|
if len(ids1) == len(ids2):
|
||||||
@ -255,15 +253,12 @@ def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerB
|
|||||||
t0 = time.perf_counter()
|
t0 = time.perf_counter()
|
||||||
logger.info("%s: %s" % (generator.__name__, "ini"))
|
logger.info("%s: %s" % (generator.__name__, "ini"))
|
||||||
for text in generator:
|
for text in generator:
|
||||||
ids1 = model.tokenize(text, add_special=False, parse_special=False)
|
ids1 = func_tokenize1(text)
|
||||||
ids2 = tokenizer.encode(text, add_special_tokens=False)
|
ids2 = func_tokenize2(text)
|
||||||
if ids1 != ids2:
|
if ids1 != ids2:
|
||||||
i = find_first_mismatch(ids1, ids2)
|
i = find_first_mismatch(ids1, ids2)
|
||||||
ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
|
ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
|
||||||
ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1]
|
ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1]
|
||||||
text2 = tokenizer.decode(ids2, skip_special_tokens=True)
|
|
||||||
assert (text2 in text)
|
|
||||||
logger.info(" Text: " + repr(text2))
|
|
||||||
logger.info(" TokenIDs: " + str(ids1))
|
logger.info(" TokenIDs: " + str(ids1))
|
||||||
logger.info(" Expected: " + str(ids2))
|
logger.info(" Expected: " + str(ids2))
|
||||||
raise Exception()
|
raise Exception()
|
||||||
@ -271,25 +266,37 @@ def test_compare_tokenizer(model: LibLlamaModel, tokenizer: PreTrainedTokenizerB
|
|||||||
logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0))
|
logger.info("%s: end, time: %.3f secs" % (generator.__name__, t1 - t0))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def main(argv: list[str] = None):
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
|
parser.add_argument("vocab_file", help="path to vocab 'gguf' file")
|
||||||
parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
|
parser.add_argument("dir_tokenizer", help="directory containing 'tokenizer.model' file")
|
||||||
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
parser.add_argument("--verbose", action="store_true", help="increase output verbosity")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args(argv)
|
||||||
|
|
||||||
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
|
||||||
|
|
||||||
model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=2048))
|
model = LibLlamaModel(LibLlama(), args.vocab_file, mparams=dict(vocab_only=True), cparams=dict(n_ctx=4096))
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
|
tokenizer = AutoTokenizer.from_pretrained(args.dir_tokenizer)
|
||||||
|
|
||||||
test_compare_tokenizer(model, tokenizer, generator_custom_text())
|
def func_tokenize2(text: str):
|
||||||
test_compare_tokenizer(model, tokenizer, generator_custom_text_edge_cases())
|
return tokenizer.encode(text, add_special_tokens=False)
|
||||||
test_compare_tokenizer(model, tokenizer, generator_random_chars(10_000))
|
|
||||||
test_compare_tokenizer(model, tokenizer, generator_random_vocab_chars(tokenizer, 10_000))
|
parse_special = all(len(func_tokenize2(t)) == 1 for t in tokenizer.all_special_tokens)
|
||||||
test_compare_tokenizer(model, tokenizer, generator_random_vocab_tokens(tokenizer, 10_000))
|
|
||||||
# test_compare_tokenizer(model, tokenizer, generator_random_bytes(10_000)) # FAIL
|
def func_tokenize1(text: str):
|
||||||
|
return model.tokenize(text, add_special=False, parse_special=parse_special)
|
||||||
|
|
||||||
|
vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
|
||||||
|
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
|
||||||
|
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
|
||||||
|
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
|
||||||
|
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
|
||||||
|
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
|
||||||
|
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 10_000))
|
||||||
|
# test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_bytes(10_000)) # FAIL
|
||||||
|
|
||||||
model.free()
|
model.free()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
9138
unicode-data.cpp
9138
unicode-data.cpp
File diff suppressed because it is too large
Load Diff
@ -1,17 +1,20 @@
|
|||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
#include <map>
|
|
||||||
#include <utility>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_number;
|
struct range_nfd {
|
||||||
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_letter;
|
uint32_t first;
|
||||||
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_separator;
|
uint32_t last;
|
||||||
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_whitespace;
|
uint32_t nfd;
|
||||||
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_accent_mark;
|
};
|
||||||
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_punctuation;
|
|
||||||
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_symbol;
|
static const uint32_t MAX_CODEPOINTS = 0x110000;
|
||||||
extern const std::vector<std::pair<uint32_t, uint32_t>> unicode_ranges_control;
|
|
||||||
extern const std::multimap<uint32_t, uint32_t> unicode_map_nfd;
|
extern const std::vector<std::pair<uint32_t, uint16_t>> unicode_ranges_flags;
|
||||||
extern const std::map<char32_t, char32_t> unicode_map_lowercase;
|
extern const std::unordered_set<uint32_t> unicode_set_whitespace;
|
||||||
|
extern const std::unordered_map<uint32_t, uint32_t> unicode_map_lowercase;
|
||||||
|
extern const std::unordered_map<uint32_t, uint32_t> unicode_map_uppercase;
|
||||||
|
extern const std::vector<range_nfd> unicode_ranges_nfd;
|
||||||
|
200
unicode.cpp
200
unicode.cpp
@ -1,4 +1,4 @@
|
|||||||
#include "unicode.h"
|
#include "unicode.h"
|
||||||
#include "unicode-data.h"
|
#include "unicode-data.h"
|
||||||
|
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
@ -109,57 +109,49 @@ static uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset)
|
|||||||
// return result;
|
// return result;
|
||||||
//}
|
//}
|
||||||
|
|
||||||
static std::unordered_map<uint32_t, int> unicode_cpt_type_map() {
|
static std::vector<codepoint_flags> unicode_cpt_flags_array() {
|
||||||
std::unordered_map<uint32_t, int> cpt_types;
|
std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED);
|
||||||
for (auto p : unicode_ranges_number) {
|
|
||||||
for (auto i = p.first; i <= p.second; ++i) {
|
assert (unicode_ranges_flags.front().first == 0);
|
||||||
cpt_types[i] = CODEPOINT_TYPE_NUMBER;
|
assert (unicode_ranges_flags.back().first == MAX_CODEPOINTS);
|
||||||
|
for (size_t i = 1; i < unicode_ranges_flags.size(); ++i) {
|
||||||
|
const auto range_ini = unicode_ranges_flags[i-1]; // codepoint_ini, flags
|
||||||
|
const auto range_end = unicode_ranges_flags[i]; // codepoint_end, flags
|
||||||
|
for (uint32_t cpt = range_ini.first; cpt < range_end.first; ++cpt) {
|
||||||
|
cpt_flags[cpt] = range_ini.second;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (auto p : unicode_ranges_letter) {
|
|
||||||
for (auto i = p.first; i <= p.second; ++i) {
|
for (auto cpt : unicode_set_whitespace) {
|
||||||
cpt_types[i] = CODEPOINT_TYPE_LETTER;
|
cpt_flags[cpt].is_whitespace = true;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for (auto p : unicode_ranges_separator) {
|
|
||||||
for (auto i = p.first; i <= p.second; ++i) {
|
for (auto p : unicode_map_lowercase) {
|
||||||
cpt_types[i] = CODEPOINT_TYPE_SEPARATOR;
|
cpt_flags[p.second].is_lowercase = true;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for (auto p : unicode_ranges_accent_mark) {
|
|
||||||
for (auto i = p.first; i <= p.second; ++i) {
|
for (auto p : unicode_map_uppercase) {
|
||||||
cpt_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
|
cpt_flags[p.second].is_uppercase = true;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for (auto p : unicode_ranges_punctuation) {
|
|
||||||
for (auto i = p.first; i <= p.second; ++i) {
|
for (auto &range : unicode_ranges_nfd) { // start, last, nfd
|
||||||
cpt_types[i] = CODEPOINT_TYPE_PUNCTUATION;
|
cpt_flags[range.nfd].is_nfd = true;
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for (auto p : unicode_ranges_symbol) {
|
|
||||||
for (auto i = p.first; i <= p.second; ++i) {
|
return cpt_flags;
|
||||||
cpt_types[i] = CODEPOINT_TYPE_SYMBOL;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (auto p : unicode_ranges_control) {
|
|
||||||
for (auto i = p.first; i <= p.second; ++i) {
|
|
||||||
cpt_types[i] = CODEPOINT_TYPE_CONTROL;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return cpt_types;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
|
static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
|
||||||
std::unordered_map<uint8_t, std::string> map;
|
std::unordered_map<uint8_t, std::string> map;
|
||||||
for (int ch = u'!'; ch <= u'~'; ++ch) {
|
for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
|
||||||
assert(0 <= ch && ch < 256);
|
assert(0 <= ch && ch < 256);
|
||||||
map[ch] = unicode_cpt_to_utf8(ch);
|
map[ch] = unicode_cpt_to_utf8(ch);
|
||||||
}
|
}
|
||||||
for (int ch = u'¡'; ch <= u'¬'; ++ch) {
|
for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
|
||||||
assert(0 <= ch && ch < 256);
|
assert(0 <= ch && ch < 256);
|
||||||
map[ch] = unicode_cpt_to_utf8(ch);
|
map[ch] = unicode_cpt_to_utf8(ch);
|
||||||
}
|
}
|
||||||
for (int ch = u'®'; ch <= u'ÿ'; ++ch) {
|
for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
|
||||||
assert(0 <= ch && ch < 256);
|
assert(0 <= ch && ch < 256);
|
||||||
map[ch] = unicode_cpt_to_utf8(ch);
|
map[ch] = unicode_cpt_to_utf8(ch);
|
||||||
}
|
}
|
||||||
@ -175,15 +167,15 @@ static std::unordered_map<uint8_t, std::string> unicode_byte_to_utf8_map() {
|
|||||||
|
|
||||||
static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
|
static std::unordered_map<std::string, uint8_t> unicode_utf8_to_byte_map() {
|
||||||
std::unordered_map<std::string, uint8_t> map;
|
std::unordered_map<std::string, uint8_t> map;
|
||||||
for (int ch = u'!'; ch <= u'~'; ++ch) {
|
for (int ch = 0x21; ch <= 0x7E; ++ch) { // u'!' to u'~'
|
||||||
assert(0 <= ch && ch < 256);
|
assert(0 <= ch && ch < 256);
|
||||||
map[unicode_cpt_to_utf8(ch)] = ch;
|
map[unicode_cpt_to_utf8(ch)] = ch;
|
||||||
}
|
}
|
||||||
for (int ch = u'¡'; ch <= u'¬'; ++ch) {
|
for (int ch = 0xA1; ch <= 0xAC; ++ch) { // u'¡' to u'¬'
|
||||||
assert(0 <= ch && ch < 256);
|
assert(0 <= ch && ch < 256);
|
||||||
map[unicode_cpt_to_utf8(ch)] = ch;
|
map[unicode_cpt_to_utf8(ch)] = ch;
|
||||||
}
|
}
|
||||||
for (int ch = u'®'; ch <= u'ÿ'; ++ch) {
|
for (int ch = 0xAE; ch <= 0xFF; ++ch) { // u'®' to u'ÿ'
|
||||||
assert(0 <= ch && ch < 256);
|
assert(0 <= ch && ch < 256);
|
||||||
map[unicode_cpt_to_utf8(ch)] = ch;
|
map[unicode_cpt_to_utf8(ch)] = ch;
|
||||||
}
|
}
|
||||||
@ -238,8 +230,9 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
|
|||||||
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
|
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto _get_cpt_type = [&] (const size_t pos) -> int {
|
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
|
||||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
|
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
||||||
|
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
|
||||||
};
|
};
|
||||||
|
|
||||||
size_t _prev_end = offset_ini;
|
size_t _prev_end = offset_ini;
|
||||||
@ -261,7 +254,7 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
|
|||||||
|
|
||||||
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
||||||
const char32_t cpt = _get_cpt(pos);
|
const char32_t cpt = _get_cpt(pos);
|
||||||
const int cpt_type = _get_cpt_type(pos);
|
const auto flags = _get_flags(pos);
|
||||||
|
|
||||||
// regex: 's|'t|'re|'ve|'m|'ll|'d
|
// regex: 's|'t|'re|'ve|'m|'ll|'d
|
||||||
if (cpt == '\'' && pos+1 < offset_end) {
|
if (cpt == '\'' && pos+1 < offset_end) {
|
||||||
@ -281,39 +274,37 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
|
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
|
||||||
int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
|
|
||||||
// regex: <space>?\p{L}+
|
// regex: <space>?\p{L}+
|
||||||
if (cpt2_type == CODEPOINT_TYPE_LETTER) {
|
if (flags2.is_letter) {
|
||||||
pos += (cpt == ' ');
|
pos += (cpt == ' ');
|
||||||
while (cpt2_type == CODEPOINT_TYPE_LETTER) {
|
while (flags2.is_letter) {
|
||||||
cpt2_type = _get_cpt_type(++pos);
|
flags2 = _get_flags(++pos);
|
||||||
}
|
}
|
||||||
_add_token(pos);
|
_add_token(pos);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// regex: <space>?\p{N}+
|
// regex: <space>?\p{N}+
|
||||||
if (cpt2_type == CODEPOINT_TYPE_NUMBER) {
|
if (flags2.is_number) {
|
||||||
pos += (cpt == ' ');
|
pos += (cpt == ' ');
|
||||||
while (cpt2_type == CODEPOINT_TYPE_NUMBER) {
|
while (flags2.is_number) {
|
||||||
cpt2_type = _get_cpt_type(++pos);
|
flags2 = _get_flags(++pos);
|
||||||
}
|
}
|
||||||
_add_token(pos);
|
_add_token(pos);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
// regex: <space>?[^\s\p{L}\p{N}]+
|
// regex: <space>?[^\s\p{L}\p{N}]+
|
||||||
if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
|
if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
|
||||||
pos += (cpt == ' ');
|
pos += (cpt == ' ');
|
||||||
while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
|
while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
|
||||||
cpt2_type = _get_cpt_type(++pos);
|
flags2 = _get_flags(++pos);
|
||||||
cpt2 = _get_cpt(pos);
|
|
||||||
}
|
}
|
||||||
_add_token(pos);
|
_add_token(pos);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t num_whitespaces = 0;
|
size_t num_whitespaces = 0;
|
||||||
while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
|
while (_get_flags(pos+num_whitespaces).is_whitespace) {
|
||||||
num_whitespaces++;
|
num_whitespaces++;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -357,8 +348,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
|||||||
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
|
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto _get_cpt_type = [&] (const size_t pos) -> int {
|
auto _get_flags = [&] (const size_t pos) -> codepoint_flags {
|
||||||
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_type(cpts[pos]) : CODEPOINT_TYPE_UNIDENTIFIED;
|
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
||||||
|
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : undef;
|
||||||
};
|
};
|
||||||
|
|
||||||
size_t _prev_end = offset_ini;
|
size_t _prev_end = offset_ini;
|
||||||
@ -380,7 +372,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
|||||||
|
|
||||||
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) {
|
||||||
const char32_t cpt = _get_cpt(pos);
|
const char32_t cpt = _get_cpt(pos);
|
||||||
const int cpt_type = _get_cpt_type(pos);
|
const auto flags = _get_flags(pos);
|
||||||
|
|
||||||
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
|
// regex: (?i:'s|'t|'re|'ve|'m|'ll|'d) // case insensitive
|
||||||
if (cpt == '\'' && pos+1 < offset_end) {
|
if (cpt == '\'' && pos+1 < offset_end) {
|
||||||
@ -401,10 +393,10 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
|||||||
}
|
}
|
||||||
|
|
||||||
// regex: [^\r\n\p{L}\p{N}]?\p{L}+ //####FIXME: the first \p{L} is correct?
|
// regex: [^\r\n\p{L}\p{N}]?\p{L}+ //####FIXME: the first \p{L} is correct?
|
||||||
if (cpt != '\r' && cpt != '\n' && /*cpt_type != CODEPOINT_TYPE_LETTER &&*/ cpt_type != CODEPOINT_TYPE_NUMBER) {
|
if (!(cpt == '\r' || cpt == '\n' || /*flags.is_letter |*/ flags.is_number)) {
|
||||||
if (cpt_type == CODEPOINT_TYPE_LETTER || _get_cpt_type(pos+1) == CODEPOINT_TYPE_LETTER) { // one or more letters
|
if (flags.is_letter || _get_flags(pos+1).is_letter) { // one or more letters
|
||||||
pos++;
|
pos++;
|
||||||
while (_get_cpt_type(pos) == CODEPOINT_TYPE_LETTER) {
|
while (_get_flags(pos).is_letter) {
|
||||||
pos++;
|
pos++;
|
||||||
}
|
}
|
||||||
_add_token(pos);
|
_add_token(pos);
|
||||||
@ -413,9 +405,9 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
|||||||
}
|
}
|
||||||
|
|
||||||
// regex: \p{N}{1,3}
|
// regex: \p{N}{1,3}
|
||||||
if (cpt_type == CODEPOINT_TYPE_NUMBER) {
|
if (flags.is_number) {
|
||||||
size_t ini = pos;
|
size_t ini = pos;
|
||||||
while (_get_cpt_type(pos) == CODEPOINT_TYPE_NUMBER) {
|
while (_get_flags(pos).is_number) {
|
||||||
if (++pos - ini >= 3 ) {
|
if (++pos - ini >= 3 ) {
|
||||||
_add_token(pos);
|
_add_token(pos);
|
||||||
ini = pos;
|
ini = pos;
|
||||||
@ -426,14 +418,13 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
|||||||
}
|
}
|
||||||
|
|
||||||
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
|
// regex: <space>?[^\s\p{L}\p{N}]+[\r\n]*
|
||||||
char32_t cpt2 = (cpt == ' ' ? _get_cpt(pos+1) : cpt);
|
auto flags2 = (cpt == ' ' ? _get_flags(pos+1) : flags);
|
||||||
int cpt2_type = (cpt == ' ' ? _get_cpt_type(pos+1) : cpt_type);
|
if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
|
||||||
if (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
|
|
||||||
pos += (cpt == ' ');
|
pos += (cpt == ' ');
|
||||||
while (!unicode_cpt_is_whitespace(cpt2) && cpt2_type != CODEPOINT_TYPE_LETTER && cpt2_type != CODEPOINT_TYPE_NUMBER && cpt2_type != CODEPOINT_TYPE_UNIDENTIFIED) {
|
while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number || flags2.is_undefined)) {
|
||||||
cpt2_type = _get_cpt_type(++pos);
|
flags2 = _get_flags(++pos);
|
||||||
cpt2 = _get_cpt(pos);
|
|
||||||
}
|
}
|
||||||
|
char32_t cpt2 = _get_cpt(pos);
|
||||||
while (cpt2 == '\r' || cpt2 == '\n') {
|
while (cpt2 == '\r' || cpt2 == '\n') {
|
||||||
cpt2 = _get_cpt(++pos);
|
cpt2 = _get_cpt(++pos);
|
||||||
}
|
}
|
||||||
@ -443,7 +434,7 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
|
|||||||
|
|
||||||
size_t num_whitespaces = 0;
|
size_t num_whitespaces = 0;
|
||||||
size_t last_end_r_or_n = 0;
|
size_t last_end_r_or_n = 0;
|
||||||
while (unicode_cpt_is_whitespace(_get_cpt(pos+num_whitespaces))) {
|
while (_get_flags(pos+num_whitespaces).is_whitespace) {
|
||||||
char32_t cpt2 = _get_cpt(pos+num_whitespaces);
|
char32_t cpt2 = _get_cpt(pos+num_whitespaces);
|
||||||
if (cpt2 == '\r' || cpt2 == '\n') {
|
if (cpt2 == '\r' || cpt2 == '\n') {
|
||||||
last_end_r_or_n = pos + num_whitespaces + 1;
|
last_end_r_or_n = pos + num_whitespaces + 1;
|
||||||
@ -589,15 +580,14 @@ std::string unicode_cpt_to_utf8(uint32_t cp) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
|
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts) {
|
||||||
std::vector<uint32_t> result;
|
auto comp = [] (const uint32_t cpt, const range_nfd & range) {
|
||||||
result.reserve(cpts.size());
|
return cpt < range.first;
|
||||||
|
};
|
||||||
|
std::vector<uint32_t> result(cpts.size());
|
||||||
for (size_t i = 0; i < cpts.size(); ++i) {
|
for (size_t i = 0; i < cpts.size(); ++i) {
|
||||||
auto it = unicode_map_nfd.find(cpts[i]);
|
const uint32_t cpt = cpts[i];
|
||||||
if (it == unicode_map_nfd.end()) {
|
auto it = std::upper_bound(unicode_ranges_nfd.cbegin(), unicode_ranges_nfd.cend(), cpt, comp) - 1;
|
||||||
result.push_back(cpts[i]);
|
result[i] = (it->first <= cpt && cpt <= it->last) ? it->nfd : cpt;
|
||||||
} else {
|
|
||||||
result.push_back(it->second);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
@ -611,31 +601,19 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
int unicode_cpt_type(uint32_t cp) {
|
codepoint_flags unicode_cpt_flags(const uint32_t cp) {
|
||||||
static std::unordered_map<uint32_t, int> cpt_types = unicode_cpt_type_map();
|
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
||||||
const auto it = cpt_types.find(cp);
|
static const auto cpt_flags = unicode_cpt_flags_array();
|
||||||
return it == cpt_types.end() ? CODEPOINT_TYPE_UNIDENTIFIED : it->second;
|
return cp < cpt_flags.size() ? cpt_flags[cp] : undef;
|
||||||
}
|
}
|
||||||
|
|
||||||
int unicode_cpt_type(const std::string & utf8) {
|
codepoint_flags unicode_cpt_flags(const std::string & utf8) {
|
||||||
if (utf8.length() == 0) {
|
static const codepoint_flags undef(codepoint_flags::UNDEFINED);
|
||||||
return CODEPOINT_TYPE_UNIDENTIFIED;
|
if (utf8.empty()) {
|
||||||
|
return undef; // undefined
|
||||||
}
|
}
|
||||||
size_t offset = 0;
|
size_t offset = 0;
|
||||||
return unicode_cpt_type(unicode_cpt_from_utf8(utf8, offset));
|
return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset));
|
||||||
}
|
|
||||||
|
|
||||||
bool unicode_cpt_is_whitespace(uint32_t cp) {
|
|
||||||
static const std::unordered_set<uint32_t> is_whitespace = [] {
|
|
||||||
std::unordered_set<uint32_t> is_whitespace;
|
|
||||||
for (auto p : unicode_ranges_whitespace) {
|
|
||||||
for (auto i = p.first; i <= p.second; ++i) {
|
|
||||||
is_whitespace.insert(i);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return is_whitespace;
|
|
||||||
}();
|
|
||||||
return (bool)is_whitespace.count(cp);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string unicode_byte_to_utf8(uint8_t byte) {
|
std::string unicode_byte_to_utf8(uint8_t byte) {
|
||||||
@ -656,21 +634,21 @@ char32_t unicode_tolower(char32_t cp) {
|
|||||||
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
|
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
|
||||||
// unicode categories
|
// unicode categories
|
||||||
static const std::map<std::string, int> k_ucat_enum = {
|
static const std::map<std::string, int> k_ucat_enum = {
|
||||||
{ "\\p{N}", CODEPOINT_TYPE_NUMBER },
|
{ "\\p{N}", codepoint_flags::NUMBER },
|
||||||
{ "\\p{L}", CODEPOINT_TYPE_LETTER },
|
{ "\\p{L}", codepoint_flags::LETTER },
|
||||||
{ "\\p{P}", CODEPOINT_TYPE_PUNCTUATION },
|
{ "\\p{P}", codepoint_flags::PUNCTUATION },
|
||||||
};
|
};
|
||||||
|
|
||||||
static const std::map<int, int> k_ucat_cpt = {
|
static const std::map<int, int> k_ucat_cpt = {
|
||||||
{ CODEPOINT_TYPE_NUMBER, 0xD1 },
|
{ codepoint_flags::NUMBER, 0xD1 },
|
||||||
{ CODEPOINT_TYPE_LETTER, 0xD2 },
|
{ codepoint_flags::LETTER, 0xD2 },
|
||||||
{ CODEPOINT_TYPE_PUNCTUATION, 0xD3 },
|
{ codepoint_flags::PUNCTUATION, 0xD3 },
|
||||||
};
|
};
|
||||||
|
|
||||||
static const std::map<int, std::string> k_ucat_map = {
|
static const std::map<int, std::string> k_ucat_map = {
|
||||||
{ CODEPOINT_TYPE_NUMBER, "\x30-\x39" }, // 0-9
|
{ codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9
|
||||||
{ CODEPOINT_TYPE_LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
{ codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
|
||||||
{ CODEPOINT_TYPE_PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
|
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
|
||||||
};
|
};
|
||||||
|
|
||||||
// compute collapsed codepoints only if needed by at least one regex
|
// compute collapsed codepoints only if needed by at least one regex
|
||||||
@ -701,10 +679,10 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int cpt_type = unicode_cpt_type(cpts[i]);
|
const int cpt_flag = unicode_cpt_flags(cpts[i]).category_flag();
|
||||||
|
|
||||||
if (k_ucat_cpt.find(cpt_type) != k_ucat_cpt.end()) {
|
if (k_ucat_cpt.find(cpt_flag) != k_ucat_cpt.end()) {
|
||||||
text_collapsed[i] = k_ucat_cpt.at(cpt_type);
|
text_collapsed[i] = k_ucat_cpt.at(cpt_flag);
|
||||||
} else {
|
} else {
|
||||||
text_collapsed[i] = (char) 0xD0; // fallback
|
text_collapsed[i] = (char) 0xD0; // fallback
|
||||||
}
|
}
|
||||||
|
56
unicode.h
56
unicode.h
@ -4,24 +4,56 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#define CODEPOINT_TYPE_UNIDENTIFIED 0
|
struct codepoint_flags {
|
||||||
#define CODEPOINT_TYPE_NUMBER 1
|
enum {
|
||||||
#define CODEPOINT_TYPE_LETTER 2
|
UNDEFINED = 0x0001,
|
||||||
#define CODEPOINT_TYPE_SEPARATOR 3
|
NUMBER = 0x0002, // regex: \p{N}
|
||||||
#define CODEPOINT_TYPE_ACCENT_MARK 4
|
LETTER = 0x0004, // regex: \p{L}
|
||||||
#define CODEPOINT_TYPE_PUNCTUATION 5
|
SEPARATOR = 0x0008, // regex: \p{Z}
|
||||||
#define CODEPOINT_TYPE_SYMBOL 6
|
ACCENT_MARK = 0x0010, // regex: \p{M}
|
||||||
#define CODEPOINT_TYPE_CONTROL 7
|
PUNCTUATION = 0x0020, // regex: \p{P}
|
||||||
|
SYMBOL = 0x0040, // regex: \p{S}
|
||||||
|
CONTROL = 0x0080, // regex: \p{C}
|
||||||
|
MASK_CATEGORIES = 0x00FF,
|
||||||
|
};
|
||||||
|
|
||||||
|
// codepoint type
|
||||||
|
uint16_t is_undefined : 1;
|
||||||
|
uint16_t is_number : 1; // regex: \p{N}
|
||||||
|
uint16_t is_letter : 1; // regex: \p{L}
|
||||||
|
uint16_t is_separator : 1; // regex: \p{Z}
|
||||||
|
uint16_t is_accent_mark : 1; // regex: \p{M}
|
||||||
|
uint16_t is_punctuation : 1; // regex: \p{P}
|
||||||
|
uint16_t is_symbol : 1; // regex: \p{S}
|
||||||
|
uint16_t is_control : 1; // regex: \p{C}
|
||||||
|
// helper flags
|
||||||
|
uint16_t is_whitespace : 1; // regex: \s
|
||||||
|
uint16_t is_lowercase : 1;
|
||||||
|
uint16_t is_uppercase : 1;
|
||||||
|
uint16_t is_nfd : 1;
|
||||||
|
|
||||||
|
// decode from uint16
|
||||||
|
inline codepoint_flags(const uint16_t flags=0) {
|
||||||
|
*reinterpret_cast<uint16_t*>(this) = flags;
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint16_t as_uint() const {
|
||||||
|
return *reinterpret_cast<const uint16_t*>(this);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline uint16_t category_flag() const {
|
||||||
|
return this->as_uint() & MASK_CATEGORIES;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
std::string unicode_cpt_to_utf8(uint32_t cp);
|
std::string unicode_cpt_to_utf8(uint32_t cp);
|
||||||
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
|
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
|
||||||
|
|
||||||
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
|
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
|
||||||
|
|
||||||
int unicode_cpt_type(uint32_t cp);
|
codepoint_flags unicode_cpt_flags(const uint32_t cp);
|
||||||
int unicode_cpt_type(const std::string & utf8);
|
codepoint_flags unicode_cpt_flags(const std::string & utf8);
|
||||||
|
|
||||||
bool unicode_cpt_is_whitespace(uint32_t cp);
|
|
||||||
|
|
||||||
std::string unicode_byte_to_utf8(uint8_t byte);
|
std::string unicode_byte_to_utf8(uint8_t byte);
|
||||||
uint8_t unicode_utf8_to_byte(const std::string & utf8);
|
uint8_t unicode_utf8_to_byte(const std::string & utf8);
|
||||||
|
Loading…
Reference in New Issue
Block a user