unicode : improve naming style (#10838)

* unicode : improve naming style

ggml-ci

* cont [no ci]
This commit is contained in:
Georgi Gerganov 2024-12-16 12:31:45 +02:00 committed by GitHub
parent 644fd71b44
commit 08ea539df2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 61 additions and 62 deletions

View File

@ -738,7 +738,7 @@ struct llm_tokenizer_wpm_session {
std::vector<std::string> words(1, ""); std::vector<std::string> words(1, "");
for (const uint32_t cpt : cpts_nfd) { for (const uint32_t cpt : cpts_nfd) {
const auto flags = unicode_cpt_flags(cpt); const auto flags = unicode_cpt_flags_from_cpt(cpt);
if (flags.is_whitespace) { if (flags.is_whitespace) {
if (words.back().size()) { // finish previous word if any if (words.back().size()) { // finish previous word if any

View File

@ -71,15 +71,15 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
throw std::invalid_argument("failed to convert utf8 to codepoint"); throw std::invalid_argument("failed to convert utf8 to codepoint");
} }
//static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cp) { //static std::vector<uint16_t> unicode_cpt_to_utf16(uint32_t cpt) {
// std::vector<uint16_t> result; // std::vector<uint16_t> result;
// if (/* 0x0000 <= cp && */ cp <= 0xffff) { // if (/* 0x0000 <= cpt && */ cpt <= 0xffff) {
// result.emplace_back(cp); // result.emplace_back(cpt);
// return result; // return result;
// } // }
// if (0x10000 <= cp && cp <= 0x10ffff) { // if (0x10000 <= cpt && cpt <= 0x10ffff) {
// result.emplace_back(0xd800 | ((cp - 0x10000) >> 10)); // result.emplace_back(0xd800 | ((cpt - 0x10000) >> 10));
// result.emplace_back(0xdc00 | ((cp - 0x10000) & 0x03ff)); // result.emplace_back(0xdc00 | ((cpt - 0x10000) & 0x03ff));
// return result; // return result;
// } // }
// throw std::invalid_argument("failed to convert codepoint to utf16"); // throw std::invalid_argument("failed to convert codepoint to utf16");
@ -120,8 +120,8 @@ uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset) {
// return result; // return result;
//} //}
static std::vector<codepoint_flags> unicode_cpt_flags_array() { static std::vector<unicode_cpt_flags> unicode_cpt_flags_array() {
std::vector<codepoint_flags> cpt_flags(MAX_CODEPOINTS, codepoint_flags::UNDEFINED); std::vector<unicode_cpt_flags> cpt_flags(MAX_CODEPOINTS, unicode_cpt_flags::UNDEFINED);
assert (unicode_ranges_flags.begin()[0].first == 0); assert (unicode_ranges_flags.begin()[0].first == 0);
assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS); assert (unicode_ranges_flags.begin()[unicode_ranges_flags.size()-1].first == MAX_CODEPOINTS);
@ -253,8 +253,8 @@ static std::vector<size_t> unicode_regex_split_custom_gpt2(const std::string & t
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
}; };
auto _get_flags = [&] (const size_t pos) -> codepoint_flags { auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{}; return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
}; };
size_t _prev_end = offset_ini; size_t _prev_end = offset_ini;
@ -371,8 +371,8 @@ static std::vector<size_t> unicode_regex_split_custom_llama3(const std::string &
return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE;
}; };
auto _get_flags = [&] (const size_t pos) -> codepoint_flags { auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags {
return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags(cpts[pos]) : codepoint_flags{}; return (offset_ini <= pos && pos < offset_end) ? unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{};
}; };
size_t _prev_end = offset_ini; size_t _prev_end = offset_ini;
@ -572,29 +572,29 @@ static std::vector<size_t> unicode_regex_split_custom(const std::string & text,
// interface // interface
// //
std::string unicode_cpt_to_utf8(uint32_t cp) { std::string unicode_cpt_to_utf8(uint32_t cpt) {
std::string result; std::string result;
if (/* 0x00 <= cp && */ cp <= 0x7f) { if (/* 0x00 <= cpt && */ cpt <= 0x7f) {
result.push_back(cp); result.push_back(cpt);
return result; return result;
} }
if (0x80 <= cp && cp <= 0x7ff) { if (0x80 <= cpt && cpt <= 0x7ff) {
result.push_back(0xc0 | ((cp >> 6) & 0x1f)); result.push_back(0xc0 | ((cpt >> 6) & 0x1f));
result.push_back(0x80 | (cp & 0x3f)); result.push_back(0x80 | (cpt & 0x3f));
return result; return result;
} }
if (0x800 <= cp && cp <= 0xffff) { if (0x800 <= cpt && cpt <= 0xffff) {
result.push_back(0xe0 | ((cp >> 12) & 0x0f)); result.push_back(0xe0 | ((cpt >> 12) & 0x0f));
result.push_back(0x80 | ((cp >> 6) & 0x3f)); result.push_back(0x80 | ((cpt >> 6) & 0x3f));
result.push_back(0x80 | (cp & 0x3f)); result.push_back(0x80 | (cpt & 0x3f));
return result; return result;
} }
if (0x10000 <= cp && cp <= 0x10ffff) { if (0x10000 <= cpt && cpt <= 0x10ffff) {
result.push_back(0xf0 | ((cp >> 18) & 0x07)); result.push_back(0xf0 | ((cpt >> 18) & 0x07));
result.push_back(0x80 | ((cp >> 12) & 0x3f)); result.push_back(0x80 | ((cpt >> 12) & 0x3f));
result.push_back(0x80 | ((cp >> 6) & 0x3f)); result.push_back(0x80 | ((cpt >> 6) & 0x3f));
result.push_back(0x80 | (cp & 0x3f)); result.push_back(0x80 | (cpt & 0x3f));
return result; return result;
} }
@ -624,19 +624,19 @@ std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8) {
return result; return result;
} }
codepoint_flags unicode_cpt_flags(const uint32_t cp) { unicode_cpt_flags unicode_cpt_flags_from_cpt(const uint32_t cpt) {
static const codepoint_flags undef(codepoint_flags::UNDEFINED); static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
static const auto cpt_flags = unicode_cpt_flags_array(); static const auto cpt_flags = unicode_cpt_flags_array();
return cp < cpt_flags.size() ? cpt_flags[cp] : undef; return cpt < cpt_flags.size() ? cpt_flags[cpt] : undef;
} }
codepoint_flags unicode_cpt_flags(const std::string & utf8) { unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8) {
static const codepoint_flags undef(codepoint_flags::UNDEFINED); static const unicode_cpt_flags undef(unicode_cpt_flags::UNDEFINED);
if (utf8.empty()) { if (utf8.empty()) {
return undef; // undefined return undef; // undefined
} }
size_t offset = 0; size_t offset = 0;
return unicode_cpt_flags(unicode_cpt_from_utf8(utf8, offset)); return unicode_cpt_flags_from_cpt(unicode_cpt_from_utf8(utf8, offset));
} }
std::string unicode_byte_to_utf8(uint8_t byte) { std::string unicode_byte_to_utf8(uint8_t byte) {
@ -649,41 +649,41 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8) {
return map.at(utf8); return map.at(utf8);
} }
uint32_t unicode_tolower(uint32_t cp) { uint32_t unicode_tolower(uint32_t cpt) {
// binary search // binary search
auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cp, auto it = std::lower_bound(unicode_map_lowercase.begin(), unicode_map_lowercase.end(), cpt,
[](const std::pair<uint32_t, uint32_t> & pair, uint32_t value) { [](const std::pair<uint32_t, uint32_t> & pair, uint32_t value) {
return pair.first < value; return pair.first < value;
}); });
if (it != unicode_map_lowercase.end() && it->first == cp) { if (it != unicode_map_lowercase.end() && it->first == cpt) {
return it->second; return it->second;
} }
return cp; // Return the original code point if no lowercase mapping is found return cpt; // Return the original code point if no lowercase mapping is found
} }
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) { std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs) {
// unicode categories // unicode categories
static const std::map<std::string, int> k_ucat_enum = { static const std::map<std::string, int> k_ucat_enum = {
{ "\\p{N}", codepoint_flags::NUMBER }, { "\\p{N}", unicode_cpt_flags::NUMBER },
{ "\\p{L}", codepoint_flags::LETTER }, { "\\p{L}", unicode_cpt_flags::LETTER },
{ "\\p{P}", codepoint_flags::PUNCTUATION }, { "\\p{P}", unicode_cpt_flags::PUNCTUATION },
}; };
static const std::map<int, int> k_ucat_cpt = { static const std::map<int, int> k_ucat_cpt = {
{ codepoint_flags::NUMBER, 0xD1 }, { unicode_cpt_flags::NUMBER, 0xD1 },
{ codepoint_flags::LETTER, 0xD2 }, { unicode_cpt_flags::LETTER, 0xD2 },
{ codepoint_flags::PUNCTUATION, 0xD3 }, { unicode_cpt_flags::PUNCTUATION, 0xD3 },
}; };
static const std::map<int, std::string> k_ucat_map = { static const std::map<int, std::string> k_ucat_map = {
{ codepoint_flags::NUMBER, "\x30-\x39" }, // 0-9 { unicode_cpt_flags::NUMBER, "\x30-\x39" }, // 0-9
{ codepoint_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z { unicode_cpt_flags::LETTER, "\x41-\x5A\x61-\x7A" }, // A-Za-z
{ codepoint_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\} { unicode_cpt_flags::PUNCTUATION, "\x21-\x23\x25-\x2A\x2C-\x2F\x3A-\x3B\x3F-\x40\\\x5B-\\\x5D\x5F\\\x7B\\\x7D" }, // !-#%-*,-/:-;?-@\[-\]_\{\}
}; };
// compute collapsed codepoints only if needed by at least one regex // compute collapsed codepoints only if needed by at least one regex
bool need_collapse = false; bool need_collapse = false;
for (auto & regex_expr : regex_exprs) { for (const auto & regex_expr : regex_exprs) {
// search for unicode categories // search for unicode categories
for (const auto & ucat : k_ucat_enum) { for (const auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) { if (std::string::npos != regex_expr.find(ucat.first)) {
@ -709,7 +709,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
continue; continue;
} }
const auto flags = unicode_cpt_flags(cpts[i]); const auto flags = unicode_cpt_flags_from_cpt(cpts[i]);
if (flags.is_whitespace) { if (flags.is_whitespace) {
//NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does. //NOTE: C++ std::regex \s does not mach 0x85, Rust and Python regex does.
@ -725,7 +725,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
std::vector<size_t> bpe_offsets = { cpts.size() }; std::vector<size_t> bpe_offsets = { cpts.size() };
for (auto & regex_expr : regex_exprs) { for (const auto & regex_expr : regex_exprs) {
// first, see if we have an efficient custom regex implementation // first, see if we have an efficient custom regex implementation
auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets); auto tmp = unicode_regex_split_custom(text, regex_expr, bpe_offsets);
@ -739,7 +739,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
// if a unicode category is used in the regex, we use the collapsed text and replace the unicode category // if a unicode category is used in the regex, we use the collapsed text and replace the unicode category
// with the corresponding collapsed representation // with the corresponding collapsed representation
bool use_collapsed = false; bool use_collapsed = false;
for (auto & ucat : k_ucat_enum) { for (const auto & ucat : k_ucat_enum) {
if (std::string::npos != regex_expr.find(ucat.first)) { if (std::string::npos != regex_expr.find(ucat.first)) {
use_collapsed = true; use_collapsed = true;
break; break;
@ -805,7 +805,7 @@ std::vector<std::string> unicode_regex_split(const std::string & text, const std
// std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback // std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback
std::wstring wtext(cpts.begin(), cpts.end()); std::wstring wtext(cpts.begin(), cpts.end());
for (size_t i = 0; i < wtext.size(); ++i) { for (size_t i = 0; i < wtext.size(); ++i) {
if (wtext[i] > 0x7F && unicode_cpt_flags(wtext[i]).is_whitespace) { if (wtext[i] > 0x7F && unicode_cpt_flags_from_cpt(wtext[i]).is_whitespace) {
wtext[i] = 0x0B; wtext[i] = 0x0B;
} }
} }

View File

@ -4,9 +4,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
// TODO: prefix all symbols with "llama_" struct unicode_cpt_flags {
struct codepoint_flags {
enum { enum {
UNDEFINED = 0x0001, UNDEFINED = 0x0001,
NUMBER = 0x0002, // regex: \p{N} NUMBER = 0x0002, // regex: \p{N}
@ -35,7 +33,7 @@ struct codepoint_flags {
uint16_t is_nfd : 1; uint16_t is_nfd : 1;
// decode from uint16 // decode from uint16
inline codepoint_flags(const uint16_t flags=0) { inline unicode_cpt_flags(const uint16_t flags = 0) {
*reinterpret_cast<uint16_t*>(this) = flags; *reinterpret_cast<uint16_t*>(this) = flags;
} }
@ -50,18 +48,19 @@ struct codepoint_flags {
size_t unicode_len_utf8(char src); size_t unicode_len_utf8(char src);
std::string unicode_cpt_to_utf8(uint32_t cp); std::string unicode_cpt_to_utf8 (uint32_t cpt);
uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset); uint32_t unicode_cpt_from_utf8(const std::string & utf8, size_t & offset);
std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8); std::vector<uint32_t> unicode_cpts_from_utf8(const std::string & utf8);
std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts); std::vector<uint32_t> unicode_cpts_normalize_nfd(const std::vector<uint32_t> & cpts);
codepoint_flags unicode_cpt_flags(const uint32_t cp); unicode_cpt_flags unicode_cpt_flags_from_cpt (uint32_t cpt);
codepoint_flags unicode_cpt_flags(const std::string & utf8); unicode_cpt_flags unicode_cpt_flags_from_utf8(const std::string & utf8);
std::string unicode_byte_to_utf8(uint8_t byte); std::string unicode_byte_to_utf8(uint8_t byte);
uint8_t unicode_utf8_to_byte(const std::string & utf8); uint8_t unicode_utf8_to_byte(const std::string & utf8);
uint32_t unicode_tolower(uint32_t cp); uint32_t unicode_tolower(uint32_t cpt);
std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs); std::vector<std::string> unicode_regex_split(const std::string & text, const std::vector<std::string> & regex_exprs);