From 3b38d48609280aa5f8ab7ea135a4351b2a5ee240 Mon Sep 17 00:00:00 2001
From: jaime-m-p <167997752+jaime-m-p@users.noreply.github.com>
Date: Tue, 4 Jun 2024 09:17:17 +0200
Subject: [PATCH] Per token attributes (#7685)

* Add per token attributes enum
* Using phi-3 for testing 'rstrip'
* Using jina-v2 for testing 'lstrip'
* Brute force test for 'lstrip' and 'rstrip'
* Implement 'rstrip' and 'lstrip'
* Update phi-3 GGUF file (obsolete since 917dc8c)
* Replace llama_token_type with llama_token_attribs

---
 llama.cpp                      | 149 +++++++++++++++++++++++----------
 llama.h                        |  18 +++-
 models/ggml-vocab-phi-3.gguf   | Bin 725846 -> 726019 bytes
 tests/test-tokenizer-random.py |  50 +++++++----
 4 files changed, 155 insertions(+), 62 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index b0bf70e20..a3e944874 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -2149,12 +2149,12 @@ struct llama_control_vector {
 struct llama_vocab {
     using id    = int32_t;
     using token = std::string;
-    using ttype = llama_token_type;
+    using tattr = llama_token_attr;
 
     struct token_data {
         token text;
         float score;
-        ttype type;
+        tattr attr;
     };
 
     enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
@@ -4750,7 +4750,20 @@ static void llm_load_vocab(
         auto & token_data = vocab.id_to_token[i];
         token_data.text  = std::move(word);
         token_data.score = scores ? scores[i] : 0.0f;
-        token_data.type  = toktypes ? (llama_token_type) toktypes[i] : LLAMA_TOKEN_TYPE_NORMAL;
+        token_data.attr  = LLAMA_TOKEN_ATTR_NORMAL;
+
+        if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file
+            switch(toktypes[i]) {
+                case LLAMA_TOKEN_TYPE_UNKNOWN:      token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN;      break;
+                case LLAMA_TOKEN_TYPE_UNUSED:       token_data.attr = LLAMA_TOKEN_ATTR_UNUSED;       break;
+                case LLAMA_TOKEN_TYPE_NORMAL:       token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;       break;
+                case LLAMA_TOKEN_TYPE_CONTROL:      token_data.attr = LLAMA_TOKEN_ATTR_CONTROL;      break;
+                case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break;
+                case LLAMA_TOKEN_TYPE_BYTE:         token_data.attr = LLAMA_TOKEN_ATTR_BYTE;         break;
+                case LLAMA_TOKEN_TYPE_UNDEFINED:    token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
+                default:                            token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED;    break;
+            }
+        }
     }
 
     GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
@@ -4841,7 +4854,7 @@ static void llm_load_vocab(
     // build special tokens cache
     {
         for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
-            if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
+            if (!(vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL)) {
                 vocab.cache_special_tokens.push_back(id);
             }
         }
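
Note: the switch above maps the old exclusive llama_token_type values onto the new bitmask, so token checks become flag tests instead of equality comparisons. A minimal standalone sketch of the pattern (the helper names here are illustrative, not part of the patch):

    #include <cstdint>
    #include "llama.h"

    // Test a flag: replaces 'type == LLAMA_TOKEN_TYPE_X'.
    static bool has_attr(uint32_t attrs, enum llama_token_attr a) {
        return (attrs & a) != 0;
    }

    // Set or clear a flag: the same idea as the patch's _set_tokenid_attr lambda.
    static uint32_t with_attr(uint32_t attrs, enum llama_token_attr a, bool value) {
        return value ? (attrs | a) : (attrs & ~((uint32_t) a));
    }
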
@@ -4871,6 +4884,59 @@ static void llm_load_vocab(
 
         LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
     }
+
+    // Handle per token attributes
+    //NOTE: Each model customizes per token attributes.
+    //NOTE: Per token attributes are missing from the GGUF file.
+    //TODO: Extract attributes from GGUF file.
+    {
+        auto _contains_any = [] (const std::string & str, const std::vector<std::string> & substrs) -> bool {
+            for (auto substr : substrs) {
+                if (str.find(substr) < std::string::npos) {
+                    return true;
+                }
+            }
+            return false;
+        };
+
+        auto _set_tokenid_attr = [&] (const llama_vocab::id id, llama_token_attr attr, bool value) {
+            uint32_t current = vocab.id_to_token.at(id).attr;
+            current = value ? (current | attr) : (current & ~attr);
+            vocab.id_to_token[id].attr = (llama_token_attr) current;
+        };
+
+        auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
+            _set_tokenid_attr(vocab.token_to_id.at(token), attr, value);
+        };
+
+        std::string model_name;
+        std::string tokenizer_pre;
+
+        ml.get_key(LLM_KV_GENERAL_NAME,  model_name,    false);
+        ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
+
+        // model name to lowercase
+        std::transform(model_name.begin(), model_name.end(), model_name.begin(),
+            [] (const std::string::value_type x) {
+                return std::tolower(x);
+            }
+        );
+
+        // set attributes by model/tokenizer name
+        if (_contains_any(tokenizer_pre, {"jina-v2-es", "jina-v2-de"})) {
+            _set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
+        } else if (_contains_any(model_name, {"phi-3", "phi3"})) {
+            for (auto id : vocab.cache_special_tokens) {
+                _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (auto token : {"</s>"}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
+            }
+            for (auto token : {"<unk>", "<s>", "<|endoftext|>"}) {
+                _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
+            }
+        }
+    }
 }
 
 static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
@@ -12620,27 +12686,27 @@ static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
 
 static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
     GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL;
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
 }
 
 static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
     GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNKNOWN;
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
 }
 
 static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
     GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL;
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
 }
 
 static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
     GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE;
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
 }
 
 static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) {
     GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED;
+    return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
 }
 
 static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
@@ -13258,7 +13324,8 @@ struct fragment_buffer_variant {
 static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
     // for each special token
     for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
-        const auto & special_token = vocab.id_to_token[special_id].text;
+        const auto & data = vocab.id_to_token[special_id];
+        const auto & special_token = data.text;
 
         // for each text fragment
         std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
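
When a special token is found inside a fragment, the text on each side is re-emitted as new fragments; the two hunks below make an LSTRIP token swallow trailing whitespace of the left fragment and an RSTRIP token swallow leading whitespace of the right fragment. A self-contained sketch of the RSTRIP half (the fragment content is made up):

    #include <cctype>
    #include <cstdio>
    #include <string>

    int main() {
        // Raw text that follows an RSTRIP special token such as phi-3's "</s>".
        std::string right = " \n b";
        size_t skip = 0;
        while (skip < right.size() && std::isspace((unsigned char) right[skip])) {
            skip++;  // whitespace is folded into the special token
        }
        std::printf("'%s'\n", right.substr(skip).c_str());  // prints 'b'
        return 0;
    }
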
@@ -13295,13 +13362,22 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
                 if (match > raw_text_base_offset) {
                     // left
                     const int64_t left_reminder_offset = raw_text_base_offset + 0;
-                    const int64_t left_reminder_length = match - raw_text_base_offset;
-                    buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
+                    int64_t left_reminder_length = match - raw_text_base_offset;
+
+                    if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) {
+                        while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) {
+                            left_reminder_length--;
+                        }
+                    }
+
+                    if (left_reminder_length > 0) {
+                        buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
+                        it++;
+                    }
 
 #ifdef PRETOKENIZERDEBUG
                     LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
 #endif
-                    it++;
                 }
 
                 // special token
@@ -13310,16 +13386,25 @@
                 // right
                 if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
-                    const int64_t right_reminder_offset = match + special_token.length();
-                    const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
-                    buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
+                    int64_t right_reminder_offset = match + special_token.length();
+                    int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
+
+                    if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
+                        while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
+                            right_reminder_offset++;
+                            right_reminder_length--;
+                        }
+                    }
+
+                    if (right_reminder_length > 0) {
+                        buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
+                        it++;
+                    }
 
 #ifdef PRETOKENIZERDEBUG
                     LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
 #endif
-                    it++;
-
                     if (source == 0) {
                         buffer.erase_after(buffer.before_begin());
                     } else {
@@ -13365,9 +13450,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
         // tokenizer.encode('<s>', add_special_tokens=True)  returns [1]
         // tokenizer.encode('<s>', add_special_tokens=False) returns []
 
-        static const bool rtrim = true;  //TODO: as param
         bool is_prev_special = false;
-        bool special_token_rtrim = false;
 
         if (add_special && vocab.special_add_bos != 0) {
             GGML_ASSERT(vocab.special_bos_id != -1);
@@ -13377,25 +13460,8 @@
 
         for (const auto & fragment : fragment_buffer) {
             if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
-                // without adding this leading whitespace, we do not get the same results as the original tokenizer
-
-                // TODO: It's likely possible to get rid of this string copy entirely
-                //  by modifying llm_tokenizer_x to operate with string offsets like pre-tokenizer
-                //  and passing 'add space prefix' as bool argument
-                //
                 auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
 
-                if (special_token_rtrim) {
-                    size_t num_whitespaces = 0;
-                    while (isspace(raw_text[num_whitespaces])) {
-                        num_whitespaces++;
-                    }
-                    if (num_whitespaces == raw_text.size()) {
-                        continue; // skip if all whitespaces
-                    }
-                    raw_text = raw_text.substr(num_whitespaces);
-                }
-
                 if (vocab.add_space_prefix) {
                     if (!output.size() || is_prev_special) {  // prefix with space if first token
                         raw_text = " " + raw_text;
                     }
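
With stripping now done at partition time, the per-fragment rtrim pass deleted above is dead code, and the next hunk removes the flag that drove it. The intended phi-3 behavior is unchanged; roughly (spacing illustrative, token names taken from the attribute setup earlier in the patch):

    // "<|user|>\nhello", where <|user|> carries RSTRIP:
    //   fragments: [ <|user|> ] [ "hello" ]    // the "\n" is consumed
    // "<s> hello", where <s> has RSTRIP cleared:
    //   fragments: [ <s> ] [ " hello" ]        // whitespace is preserved
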
@@ -13411,11 +13477,6 @@
             } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
                 output.push_back(fragment.token);
                 is_prev_special = true;
-                // phi-3 special tokens without rtrim, works fine for llama-spm too
-                special_token_rtrim = rtrim
-                    && fragment.token != vocab.special_bos_id
-                    && fragment.token != vocab.special_unk_id
-                    && fragment.token != vocab.special_eos_id;
             }
         }
 
@@ -18221,9 +18282,9 @@ float llama_token_get_score(const struct llama_model * model, llama_token token
     return model->vocab.id_to_token[token].score;
 }
 
-llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token) {
+llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
     GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
-    return model->vocab.id_to_token[token].type;
+    return model->vocab.id_to_token[token].attr;
 }
 
 bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
diff --git a/llama.h b/llama.h
index 95105c28e..a78ccdaf5 100644
--- a/llama.h
+++ b/llama.h
@@ -97,7 +97,7 @@ extern "C" {
         LLAMA_ROPE_TYPE_GLM  = 4,
     };
 
-    enum llama_token_type {
+    enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file
         LLAMA_TOKEN_TYPE_UNDEFINED = 0,
         LLAMA_TOKEN_TYPE_NORMAL    = 1,
         LLAMA_TOKEN_TYPE_UNKNOWN   = 2,
@@ -107,6 +107,20 @@ extern "C" {
         LLAMA_TOKEN_TYPE_BYTE      = 6,
     };
 
+    enum llama_token_attr {
+        LLAMA_TOKEN_ATTR_UNDEFINED    = 0,
+        LLAMA_TOKEN_ATTR_UNKNOWN      = 1 << 0,
+        LLAMA_TOKEN_ATTR_UNUSED       = 1 << 1,
+        LLAMA_TOKEN_ATTR_NORMAL       = 1 << 2,
+        LLAMA_TOKEN_ATTR_CONTROL      = 1 << 3, // SPECIAL?
+        LLAMA_TOKEN_ATTR_USER_DEFINED = 1 << 4,
+        LLAMA_TOKEN_ATTR_BYTE         = 1 << 5,
+        LLAMA_TOKEN_ATTR_NORMALIZED   = 1 << 6,
+        LLAMA_TOKEN_ATTR_LSTRIP       = 1 << 7,
+        LLAMA_TOKEN_ATTR_RSTRIP       = 1 << 8,
+        LLAMA_TOKEN_ATTR_SINGLE_WORD  = 1 << 9,
+    };
+
     // model file types
     enum llama_ftype {
         LLAMA_FTYPE_ALL_F32 = 0,
@@ -821,7 +835,7 @@ extern "C" {
 
     LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
 
-    LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
+    LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token);
 
     // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
     LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);
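
Callers migrate from llama_token_get_type to llama_token_get_attr and test bits instead of comparing a single enum value. A minimal usage sketch (assumes a model already obtained through the usual llama.cpp loading API):

    #include "llama.h"

    // True when the token erases adjacent whitespace during tokenization.
    static bool strips_whitespace(const struct llama_model * model, llama_token token) {
        enum llama_token_attr attr = llama_token_get_attr(model, token);
        return (attr & (LLAMA_TOKEN_ATTR_LSTRIP | LLAMA_TOKEN_ATTR_RSTRIP)) != 0;
    }
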
diff --git a/models/ggml-vocab-phi-3.gguf b/models/ggml-vocab-phi-3.gguf
index f8022a385e4aa48ca10d40d0f079a25365a7be78..745be416a798a1e7a2effaa935bc903edaaf303d 100644
Binary files a/models/ggml-vocab-phi-3.gguf and b/models/ggml-vocab-phi-3.gguf differ
diff --git a/tests/test-tokenizer-random.py b/tests/test-tokenizer-random.py
--- a/tests/test-tokenizer-random.py
+++ b/tests/test-tokenizer-random.py
@@ ... @@ def generator_custom_text_edge_cases() -> Iterator[str]:
         '<s>a', # Phi-3 fail
         '<unk><|endoftext|><s>', # Phi-3 fail
         'a\na', # TODO: Bert fail
+        'a </s> b', # rstrip phi-3
+        'a <mask> b', # lstrip jina-v2
     ]
 
 
-def generator_random_special_tokens(tokenizer, iterations=100) -> Iterator[str]:
-    special_tokens = set(tokenizer.all_special_tokens)
-    special_tokens.update([" ", "\n", "\t", "-", "!", "one", "1", "<s>", "</s>"])
-    special_tokens = list(sorted(special_tokens))
+def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
+    """Brute force check all vocab words"""
+    yield from vocab
+
+
+def generator_added_lr_strip(tokenizer) -> Iterator[str]:
+    WHITESPACES = ["", " ", "  ", "    "]
+    special_tokens = list(tokenizer.all_special_tokens)
+    added_tokens = list(tokenizer.added_tokens_encoder)
+    all_tokens = list(sorted(set(special_tokens + added_tokens)))
+    for token in all_tokens:
+        for lstrip in WHITESPACES:
+            for rstrip in WHITESPACES:
+                yield lstrip + token + rstrip
+                yield "a" + lstrip + token + rstrip
+                yield lstrip + token + rstrip + "z"
+                yield "a" + lstrip + token + rstrip + "z"
+
+
+def generator_random_added_tokens(tokenizer, iterations=100) -> Iterator[str]:
+    special_tokens = list(tokenizer.all_special_tokens)
+    added_tokens = list(tokenizer.added_tokens_encoder)
+    separations = [" ", "\n", "\t", "-", "!", "one", "1", "<s>", "</s>"]
+    all_tokens = list(sorted(set(special_tokens + added_tokens + separations)))
     rand = random.Random()
     for m in range(iterations):
         rand.seed(m)
-        words = rand.choices(special_tokens, k=500)
+        words = rand.choices(all_tokens, k=500)
         if words[0] == tokenizer.bos_token:  # skip spam warning of double BOS
             while len(words) > 1 and words[1] == tokenizer.bos_token:  # leave one starting BOS
                 words.pop(0)
@@ -175,11 +197,6 @@ def generator_random_special_tokens(tokenizer, iterations=100) -> Iterator[str]
         yield "".join(words)
 
 
-def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
-    """Brute force check all vocab words"""
-    yield from vocab
-
-
 def generator_random_chars(iterations=100) -> Iterator[str]:
     """Brute force random text with simple characters"""
 
@@ -274,8 +291,8 @@ def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, g
         ids2 = func_tokenize2(text)
         if ids1 != ids2:
             i = find_first_mismatch(ids1, ids2)
-            ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
-            ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1]
+            ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
+            ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
             logger.info(" TokenIDs: " + str(ids1))
             logger.info(" Expected: " + str(ids2))
             raise Exception()
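
The new generator_added_lr_strip above is the brute-force part of the test: for every special and added token it yields the token padded with zero, one, two, or four spaces on either side, bare and anchored by plain text (for example '</s>', 'a </s>', '</s>  z', 'a  </s>  z'), which covers exactly the surfaces where the LSTRIP/RSTRIP attributes can diverge from the reference tokenizer.
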
@@ -309,8 +326,9 @@ def main(argv: list[str] = None):
     vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
-    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_special_tokens(tokenizer, 10_000))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer))
+    test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
     test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 5_000))
@@ -322,14 +340,14 @@ def main(argv: list[str] = None):
 
 if __name__ == "__main__":
     # main()
 
-    path_tokenizers = "./models/tokenizers/"
+    path_tokenizers   = "./models/tokenizers/"
     path_vocab_format = "./models/ggml-vocab-%s.gguf"
 
     # import os
     # tokenizers = os.listdir(path_tokenizers)
     tokenizers = [
-        # "llama-spm",   # SPM
-        # "phi-3",       # SPM
+        "llama-spm",   # SPM
+        "phi-3",       # SPM
         "jina-v2-en",  # WPM
         "bert-bge",    # WPM
     ]