llama : update tokenizer style

This commit is contained in:
Georgi Gerganov 2023-08-14 22:10:19 +03:00
parent 7494c78428
commit 6c63550f63
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 101 additions and 75 deletions

View File

@ -2112,49 +2112,56 @@ static bool llama_eval_internal(
// tokenizer // tokenizer
// //
static std::string llama_vocab_type(const llama_vocab& vocab) { static std::string llama_vocab_type(const llama_vocab & vocab) {
return vocab.token_to_id.size() == 32000 ? "spm": "bpe"; return vocab.token_to_id.size() == 32000 ? "spm": "bpe";
} }
static bool llama_is_normal_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_normal_token(const llama_vocab & vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return token >= 259; return token >= 259;
else if(llama_vocab_type(vocab) == "bpe") }
if (llama_vocab_type(vocab) == "bpe") {
return token >= 95; return token >= 95;
else }
return false;
return false;
} }
static bool llama_is_unknown_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return token == 0; return token == 0;
else }
// TODO: improve?
return false; // TODO: improve?
return false;
} }
static bool llama_is_control_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_control_token(const llama_vocab & vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return token == 1 || token == 2; return token == 1 || token == 2;
else }
// TODO: improve?
return false; // TODO: improve?
return false;
} }
static bool llama_is_bos_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_bos_token(const llama_vocab & vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return token == 1; return token == 1;
else }
// TODO: improve?
return false; // TODO: improve?
return false;
} }
static bool llama_is_eos_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_eos_token(const llama_vocab & vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return token == 2; return token == 2;
else }
// TODO: improve?
return false; // TODO: improve?
return false;
} }
static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) { static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) {
@ -2164,29 +2171,35 @@ static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token t
return false; return false;
} }
static bool llama_is_unused_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_unused_token(const llama_vocab & vocab, llama_token token) {
UNUSED(vocab); UNUSED(vocab);
UNUSED(token); UNUSED(token);
// TODO: improve? // TODO: improve?
return false; return false;
} }
static bool llama_is_byte_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_byte_token(const llama_vocab & vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return 3 <= token && token < 259; return 3 <= token && token < 259;
else if(llama_vocab_type(vocab) == "bpe") }
if (llama_vocab_type(vocab) == "bpe") {
return 1 <= token && token < 95; return 1 <= token && token < 95;
else }
return false;
return false;
} }
static uint8_t llama_byte_to_char(const llama_vocab& vocab, uint8_t byte) { static uint8_t llama_byte_to_char(const llama_vocab & vocab, uint8_t byte) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return byte + 3; return byte + 3;
else if(llama_vocab_type(vocab) == "bpe") }
if (llama_vocab_type(vocab) == "bpe") {
return byte + 32; return byte + 32;
else }
return false;
return false;
} }
static std::string llama_escape_whitespace(const std::string& text) { static std::string llama_escape_whitespace(const std::string& text) {

View File

@ -1944,81 +1944,94 @@ static bool llama_eval_internal(
// tokenizer // tokenizer
// //
static std::string llama_vocab_type(const llama_vocab& vocab) { static std::string llama_vocab_type(const llama_vocab & vocab) {
return vocab.token_to_id.size() == 32000 ? "spm": "bpe"; return vocab.token_to_id.size() == 32000 ? "spm": "bpe";
} }
static bool llama_is_normal_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_normal_token(const llama_vocab & vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return token >= 259; return token >= 259;
else if(llama_vocab_type(vocab) == "bpe") }
if (llama_vocab_type(vocab) == "bpe") {
return token >= 95; return token >= 95;
else }
return false;
return false;
} }
static bool llama_is_unknown_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return token == 0; return token == 0;
else }
// TODO: improve?
return false; // TODO: improve?
return false;
} }
static bool llama_is_control_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_control_token(const llama_vocab & vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return token == 1 || token == 2; return token == 1 || token == 2;
else }
// TODO: improve?
return false; // TODO: improve?
return false;
} }
static bool llama_is_bos_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_bos_token(const llama_vocab & vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return token == 1; return token == 1;
else }
// TODO: improve?
return false; // TODO: improve?
return false;
} }
static bool llama_is_eos_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_eos_token(const llama_vocab & vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return token == 2; return token == 2;
else }
// TODO: improve?
return false; // TODO: improve?
return false;
} }
static bool llama_is_user_defined_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) {
UNUSED(vocab); UNUSED(vocab);
UNUSED(token); UNUSED(token);
// TODO: improve? // TODO: improve?
return false; return false;
} }
static bool llama_is_unused_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_unused_token(const llama_vocab & vocab, llama_token token) {
UNUSED(vocab); UNUSED(vocab);
UNUSED(token); UNUSED(token);
// TODO: improve? // TODO: improve?
return false; return false;
} }
static bool llama_is_byte_token(const llama_vocab& vocab, llama_token token) { static bool llama_is_byte_token(const llama_vocab & vocab, llama_token token) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return 3 <= token && token < 259; return 3 <= token && token < 259;
else if(llama_vocab_type(vocab) == "bpe") }
if (llama_vocab_type(vocab) == "bpe") {
return 1 <= token && token < 95; return 1 <= token && token < 95;
else }
return false;
return false;
} }
static uint8_t llama_byte_to_char(const llama_vocab& vocab, uint8_t byte) { static uint8_t llama_byte_to_char(const llama_vocab & vocab, uint8_t byte) {
if(llama_vocab_type(vocab) == "spm") if (llama_vocab_type(vocab) == "spm") {
return byte + 3; return byte + 3;
else if(llama_vocab_type(vocab) == "bpe") }
if (llama_vocab_type(vocab) == "bpe") {
return byte + 32; return byte + 32;
else }
return false;
return false;
} }
static std::string llama_escape_whitespace(const std::string& text) { static std::string llama_escape_whitespace(const std::string& text) {