mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 03:44:35 +00:00
llama : introduce enum llama_vocab_type + remove hardcoded string constants
This commit is contained in:
parent
a4ad2bf35c
commit
25b8a8922d
78
llama.cpp
78
llama.cpp
@ -777,8 +777,10 @@ struct llama_vocab {
|
||||
float score;
|
||||
};
|
||||
|
||||
llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
|
||||
|
||||
std::unordered_map<token, id> token_to_id;
|
||||
std::vector<token_score> id_to_token;
|
||||
std::vector<token_score> id_to_token;
|
||||
|
||||
// default LLaMA special tokens
|
||||
id special_bos_id = 1;
|
||||
@ -1406,6 +1408,19 @@ static void llama_model_load_internal(
|
||||
} \
|
||||
}
|
||||
|
||||
std::string tokenizer_name;
|
||||
GGUF_GET(tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, "tokenizer.ggml.model");
|
||||
|
||||
if (tokenizer_name == "llama") {
|
||||
vocab.type = LLAMA_VOCAB_TYPE_SPM;
|
||||
} else if (tokenizer_name == "gpt2") {
|
||||
vocab.type = LLAMA_VOCAB_TYPE_BPE;
|
||||
} else {
|
||||
LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
|
||||
LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
|
||||
vocab.type = LLAMA_VOCAB_TYPE_SPM;
|
||||
}
|
||||
|
||||
// get hparams kv
|
||||
GGUF_GET(hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, "tokenizer.ggml.tokens");
|
||||
GGUF_GET(hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.context_length");
|
||||
@ -1504,6 +1519,7 @@ static void llama_model_load_internal(
|
||||
// hparams
|
||||
LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml->fver));
|
||||
LLAMA_LOG_INFO("%s: arch = %s\n", __func__, general_arch.c_str());
|
||||
LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix
|
||||
LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab);
|
||||
LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
|
||||
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx);
|
||||
@ -2317,40 +2333,22 @@ static bool llama_eval_internal(
|
||||
// tokenizer
|
||||
//
|
||||
|
||||
static std::string llama_vocab_type(const llama_vocab & vocab) {
|
||||
return vocab.token_to_id.size() == 32000 ? "spm": "bpe";
|
||||
static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
|
||||
return vocab.type;
|
||||
}
|
||||
|
||||
static bool llama_is_normal_token(const llama_vocab & vocab, llama_token token) {
|
||||
if (llama_vocab_type(vocab) == "spm") {
|
||||
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) {
|
||||
return token >= 259;
|
||||
}
|
||||
|
||||
if (llama_vocab_type(vocab) == "bpe") {
|
||||
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) {
|
||||
return token >= 95;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token token) {
|
||||
if (llama_vocab_type(vocab) == "spm") {
|
||||
return token == 0;
|
||||
}
|
||||
|
||||
// TODO: improve?
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool llama_is_control_token(const llama_vocab & vocab, llama_token token) {
|
||||
if (llama_vocab_type(vocab) == "spm") {
|
||||
return token == 1 || token == 2;
|
||||
}
|
||||
|
||||
// TODO: improve?
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool llama_is_bos_token(const llama_vocab & vocab, llama_token token) {
|
||||
return token == vocab.special_bos_id;
|
||||
}
|
||||
@ -2359,6 +2357,24 @@ static bool llama_is_eos_token(const llama_vocab & vocab, llama_token token) {
|
||||
return token == vocab.special_eos_id;
|
||||
}
|
||||
|
||||
static bool llama_is_control_token(const llama_vocab & vocab, llama_token token) {
|
||||
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) {
|
||||
return token == llama_is_bos_token(vocab, token) || token == llama_is_eos_token(vocab, token);
|
||||
}
|
||||
|
||||
// TODO: improve?
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token token) {
|
||||
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) {
|
||||
return token == 0;
|
||||
}
|
||||
|
||||
// TODO: improve?
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) {
|
||||
GGML_UNUSED(vocab);
|
||||
GGML_UNUSED(token);
|
||||
@ -2374,11 +2390,11 @@ static bool llama_is_unused_token(const llama_vocab & vocab, llama_token token)
|
||||
}
|
||||
|
||||
static bool llama_is_byte_token(const llama_vocab & vocab, llama_token token) {
|
||||
if (llama_vocab_type(vocab) == "spm") {
|
||||
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) {
|
||||
return 3 <= token && token < 259;
|
||||
}
|
||||
|
||||
if (llama_vocab_type(vocab) == "bpe") {
|
||||
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) {
|
||||
return 1 <= token && token < 95;
|
||||
}
|
||||
|
||||
@ -2386,11 +2402,11 @@ static bool llama_is_byte_token(const llama_vocab & vocab, llama_token token) {
|
||||
}
|
||||
|
||||
static uint8_t llama_byte_to_char(const llama_vocab & vocab, uint8_t byte) {
|
||||
if (llama_vocab_type(vocab) == "spm") {
|
||||
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) {
|
||||
return byte - 3;
|
||||
}
|
||||
|
||||
if (llama_vocab_type(vocab) == "bpe") {
|
||||
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) {
|
||||
return byte + 32;
|
||||
}
|
||||
|
||||
@ -2398,11 +2414,11 @@ static uint8_t llama_byte_to_char(const llama_vocab & vocab, uint8_t byte) {
|
||||
}
|
||||
|
||||
static uint8_t llama_char_to_byte(const llama_vocab & vocab, uint8_t ch) {
|
||||
if (llama_vocab_type(vocab) == "spm") {
|
||||
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) {
|
||||
return ch + 3;
|
||||
}
|
||||
|
||||
if (llama_vocab_type(vocab) == "bpe") {
|
||||
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) {
|
||||
return ch - 32;
|
||||
}
|
||||
|
||||
@ -5027,7 +5043,7 @@ int llama_tokenize_with_model(
|
||||
llama_token * tokens,
|
||||
int n_max_tokens,
|
||||
bool add_bos) {
|
||||
auto escape = llama_vocab_type(model->vocab) == "spm";
|
||||
auto escape = llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM;
|
||||
auto res = llama_tokenize_internal(model->vocab, text, add_bos, escape);
|
||||
|
||||
if (n_max_tokens < (int) res.size()) {
|
||||
@ -5063,7 +5079,7 @@ int llama_token_to_str_with_model(const struct llama_model * model, llama_token
|
||||
if (0 <= token && token < llama_model_n_vocab(model)) {
|
||||
if (llama_is_normal_token(model->vocab, token)) {
|
||||
std::string result = model->vocab.id_to_token[token].tok;
|
||||
if (llama_vocab_type(model->vocab) == "spm") {
|
||||
if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) {
|
||||
result = llama_unescape_whitespace(result);
|
||||
}
|
||||
if (length < (int) result.length()) {
|
||||
|
75
llama.h
75
llama.h
@ -61,6 +61,40 @@ extern "C" {
|
||||
|
||||
typedef int llama_token;
|
||||
|
||||
enum llama_log_level {
|
||||
LLAMA_LOG_LEVEL_ERROR = 2,
|
||||
LLAMA_LOG_LEVEL_WARN = 3,
|
||||
LLAMA_LOG_LEVEL_INFO = 4
|
||||
};
|
||||
|
||||
enum llama_vocab_type {
|
||||
LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
|
||||
LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
|
||||
};
|
||||
|
||||
// model file types
|
||||
enum llama_ftype {
|
||||
LLAMA_FTYPE_ALL_F32 = 0,
|
||||
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
|
||||
// LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
|
||||
// LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
|
||||
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors
|
||||
};
|
||||
|
||||
typedef struct llama_token_data {
|
||||
llama_token id; // token id
|
||||
float logit; // log-odds of the token
|
||||
@ -75,19 +109,6 @@ extern "C" {
|
||||
|
||||
typedef void (*llama_progress_callback)(float progress, void *ctx);
|
||||
|
||||
enum llama_log_level {
|
||||
LLAMA_LOG_LEVEL_ERROR = 2,
|
||||
LLAMA_LOG_LEVEL_WARN = 3,
|
||||
LLAMA_LOG_LEVEL_INFO = 4
|
||||
};
|
||||
|
||||
// Signature for logging events
|
||||
// Note that text includes the new line character at the end for most events.
|
||||
// If your logging mechanism cannot handle that, check if the last character is '\n' and strip it
|
||||
// if it exists.
|
||||
// It might not exist for progress report where '.' is output repeatedly.
|
||||
typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data);
|
||||
|
||||
struct llama_context_params {
|
||||
uint32_t seed; // RNG seed, -1 for random
|
||||
int32_t n_ctx; // text context
|
||||
@ -117,28 +138,12 @@ extern "C" {
|
||||
bool embedding; // embedding mode only
|
||||
};
|
||||
|
||||
// model file types
|
||||
enum llama_ftype {
|
||||
LLAMA_FTYPE_ALL_F32 = 0,
|
||||
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
|
||||
// LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
|
||||
// LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
|
||||
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors
|
||||
LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors
|
||||
};
|
||||
// Signature for logging events
|
||||
// Note that text includes the new line character at the end for most events.
|
||||
// If your logging mechanism cannot handle that, check if the last character is '\n' and strip it
|
||||
// if it exists.
|
||||
// It might not exist for progress report where '.' is output repeatedly.
|
||||
typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data);
|
||||
|
||||
// model quantization parameters
|
||||
typedef struct llama_model_quantize_params {
|
||||
|
Loading…
Reference in New Issue
Block a user