tests : multi-thread the tokenizer tests (#5474)

* tests : multi-thread the tokenizer tests

ggml-ci

* unicode : fix data race for unidentified codepoints

ggml-ci

* unicode : minor style fixes

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-02-13 15:14:22 +02:00 committed by GitHub
parent 03bf161eb6
commit cf45252a7c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 124 additions and 102 deletions

View File

@ -7857,6 +7857,7 @@ private:
if (p == rev_merge.end()) { if (p == rev_merge.end()) {
// output any symbols that did not form tokens as bytes. // output any symbols that did not form tokens as bytes.
output.reserve(output.size() + symbol.n);
for (int j = 0; j < (int)symbol.n; ++j) { for (int j = 0; j < (int)symbol.n; ++j) {
llama_vocab::id token_id = llama_byte_to_token(vocab, symbol.text[j]); llama_vocab::id token_id = llama_byte_to_token(vocab, symbol.text[j]);
output.push_back(token_id); output.push_back(token_id);
@ -8420,6 +8421,7 @@ struct fragment_buffer_variant {
raw_text(_dummy), raw_text(_dummy),
offset(0), offset(0),
length(0) {} length(0) {}
fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length) fragment_buffer_variant(const std::string & _raw_text, int64_t _offset, int64_t _length)
: :
type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT), type(FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT),

View File

@ -4,13 +4,13 @@
#include "console.h" #include "console.h"
#include <cassert> #include <cassert>
#include <codecvt>
#include <cstdio> #include <cstdio>
#include <cstring> #include <cstring>
#include <string>
#include <codecvt>
#include <map>
#include <vector>
#include <locale> #include <locale>
#include <string>
#include <thread>
#include <vector>
int main(int argc, char **argv) { int main(int argc, char **argv) {
if (argc < 2) { if (argc < 2) {
@ -74,45 +74,46 @@ int main(int argc, char **argv) {
} }
} }
catch (const std::invalid_argument &) { catch (const std::invalid_argument &) {
fprintf(stderr, "%s : info: utf8 conversion %d '%s'\n", __func__, i, str.c_str()); //fprintf(stderr, "%s : info: utf8 conversion %d '%s'\n", __func__, i, str.c_str());
} }
} }
for (uint32_t cp = 0x0000; cp < 0xffff; ++cp) { // unicode
// NOTE: these exceptions seem to be necessary, because the GPT2 tokenizer doesn't want to interfere with some ASCII control characters {
if ((cp < 0x03 || cp > 0x05) && cp != 0x0b && cp != 0x11 && (cp < 0x13 || cp > 0x17) && cp != 0x19 && (cp < 0x1c || cp > 0x1e) && (cp < 0xd800 || cp > 0xdfff)) { const int nthread = std::thread::hardware_concurrency();
std::string str = " " + codepoint_to_utf8(cp);
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false); std::vector<std::thread> threads(nthread);
std::string check = llama_detokenize_bpe(ctx, tokens);
if (str != check) { for (int i = 0; i < nthread; ++i) {
fprintf(stderr, "%s : error: codepoint %x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n", threads[i] = std::thread([i, nthread, ctx]() {
__func__, cp, check.c_str(), check.length(), str.c_str(), str.length()); for (uint32_t cp = i; cp < 0x0010ffff; cp += nthread) {
return 3; if (!( // NOLINT
(cp < 0x03 || cp > 0x05) && cp != 0x0b && cp != 0x11 &&
(cp < 0x13 || cp > 0x17) && cp != 0x19 &&
(cp < 0x1c || cp > 0x1e) &&
(cp < 0xd800 || cp > 0xdfff) &&
(cp < 0x00040000 || cp >= 0x000e0000)
)) {
continue;
} }
}
}
// Restrict to assigned unicode planes
// for (uint32_t cp = 0x10000; cp < 0x0010ffff; ++cp) {
for (uint32_t cp = 0x10000; cp < 0x00040000; ++cp) {
std::string str = codepoint_to_utf8(cp); std::string str = codepoint_to_utf8(cp);
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false); std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
std::string check = llama_detokenize_bpe(ctx, tokens); std::string check = llama_detokenize_bpe(ctx, tokens);
if (str != check) { if (cp != 9601 && str != check) {
fprintf(stderr, "%s : error: codepoint %x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n", fprintf(stderr, "error: codepoint %x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
__func__, cp, check.c_str(), check.length(), str.c_str(), str.length()); cp, check.c_str(), check.length(), str.c_str(), str.length());
return 4; std::exit(3);
} }
} }
for (uint32_t cp = 0x000e0000; cp < 0x0010ffff; ++cp) { });
std::string str = codepoint_to_utf8(cp); }
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
std::string check = llama_detokenize_bpe(ctx, tokens); for (auto & t : threads) {
if (str != check) { t.join();
fprintf(stderr, "%s : error: codepoint %x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
__func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
return 4;
} }
} }
llama_free_model(model); llama_free_model(model);
llama_free(ctx); llama_free(ctx);

View File

@ -4,13 +4,13 @@
#include "console.h" #include "console.h"
#include <cassert> #include <cassert>
#include <codecvt>
#include <cstdio> #include <cstdio>
#include <cstring> #include <cstring>
#include <string>
#include <codecvt>
#include <map>
#include <vector>
#include <locale> #include <locale>
#include <string>
#include <thread>
#include <vector>
int main(int argc, char **argv) { int main(int argc, char **argv) {
if (argc < 2) { if (argc < 2) {
@ -72,26 +72,33 @@ int main(int argc, char **argv) {
} }
} }
for (uint32_t cp = 0x0000; cp < 0xffff; ++cp) { // unicode
if (cp < 0xd800 || cp > 0xdfff) { {
const int nthread = std::thread::hardware_concurrency();
std::vector<std::thread> threads(nthread);
for (int i = 0; i < nthread; ++i) {
threads[i] = std::thread([i, nthread, ctx]() {
for (uint32_t cp = i; cp < 0x0010ffff; cp += nthread) {
if (cp >= 0xd800 && cp <= 0xdfff) {
continue;
}
std::string str = codepoint_to_utf8(cp); std::string str = codepoint_to_utf8(cp);
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false); std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
std::string check = llama_detokenize_spm(ctx, tokens); std::string check = llama_detokenize_spm(ctx, tokens);
if (cp != 9601 && str != check) { if (cp != 9601 && str != check) {
fprintf(stderr, "%s : error: codepoint %d detokenizes to '%s'(%zu) instead of '%s'(%zu)\n", fprintf(stderr, "error: codepoint %x detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
__func__, cp, check.c_str(), check.length(), str.c_str(), str.length()); cp, check.c_str(), check.length(), str.c_str(), str.length());
return 3; std::exit(3);
} }
} }
});
} }
for (uint32_t cp = 0x10000; cp < 0x0010ffff; ++cp) {
std::string str = codepoint_to_utf8(cp); for (auto & t : threads) {
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false); t.join();
std::string check = llama_detokenize_spm(ctx, tokens);
if (str != check) {
fprintf(stderr, "%s : error: codepoint %d detokenizes to '%s'(%zu) instead of '%s'(%zu)\n",
__func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
return 4;
} }
} }

View File

@ -264,26 +264,29 @@ static uint32_t codepoint_from_utf8(const std::string & utf8, size_t & offset) {
offset += 1; offset += 1;
return result; return result;
} }
else if (!(utf8[offset + 0] & 0x40)) { if (!(utf8[offset + 0] & 0x40)) {
throw std::invalid_argument("invalid character"); throw std::invalid_argument("invalid character");
} }
else if (!(utf8[offset + 0] & 0x20)) { if (!(utf8[offset + 0] & 0x20)) {
if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) if (offset + 1 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80)) {
throw std::invalid_argument("invalid character"); throw std::invalid_argument("invalid character");
}
auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f); auto result = ((utf8[offset + 0] & 0x1f) << 6) | (utf8[offset + 1] & 0x3f);
offset += 2; offset += 2;
return result; return result;
} }
else if (!(utf8[offset + 0] & 0x10)) { if (!(utf8[offset + 0] & 0x10)) {
if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) if (offset + 2 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80)) {
throw std::invalid_argument("invalid character"); throw std::invalid_argument("invalid character");
}
auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f); auto result = ((utf8[offset + 0] & 0x0f) << 12) | ((utf8[offset + 1] & 0x3f) << 6) | (utf8[offset + 2] & 0x3f);
offset += 3; offset += 3;
return result; return result;
} }
else if (!(utf8[offset + 0] & 0x08)) { if (!(utf8[offset + 0] & 0x08)) {
if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) if (offset + 3 >= utf8.size() || ! ((utf8[offset + 1] & 0xc0) == 0x80) || ! ((utf8[offset + 2] & 0xc0) == 0x80) || !((utf8[offset + 3] & 0xc0) == 0x80)) {
throw std::invalid_argument("invalid character"); throw std::invalid_argument("invalid character");
}
auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f); auto result = ((utf8[offset + 0] & 0x07) << 18) | ((utf8[offset + 1] & 0x3f) << 12) | ((utf8[offset + 2] & 0x3f) << 6) | (utf8[offset + 3] & 0x3f);
offset += 4; offset += 4;
return result; return result;
@ -331,21 +334,22 @@ static uint32_t codepoint_from_utf16(const std::vector<uint16_t> & utf16, size_t
offset += 1; offset += 1;
return result; return result;
} }
else {
if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) if (offset + 1 >= utf16.size() || !((utf16[1] & 0xdc00) == 0xdc00)) {
throw std::invalid_argument("invalid character"); throw std::invalid_argument("invalid character");
}
auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff)); auto result = 0x10000 + (((utf16[0] & 0x03ff) << 10) | (utf16[1] & 0x03ff));
offset += 2; offset += 2;
return result; return result;
} }
throw std::invalid_argument("invalid string");
}
static std::vector<uint32_t> codepoints_from_utf16(const std::vector<uint16_t> & utf16) { static std::vector<uint32_t> codepoints_from_utf16(const std::vector<uint16_t> & utf16) {
std::vector<uint32_t> result; std::vector<uint32_t> result;
size_t offset = 0; size_t offset = 0;
while (offset < utf16.size()) while (offset < utf16.size()) {
result.push_back(codepoint_from_utf16(utf16, offset)); result.push_back(codepoint_from_utf16(utf16, offset));
}
return result; return result;
} }
@ -361,44 +365,52 @@ static std::vector<uint32_t> codepoints_from_utf16(const std::vector<uint16_t> &
static std::unordered_map<uint32_t, int> codepoint_type_map() { static std::unordered_map<uint32_t, int> codepoint_type_map() {
std::unordered_map<uint32_t, int> codepoint_types; std::unordered_map<uint32_t, int> codepoint_types;
for (auto p : digit_ranges) { for (auto p : digit_ranges) {
for(auto i = p.first; i <= p.second; ++ i) for (auto i = p.first; i <= p.second; ++ i) {
codepoint_types[i] = CODEPOINT_TYPE_DIGIT; codepoint_types[i] = CODEPOINT_TYPE_DIGIT;
} }
}
for (auto p : letter_ranges) { for (auto p : letter_ranges) {
for(auto i = p.first; i <= p.second; ++ i) for (auto i = p.first; i <= p.second; ++ i) {
codepoint_types[i] = CODEPOINT_TYPE_LETTER; codepoint_types[i] = CODEPOINT_TYPE_LETTER;
} }
}
for (auto p : whitespace_ranges) { for (auto p : whitespace_ranges) {
for(auto i = p.first; i <= p.second; ++ i) for (auto i = p.first; i <= p.second; ++ i) {
codepoint_types[i] = CODEPOINT_TYPE_WHITESPACE; codepoint_types[i] = CODEPOINT_TYPE_WHITESPACE;
} }
}
for (auto p : accent_mark_ranges) { for (auto p : accent_mark_ranges) {
for(auto i = p.first; i <= p.second; ++ i) for (auto i = p.first; i <= p.second; ++ i) {
codepoint_types[i] = CODEPOINT_TYPE_ACCENT_MARK; codepoint_types[i] = CODEPOINT_TYPE_ACCENT_MARK;
} }
}
for (auto p : punctuation_ranges) { for (auto p : punctuation_ranges) {
for(auto i = p.first; i <= p.second; ++ i) for (auto i = p.first; i <= p.second; ++ i) {
codepoint_types[i] = CODEPOINT_TYPE_PUNCTUATION; codepoint_types[i] = CODEPOINT_TYPE_PUNCTUATION;
} }
}
for (auto p : symbol_ranges) { for (auto p : symbol_ranges) {
for (auto i = p.first; i <= p.second; ++i) for (auto i = p.first; i <= p.second; ++i) {
codepoint_types[i] = CODEPOINT_TYPE_SYMBOL; codepoint_types[i] = CODEPOINT_TYPE_SYMBOL;
} }
}
for (auto p : control_ranges) { for (auto p : control_ranges) {
for(auto i = p.first; i <= p.second; ++ i) for (auto i = p.first; i <= p.second; ++ i) {
codepoint_types[i] = CODEPOINT_TYPE_CONTROL; codepoint_types[i] = CODEPOINT_TYPE_CONTROL;
} }
}
return codepoint_types; return codepoint_types;
} }
static int codepoint_type(uint32_t cp) { static int codepoint_type(uint32_t cp) {
static std::unordered_map<uint32_t, int> codepoint_types = codepoint_type_map(); static std::unordered_map<uint32_t, int> codepoint_types = codepoint_type_map();
return codepoint_types[cp]; return codepoint_types.find(cp) == codepoint_types.end() ? CODEPOINT_TYPE_UNIDENTIFIED : codepoint_types.at(cp);
} }
static int codepoint_type(const std::string & utf8) { static int codepoint_type(const std::string & utf8) {
if (utf8.length() == 0) if (utf8.length() == 0) {
return CODEPOINT_TYPE_UNIDENTIFIED; return CODEPOINT_TYPE_UNIDENTIFIED;
}
size_t offset = 0; size_t offset = 0;
return codepoint_type(codepoint_from_utf8(utf8, offset)); return codepoint_type(codepoint_from_utf8(utf8, offset));
} }