From 8d177eddeb8c0b284a47a6840ca820ed846708c1 Mon Sep 17 00:00:00 2001 From: goerch Date: Mon, 21 Aug 2023 17:56:02 +0200 Subject: [PATCH] llama : improve token type support (#2668) * Merge tokenizer fixes into the gguf branch. * Add test vocabularies * Adapt convert-new.py (and fix a clang-cl compiler error on windows) * Improved tokenizer test But does it work on MacOS? * Improve token type support - Added @klosax code to convert.py - Improved token type support in vocabulary * Exclude platform dependent tests * More sentencepiece compatibility by eliminating magic numbers * Restored accidentally removed comment --- convert.py | 30 ++++++-- llama.cpp | 128 ++++++++++++++--------------------- models/ggml-vocab-llama.gguf | Bin 467382 -> 595423 bytes tests/test-tokenizer-1.cpp | 34 ++++++---- 4 files changed, 94 insertions(+), 98 deletions(-) diff --git a/convert.py b/convert.py index df589928b..f680f8596 100755 --- a/convert.py +++ b/convert.py @@ -261,12 +261,12 @@ class BpeVocab: for i, item in enumerate(tokenizer): text: bytes = item.encode("utf-8") score: float = -i - yield text, score + yield text, score, 4 def added_tokens(self) -> Iterable[Tuple[bytes, float]]: for text in self.added_tokens_list: score = -1000.0 - yield text.encode("utf-8"), score + yield text.encode("utf-8"), score, 4 def all_tokens(self) -> Iterable[Tuple[bytes, float]]: yield from self.bpe_tokens() @@ -303,12 +303,28 @@ class SentencePieceVocab: piece = tokenizer.id_to_piece(i) text: bytes = piece.encode("utf-8") score: float = tokenizer.get_score(i) - yield text, score + + toktype = 1 # defualt to normal token type + if tokenizer.is_unknown(i): + toktype = 2 + if tokenizer.is_control(i): + toktype = 3 + + # NOTE: I think added_tokens are user defined. + # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto + # if tokenizer.is_user_defined(i): toktype = 4 + + if tokenizer.is_unused(i): + toktype = 5 + if tokenizer.is_byte(i): + toktype = 6 + + yield text, score, toktype def added_tokens(self) -> Iterable[Tuple[bytes, float]]: for text in self.added_tokens_list: score = -1000.0 - yield text.encode("utf-8"), score + yield text.encode("utf-8"), score, 4 def all_tokens(self) -> Iterable[Tuple[bytes, float]]: yield from self.sentencepiece_tokens() @@ -721,14 +737,16 @@ class OutputFile: def add_meta_vocab(self, vocab: Vocab) -> None: tokens = [] scores = [] - for text, score in vocab.all_tokens(): + toktypes = [] + for text, score, toktype in vocab.all_tokens(): tokens.append(text) scores.append(score) + toktypes.append(toktype) self.gguf.add_tokenizer_model("llama") self.gguf.add_token_list(tokens) self.gguf.add_token_scores(scores) - #self.gguf.add_token_types(toktypes) # TODO: add this + self.gguf.add_token_types(toktypes) # TODO: added / special tokens diff --git a/llama.cpp b/llama.cpp index ec954b84f..1785025f0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -772,15 +772,16 @@ struct llama_vocab { using id = int32_t; using token = std::string; - struct token_score { + struct token_data { token tok; float score; + int toktype; }; llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM; std::unordered_map token_to_id; - std::vector id_to_token; + std::vector id_to_token; // default LLaMA special tokens id special_bos_id = 1; @@ -1507,17 +1508,25 @@ static void llama_model_load_internal( const float * scores = (const float * ) gguf_get_arr_data(ctx, score_idx); + const int toktype_idx = gguf_find_key(ctx, "tokenizer.ggml.token_type"); + if (toktype_idx == -1) { + throw std::runtime_error("cannot find token type list in GGUF file\n"); + } + + const int * toktypes = (const int * ) gguf_get_arr_data(ctx, toktype_idx); + for (uint32_t i = 0; i < hparams.n_vocab; i++) { std::string word = gguf_get_arr_str(ctx, token_idx, i); vocab.token_to_id[word] = i; - auto & tok_score = vocab.id_to_token[i]; - tok_score.tok = std::move(word); - tok_score.score = scores[i]; + auto & token_data = vocab.id_to_token[i]; + token_data.tok = std::move(word); + token_data.score = scores[i]; + token_data.toktype = toktypes[i]; // determine the newline token: 0x0A == 10 == '\n' - if (tok_score.tok == "<0x0A>") { + if (token_data.tok == "<0x0A>") { vocab.linefeed_id = i; } } @@ -2345,92 +2354,57 @@ static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) { return vocab.type; } -static bool llama_is_normal_token(const llama_vocab & vocab, llama_token token) { - if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) { - return token >= 259; - } - - if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) { - return token >= 95; - } - - return false; +static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].toktype == 1; } -static bool llama_is_bos_token(const llama_vocab & vocab, llama_token token) { - return token == vocab.special_bos_id; +static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].toktype == 2; } -static bool llama_is_eos_token(const llama_vocab & vocab, llama_token token) { - return token == vocab.special_eos_id; +static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].toktype == 3; } -static bool llama_is_control_token(const llama_vocab & vocab, llama_token token) { - if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) { - return token == llama_is_bos_token(vocab, token) || token == llama_is_eos_token(vocab, token); - } - - // TODO: improve? - return false; +static bool llama_is_bos_token(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(llama_is_control_token(vocab, id)); + return id == vocab.special_bos_id; } -static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token token) { - if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) { - return token == 0; - } - - // TODO: improve? - return false; +static bool llama_is_eos_token(const llama_vocab & vocab, llama_token id ) { + GGML_ASSERT(llama_is_control_token(vocab, id)); + return id == vocab.special_eos_id; } -static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) { - GGML_UNUSED(vocab); - GGML_UNUSED(token); - // TODO: improve? - return false; +static bool llama_is_pad_token(const llama_vocab & vocab, llama_token id ) { + GGML_ASSERT(id < 0 || llama_is_control_token(vocab, id)); + return id == vocab.special_pad_id; } -static bool llama_is_unused_token(const llama_vocab & vocab, llama_token token) { - GGML_UNUSED(vocab); - GGML_UNUSED(token); - // TODO: improve? - return false; +static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].toktype == 4; } -static bool llama_is_byte_token(const llama_vocab & vocab, llama_token token) { - if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) { - return 3 <= token && token < 259; - } - - if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) { - return 1 <= token && token < 95; - } - - return false; +static bool llama_is_unused_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].toktype == 5; } -static uint8_t llama_byte_to_char(const llama_vocab & vocab, uint8_t byte) { - if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) { - return byte - 3; - } - - if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) { - return byte + 32; - } - - return false; +static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) { + return vocab.id_to_token[id].toktype == 6; } -static uint8_t llama_char_to_byte(const llama_vocab & vocab, uint8_t ch) { - if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) { - return ch + 3; - } +static uint8_t llama_token_to_byte(const llama_vocab & vocab, llama_token id) { + GGML_ASSERT(llama_is_byte_token(vocab, id)); + const auto& token_data = vocab.id_to_token.at(id); + auto buf = token_data.tok.substr(3, 2); + return strtol(buf.c_str(), NULL, 16); +} - if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) { - return ch - 32; - } - - return false; +static llama_token llama_byte_to_token(const llama_vocab & vocab, uint8_t ch) { + char buf[7]; + int result = snprintf(buf, sizeof(buf), "<0x%02X>", ch); + GGML_ASSERT(0 <= result && result < 7); + return vocab.token_to_id.at(buf); } static std::string llama_escape_whitespace(const std::string& text) { @@ -2569,7 +2543,7 @@ private: if (p == rev_merge.end()) { // output any symbols that did not form tokens as bytes. for (int j = 0; j < (int)symbol.n; ++j) { - llama_vocab::id token_id = llama_char_to_byte(vocab_, symbol.text[j]); + llama_vocab::id token_id = llama_byte_to_token(vocab_, symbol.text[j]); output.push_back(token_id); } return; @@ -2595,12 +2569,12 @@ private: return; } - const auto &tok_score = vocab_.id_to_token[(*token).second]; + const auto &tok_data = vocab_.id_to_token[(*token).second]; llama_sp_bigram bigram; bigram.left = left; bigram.right = right; - bigram.score = tok_score.score; + bigram.score = tok_data.score; bigram.size = text.size(); work_queue_.push(bigram); @@ -5109,7 +5083,7 @@ int llama_token_to_str_with_model(const struct llama_model * model, llama_token if (length < 1) { return -1; } - buf[0] = llama_byte_to_char(model->vocab, token); + buf[0] = llama_token_to_byte(model->vocab, token); return 1; } } diff --git a/models/ggml-vocab-llama.gguf b/models/ggml-vocab-llama.gguf index c50db67dc52023444e607bde7b684da5d631e564..63bfaf672f382c0f5bbcffe54736e2698ef3ac55 100644 GIT binary patch delta 129117 zcmeI(y=qf&0Egk6endOu0tyZe&fUZ-&|ZM(YL-yMPo#^2oI%jp>u_?>$;k`xR{Td@ zU6R1NJi?)fG^c&jlRW2F{Py$KuP=92mzR$oOpCvNR-ae*r<-3No=qn|?tQ%bVLg3X zPut^Y9LM808OP~3mfzRYjmNj=^YMAtubypQK7F@&b8&U`;`zm&zdpHs`+9RW{_%AC z!0#7}qw(SR-?1DY4|B|idY~O2XPo~)-^p%|Gan}7ZYLk-p40DsInLjH0iOSTegLq~BvViN~^N#7T`brD+};nsUQ0v|CRc&|M6d`ANwEwmHM&&@gM({1^AEu$^!hye`Nvw z1|CRc&|M6d`ANwEwmHM&&@n5MQ`yc=DUs-_v_^&L$fBaV#;6MH=3-Dj5ANwEw zmHM&&@n5MQ`yc<6`mz7cd}0{q8+WdZ)KmIEV@L#DP`yc<6`mz7N#7T`brD+};nsUQ0v|CRc&|M6d`ANwEw zmHM&&@gM({1^AEu$^!hye`Nvw1|CRc&|M6d`ANwEwmHM&&@n5MQ`yc=DUs-_v z_^&L$fBaV#;6MH=3-Dj5ANwEwmHM&&@n5MQ`yc<6`mz7cd}0{q8+WdZ)< zzp?=TmHM&&@n5MQ`yc<6`mz7KmIEV@L#DP`yc<6 z`mz7N#7T`br zD+};nsUQ0v|CRc&|M6d`ANwEwmHM&&@gM({1^AEu$^!hye`Nvw1|CRc&|M6d` zANwEwmHM&&@n5MQ`yc=DUs-_v_^&L$fBaV#;6MH=3-Dj5ANwEwmHM&&@n5MQ`yc<6 z`mz7cd}0{q8+WdZ)KmIEV@L#DP`yc<6`mz7>&hswF_@E)WzA7_^X3G`_5?-{W&&bnAZ7t#Rv>2E Kp1{Z+rT_p&>=0T2 diff --git a/tests/test-tokenizer-1.cpp b/tests/test-tokenizer-1.cpp index 5841f7339..a8a7e8898 100644 --- a/tests/test-tokenizer-1.cpp +++ b/tests/test-tokenizer-1.cpp @@ -10,10 +10,6 @@ #include #include -static std::string vocab_type(llama_context * ctx) { - return llama_n_vocab(ctx) == 32000 ? "spm": "bpe"; -} - static std::string escape_whitespace(const std::string& text) { std::string result; bool escaping = false; @@ -91,8 +87,8 @@ int main(int argc, char **argv) { return 2; } } else { - if ((vocab_type(ctx) == "spm" && i <= 258) || - (vocab_type(ctx) == "bpe" && (i == 0 || i >= 100000))) { + // TODO: needs access to token types + if (0 <= i && i < 259) { fprintf(stderr, "%s : info: token %d is string %s and bpe returns tokens %s\n", __func__, i, llama_token_to_str(ctx, i).c_str(), unescape_whitespace(ctx, tokens).c_str()); } else { @@ -103,20 +99,28 @@ int main(int argc, char **argv) { } } - std::wstring_convert, wchar_t> converter; - for (wchar_t ch = 0x0000; ch < 0xffff; ++ch) { - std::wstring wstr(1, ch); - std::string str; - try { - str = converter.to_bytes(wstr); - } catch (std::exception & e) { - continue; +#ifdef _WIN32 + std::wstring_convert, char16_t> u16converter; + for (char16_t ch = 0x0000; ch < 0xffff; ++ch) { + std::u16string u16str(1, ch); + std::string str = u16converter.to_bytes(u16str); + std::vector tokens = llama_tokenize(ctx, escape_whitespace(str).c_str(), false); + if (tokens.size() == 1) { + fprintf(stderr, "%s : info: %s tokenized to %d \n", + __func__, str.c_str(), tokens[0]); } - std::vector tokens = llama_tokenize(ctx, escape_whitespace(str), false); + } + + std::wstring_convert, char32_t> u32converter; + for (char32_t ch = 0x0000; ch < 0x0010ffff; ++ch) { + std::u32string u32str(1, ch); + std::string str = u32converter.to_bytes(u32str); + std::vector tokens = llama_tokenize(ctx, escape_whitespace(str).c_str(), false); if (tokens.size() == 1) { fprintf(stderr, "%s : info: %s tokenized to %d \n", __func__, str.c_str(), tokens[0]); } } +#endif llama_free_model(model); llama_free(ctx);