mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 12:24:35 +00:00
llama : introduce enum llama_vocab_type + remove hardcoded string constants
This commit is contained in:
parent
a4ad2bf35c
commit
25b8a8922d
78
llama.cpp
78
llama.cpp
@ -777,8 +777,10 @@ struct llama_vocab {
|
|||||||
float score;
|
float score;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
|
||||||
|
|
||||||
std::unordered_map<token, id> token_to_id;
|
std::unordered_map<token, id> token_to_id;
|
||||||
std::vector<token_score> id_to_token;
|
std::vector<token_score> id_to_token;
|
||||||
|
|
||||||
// default LLaMA special tokens
|
// default LLaMA special tokens
|
||||||
id special_bos_id = 1;
|
id special_bos_id = 1;
|
||||||
@ -1406,6 +1408,19 @@ static void llama_model_load_internal(
|
|||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string tokenizer_name;
|
||||||
|
GGUF_GET(tokenizer_name, gguf_get_val_str, GGUF_TYPE_STRING, true, "tokenizer.ggml.model");
|
||||||
|
|
||||||
|
if (tokenizer_name == "llama") {
|
||||||
|
vocab.type = LLAMA_VOCAB_TYPE_SPM;
|
||||||
|
} else if (tokenizer_name == "gpt2") {
|
||||||
|
vocab.type = LLAMA_VOCAB_TYPE_BPE;
|
||||||
|
} else {
|
||||||
|
LLAMA_LOG_WARN("%s: unknown tokenizer: '%s'", __func__, tokenizer_name.c_str());
|
||||||
|
LLAMA_LOG_WARN("%s: using default tokenizer: 'llama'", __func__);
|
||||||
|
vocab.type = LLAMA_VOCAB_TYPE_SPM;
|
||||||
|
}
|
||||||
|
|
||||||
// get hparams kv
|
// get hparams kv
|
||||||
GGUF_GET(hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, "tokenizer.ggml.tokens");
|
GGUF_GET(hparams.n_vocab, gguf_get_arr_n, GGUF_TYPE_ARRAY, true, "tokenizer.ggml.tokens");
|
||||||
GGUF_GET(hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.context_length");
|
GGUF_GET(hparams.n_ctx_train, gguf_get_val_u32, GGUF_TYPE_UINT32, true, "llama.context_length");
|
||||||
@ -1504,6 +1519,7 @@ static void llama_model_load_internal(
|
|||||||
// hparams
|
// hparams
|
||||||
LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml->fver));
|
LLAMA_LOG_INFO("%s: format = %s\n", __func__, llama_file_version_name(ml->fver));
|
||||||
LLAMA_LOG_INFO("%s: arch = %s\n", __func__, general_arch.c_str());
|
LLAMA_LOG_INFO("%s: arch = %s\n", __func__, general_arch.c_str());
|
||||||
|
LLAMA_LOG_INFO("%s: vocab type = %s\n", __func__, vocab.type == LLAMA_VOCAB_TYPE_SPM ? "SPM" : "BPE"); // TODO: fix
|
||||||
LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab);
|
LLAMA_LOG_INFO("%s: n_vocab = %u\n", __func__, hparams.n_vocab);
|
||||||
LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
|
LLAMA_LOG_INFO("%s: n_ctx_train = %u\n", __func__, hparams.n_ctx_train);
|
||||||
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx);
|
LLAMA_LOG_INFO("%s: n_ctx = %u\n", __func__, hparams.n_ctx);
|
||||||
@ -2317,40 +2333,22 @@ static bool llama_eval_internal(
|
|||||||
// tokenizer
|
// tokenizer
|
||||||
//
|
//
|
||||||
|
|
||||||
static std::string llama_vocab_type(const llama_vocab & vocab) {
|
static enum llama_vocab_type llama_vocab_get_type(const llama_vocab & vocab) {
|
||||||
return vocab.token_to_id.size() == 32000 ? "spm": "bpe";
|
return vocab.type;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool llama_is_normal_token(const llama_vocab & vocab, llama_token token) {
|
static bool llama_is_normal_token(const llama_vocab & vocab, llama_token token) {
|
||||||
if (llama_vocab_type(vocab) == "spm") {
|
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) {
|
||||||
return token >= 259;
|
return token >= 259;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_vocab_type(vocab) == "bpe") {
|
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) {
|
||||||
return token >= 95;
|
return token >= 95;
|
||||||
}
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token token) {
|
|
||||||
if (llama_vocab_type(vocab) == "spm") {
|
|
||||||
return token == 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: improve?
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool llama_is_control_token(const llama_vocab & vocab, llama_token token) {
|
|
||||||
if (llama_vocab_type(vocab) == "spm") {
|
|
||||||
return token == 1 || token == 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: improve?
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
static bool llama_is_bos_token(const llama_vocab & vocab, llama_token token) {
|
static bool llama_is_bos_token(const llama_vocab & vocab, llama_token token) {
|
||||||
return token == vocab.special_bos_id;
|
return token == vocab.special_bos_id;
|
||||||
}
|
}
|
||||||
@ -2359,6 +2357,24 @@ static bool llama_is_eos_token(const llama_vocab & vocab, llama_token token) {
|
|||||||
return token == vocab.special_eos_id;
|
return token == vocab.special_eos_id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool llama_is_control_token(const llama_vocab & vocab, llama_token token) {
|
||||||
|
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) {
|
||||||
|
return token == llama_is_bos_token(vocab, token) || token == llama_is_eos_token(vocab, token);
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: improve?
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool llama_is_unknown_token(const llama_vocab & vocab, llama_token token) {
|
||||||
|
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) {
|
||||||
|
return token == 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: improve?
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) {
|
static bool llama_is_user_defined_token(const llama_vocab & vocab, llama_token token) {
|
||||||
GGML_UNUSED(vocab);
|
GGML_UNUSED(vocab);
|
||||||
GGML_UNUSED(token);
|
GGML_UNUSED(token);
|
||||||
@ -2374,11 +2390,11 @@ static bool llama_is_unused_token(const llama_vocab & vocab, llama_token token)
|
|||||||
}
|
}
|
||||||
|
|
||||||
static bool llama_is_byte_token(const llama_vocab & vocab, llama_token token) {
|
static bool llama_is_byte_token(const llama_vocab & vocab, llama_token token) {
|
||||||
if (llama_vocab_type(vocab) == "spm") {
|
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) {
|
||||||
return 3 <= token && token < 259;
|
return 3 <= token && token < 259;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_vocab_type(vocab) == "bpe") {
|
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) {
|
||||||
return 1 <= token && token < 95;
|
return 1 <= token && token < 95;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2386,11 +2402,11 @@ static bool llama_is_byte_token(const llama_vocab & vocab, llama_token token) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static uint8_t llama_byte_to_char(const llama_vocab & vocab, uint8_t byte) {
|
static uint8_t llama_byte_to_char(const llama_vocab & vocab, uint8_t byte) {
|
||||||
if (llama_vocab_type(vocab) == "spm") {
|
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) {
|
||||||
return byte - 3;
|
return byte - 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_vocab_type(vocab) == "bpe") {
|
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) {
|
||||||
return byte + 32;
|
return byte + 32;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2398,11 +2414,11 @@ static uint8_t llama_byte_to_char(const llama_vocab & vocab, uint8_t byte) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
static uint8_t llama_char_to_byte(const llama_vocab & vocab, uint8_t ch) {
|
static uint8_t llama_char_to_byte(const llama_vocab & vocab, uint8_t ch) {
|
||||||
if (llama_vocab_type(vocab) == "spm") {
|
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_SPM) {
|
||||||
return ch + 3;
|
return ch + 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_vocab_type(vocab) == "bpe") {
|
if (llama_vocab_get_type(vocab) == LLAMA_VOCAB_TYPE_BPE) {
|
||||||
return ch - 32;
|
return ch - 32;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -5027,7 +5043,7 @@ int llama_tokenize_with_model(
|
|||||||
llama_token * tokens,
|
llama_token * tokens,
|
||||||
int n_max_tokens,
|
int n_max_tokens,
|
||||||
bool add_bos) {
|
bool add_bos) {
|
||||||
auto escape = llama_vocab_type(model->vocab) == "spm";
|
auto escape = llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM;
|
||||||
auto res = llama_tokenize_internal(model->vocab, text, add_bos, escape);
|
auto res = llama_tokenize_internal(model->vocab, text, add_bos, escape);
|
||||||
|
|
||||||
if (n_max_tokens < (int) res.size()) {
|
if (n_max_tokens < (int) res.size()) {
|
||||||
@ -5063,7 +5079,7 @@ int llama_token_to_str_with_model(const struct llama_model * model, llama_token
|
|||||||
if (0 <= token && token < llama_model_n_vocab(model)) {
|
if (0 <= token && token < llama_model_n_vocab(model)) {
|
||||||
if (llama_is_normal_token(model->vocab, token)) {
|
if (llama_is_normal_token(model->vocab, token)) {
|
||||||
std::string result = model->vocab.id_to_token[token].tok;
|
std::string result = model->vocab.id_to_token[token].tok;
|
||||||
if (llama_vocab_type(model->vocab) == "spm") {
|
if (llama_vocab_get_type(model->vocab) == LLAMA_VOCAB_TYPE_SPM) {
|
||||||
result = llama_unescape_whitespace(result);
|
result = llama_unescape_whitespace(result);
|
||||||
}
|
}
|
||||||
if (length < (int) result.length()) {
|
if (length < (int) result.length()) {
|
||||||
|
75
llama.h
75
llama.h
@ -61,6 +61,40 @@ extern "C" {
|
|||||||
|
|
||||||
typedef int llama_token;
|
typedef int llama_token;
|
||||||
|
|
||||||
|
enum llama_log_level {
|
||||||
|
LLAMA_LOG_LEVEL_ERROR = 2,
|
||||||
|
LLAMA_LOG_LEVEL_WARN = 3,
|
||||||
|
LLAMA_LOG_LEVEL_INFO = 4
|
||||||
|
};
|
||||||
|
|
||||||
|
enum llama_vocab_type {
|
||||||
|
LLAMA_VOCAB_TYPE_SPM = 0, // SentencePiece
|
||||||
|
LLAMA_VOCAB_TYPE_BPE = 1, // Byte Pair Encoding
|
||||||
|
};
|
||||||
|
|
||||||
|
// model file types
|
||||||
|
enum llama_ftype {
|
||||||
|
LLAMA_FTYPE_ALL_F32 = 0,
|
||||||
|
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
|
||||||
|
// LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
|
||||||
|
// LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors
|
||||||
|
LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors
|
||||||
|
};
|
||||||
|
|
||||||
typedef struct llama_token_data {
|
typedef struct llama_token_data {
|
||||||
llama_token id; // token id
|
llama_token id; // token id
|
||||||
float logit; // log-odds of the token
|
float logit; // log-odds of the token
|
||||||
@ -75,19 +109,6 @@ extern "C" {
|
|||||||
|
|
||||||
typedef void (*llama_progress_callback)(float progress, void *ctx);
|
typedef void (*llama_progress_callback)(float progress, void *ctx);
|
||||||
|
|
||||||
enum llama_log_level {
|
|
||||||
LLAMA_LOG_LEVEL_ERROR = 2,
|
|
||||||
LLAMA_LOG_LEVEL_WARN = 3,
|
|
||||||
LLAMA_LOG_LEVEL_INFO = 4
|
|
||||||
};
|
|
||||||
|
|
||||||
// Signature for logging events
|
|
||||||
// Note that text includes the new line character at the end for most events.
|
|
||||||
// If your logging mechanism cannot handle that, check if the last character is '\n' and strip it
|
|
||||||
// if it exists.
|
|
||||||
// It might not exist for progress report where '.' is output repeatedly.
|
|
||||||
typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data);
|
|
||||||
|
|
||||||
struct llama_context_params {
|
struct llama_context_params {
|
||||||
uint32_t seed; // RNG seed, -1 for random
|
uint32_t seed; // RNG seed, -1 for random
|
||||||
int32_t n_ctx; // text context
|
int32_t n_ctx; // text context
|
||||||
@ -117,28 +138,12 @@ extern "C" {
|
|||||||
bool embedding; // embedding mode only
|
bool embedding; // embedding mode only
|
||||||
};
|
};
|
||||||
|
|
||||||
// model file types
|
// Signature for logging events
|
||||||
enum llama_ftype {
|
// Note that text includes the new line character at the end for most events.
|
||||||
LLAMA_FTYPE_ALL_F32 = 0,
|
// If your logging mechanism cannot handle that, check if the last character is '\n' and strip it
|
||||||
LLAMA_FTYPE_MOSTLY_F16 = 1, // except 1d tensors
|
// if it exists.
|
||||||
LLAMA_FTYPE_MOSTLY_Q4_0 = 2, // except 1d tensors
|
// It might not exist for progress report where '.' is output repeatedly.
|
||||||
LLAMA_FTYPE_MOSTLY_Q4_1 = 3, // except 1d tensors
|
typedef void (*llama_log_callback)(enum llama_log_level level, const char * text, void * user_data);
|
||||||
LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16 = 4, // tok_embeddings.weight and output.weight are F16
|
|
||||||
// LLAMA_FTYPE_MOSTLY_Q4_2 = 5, // support has been removed
|
|
||||||
// LLAMA_FTYPE_MOSTLY_Q4_3 = 6, // support has been removed
|
|
||||||
LLAMA_FTYPE_MOSTLY_Q8_0 = 7, // except 1d tensors
|
|
||||||
LLAMA_FTYPE_MOSTLY_Q5_0 = 8, // except 1d tensors
|
|
||||||
LLAMA_FTYPE_MOSTLY_Q5_1 = 9, // except 1d tensors
|
|
||||||
LLAMA_FTYPE_MOSTLY_Q2_K = 10,// except 1d tensors
|
|
||||||
LLAMA_FTYPE_MOSTLY_Q3_K_S = 11,// except 1d tensors
|
|
||||||
LLAMA_FTYPE_MOSTLY_Q3_K_M = 12,// except 1d tensors
|
|
||||||
LLAMA_FTYPE_MOSTLY_Q3_K_L = 13,// except 1d tensors
|
|
||||||
LLAMA_FTYPE_MOSTLY_Q4_K_S = 14,// except 1d tensors
|
|
||||||
LLAMA_FTYPE_MOSTLY_Q4_K_M = 15,// except 1d tensors
|
|
||||||
LLAMA_FTYPE_MOSTLY_Q5_K_S = 16,// except 1d tensors
|
|
||||||
LLAMA_FTYPE_MOSTLY_Q5_K_M = 17,// except 1d tensors
|
|
||||||
LLAMA_FTYPE_MOSTLY_Q6_K = 18,// except 1d tensors
|
|
||||||
};
|
|
||||||
|
|
||||||
// model quantization parameters
|
// model quantization parameters
|
||||||
typedef struct llama_model_quantize_params {
|
typedef struct llama_model_quantize_params {
|
||||||
|
Loading…
Reference in New Issue
Block a user