Per token attributes (#7685)

* Add per token attributes enum
* Using phi-3 for testing 'rstrip'
* Using jina-v2 for testing 'lstrip'
* Brute force test for 'lstrip' and 'rstrip'
* Implement 'rstrip' and 'lstrip'
* Update phi-3 GGUF file (obsolete since 917dc8c)
* Replace llama_token_type with llama_token_attribs
This commit is contained in:
jaime-m-p 2024-06-04 09:17:17 +02:00 committed by GitHub
parent 6d1616944d
commit 3b38d48609
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 155 additions and 62 deletions

145
llama.cpp
View File

@ -2149,12 +2149,12 @@ struct llama_control_vector {
struct llama_vocab {
using id = int32_t;
using token = std::string;
using ttype = llama_token_type;
using tattr = llama_token_attr;
struct token_data {
token text;
float score;
ttype type;
tattr attr;
};
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
@ -4750,7 +4750,20 @@ static void llm_load_vocab(
auto & token_data = vocab.id_to_token[i];
token_data.text = std::move(word);
token_data.score = scores ? scores[i] : 0.0f;
token_data.type = toktypes ? (llama_token_type) toktypes[i] : LLAMA_TOKEN_TYPE_NORMAL;
token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;
if (toktypes) { //TODO: remove, required until per token attributes are available from GGUF file
switch(toktypes[i]) {
case LLAMA_TOKEN_TYPE_UNKNOWN: token_data.attr = LLAMA_TOKEN_ATTR_UNKNOWN; break;
case LLAMA_TOKEN_TYPE_UNUSED: token_data.attr = LLAMA_TOKEN_ATTR_UNUSED; break;
case LLAMA_TOKEN_TYPE_NORMAL: token_data.attr = LLAMA_TOKEN_ATTR_NORMAL; break;
case LLAMA_TOKEN_TYPE_CONTROL: token_data.attr = LLAMA_TOKEN_ATTR_CONTROL; break;
case LLAMA_TOKEN_TYPE_USER_DEFINED: token_data.attr = LLAMA_TOKEN_ATTR_USER_DEFINED; break;
case LLAMA_TOKEN_TYPE_BYTE: token_data.attr = LLAMA_TOKEN_ATTR_BYTE; break;
case LLAMA_TOKEN_TYPE_UNDEFINED: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break;
default: token_data.attr = LLAMA_TOKEN_ATTR_UNDEFINED; break;
}
}
}
GGML_ASSERT(vocab.id_to_token.size() == vocab.token_to_id.size());
@ -4841,7 +4854,7 @@ static void llm_load_vocab(
// build special tokens cache
{
for (llama_vocab::id id = 0; id < (llama_vocab::id)n_vocab; ++id) {
if (vocab.id_to_token[id].type != LLAMA_TOKEN_TYPE_NORMAL) {
if (!(vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL)) {
vocab.cache_special_tokens.push_back(id);
}
}
@ -4871,6 +4884,59 @@ static void llm_load_vocab(
LLAMA_LOG_INFO("%s: token to piece cache size = %.4f MB\n", __func__, size_cache / 1024.0 / 1024.0);
}
// Handle per token attributes
//NOTE: Each model customizes per token attributes.
//NOTE: Per token attributes are missing from the GGUF file.
//TODO: Extract attributes from GGUF file.
{
auto _contains_any = [] (const std::string &str, const std::vector<std::string> &substrs) -> bool {
for (auto substr : substrs) {
if (str.find(substr) < std::string::npos) {
return true;
}
}
return false;
};
auto _set_tokenid_attr = [&] (const llama_vocab::id id, llama_token_attr attr, bool value) {
uint32_t current = vocab.id_to_token.at(id).attr;
current = value ? (current | attr) : (current & ~attr);
vocab.id_to_token[id].attr = (llama_token_attr) current;
};
auto _set_token_attr = [&] (const std::string & token, llama_token_attr attr, bool value) {
_set_tokenid_attr(vocab.token_to_id.at(token), attr, value);
};
std::string model_name;
std::string tokenizer_pre;
ml.get_key(LLM_KV_GENERAL_NAME, model_name, false);
ml.get_key(LLM_KV_TOKENIZER_PRE, tokenizer_pre, false);
// model name to lowercase
std::transform(model_name.begin(), model_name.end(), model_name.begin(),
[] (const std::string::value_type x) {
return std::tolower(x);
}
);
// set attributes by model/tokenizer name
if (_contains_any(tokenizer_pre, {"jina-v2-es", "jina-v2-de"})) {
_set_token_attr("<mask>", LLAMA_TOKEN_ATTR_LSTRIP, true);
} else if (_contains_any(model_name, {"phi-3", "phi3"})) {
for (auto id : vocab.cache_special_tokens) {
_set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
}
for (auto token : {"</s>"}) {
_set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
}
for (auto token : {"<unk>", "<s>", "<|endoftext|>"}) {
_set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
}
}
}
}
static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
@ -12620,27 +12686,27 @@ static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
static bool llama_is_normal_token(const llama_vocab & vocab, llama_token id) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_NORMAL;
return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_NORMAL;
}
static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token id) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_UNKNOWN;
return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_UNKNOWN;
}
static bool llama_is_control_token(const llama_vocab & vocab, llama_token id) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_CONTROL;
return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL;
}
static bool llama_is_byte_token(const llama_vocab & vocab, llama_token id) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_BYTE;
return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_BYTE;
}
static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token id) {
GGML_ASSERT(vocab.type != LLAMA_VOCAB_TYPE_NONE);
return vocab.id_to_token[id].type == LLAMA_TOKEN_TYPE_USER_DEFINED;
return vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_USER_DEFINED;
}
static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
@ -13258,7 +13324,8 @@ struct fragment_buffer_variant {
static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<fragment_buffer_variant> & buffer) {
// for each special token
for (const llama_vocab::id special_id : vocab.cache_special_tokens) {
const auto & special_token = vocab.id_to_token[special_id].text;
const auto & data = vocab.id_to_token[special_id];
const auto & special_token = data.text;
// for each text fragment
std::forward_list<fragment_buffer_variant>::iterator it = buffer.begin();
@ -13295,13 +13362,22 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
if (match > raw_text_base_offset) {
// left
const int64_t left_reminder_offset = raw_text_base_offset + 0;
const int64_t left_reminder_length = match - raw_text_base_offset;
int64_t left_reminder_length = match - raw_text_base_offset;
if (data.attr & LLAMA_TOKEN_ATTR_LSTRIP) {
while (left_reminder_length > 0 && isspace(raw_text[left_reminder_offset + left_reminder_length - 1])) {
left_reminder_length--;
}
}
if (left_reminder_length > 0) {
buffer.emplace_after(it, raw_text, left_reminder_offset, left_reminder_length);
it++;
}
#ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("FL: (%ld %ld) '%s'\n", left_reminder_offset, left_reminder_length, raw_text->substr(left_reminder_offset, left_reminder_length).c_str());
#endif
it++;
}
// special token
@ -13310,16 +13386,25 @@ static void tokenizer_st_partition(const llama_vocab & vocab, std::forward_list<
// right
if (match + special_token.length() < raw_text_base_offset + raw_text_base_length) {
const int64_t right_reminder_offset = match + special_token.length();
const int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
int64_t right_reminder_offset = match + special_token.length();
int64_t right_reminder_length = raw_text_base_length - ((match - raw_text_base_offset) + special_token.length());
if (data.attr & LLAMA_TOKEN_ATTR_RSTRIP) {
while (right_reminder_length > 0 && isspace(raw_text[right_reminder_offset])) {
right_reminder_offset++;
right_reminder_length--;
}
}
if (right_reminder_length > 0) {
buffer.emplace_after(it, raw_text, right_reminder_offset, right_reminder_length);
it++;
}
#ifdef PRETOKENIZERDEBUG
LLAMA_LOG_WARN("FR: (%ld %ld) '%s'\n", right_reminder_offset, right_reminder_length, raw_text->substr(right_reminder_offset, right_reminder_length).c_str());
#endif
it++;
if (source == 0) {
buffer.erase_after(buffer.before_begin());
} else {
@ -13365,9 +13450,7 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
// tokenizer.encode('', add_special_tokens=True) returns [1]
// tokenizer.encode('', add_special_tokens=False) returns []
static const bool rtrim = true; //TODO: as param
bool is_prev_special = false;
bool special_token_rtrim = false;
if (add_special && vocab.special_add_bos != 0) {
GGML_ASSERT(vocab.special_bos_id != -1);
@ -13377,25 +13460,8 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
for (const auto & fragment : fragment_buffer) {
if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) {
// without adding this leading whitespace, we do not get the same results as the original tokenizer
// TODO: It's likely possible to get rid of this string copy entirely
// by modifying llm_tokenizer_x to operate with string offsets like pre-tokenizer
// and passing 'add space prefix' as bool argument
//
auto raw_text = fragment.raw_text.substr(fragment.offset, fragment.length);
if (special_token_rtrim) {
size_t num_whitespaces = 0;
while (isspace(raw_text[num_whitespaces])) {
num_whitespaces++;
}
if (num_whitespaces == raw_text.size()) {
continue; // skip if all whitespaces
}
raw_text = raw_text.substr(num_whitespaces);
}
if (vocab.add_space_prefix) {
if (!output.size() || is_prev_special) { // prefix with space if first token
raw_text = " " + raw_text;
@ -13411,11 +13477,6 @@ static std::vector<llama_vocab::id> llama_tokenize_internal(const llama_vocab &
} else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN)
output.push_back(fragment.token);
is_prev_special = true;
// phi-3 special tokens without rtrim, works fine for llama-spm too
special_token_rtrim = rtrim
&& fragment.token != vocab.special_bos_id
&& fragment.token != vocab.special_unk_id
&& fragment.token != vocab.special_eos_id;
}
}
@ -18221,9 +18282,9 @@ float llama_token_get_score(const struct llama_model * model, llama_token token)
return model->vocab.id_to_token[token].score;
}
llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token) {
llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
GGML_ASSERT(model->vocab.type != LLAMA_VOCAB_TYPE_NONE);
return model->vocab.id_to_token[token].type;
return model->vocab.id_to_token[token].attr;
}
bool llama_token_is_eog(const struct llama_model * model, llama_token token) {

18
llama.h
View File

@ -97,7 +97,7 @@ extern "C" {
LLAMA_ROPE_TYPE_GLM = 4,
};
enum llama_token_type {
enum llama_token_type { //TODO: remove, required until per token attributes are available from GGUF file
LLAMA_TOKEN_TYPE_UNDEFINED = 0,
LLAMA_TOKEN_TYPE_NORMAL = 1,
LLAMA_TOKEN_TYPE_UNKNOWN = 2,
@ -107,6 +107,20 @@ extern "C" {
LLAMA_TOKEN_TYPE_BYTE = 6,
};
enum llama_token_attr {
LLAMA_TOKEN_ATTR_UNDEFINED = 0,
LLAMA_TOKEN_ATTR_UNKNOWN = 1 << 1,
LLAMA_TOKEN_ATTR_UNUSED = 1 << 2,
LLAMA_TOKEN_ATTR_NORMAL = 1 << 3,
LLAMA_TOKEN_ATTR_CONTROL = 1 << 4, // SPECIAL?
LLAMA_TOKEN_ATTR_USER_DEFINED = 1 << 5,
LLAMA_TOKEN_ATTR_BYTE = 1 << 6,
LLAMA_TOKEN_ATTR_NORMALIZED = 1 << 7,
LLAMA_TOKEN_ATTR_LSTRIP = 1 << 8,
LLAMA_TOKEN_ATTR_RSTRIP = 1 << 9,
LLAMA_TOKEN_ATTR_SINGLE_WORD = 1 << 10,
};
// model file types
enum llama_ftype {
LLAMA_FTYPE_ALL_F32 = 0,
@ -821,7 +835,7 @@ extern "C" {
LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token);
LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_model * model, llama_token token);
LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token);
// Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token);

Binary file not shown.

View File

@ -156,17 +156,39 @@ def generator_custom_text_edge_cases() -> Iterator[str]:
'<s>a', # Phi-3 fail
'<unk><|endoftext|><s>', # Phi-3 fail
'a\na', # TODO: Bert fail
'a </s> b', # rstrip phi-3
'a <mask> b', # lstrip jina-v2
]
def generator_random_special_tokens(tokenizer, iterations=100) -> Iterator[str]:
special_tokens = set(tokenizer.all_special_tokens)
special_tokens.update([" ", "\n", "\t", "-", "!", "one", "1", "<s>", "</s>"])
special_tokens = list(sorted(special_tokens))
def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
"""Brute force check all vocab words"""
yield from vocab
def generator_added_lr_strip(tokenizer) -> Iterator[str]:
WHITESPACES = ["", " ", " ", " "]
special_tokens = list(tokenizer.all_special_tokens)
added_tokens = list(tokenizer.added_tokens_encoder)
all_tokens = list(sorted(set(special_tokens + added_tokens)))
for token in all_tokens:
for lstrip in WHITESPACES:
for rstrip in WHITESPACES:
yield lstrip + token + rstrip
yield "a" + lstrip + token + rstrip
yield lstrip + token + rstrip + "z"
yield "a" + lstrip + token + rstrip + "z"
def generator_random_added_tokens(tokenizer, iterations=100) -> Iterator[str]:
special_tokens = list(tokenizer.all_special_tokens)
added_tokens = list(tokenizer.added_tokens_encoder)
separations = [" ", "\n", "\t", "-", "!", "one", "1", "<s>", "</s>"]
all_tokens = list(sorted(set(special_tokens + added_tokens + separations)))
rand = random.Random()
for m in range(iterations):
rand.seed(m)
words = rand.choices(special_tokens, k=500)
words = rand.choices(all_tokens, k=500)
if words[0] == tokenizer.bos_token: # skip spam warning of double BOS
while len(words) > 1 and words[1] == tokenizer.bos_token: # leave one starting BOS
words.pop(0)
@ -175,11 +197,6 @@ def generator_random_special_tokens(tokenizer, iterations=100) -> Iterator[str]:
yield "".join(words)
def generator_vocab_words(vocab: list[str]) -> Iterator[str]:
"""Brute force check all vocab words"""
yield from vocab
def generator_random_chars(iterations=100) -> Iterator[str]:
"""Brute force random text with simple characters"""
@ -274,8 +291,8 @@ def test_compare_tokenizer(func_tokenize1: Callable, func_tokenize2: Callable, g
ids2 = func_tokenize2(text)
if ids1 != ids2:
i = find_first_mismatch(ids1, ids2)
ids1 = list(ids1)[max(0, i - 2) : i + 2 + 1]
ids2 = list(ids2)[max(0, i - 2) : i + 2 + 1]
ids1 = list(ids1)[max(0, i - 2) : i + 5 + 1]
ids2 = list(ids2)[max(0, i - 2) : i + 5 + 1]
logger.info(" TokenIDs: " + str(ids1))
logger.info(" Expected: " + str(ids2))
raise Exception()
@ -309,8 +326,9 @@ def main(argv: list[str] = None):
vocab = list(sorted(tokenizer.batch_decode(list(tokenizer.get_vocab().values()), skip_special_tokens=True)))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text())
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_custom_text_edge_cases())
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_special_tokens(tokenizer, 10_000))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_vocab_words(vocab))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_added_lr_strip(tokenizer))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_added_tokens(tokenizer, 10_000))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_chars(10_000))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_chars(vocab, 10_000))
test_compare_tokenizer(func_tokenize1, func_tokenize2, generator_random_vocab_words(vocab, 5_000))
@ -328,8 +346,8 @@ if __name__ == "__main__":
# import os
# tokenizers = os.listdir(path_tokenizers)
tokenizers = [
# "llama-spm", # SPM
# "phi-3", # SPM
"llama-spm", # SPM
"phi-3", # SPM
"jina-v2-en", # WPM
"bert-bge", # WPM
]