llama : introduce enum llama_vocab_type + remove hardcoded string constants

This commit is contained in:
Georgi Gerganov 2023-08-18 18:46:38 +03:00
parent a4ad2bf35c
commit 25b8a8922d
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 87 additions and 66 deletions

View File

@ -777,8 +777,10 @@ struct llama_vocab {
float score;
};
llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
std::unordered_map<token, id> token_to_id;
std::vector<token_score> id_to_token;
std::vector<token_score> id_to_token;
// default LLaMA special tokens
id special_bos_id = 1;
@ -1406,6 +1408,19 @@ static void llama_model_load_internal(
} \
}
std::string tokenizer_name;
GGUF_GET(tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, "tokenizer.ggml.model");
if (tokenizer_name == "llama") {
vocab.type = LLAMA_VOCAB_TYPE_SPM;
} else if (tokenizer_name == "gpt2") {
vocab.type = LLAMA_VOCAB_TYPE_BPE;
} else {
LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
vocab.type = LLAMA_VOCAB_TYPE_SPM;
}
// get hparams kv
GGUF_GET(hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, "tokenizer.ggml.tokens");
GGUF_GET(hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.context_length");
@ -1504,6 +1519,7 @@ static void llama_model_load_internal(
// hparams
LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml->fver));
LLAMA_LOG_INFO("%s: arch = %s\n", __func__, general_arch.c_str());
LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix
LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab);
LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx);
@ -2317,40 +2333,22 @@ static bool llama_eval_internal(
// tokenizer
//
static std::string llama_vocab_type(const llama_vocab & vocab) {
return vocab.token_to_id.size() == 32000 ? "spm": "bpe";
static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
return vocab.type;
}
static bool llama_is_normal_token(const llama_vocab & vocab, llama_token token) {
if (llama_vocab_type(vocab) == "spm") {
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) {
return token >= 259;
}
if (llama_vocab_type(vocab) == "bpe") {
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) {
return token >= 95;
}
return false;
}
static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token token) {
if (llama_vocab_type(vocab) == "spm") {
return token == 0;
}
// TODO: improve?
return false;
}
static bool llama_is_control_token(const llama_vocab & vocab, llama_token token) {
if (llama_vocab_type(vocab) == "spm") {
return token == 1 || token == 2;
}
// TODO: improve?
return false;
}
static bool llama_is_bos_token(const llama_vocab & vocab, llama_token token) {
return token == vocab.special_bos_id;
}
@ -2359,6 +2357,24 @@ static bool llama_is_eos_token(const llama_vocab & vocab, llama_token token) {
return token == vocab.special_eos_id;
}
static bool llama_is_control_token(const llama_vocab & vocab, llama_token token) {
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) {
return token == llama_is_bos_token(vocab, token) || token == llama_is_eos_token(vocab, token);
}
// TODO: improve?
return false;
}
static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token token) {
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) {
return token == 0;
}
// TODO: improve?
return false;
}
static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) {
GGML_UNUSED(vocab);
GGML_UNUSED(token);
@ -2374,11 +2390,11 @@ static bool llama_is_unused_token(const llama_vocab & vocab, llama_token token)
}
static bool llama_is_byte_token(const llama_vocab & vocab, llama_token token) {
if (llama_vocab_type(vocab) == "spm") {
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) {
return 3 <= token && token < 259;
}
if (llama_vocab_type(vocab) == "bpe") {
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) {
return 1 <= token && token < 95;
}
@ -2386,11 +2402,11 @@ static bool llama_is_byte_token(const llama_vocab & vocab, llama_token token) {
}
static uint8_t llama_byte_to_char(const llama_vocab & vocab, uint8_t byte) {
if (llama_vocab_type(vocab) == "spm") {
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) {
return byte - 3;
}
if (llama_vocab_type(vocab) == "bpe") {
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) {
return byte + 32;
}
@ -2398,11 +2414,11 @@ static uint8_t llama_byte_to_char(const llama_vocab & vocab, uint8_t byte) {
}
static uint8_t llama_char_to_byte(const llama_vocab & vocab, uint8_t ch) {
if (llama_vocab_type(vocab) == "spm") {
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) {
return ch + 3;
}
if (llama_vocab_type(vocab) == "bpe") {
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) {
return ch - 32;
}
@ -5027,7 +5043,7 @@ int llama_tokenize_with_model(
llama_token * tokens,
int n_max_tokens,
bool add_bos) {
auto escape = llama_vocab_type(model->vocab) == "spm";
auto escape = llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM;
auto res = llama_tokenize_internal(model->vocab, text, add_bos, escape);
if (n_max_tokens < (int) res.size()) {
@ -5063,7 +5079,7 @@ int llama_token_to_str_with_model(const struct llama_model * model, llama_token
if (0 <= token && token < llama_model_n_vocab(model)) {
if (llama_is_normal_token(model->vocab, token)) {
std::string result = model->vocab.id_to_token[token].tok;
if (llama_vocab_type(model->vocab) == "spm") {
if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) {
result = llama_unescape_whitespace(result);
}
if (length < (int) result.length()) {

75
llama.h
View File

@ -61,6 +61,40 @@ extern "C" {
typedef int llama_token;
enum llama_log_level {
LLAMA_LOG_LEVEL_ERROR = 2,
LLAMA_LOG_LEVEL_WARN = 3,
LLAMA_LOG_LEVEL_INFO = 4
};
enum llama_vocab_type {
LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
};
// model file types
enum llama_ftype {
LLAMA_FTYPE_ALL_F32 = 0,
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
// LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
// LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors
};
typedef struct llama_token_data {
llama_token id; // token id
float logit; // log-odds of the token
@ -75,19 +109,6 @@ extern "C" {
typedef void (*llama_progress_callback)(float progress, void *ctx);
enum llama_log_level {
LLAMA_LOG_LEVEL_ERROR = 2,
LLAMA_LOG_LEVEL_WARN = 3,
LLAMA_LOG_LEVEL_INFO = 4
};
// Signature for logging events
// Note that text includes the new line character at the end for most events.
// If your logging mechanism cannot handle that, check if the last character is '\n' and strip it
// if it exists.
// It might not exist for progress report where '.' is output repeatedly.
typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data);
struct llama_context_params {
uint32_t seed; // RNG seed, -1 for random
int32_t n_ctx; // text context
@ -117,28 +138,12 @@ extern "C" {
bool embedding; // embedding mode only
};
// model file types
enum llama_ftype {
LLAMA_FTYPE_ALL_F32 = 0,
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
// LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
// LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors
LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors
};
// Signature for logging events
// Note that text includes the new line character at the end for most events.
// If your logging mechanism cannot handle that, check if the last character is '\n' and strip it
// if it exists.
// It might not exist for progress report where '.' is output repeatedly.
typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data);
// model quantization parameters
typedef struct llama_model_quantize_params {