llama : add llama_vocab, functions -> methods, naming (#11110)

* llama : functions -> methods (#11110)

* llama : add struct llama_vocab to the API (#11156)

ggml-ci

* hparams : move vocab params to llama_vocab (#11159)

ggml-ci

* vocab : more pimpl (#11165)

ggml-ci

* vocab : minor tokenization optimizations (#11160)

ggml-ci

Co-authored-by: Diego Devesa <slarengh@gmail.com>

* lora : update API names (#11167)

ggml-ci

* llama : update API names to use correct prefix (#11174)

* llama : update API names to use correct prefix

ggml-ci

* cont

ggml-ci

* cont

ggml-ci

* minor [no ci]

* vocab : llama_vocab_add_[be]os -> llama_vocab_get_add_[be]os (#11174)

ggml-ci

* vocab : llama_vocab_n_vocab -> llama_vocab_n_tokens (#11174)

ggml-ci

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
Author: Georgi Gerganov
Date:   2025-01-12 11:32:42 +02:00 (committed by GitHub)
Commit: afa8a9ec9b
Parent: c05e8c9934
68 changed files with 5855 additions and 5400 deletions
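The renames amount to fetching a llama_vocab handle from the model and calling the vocab-prefixed accessors on it. A minimal illustrative sketch, not part of the change set itself (the helper function and variable names below are assumptions):

```cpp
// Illustrative sketch of the migration, assuming the vocab accessors shown in
// the hunks below; the helper function itself is hypothetical.
#include "llama.h"

static void use_new_vocab_api(llama_model * model) {
    // previously: llama_n_vocab(model), llama_token_bos(model), llama_token_eos(model)
    const llama_vocab * vocab = llama_model_get_vocab(model);

    const int         n_vocab = llama_vocab_n_tokens(vocab);
    const llama_token bos     = llama_vocab_bos(vocab);
    const llama_token eos     = llama_vocab_eos(vocab);

    (void) n_vocab; (void) bos; (void) eos; // typically fed into tokenization/sampling
}
```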


@@ -857,21 +857,23 @@ struct common_init_result common_init_from_params(common_params & params) {
         return iparams;
     }

+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
     if (params.reranking) {
         bool ok = true;

-        if (llama_token_bos(model) == LLAMA_TOKEN_NULL) {
-            LOG_WRN("%s: warning: model does not have a BOS token, reranking will not work\n", __func__);
+        if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
             ok = false;
         }

-        if (llama_token_eos(model) == LLAMA_TOKEN_NULL) {
-            LOG_WRN("%s: warning: model does not have an EOS token, reranking will not work\n", __func__);
+        if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__);
             ok = false;
         }

-        if (llama_token_sep(model) == LLAMA_TOKEN_NULL) {
-            LOG_WRN("%s: warning: model does not have a SEP token, reranking will not work\n", __func__);
+        if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) {
+            LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
             ok = false;
         }
@@ -884,7 +886,7 @@ struct common_init_result common_init_from_params(common_params & params) {
     auto cparams = common_context_params_to_llama(params);

-    llama_context * lctx = llama_new_context_with_model(model, cparams);
+    llama_context * lctx = llama_init_from_model(model, cparams);
     if (lctx == NULL) {
         LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str());
         llama_model_free(model);
@@ -898,7 +900,7 @@ struct common_init_result common_init_from_params(common_params & params) {
     if (!params.control_vectors.empty()) {
         if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
-        if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model);
+        if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_model_n_layer(model);

         const auto cvec = common_control_vector_load(params.control_vectors);
         if (cvec.n_embd == -1) {
@@ -908,12 +910,13 @@ struct common_init_result common_init_from_params(common_params & params) {
             return iparams;
         }

-        int err = llama_control_vector_apply(lctx,
-                                             cvec.data.data(),
-                                             cvec.data.size(),
-                                             cvec.n_embd,
-                                             params.control_vector_layer_start,
-                                             params.control_vector_layer_end);
+        int err = llama_apply_adapter_cvec(
+                lctx,
+                cvec.data.data(),
+                cvec.data.size(),
+                cvec.n_embd,
+                params.control_vector_layer_start,
+                params.control_vector_layer_end);
         if (err) {
             llama_free(lctx);
             llama_model_free(model);
@@ -924,8 +927,8 @@ struct common_init_result common_init_from_params(common_params & params) {
     // load and optionally apply lora adapters
     for (auto & la : params.lora_adapters) {
-        llama_lora_adapter_ptr lora;
-        lora.reset(llama_lora_adapter_init(model, la.path.c_str()));
+        llama_adapter_lora_ptr lora;
+        lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
         if (lora == nullptr) {
             LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
             llama_free(lctx);
@@ -938,17 +941,17 @@ struct common_init_result common_init_from_params(common_params & params) {
     }

     if (!params.lora_init_without_apply) {
-        common_lora_adapters_apply(lctx, params.lora_adapters);
+        common_set_adapter_lora(lctx, params.lora_adapters);
     }

-    if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) {
-        LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__);
+    if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) {
+        LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
         params.sampling.ignore_eos = false;
     }

     if (params.sampling.ignore_eos) {
-        for (llama_token i = 0; i < llama_n_vocab(model); i++) {
-            if (llama_token_is_eog(model, i)) {
+        for (llama_token i = 0; i < llama_vocab_n_tokens(vocab); i++) {
+            if (llama_vocab_is_eog(vocab, i)) {
                 LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
                 params.sampling.logit_bias.push_back({i, -INFINITY});
             }
@@ -969,8 +972,9 @@ struct common_init_result common_init_from_params(common_params & params) {
         LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);

         std::vector<llama_token> tmp;
-        llama_token bos = llama_token_bos(model);
-        llama_token eos = llama_token_eos(model);
+        llama_token bos = llama_vocab_bos(vocab);
+        llama_token eos = llama_vocab_eos(vocab);
+
         // some models (e.g. T5) don't have a BOS token
         if (bos != LLAMA_TOKEN_NULL) {
             tmp.push_back(bos);
@@ -1005,11 +1009,11 @@ struct common_init_result common_init_from_params(common_params & params) {
     return iparams;
 }

-void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora) {
-    llama_lora_adapter_clear(ctx);
+void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora) {
+    llama_clear_adapter_lora(ctx);
     for (auto & la : lora) {
         if (la.scale != 0.0f) {
-            llama_lora_adapter_set(ctx, la.ptr, la.scale);
+            llama_set_adapter_lora(ctx, la.ptr, la.scale);
         }
     }
 }
@@ -1559,21 +1563,23 @@ std::vector<llama_token> common_tokenize(
     const std::string & text,
     bool add_special,
     bool parse_special) {
-    return common_tokenize(llama_get_model(ctx), text, add_special, parse_special);
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+    return common_tokenize(vocab, text, add_special, parse_special);
 }

 std::vector<llama_token> common_tokenize(
-    const struct llama_model * model,
+    const struct llama_vocab * vocab,
     const std::string & text,
     bool add_special,
     bool parse_special) {
     // upper limit for the number of tokens
     int n_tokens = text.length() + 2 * add_special;
     std::vector<llama_token> result(n_tokens);
-    n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+    n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
     if (n_tokens < 0) {
         result.resize(-n_tokens);
-        int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
+        int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
         GGML_ASSERT(check == -n_tokens);
     } else {
         result.resize(n_tokens);
@@ -1582,12 +1588,18 @@ std::vector<llama_token> common_tokenize(
 }

 std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+    return common_token_to_piece(vocab, token, special);
+}
+
+std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token token, bool special) {
     std::string piece;
     piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
-    const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+    const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
     if (n_chars < 0) {
         piece.resize(-n_chars);
-        int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special);
+        int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
         GGML_ASSERT(check == -n_chars);
     }
     else {
@@ -1597,13 +1609,19 @@ std::string common_token_to_piece(const struct llama_context * ctx, llama_token
     return piece;
 }

-std::string common_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
+std::string common_detokenize(const struct llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+    return common_detokenize(vocab, tokens, special);
+}
+
+std::string common_detokenize(const struct llama_vocab * vocab, const std::vector<llama_token> & tokens, bool special) {
     std::string text;
     text.resize(std::max(text.capacity(), tokens.size()));
-    int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    int32_t n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
     if (n_chars < 0) {
         text.resize(-n_chars);
-        n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
         GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
     }
@@ -1631,7 +1649,7 @@ std::string common_get_builtin_chat_template(const struct llama_model * model) {

 bool common_chat_verify_template(const std::string & tmpl) {
     llama_chat_message chat[] = {{"user", "test"}};
-    int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
+    const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
     return res >= 0;
 }
@@ -1642,16 +1660,16 @@ std::string common_chat_apply_template(const struct llama_model * model,
     int alloc_size = 0;
     bool fallback = false; // indicate if we must fallback to default chatml
     std::vector<llama_chat_message> chat;
-    for (auto & msg : msgs) {
+    for (const auto & msg : msgs) {
         chat.push_back({msg.role.c_str(), msg.content.c_str()});
         alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
     }

-    const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str();
+    const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model) : tmpl.c_str();
     std::vector<char> buf(alloc_size);

     // run the first time to get the total output length
-    int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+    int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());

     // error: chat template is not supported
     if (res < 0) {
@@ -1659,18 +1677,17 @@ std::string common_chat_apply_template(const struct llama_model * model,
             // if the custom "tmpl" is not supported, we throw an error
             // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
             throw std::runtime_error("this custom template is not supported");
-        } else {
-            // If the built-in template is not supported, we default to chatml
-            res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
-            fallback = true;
         }
+        // If the built-in template is not supported, we default to chatml
+        res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
+        fallback = true;
     }

     // if it turns out that our buffer is too small, we resize it
     if ((size_t) res > buf.size()) {
         buf.resize(res);
         res = llama_chat_apply_template(
-            fallback ? nullptr : model,
             fallback ? "chatml" : ptr_tmpl,
             chat.data(), chat.size(), add_ass, buf.data(), buf.size());
     }


@@ -24,11 +24,11 @@

 #define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"

-struct common_lora_adapter_info {
+struct common_adapter_lora_info {
     std::string path;
     float scale;

-    struct llama_lora_adapter * ptr;
+    struct llama_adapter_lora * ptr;
 };

 using llama_tokens = std::vector<llama_token>;
@@ -246,8 +246,8 @@ struct common_params {
     std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
     std::vector<llama_model_kv_override> kv_overrides;

-    bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_lora_adapter_apply)
-    std::vector<common_lora_adapter_info> lora_adapters; // lora adapter path with user defined scale
+    bool lora_init_without_apply = false; // only load lora to memory, but do not apply it to ctx (user can manually apply lora later using llama_adapter_lora_apply)
+    std::vector<common_adapter_lora_info> lora_adapters; // lora adapter path with user defined scale

     std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale
@@ -481,7 +481,7 @@ struct common_init_result {
     llama_model_ptr model;
     llama_context_ptr context;

-    std::vector<llama_lora_adapter_ptr> lora;
+    std::vector<llama_adapter_lora_ptr> lora;
 };

 struct common_init_result common_init_from_params(common_params & params);
@@ -503,7 +503,7 @@ struct llama_model * common_load_model_from_hf(
     const struct llama_model_params & params);

 // clear LoRA adapters from context, then apply new list of adapters
-void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_info> & lora);
+void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);

 //
 // Batch utils
@@ -541,7 +541,7 @@ std::vector<llama_token> common_tokenize(
     bool parse_special = false);

 std::vector<llama_token> common_tokenize(
-    const struct llama_model * model,
+    const struct llama_vocab * vocab,
     const std::string & text,
     bool add_special,
     bool parse_special = false);
@@ -553,11 +553,21 @@ std::string common_token_to_piece(
     llama_token token,
     bool special = true);

+std::string common_token_to_piece(
+    const struct llama_vocab * vocab,
+    llama_token token,
+    bool special = true);
+
 // detokenizes a vector of tokens into a string
 // should work similar to Python's `tokenizer.decode`
 // optionally renders special/control tokens
 std::string common_detokenize(
-    llama_context * ctx,
+    const struct llama_context * ctx,
+    const std::vector<llama_token> & tokens,
+    bool special = true);
+
+std::string common_detokenize(
+    const struct llama_vocab * vocab,
     const std::vector<llama_token> & tokens,
     bool special = true);
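The vocab-based overloads above can be used without a llama_context. A small usage sketch (the roundtrip helper and prompt handling are hypothetical, not part of the change set):

```cpp
// Hypothetical caller of the vocab-based common_* overloads declared above.
#include "common.h"
#include "llama.h"

static std::string tokenize_roundtrip(const llama_model * model, const std::string & prompt) {
    const llama_vocab * vocab = llama_model_get_vocab(model);

    // tokenize against the vocab directly - no llama_context required
    std::vector<llama_token> toks = common_tokenize(vocab, prompt, /*add_special=*/ true, /*parse_special=*/ false);

    // detokenize back to text, rendering special tokens
    return common_detokenize(vocab, toks, /*special=*/ true);
}
```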


@@ -113,7 +113,10 @@ struct common_sampler {
     void set_logits(struct llama_context * ctx, int idx) {
         const auto * logits = llama_get_logits_ith(ctx, idx);

-        const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+        const llama_model * model = llama_get_model(ctx);
+        const llama_vocab * vocab = llama_model_get_vocab(model);
+
+        const int n_vocab = llama_vocab_n_tokens(vocab);

         cur.resize(n_vocab);
@@ -142,13 +145,15 @@ std::string common_params_sampling::print() const {
 }

 struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();

     lparams.no_perf = params.no_perf;

     auto * result = new common_sampler {
         /* .params = */ params,
-        /* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"),
+        /* .grmr = */ llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"),
         /* .chain = */ llama_sampler_chain_init(lparams),
         /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
         /* .cur = */ {},
@@ -157,7 +162,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
     llama_sampler_chain_add(result->chain,
             llama_sampler_init_logit_bias(
-                llama_n_vocab(model),
+                llama_vocab_n_tokens(vocab),
                 params.logit_bias.size(),
                 params.logit_bias.data()));
@@ -172,7 +177,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                         c_breakers.push_back(str.c_str());
                     }

-                    llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
+                    llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
                 }
                 break;
             case COMMON_SAMPLER_TYPE_TOP_K:
@@ -194,7 +199,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
                 llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                 break;
             case COMMON_SAMPLER_TYPE_INFILL:
-                llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model));
+                llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
                 break;
             case COMMON_SAMPLER_TYPE_PENALTIES:
                 llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
@@ -206,7 +211,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
         llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
     } else if (params.mirostat == 1) {
         llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
-        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
+        llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_tokens(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
     } else if (params.mirostat == 2) {
         llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
         llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));
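The grammar, DRY and infill samplers now take the vocab rather than the model. A minimal sketch of building a chain with the vocab-based constructors used above (the chain composition and grammar string are illustrative only):

```cpp
// Illustrative only: a tiny sampler chain built with the vocab-based
// constructors referenced in the hunks above; the grammar string is a placeholder.
#include "llama.h"

static llama_sampler * make_example_chain(const llama_model * model) {
    const llama_vocab * vocab = llama_model_get_vocab(model);

    llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
    llama_sampler * chain = llama_sampler_chain_init(lparams);

    // grammar-constrained sampling is constructed from the vocab, not the model
    llama_sampler_chain_add(chain, llama_sampler_init_grammar(vocab, "root ::= [a-z]+", "root"));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    return chain;
}
```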


@@ -79,10 +79,13 @@ bool common_speculative_are_compatible(
     const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
     const struct llama_model * model_dft = llama_get_model(ctx_dft);

-    const bool vocab_type_tgt = llama_vocab_type(model_tgt);
+    const struct llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
+    const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
+
+    const bool vocab_type_tgt = llama_vocab_type(vocab_tgt);
     LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);

-    const bool vocab_type_dft = llama_vocab_type(model_dft);
+    const bool vocab_type_dft = llama_vocab_type(vocab_dft);
     LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);

     if (vocab_type_tgt != vocab_type_dft) {
@@ -91,34 +94,34 @@ bool common_speculative_are_compatible(
         return false;
     }

-    if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
-        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
-        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
-        llama_token_eos(model_tgt) != llama_token_eos(model_dft)) {
-        LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
-        LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_tgt), llama_add_bos_token(model_tgt), llama_token_eos(model_tgt), llama_add_eos_token(model_tgt));
-        LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_dft), llama_add_bos_token(model_dft), llama_token_eos(model_dft), llama_add_eos_token(model_dft));
+    if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
+        llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
+        llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
+        llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) {
+        LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__);
+        LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_tgt));
+        LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_get_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_get_add_eos(vocab_dft));
         return false;
     }

     {
-        const int n_vocab_tgt = llama_n_vocab(model_tgt);
-        const int n_vocab_dft = llama_n_vocab(model_dft);
+        const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
+        const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
         const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);

         if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
             LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
                     "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
-                    __func__, n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+                    __func__, n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
             return false;
         }

         for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
-            const char * token_text_tgt = llama_token_get_text(model_tgt, i);
-            const char * token_text_dft = llama_token_get_text(model_dft, i);
+            const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
+            const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
             if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
-                LOG_ERR("%s: draft model vocab must match target model to use speculation but "
+                LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but "
                         "token %d content differs - target '%s', draft '%s'\n", __func__, i,
                         common_token_to_piece(ctx_tgt, i).c_str(),
                         common_token_to_piece(ctx_dft, i).c_str());


@@ -50,7 +50,7 @@ int main(int argc, char ** argv) {
     // ensure enough sequences are available
     ctx_params.n_seq_max = n_pl.empty() ? 1 : *std::max_element(n_pl.begin(), n_pl.end());

-    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
+    llama_context * ctx = llama_init_from_model(model, ctx_params);

     if (ctx == NULL) {
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);


@@ -23,12 +23,12 @@ defer {
 }

 let model_params = llama_model_default_params()
-guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), model_params) else {
+guard let model = llama_model_load_from_file(modelPath.cString(using: .utf8), model_params) else {
     print("Failed to load model")
     exit(1)
 }
 defer {
-    llama_free_model(model)
+    llama_model_free(model)
 }

 var tokens = tokenize(text: prompt, add_bos: true)
@@ -141,7 +141,7 @@ while n_cur <= n_len {
         let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])

         // is it an end of stream? -> mark the stream as finished
-        if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
+        if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len {
             i_batch[i] = -1
             // print("")
             if n_parallel > 1 {


@@ -48,10 +48,12 @@ int main(int argc, char ** argv) {
         return 1;
     }

+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
     // tokenize the prompt

     std::vector<llama_token> tokens_list;
-    tokens_list = common_tokenize(model, params.prompt, true);
+    tokens_list = common_tokenize(vocab, params.prompt, true);

     const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size())*n_parallel;
@@ -62,7 +64,7 @@ int main(int argc, char ** argv) {
     ctx_params.n_ctx = n_kv_req;
     ctx_params.n_batch = std::max(n_predict, n_parallel);

-    llama_context * ctx = llama_new_context_with_model(model, ctx_params);
+    llama_context * ctx = llama_init_from_model(model, ctx_params);

     auto sparams = llama_sampler_chain_default_params();
     sparams.no_perf = false;
@@ -121,7 +123,7 @@ int main(int argc, char ** argv) {
         llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
         if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
-            decoder_start_token_id = llama_token_bos(model);
+            decoder_start_token_id = llama_vocab_bos(vocab);
         }

         common_batch_clear(batch);
@@ -174,7 +176,7 @@ int main(int argc, char ** argv) {
             const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);

             // is it an end of generation? -> mark the stream as finished
-            if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
+            if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {
                 i_batch[i] = -1;
                 LOG("\n");
                 if (n_parallel > 1) {


@@ -911,7 +911,7 @@ int main(int argc, char ** argv) {
     load_vocab(params.fn_vocab_model, &config, &vocab);

     struct my_llama_model model;
-    model.hparams.n_vocab = config.vocab_size; //llama_n_vocab(lctx);
+    model.hparams.n_vocab = config.vocab_size; //llama_vocab_n_vocab(lctx);
     model.hparams.n_ctx = params.n_ctx;
     model.hparams.n_embd = config.dim; //params.n_embd;
     model.hparams.n_ff = config.hidden_dim;


@@ -273,7 +273,9 @@ struct tokenized_prompt {
     size_t max_seq_len;

     tokenized_prompt(llama_context * ctx, std::string pos, std::string neg) {
-        const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
+        const llama_model * model = llama_get_model(ctx);
+        const llama_vocab * vocab = llama_model_get_vocab(model);
+        const bool add_bos = llama_vocab_get_add_bos(vocab);
         tokens_pos = common_tokenize(ctx, pos, add_bos, true);
         tokens_neg = common_tokenize(ctx, neg, add_bos, true);
         max_seq_len = std::max(tokens_pos.size(), tokens_neg.size());
@@ -421,8 +423,8 @@ int main(int argc, char ** argv) {
     llama_context * ctx = llama_init.context.get();

     // int n_ctx = llama_n_ctx(ctx);
-    int n_layers = llama_n_layer(model);
-    int n_embd = llama_n_embd(model);
+    int n_layers = llama_model_n_layer(model);
+    int n_embd = llama_model_n_embd(model);

     // get model hint param (a.k.a model arch name)
     char model_hint[128];


@@ -105,7 +105,9 @@ int main(int argc, char ** argv) {
         return 1;
     }

-    const int n_ctx_train = llama_n_ctx_train(model);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const int n_ctx_train = llama_model_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);

     const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
@@ -148,7 +150,7 @@ int main(int argc, char ** argv) {
     // check if the last token is SEP
     // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
     for (auto & inp : inputs) {
-        if (inp.empty() || inp.back() != llama_token_sep(model)) {
+        if (inp.empty() || inp.back() != llama_vocab_sep(vocab)) {
             LOG_WRN("%s: last token in the prompt is not SEP\n", __func__);
             LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
         }
@@ -181,7 +183,7 @@ int main(int argc, char ** argv) {
     }

     // allocate output
-    const int n_embd = llama_n_embd(model);
+    const int n_embd = llama_model_n_embd(model);
     std::vector<float> embeddings(n_embd_count * n_embd, 0);
     float * emb = embeddings.data();


@@ -127,7 +127,10 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
 }

 static bool run(llama_context * ctx, const common_params & params) {
-    const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const bool add_bos = llama_vocab_get_add_bos(vocab);

     std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);


@@ -8,7 +8,6 @@
 #include <map>
 #include <vector>
 #include <string>
-#include <thread>
 #include <fstream>

 static bool g_verbose = false;
@@ -130,7 +129,7 @@ struct lora_merge_ctx {
     lora_merge_ctx(
             std::string & base_fname,
-            std::vector<common_lora_adapter_info> & lora_files,
+            std::vector<common_adapter_lora_info> & lora_files,
             std::string & outfile,
             int n_threads) : base_model(base_fname, 0), n_threads(n_threads), fout(outfile, std::ios::binary) {
         fout.exceptions(std::ofstream::failbit); // fail fast on write errors


@@ -11,6 +11,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
     std::vector<std::vector<float>> result;

     const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);

     llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
@@ -19,16 +20,16 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
         const std::string input_string = instruction + sentences[i];

-        std::vector<llama_token> inputs = common_tokenize(model, input_string, true, false);
+        std::vector<llama_token> inputs = common_tokenize(vocab, input_string, true, false);

         const int32_t n_toks = inputs.size();

         // GritLM seems to have EOS = ""
         // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
-        // inputs.push_back(llama_token_eos(model));
+        // inputs.push_back(llama_vocab_eos(vocab));

         // we want to ignore instruction tokens for mean pooling
-        const int32_t n_inst = common_tokenize(model, instruction, true, false).size();
+        const int32_t n_inst = common_tokenize(vocab, instruction, true, false).size();

 #ifdef GRIT_DEBUG
         // debug tokens - should be matching as referenced in the GritLM sample
@@ -52,7 +53,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
         llama_decode(ctx, batch);

         // get embedding dimensions
-        uint64_t n_embd = llama_n_embd(model);
+        uint64_t n_embd = llama_model_n_embd(model);

         // allocate embedding output
         std::vector<float> emb_unorm(n_embd, 0.0f);
@@ -97,7 +98,9 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
     std::string result;

     const llama_model * model = llama_get_model(ctx);
-    llama_token eos_token = llama_token_eos(model);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    llama_token eos_token = llama_vocab_eos(vocab);

     llama_kv_cache_clear(ctx);
     llama_set_embeddings(ctx, false);
@@ -105,7 +108,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std

     llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);

-    std::vector<llama_token> inputs = common_tokenize(model, prompt, false, true);
+    std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true);
     int32_t i_current_token = 0;

     while (true) {
@@ -168,7 +171,7 @@ int main(int argc, char * argv[]) {
     llama_model * model = llama_model_load_from_file(params.model.c_str(), mparams);

     // create generation context
-    llama_context * ctx = llama_new_context_with_model(model, cparams);
+    llama_context * ctx = llama_init_from_model(model, cparams);

     auto sparams = llama_sampler_chain_default_params();
@@ -197,7 +200,7 @@ int main(int argc, char * argv[]) {
     const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
     const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));

-    const int n_embd = llama_n_embd(model);
+    const int n_embd = llama_model_n_embd(model);

     const float cosine_sim_q0_d0 = common_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
     const float cosine_sim_q0_d1 = common_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);


@@ -7,7 +7,6 @@
 #include <cstdio>
 #include <cstring>
 #include <ctime>
-#include <sstream>
 #include <thread>
 #include <mutex>
 #include <vector>
@@ -40,7 +39,7 @@ public:
     void set_params(common_params params) { m_params = std::move(params); }
     bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data);
     void save_imatrix(int ncall = -1) const;
-    bool load_imatrix(const char * file_name);
+    bool load_imatrix(const char * fname);
 private:
     std::unordered_map<std::string, Stats> m_stats;
     common_params m_params;
@@ -429,10 +428,13 @@ static void process_logits(
 }

 static bool compute_imatrix(llama_context * ctx, const common_params & params) {
-    const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const bool add_bos = llama_vocab_get_add_bos(vocab);
     const int n_ctx = llama_n_ctx(ctx);

-    GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
+    GGML_ASSERT(!llama_vocab_get_add_eos(vocab));

     auto tim1 = std::chrono::high_resolution_clock::now();
     LOG_INF("%s: tokenizing the input ..\n", __func__);
@@ -468,7 +470,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
     const int n_chunk_max = tokens.size() / n_ctx;

     const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+    const int n_vocab = llama_vocab_n_tokens(vocab);
     const int n_batch = params.n_batch;

     int count = 0;
@@ -508,7 +510,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
             // add BOS token for the first batch of each chunk
             if (add_bos && j == 0) {
-                tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
+                tokens[batch_start] = llama_vocab_bos(vocab);
             }

             common_batch_clear(batch);
@@ -627,7 +629,7 @@ int main(int argc, char ** argv) {
         return 1;
     }

-    const int n_ctx_train = llama_n_ctx_train(model);
+    const int n_ctx_train = llama_model_n_ctx_train(model);
     if (params.n_ctx > n_ctx_train) {
         LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n",
                 __func__, n_ctx_train, params.n_ctx);


@@ -139,7 +139,9 @@ int main(int argc, char ** argv) {
         return 1;
     }

-    const int n_ctx_train = llama_n_ctx_train(model);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+
+    const int n_ctx_train = llama_model_n_ctx_train(model);
     const int n_ctx = llama_n_ctx(ctx);
     LOG_DBG("n_ctx: %d\n", n_ctx);
@@ -152,28 +154,28 @@ int main(int argc, char ** argv) {
         LOG_INF("\n");
         LOG_INF("%s\n", common_params_get_system_info(params).c_str());
     }
-    const bool add_bos = llama_add_bos_token(model);
-    GGML_ASSERT(!llama_add_eos_token(model));
+    const bool add_bos = llama_vocab_get_add_bos(vocab);
+    GGML_ASSERT(!llama_vocab_get_add_eos(vocab));

     std::vector<llama_token> embd_inp;
     std::vector<llama_token> embd_end;
     std::vector<llama_token> inp_pfx = common_tokenize(ctx, params.input_prefix, false);
     std::vector<llama_token> inp_sfx = common_tokenize(ctx, params.input_suffix, false);

-    GGML_ASSERT(llama_token_fim_pre(model) >= 0);
-    GGML_ASSERT(llama_token_fim_suf(model) >= 0);
+    GGML_ASSERT(llama_vocab_fim_pre(vocab) >= 0);
+    GGML_ASSERT(llama_vocab_fim_suf(vocab) >= 0);

-    inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model));
-    inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model));
+    inp_pfx.insert(inp_pfx.begin(), llama_vocab_fim_pre(vocab));
+    inp_sfx.insert(inp_sfx.begin(), llama_vocab_fim_suf(vocab));

     embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
     embd_end = params.spm_infill ? inp_pfx : inp_sfx;
     if (add_bos) {
-        embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
+        embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab));
     }
     embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());

-    const llama_token middle_token = llama_token_fim_mid(model);
+    const llama_token middle_token = llama_vocab_fim_mid(vocab);
     if (middle_token >= 0) {
         embd_inp.push_back(middle_token);
     }
@@ -185,7 +187,7 @@ int main(int argc, char ** argv) {

     // Should not run without any tokens
     if (embd_inp.empty()) {
-        embd_inp.push_back(llama_token_bos(model));
+        embd_inp.push_back(llama_vocab_bos(vocab));
         LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str());
     }
@@ -420,10 +422,10 @@ int main(int argc, char ** argv) {
             // if not currently processing queued inputs;
             if ((int) embd_inp.size() <= n_consumed) {
                 // deal with eot token in infill mode
-                if ((common_sampler_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){
+                if ((common_sampler_last(smpl) == llama_vocab_eot(vocab) || is_interacting) && params.interactive){
                     if (is_interacting && !params.interactive_first) {
                         // print an eot token
-                        LOG("%s", common_token_to_piece(ctx, llama_token_eot(model)).c_str());
+                        LOG("%s", common_token_to_piece(ctx, llama_vocab_eot(vocab)).c_str());
                     }
                     LOG("\n");
                     console::set_display(console::user_input);
@@ -463,13 +465,13 @@ int main(int argc, char ** argv) {
                     std::vector<llama_token> inp_pfx = common_tokenize(ctx, params.input_prefix, false);
                     std::vector<llama_token> inp_sfx = common_tokenize(ctx, params.input_suffix, false);

-                    inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model));
-                    inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model));
+                    inp_pfx.insert(inp_pfx.begin(), llama_vocab_fim_pre(vocab));
+                    inp_sfx.insert(inp_sfx.begin(), llama_vocab_fim_suf(vocab));

                     embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
                     embd_end = params.spm_infill ? inp_pfx : inp_sfx;
                     if (add_bos) {
-                        embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
+                        embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab));
                     }
                     embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
@@ -484,7 +486,7 @@ int main(int argc, char ** argv) {
                     is_interacting = false;
                 }
                 // deal with end of generation tokens in interactive mode
-                else if (llama_token_is_eog(model, common_sampler_last(smpl))) {
+                else if (llama_vocab_is_eog(vocab, common_sampler_last(smpl))) {
                     LOG_DBG("found EOS token\n");

                     if (params.interactive) {
@@ -500,7 +502,7 @@ int main(int argc, char ** argv) {
                 if (params.input_prefix_bos) {
                     LOG_DBG("adding input prefix BOS token\n");
-                    embd_inp.push_back(llama_token_bos(model));
+                    embd_inp.push_back(llama_vocab_bos(vocab));
                 }

                 std::string buffer;
@@ -563,7 +565,7 @@ int main(int argc, char ** argv) {
         }

         // end of generation
-        if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !params.interactive) {
+        if (!embd.empty() && llama_vocab_is_eog(vocab, embd.back()) && !params.interactive) {
             break;
         }
@@ -575,7 +577,7 @@ int main(int argc, char ** argv) {
         }
     }
     if (!params.interactive && n_remain <= 0) {
-        LOG("%s", common_token_to_piece(ctx, llama_token_eot(model)).c_str());
+        LOG("%s", common_token_to_piece(ctx, llama_vocab_eot(vocab)).c_str());
     }

     LOG("\n");


@@ -1401,7 +1401,8 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
     llama_set_n_threads(ctx, n_threads, n_threads);

     const llama_model * model = llama_get_model(ctx);
-    const int32_t n_vocab = llama_n_vocab(model);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+    const int32_t n_vocab = llama_vocab_n_tokens(vocab);

     std::vector<llama_token> tokens(n_batch);
@@ -1409,7 +1410,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
     while (n_processed < n_prompt) {
         int n_tokens = std::min(n_prompt - n_processed, n_batch);
-        tokens[0] = n_processed == 0 && llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab;
+        tokens[0] = n_processed == 0 && llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;
         for (int i = 1; i < n_tokens; i++) {
             tokens[i] = std::rand() % n_vocab;
         }
@@ -1424,9 +1425,10 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
     llama_set_n_threads(ctx, n_threads, n_threads);

     const llama_model * model = llama_get_model(ctx);
-    const int32_t n_vocab = llama_n_vocab(model);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+    const int32_t n_vocab = llama_vocab_n_tokens(vocab);

-    llama_token token = llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab;
+    llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab;

     for (int i = 0; i < n_gen; i++) {
         llama_decode(ctx, llama_batch_get_one(&token, 1));
@@ -1537,7 +1539,7 @@ int main(int argc, char ** argv) {
             prev_inst = &inst;
         }

-        llama_context * ctx = llama_new_context_with_model(lmodel, inst.to_llama_cparams());
+        llama_context * ctx = llama_init_from_model(lmodel, inst.to_llama_cparams());
         if (ctx == NULL) {
             fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, inst.model.c_str());
             llama_model_free(lmodel);


@ -87,7 +87,7 @@ Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring fi
auto path_to_model = env->GetStringUTFChars(filename, 0); auto path_to_model = env->GetStringUTFChars(filename, 0);
LOGi("Loading model from %s", path_to_model); LOGi("Loading model from %s", path_to_model);
auto model = llama_load_model_from_file(path_to_model, model_params); auto model = llama_model_load_from_file(path_to_model, model_params);
env->ReleaseStringUTFChars(filename, path_to_model); env->ReleaseStringUTFChars(filename, path_to_model);
if (!model) { if (!model) {
@ -102,7 +102,7 @@ Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring fi
extern "C" extern "C"
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) { Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) {
llama_free_model(reinterpret_cast<llama_model *>(model)); llama_model_free(reinterpret_cast<llama_model *>(model));
} }
extern "C" extern "C"
@ -405,6 +405,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer); const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer); const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
const auto model = llama_get_model(context); const auto model = llama_get_model(context);
const auto vocab = llama_model_get_vocab(model);
if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur); if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I"); if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
@ -414,7 +415,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
const auto new_token_id = llama_sampler_sample(sampler, context, -1); const auto new_token_id = llama_sampler_sample(sampler, context, -1);
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
return nullptr; return nullptr;
} }
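
For reference (not part of the diff): the JNI hunks above exercise the renamed model-lifecycle and vocab calls. A condensed C++ sketch of the same calls, with the JNI plumbing stripped and a hypothetical helper name:

    #include "llama.h"

    // condensed sketch; `path_to_model` and `new_token_id` stand in for the JNI arguments above
    static bool demo(const char * path_to_model, llama_token new_token_id) {
        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_model_load_from_file(path_to_model, mparams); // was llama_load_model_from_file
        if (model == nullptr) {
            return false;
        }
        const llama_vocab * vocab = llama_model_get_vocab(model);
        const bool is_eog = llama_vocab_is_eog(vocab, new_token_id);              // was llama_token_is_eog(model, ...)
        llama_model_free(model);                                                   // was llama_free_model(model)
        return is_eog;
    }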


@ -52,8 +52,8 @@ actor LlamaContext {
deinit { deinit {
llama_sampler_free(sampling) llama_sampler_free(sampling)
llama_batch_free(batch) llama_batch_free(batch)
llama_model_free(model)
llama_free(context) llama_free(context)
llama_free_model(model)
llama_backend_free() llama_backend_free()
} }
@ -65,7 +65,7 @@ actor LlamaContext {
model_params.n_gpu_layers = 0 model_params.n_gpu_layers = 0
print("Running on simulator, force use n_gpu_layers = 0") print("Running on simulator, force use n_gpu_layers = 0")
#endif #endif
let model = llama_load_model_from_file(path, model_params) let model = llama_model_load_from_file(path, model_params)
guard let model else { guard let model else {
print("Could not load model at \(path)") print("Could not load model at \(path)")
throw LlamaError.couldNotInitializeContext throw LlamaError.couldNotInitializeContext
@ -151,7 +151,7 @@ actor LlamaContext {
new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1) new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
if llama_token_is_eog(model, new_token_id) || n_cur == n_len { if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len {
print("\n") print("\n")
is_done = true is_done = true
let new_token_str = String(cString: temporary_invalid_cchars + [0]) let new_token_str = String(cString: temporary_invalid_cchars + [0])


@ -47,8 +47,12 @@ static const char * sample(struct common_sampler * smpl,
int * n_past) { int * n_past) {
const llama_token id = common_sampler_sample(smpl, ctx_llama, -1); const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
common_sampler_accept(smpl, id, true); common_sampler_accept(smpl, id, true);
const llama_model * model = llama_get_model(ctx_llama);
const llama_vocab * vocab = llama_model_get_vocab(model);
static std::string ret; static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { if (llama_vocab_is_eog(vocab, id)) {
ret = "</s>"; ret = "</s>";
} else { } else {
ret = common_token_to_piece(ctx_llama, id); ret = common_token_to_piece(ctx_llama, id);
@ -239,11 +243,10 @@ static struct llava_context * llava_init_context(common_params * params, llama_m
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
llama_context_params ctx_params = common_context_params_to_llama(*params); llama_context_params ctx_params = common_context_params_to_llama(*params);
ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); llama_context * ctx_llama = llama_init_from_model(model, ctx_params);
if (ctx_llama == NULL) { if (ctx_llama == NULL) {
LOG_ERR("%s: failed to create the llama_context\n" , __func__); LOG_ERR("%s: failed to create the llama_context\n" , __func__);


@ -384,7 +384,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) { bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) {
// make sure that the correct mmproj was used, i.e., compare apples to apples // make sure that the correct mmproj was used, i.e., compare apples to apples
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama)); int n_llama_embd = llama_model_n_embd(llama_get_model(ctx_llama));
auto n_image_embd = clip_n_mmproj_embd(ctx_clip); auto n_image_embd = clip_n_mmproj_embd(ctx_clip);
if (n_image_embd != n_llama_embd) { if (n_image_embd != n_llama_embd) {
LOG_ERR("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd); LOG_ERR("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd);
@ -456,7 +456,7 @@ struct llava_embd_batch {
}; };
bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) { bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
int n_embd = llama_n_embd(llama_get_model(ctx_llama)); int n_embd = llama_model_n_embd(llama_get_model(ctx_llama));
for (int i = 0; i < image_embed->n_image_pos; i += n_batch) { for (int i = 0; i < image_embed->n_image_pos; i += n_batch) {
int n_eval = image_embed->n_image_pos - i; int n_eval = image_embed->n_image_pos - i;


@ -54,7 +54,7 @@ static struct llava_context * llava_init_context(common_params * params, llama_m
ctx_params.n_ctx = params->n_ctx; ctx_params.n_ctx = params->n_ctx;
} }
llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); llama_context * ctx_llama = llama_init_from_model(model, ctx_params);
if (ctx_llama == NULL) { if (ctx_llama == NULL) {
LOG_ERR("%s: failed to create the llama_context\n" , __func__); LOG_ERR("%s: failed to create the llama_context\n" , __func__);
@ -167,8 +167,12 @@ static const char * sample(struct common_sampler * smpl,
int * n_past) { int * n_past) {
const llama_token id = common_sampler_sample(smpl, ctx_llama, -1); const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
common_sampler_accept(smpl, id, true); common_sampler_accept(smpl, id, true);
const llama_model * model = llama_get_model(ctx_llama);
const llama_vocab * vocab = llama_model_get_vocab(model);
static std::string ret; static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { if (llama_vocab_is_eog(vocab, id)) {
ret = "</s>"; ret = "</s>";
} else { } else {
ret = common_token_to_piece(ctx_llama, id); ret = common_token_to_piece(ctx_llama, id);


@ -27,7 +27,7 @@
static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed,
int n_batch, int * n_past, int * st_pos_id, struct clip_image_size * image_size) { int n_batch, int * n_past, int * st_pos_id, struct clip_image_size * image_size) {
int n_embd = llama_n_embd(llama_get_model(ctx_llama)); int n_embd = llama_model_n_embd(llama_get_model(ctx_llama));
const int patch_size = 14 * 2; const int patch_size = 14 * 2;
const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0); const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0);
const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0); const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0);
@ -132,8 +132,12 @@ static const char * sample(struct common_sampler * smpl,
int * n_past, int * st_pos_id) { int * n_past, int * st_pos_id) {
const llama_token id = common_sampler_sample(smpl, ctx_llama, -1); const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
common_sampler_accept(smpl, id, true); common_sampler_accept(smpl, id, true);
const llama_model * model = llama_get_model(ctx_llama);
const llama_vocab * vocab = llama_model_get_vocab(model);
static std::string ret; static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { if (llama_vocab_is_eog(vocab, id)) {
ret = "</s>"; ret = "</s>";
} else { } else {
ret = common_token_to_piece(ctx_llama, id); ret = common_token_to_piece(ctx_llama, id);
@ -328,11 +332,10 @@ static struct llava_context * llava_init_context(common_params * params, llama_m
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
llama_context_params ctx_params = common_context_params_to_llama(*params); llama_context_params ctx_params = common_context_params_to_llama(*params);
ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); llama_context * ctx_llama = llama_init_from_model(model, ctx_params);
if (ctx_llama == NULL) { if (ctx_llama == NULL) {
LOG_ERR("%s: failed to create the llama_context\n" , __func__); LOG_ERR("%s: failed to create the llama_context\n" , __func__);
@ -481,7 +484,7 @@ static void debug_test_mrope_2d() {
} }
static void debug_dump_img_embed(struct llava_context * ctx_llava) { static void debug_dump_img_embed(struct llava_context * ctx_llava) {
int n_embd = llama_n_embd(llama_get_model(ctx_llava->ctx_llama)); int n_embd = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama));
int ne = n_embd * 4; int ne = n_embd * 4;
float vals[56 * 56 * 3]; float vals[56 * 56 * 3];
// float embd[ne]; // float embd[ne];


@ -61,6 +61,8 @@ int main(int argc, char ** argv) {
llama_model * model = llama_init.model.get(); llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get(); llama_context * ctx = llama_init.context.get();
const llama_vocab * vocab = llama_model_get_vocab(model);
// Tokenize the prompt // Tokenize the prompt
std::vector<llama_token> inp; std::vector<llama_token> inp;
std::vector<llama_token> all; std::vector<llama_token> all;
@ -147,7 +149,7 @@ int main(int argc, char ** argv) {
} }
// here we keep adding new n-grams as we go // here we keep adding new n-grams as we go
ngram_container ngrams_observed(llama_n_vocab(model), N, G); ngram_container ngrams_observed(llama_vocab_n_tokens(vocab), N, G);
// debug // debug
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1); struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1);
@ -297,7 +299,7 @@ int main(int argc, char ** argv) {
} }
fflush(stdout); fflush(stdout);
if (llama_token_is_eog(model, id)) { if (llama_vocab_is_eog(vocab, id)) {
has_eos = true; has_eos = true;
} }


@ -36,6 +36,8 @@ int main(int argc, char ** argv){
llama_model * model = llama_init.model.get(); llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get(); llama_context * ctx = llama_init.context.get();
const llama_vocab * vocab = llama_model_get_vocab(model);
// tokenize the prompt // tokenize the prompt
std::vector<llama_token> inp; std::vector<llama_token> inp;
inp = common_tokenize(ctx, params.prompt, true, true); inp = common_tokenize(ctx, params.prompt, true, true);
@ -136,7 +138,7 @@ int main(int argc, char ** argv){
LOG("%s", token_str.c_str()); LOG("%s", token_str.c_str());
} }
if (llama_token_is_eog(model, id)) { if (llama_vocab_is_eog(vocab, id)) {
has_eos = true; has_eos = true;
} }


@ -5,7 +5,6 @@
#include "sampling.h" #include "sampling.h"
#include "llama.h" #include "llama.h"
#include <cassert>
#include <cstdio> #include <cstdio>
#include <cstring> #include <cstring>
#include <ctime> #include <ctime>
@ -163,6 +162,8 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
const llama_vocab * vocab = llama_model_get_vocab(model);
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
@ -196,7 +197,7 @@ int main(int argc, char ** argv) {
llama_attach_threadpool(ctx, threadpool, threadpool_batch); llama_attach_threadpool(ctx, threadpool, threadpool_batch);
const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx_train = llama_model_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
if (n_ctx > n_ctx_train) { if (n_ctx > n_ctx_train) {
@ -241,9 +242,9 @@ int main(int argc, char ** argv) {
} }
} }
const bool add_bos = llama_add_bos_token(model); const bool add_bos = llama_vocab_get_add_bos(vocab);
if (!llama_model_has_encoder(model)) { if (!llama_model_has_encoder(model)) {
GGML_ASSERT(!llama_add_eos_token(model)); GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
} }
LOG_DBG("n_ctx: %d, add_bos: %d\n", n_ctx, add_bos); LOG_DBG("n_ctx: %d, add_bos: %d\n", n_ctx, add_bos);
@ -269,7 +270,7 @@ int main(int argc, char ** argv) {
// Should not run without any tokens // Should not run without any tokens
if (embd_inp.empty()) { if (embd_inp.empty()) {
if (add_bos) { if (add_bos) {
embd_inp.push_back(llama_token_bos(model)); embd_inp.push_back(llama_vocab_bos(vocab));
LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str()); LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str());
} else { } else {
LOG_ERR("input is empty\n"); LOG_ERR("input is empty\n");
@ -495,7 +496,7 @@ int main(int argc, char ** argv) {
llama_token decoder_start_token_id = llama_model_decoder_start_token(model); llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == LLAMA_TOKEN_NULL) { if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
decoder_start_token_id = llama_token_bos(model); decoder_start_token_id = llama_vocab_bos(vocab);
} }
embd_inp.clear(); embd_inp.clear();
@ -742,7 +743,7 @@ int main(int argc, char ** argv) {
} }
// deal with end of generation tokens in interactive mode // deal with end of generation tokens in interactive mode
if (llama_token_is_eog(model, common_sampler_last(smpl))) { if (llama_vocab_is_eog(vocab, common_sampler_last(smpl))) {
LOG_DBG("found an EOG token\n"); LOG_DBG("found an EOG token\n");
if (params.interactive) { if (params.interactive) {
@ -776,7 +777,7 @@ int main(int argc, char ** argv) {
if (params.input_prefix_bos) { if (params.input_prefix_bos) {
LOG_DBG("adding input prefix BOS token\n"); LOG_DBG("adding input prefix BOS token\n");
embd_inp.push_back(llama_token_bos(model)); embd_inp.push_back(llama_vocab_bos(vocab));
} }
std::string buffer; std::string buffer;
@ -830,8 +831,8 @@ int main(int argc, char ** argv) {
// if user stop generation mid-way, we must add EOT to finish model's last response // if user stop generation mid-way, we must add EOT to finish model's last response
if (need_insert_eot && format_chat) { if (need_insert_eot && format_chat) {
llama_token eot = llama_token_eot(model); llama_token eot = llama_vocab_eot(vocab);
embd_inp.push_back(eot == LLAMA_TOKEN_NULL ? llama_token_eos(model) : eot); embd_inp.push_back(eot == LLAMA_TOKEN_NULL ? llama_vocab_eos(vocab) : eot);
need_insert_eot = false; need_insert_eot = false;
} }
@ -866,7 +867,7 @@ int main(int argc, char ** argv) {
} }
// end of generation // end of generation
if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !(params.interactive)) { if (!embd.empty() && llama_vocab_is_eog(vocab, embd.back()) && !(params.interactive)) {
LOG(" [end of text]\n"); LOG(" [end of text]\n");
break; break;
} }
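
For reference (not part of the diff): in main.cpp the special-token queries now go through the vocab handle. A condensed sketch of the EOS/EOT handling from the hunks above, with a hypothetical helper name:

    #include "llama.h"
    #include <vector>

    // condensed sketch of the main.cpp changes above: special tokens are queried from the vocab
    static void insert_eot(const llama_model * model, std::vector<llama_token> & embd_inp) {
        const llama_vocab * vocab = llama_model_get_vocab(model);
        if (!llama_model_has_encoder(model)) {
            GGML_ASSERT(!llama_vocab_get_add_eos(vocab));              // was llama_add_eos_token(model)
        }
        // close an interrupted response with EOT, falling back to EOS
        llama_token eot = llama_vocab_eot(vocab);                      // was llama_token_eot(model)
        embd_inp.push_back(eot == LLAMA_TOKEN_NULL ? llama_vocab_eos(vocab) : eot);
    }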


@ -135,6 +135,8 @@ int main(int argc, char ** argv) {
llama_model * model = llama_init.model.get(); llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get(); llama_context * ctx = llama_init.context.get();
const llama_vocab * vocab = llama_model_get_vocab(model);
// load the prompts from an external file if there are any // load the prompts from an external file if there are any
if (params.prompt.empty()) { if (params.prompt.empty()) {
LOG_INF("\033[32mNo new questions so proceed with build-in defaults.\033[0m\n"); LOG_INF("\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
@ -358,7 +360,7 @@ int main(int argc, char ** argv) {
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str()); // client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
if (client.n_decoded > 2 && if (client.n_decoded > 2 &&
(llama_token_is_eog(model, id) || (llama_vocab_is_eog(vocab, id) ||
(params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) || (params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) ||
client.response.find("User:") != std::string::npos || client.response.find("User:") != std::string::npos ||
client.response.find('\n') != std::string::npos)) { client.response.find('\n') != std::string::npos)) {


@ -70,15 +70,17 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
const llama_vocab * vocab = llama_model_get_vocab(model);
// initialize the context // initialize the context
llama_context_params ctx_params = common_context_params_to_llama(params); llama_context_params ctx_params = common_context_params_to_llama(params);
ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep; ctx_params.n_ctx = llama_model_n_ctx_train(model)*n_grp + n_keep;
GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp"); GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
llama_context * ctx = llama_new_context_with_model(model, ctx_params); llama_context * ctx = llama_init_from_model(model, ctx_params);
if (ctx == NULL) { if (ctx == NULL) {
LOG_ERR("%s: failed to create the llama_context\n" , __func__); LOG_ERR("%s: failed to create the llama_context\n" , __func__);
return 1; return 1;
@ -223,7 +225,7 @@ int main(int argc, char ** argv) {
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
// is it an end of generation? // is it an end of generation?
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) {
LOG("\n"); LOG("\n");
break; break;


@ -296,8 +296,11 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
// Output: `perplexity: 13.5106 [114/114]` // Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval // BOS tokens will be added for each chunk before eval
const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); const llama_model * model = llama_get_model(ctx);
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); const llama_vocab * vocab = llama_model_get_vocab(model);
const bool add_bos = llama_vocab_get_add_bos(vocab);
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
LOG_INF("%s: tokenizing the input ..\n", __func__); LOG_INF("%s: tokenizing the input ..\n", __func__);
@ -338,7 +341,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_vocab = llama_vocab_n_tokens(vocab);
int count = 0; int count = 0;
double nll = 0.0; double nll = 0.0;
@ -382,7 +385,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
// add BOS token for the first batch of each chunk // add BOS token for the first batch of each chunk
if (add_bos && j == 0) { if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); tokens[batch_start] = llama_vocab_bos(vocab);
} }
const auto * batch_logits = llama_get_logits(ctx); const auto * batch_logits = llama_get_logits(ctx);
@ -444,8 +447,11 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
// Output: `perplexity: 13.5106 [114/114]` // Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval // BOS tokens will be added for each chunk before eval
const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); const llama_model * model = llama_get_model(ctx);
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); const llama_vocab * vocab = llama_model_get_vocab(model);
const bool add_bos = llama_vocab_get_add_bos(vocab);
GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
std::ofstream logits_stream; std::ofstream logits_stream;
if (!params.logits_file.empty()) { if (!params.logits_file.empty()) {
@ -485,7 +491,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_vocab = llama_vocab_n_tokens(vocab);
int count = 0; int count = 0;
double nll = 0.0; double nll = 0.0;
@ -557,7 +563,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
// add BOS token for the first batch of each chunk // add BOS token for the first batch of each chunk
if (add_bos && j == 0) { if (add_bos && j == 0) {
tokens[seq_start] = llama_token_bos(llama_get_model(ctx)); tokens[seq_start] = llama_vocab_bos(vocab);
} }
for (int k = 0; k < batch_size; ++k) { for (int k = 0; k < batch_size; ++k) {
@ -732,6 +738,9 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto
} }
static void hellaswag_score(llama_context * ctx, const common_params & params) { static void hellaswag_score(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
// Calculates hellaswag score (acc_norm) from prompt // Calculates hellaswag score (acc_norm) from prompt
// //
// Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl // Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
@ -765,7 +774,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
size_t hs_task_count = prompt_lines.size()/6; size_t hs_task_count = prompt_lines.size()/6;
LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count); LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM; const bool is_spm = llama_vocab_type(vocab) == LLAMA_VOCAB_TYPE_SPM;
LOG_INF("================================= is_spm = %d\n", is_spm); LOG_INF("================================= is_spm = %d\n", is_spm);
// The tasks should be randomized so the score stabilizes quickly. // The tasks should be randomized so the score stabilizes quickly.
@ -848,7 +857,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_vocab = llama_vocab_n_tokens(vocab);
const int max_tasks_per_batch = 32; const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
@ -1072,6 +1081,8 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string
* *
*/ */
static void winogrande_score(llama_context * ctx, const common_params & params) { static void winogrande_score(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
constexpr int k_min_trailing_ctx = 3; constexpr int k_min_trailing_ctx = 3;
@ -1130,7 +1141,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_vocab = llama_vocab_n_tokens(vocab);
const int max_tasks_per_batch = 128; const int max_tasks_per_batch = 128;
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
@ -1374,6 +1385,8 @@ static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choic
// https://huggingface.co/datasets/truthful_qa // https://huggingface.co/datasets/truthful_qa
// //
static void multiple_choice_score(llama_context * ctx, const common_params & params) { static void multiple_choice_score(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
std::istringstream strstream(params.prompt); std::istringstream strstream(params.prompt);
uint32_t n_task; uint32_t n_task;
@ -1482,7 +1495,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_vocab = llama_vocab_n_tokens(vocab);
const int max_tasks_per_batch = 32; const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
@ -1655,6 +1668,9 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
} }
static void kl_divergence(llama_context * ctx, const common_params & params) { static void kl_divergence(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
if (params.logits_file.empty()) { if (params.logits_file.empty()) {
LOG_ERR("%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__); LOG_ERR("%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__);
return; return;
@ -1688,8 +1704,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str()); LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
return; return;
} }
if (n_vocab != llama_n_vocab(llama_get_model(ctx))) { if (n_vocab != llama_vocab_n_tokens(vocab)) {
LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx))); LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_vocab_n_tokens(vocab));
} }
std::vector<llama_token> tokens(size_t(n_ctx) * n_chunk); std::vector<llama_token> tokens(size_t(n_ctx) * n_chunk);
@ -1701,8 +1717,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int num_batches = (n_ctx + n_batch - 1)/n_batch; const int num_batches = (n_ctx + n_batch - 1)/n_batch;
const int nv = 2*((n_vocab + 1)/2) + 4; const int nv = 2*((n_vocab + 1)/2) + 4;
const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); const bool add_bos = llama_vocab_get_add_bos(vocab);
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); GGML_ASSERT(!llama_vocab_get_add_eos(vocab));
std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv); std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk); std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
@ -1761,7 +1777,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
// add BOS token for the first batch of each chunk // add BOS token for the first batch of each chunk
if (add_bos && j == 0) { if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); tokens[batch_start] = llama_vocab_bos(vocab);
} }
common_batch_clear(batch); common_batch_clear(batch);
@ -1995,7 +2011,7 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx_train = llama_model_n_ctx_train(model);
if (params.n_ctx > n_ctx_train) { if (params.n_ctx > n_ctx_train) {
LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n",


@ -319,7 +319,7 @@ int main(int argc, char ** argv) {
auto cparams = llama_context_default_params(); auto cparams = llama_context_default_params();
cparams.n_ctx = 256; cparams.n_ctx = 256;
ctx = llama_new_context_with_model(model, cparams); ctx = llama_init_from_model(model, cparams);
if (ctx == NULL) { if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());


@ -159,7 +159,9 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
const int n_ctx_train = llama_n_ctx_train(model); const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_ctx_train = llama_model_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
@ -192,8 +194,8 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
// add eos if not present // add eos if not present
if (llama_token_eos(model) >= 0 && (inp.empty() || inp.back() != llama_token_eos(model))) { if (llama_vocab_eos(vocab) >= 0 && (inp.empty() || inp.back() != llama_vocab_eos(vocab))) {
inp.push_back(llama_token_eos(model)); inp.push_back(llama_vocab_eos(vocab));
} }
chunk.tokens = inp; chunk.tokens = inp;
} }
@ -215,7 +217,7 @@ int main(int argc, char ** argv) {
struct llama_batch batch = llama_batch_init(n_batch, 0, 1); struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
// allocate output // allocate output
const int n_embd = llama_n_embd(model); const int n_embd = llama_model_n_embd(model);
std::vector<float> embeddings(n_chunks * n_embd, 0); std::vector<float> embeddings(n_chunks * n_embd, 0);
float * emb = embeddings.data(); float * emb = embeddings.data();


@ -685,7 +685,7 @@ class LlamaData {
// Initializes the context with the specified parameters // Initializes the context with the specified parameters
llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) { llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) {
llama_context_ptr context(llama_new_context_with_model(model.get(), opt.ctx_params)); llama_context_ptr context(llama_init_from_model(model.get(), opt.ctx_params));
if (!context) { if (!context) {
printe("%s: error: failed to create the llama_context\n", __func__); printe("%s: error: failed to create the llama_context\n", __func__);
} }
@ -713,11 +713,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &
// Function to apply the chat template and resize `formatted` if needed // Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(LlamaData & llama_data, const bool append) { static int apply_chat_template(LlamaData & llama_data, const bool append) {
int result = llama_chat_apply_template( int result = llama_chat_apply_template(
llama_data.model.get(), nullptr, llama_data.messages.data(), llama_data.messages.size(), append, llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0); append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
if (append && result > static_cast<int>(llama_data.fmtted.size())) { if (append && result > static_cast<int>(llama_data.fmtted.size())) {
llama_data.fmtted.resize(result); llama_data.fmtted.resize(result);
result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(), result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
llama_data.messages.size(), append, llama_data.fmtted.data(), llama_data.messages.size(), append, llama_data.fmtted.data(),
llama_data.fmtted.size()); llama_data.fmtted.size());
} }
@ -726,11 +726,11 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) {
} }
// Function to tokenize the prompt // Function to tokenize the prompt
static int tokenize_prompt(const llama_model_ptr & model, const std::string & prompt, static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
std::vector<llama_token> & prompt_tokens) { std::vector<llama_token> & prompt_tokens) {
const int n_prompt_tokens = -llama_tokenize(model.get(), prompt.c_str(), prompt.size(), NULL, 0, true, true); const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
prompt_tokens.resize(n_prompt_tokens); prompt_tokens.resize(n_prompt_tokens);
if (llama_tokenize(model.get(), prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
true) < 0) { true) < 0) {
printe("failed to tokenize the prompt\n"); printe("failed to tokenize the prompt\n");
return -1; return -1;
@ -753,9 +753,9 @@ static int check_context_size(const llama_context_ptr & ctx, const llama_batch &
} }
// convert the token to a string // convert the token to a string
static int convert_token_to_string(const llama_model_ptr & model, const llama_token token_id, std::string & piece) { static int convert_token_to_string(const llama_vocab * vocab, const llama_token token_id, std::string & piece) {
char buf[256]; char buf[256];
int n = llama_token_to_piece(model.get(), token_id, buf, sizeof(buf), 0, true); int n = llama_token_to_piece(vocab, token_id, buf, sizeof(buf), 0, true);
if (n < 0) { if (n < 0) {
printe("failed to convert token to piece\n"); printe("failed to convert token to piece\n");
return 1; return 1;
@ -773,8 +773,10 @@ static void print_word_and_concatenate_to_response(const std::string & piece, st
// helper function to evaluate a prompt and generate a response // helper function to evaluate a prompt and generate a response
static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) { static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) {
const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get());
std::vector<llama_token> tokens; std::vector<llama_token> tokens;
if (tokenize_prompt(llama_data.model, prompt, tokens) < 0) { if (tokenize_prompt(vocab, prompt, tokens) < 0) {
return 1; return 1;
} }
@ -790,12 +792,12 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
// sample the next token, check is it an end of generation? // sample the next token, check is it an end of generation?
new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1); new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1);
if (llama_token_is_eog(llama_data.model.get(), new_token_id)) { if (llama_vocab_is_eog(vocab, new_token_id)) {
break; break;
} }
std::string piece; std::string piece;
if (convert_token_to_string(llama_data.model, new_token_id, piece)) { if (convert_token_to_string(vocab, new_token_id, piece)) {
return 1; return 1;
} }
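
For reference (not part of the diff): tokenization and detokenization in run.cpp now take the vocab rather than the model. A condensed sketch of the two helpers above:

    #include "llama.h"
    #include <string>
    #include <vector>

    // Sketch: tokenize a prompt with the vocab-based API (mirrors tokenize_prompt above).
    static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
                               std::vector<llama_token> & out) {
        // first pass with a null buffer returns the negated token count
        const int n_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
        out.resize(n_tokens);
        if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), out.data(), out.size(), true, true) < 0) {
            return -1; // tokenization failed
        }
        return n_tokens;
    }

    // Sketch: convert a single token back to text; the vocab replaces the model here as well.
    static int token_to_string(const llama_vocab * vocab, llama_token id, std::string & piece) {
        char buf[256];
        const int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true); // was llama_token_to_piece(model, ...)
        if (n < 0) {
            return 1;
        }
        piece = std::string(buf, n);
        return 0;
    }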


@ -97,7 +97,7 @@ int main(int argc, char ** argv) {
printf("\n\n"); printf("\n\n");
// make new context // make new context
llama_context * ctx2 = llama_new_context_with_model(model, common_context_params_to_llama(params)); llama_context * ctx2 = llama_init_from_model(model, common_context_params_to_llama(params));
llama_sampler * smpl2 = llama_sampler_chain_init(sparams); llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
@ -154,7 +154,7 @@ int main(int argc, char ** argv) {
} }
// make new context // make new context
llama_context * ctx3 = llama_new_context_with_model(model, common_context_params_to_llama(params)); llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params));
llama_sampler * smpl3 = llama_sampler_chain_init(sparams); llama_sampler * smpl3 = llama_sampler_chain_init(sparams);


@ -98,7 +98,7 @@ struct slot_params {
int64_t t_max_prompt_ms = -1; // TODO: implement int64_t t_max_prompt_ms = -1; // TODO: implement
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
std::vector<common_lora_adapter_info> lora; std::vector<common_adapter_lora_info> lora;
std::vector<std::string> antiprompt; std::vector<std::string> antiprompt;
std::vector<std::string> response_fields; std::vector<std::string> response_fields;
@ -198,15 +198,17 @@ struct server_task {
bool metrics_reset_bucket = false; bool metrics_reset_bucket = false;
// used by SERVER_TASK_TYPE_SET_LORA // used by SERVER_TASK_TYPE_SET_LORA
std::vector<common_lora_adapter_info> set_lora; std::vector<common_adapter_lora_info> set_lora;
server_task(server_task_type type) : type(type) {} server_task(server_task_type type) : type(type) {}
static slot_params params_from_json_cmpl( static slot_params params_from_json_cmpl(
const llama_model * model,
const llama_context * ctx, const llama_context * ctx,
const common_params & params_base, const common_params & params_base,
const json & data) { const json & data) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
slot_params params; slot_params params;
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
@ -329,7 +331,7 @@ struct server_task {
const auto & logit_bias = data.find("logit_bias"); const auto & logit_bias = data.find("logit_bias");
if (logit_bias != data.end() && logit_bias->is_array()) { if (logit_bias != data.end() && logit_bias->is_array()) {
const int n_vocab = llama_n_vocab(model); const int n_vocab = llama_vocab_n_tokens(vocab);
for (const auto & el : *logit_bias) { for (const auto & el : *logit_bias) {
// TODO: we may want to throw errors here, in case "el" is incorrect // TODO: we may want to throw errors here, in case "el" is incorrect
if (el.is_array() && el.size() == 2) { if (el.is_array() && el.size() == 2) {
@ -348,7 +350,7 @@ struct server_task {
params.sampling.logit_bias.push_back({tok, bias}); params.sampling.logit_bias.push_back({tok, bias});
} }
} else if (el[0].is_string()) { } else if (el[0].is_string()) {
auto toks = common_tokenize(model, el[0].get<std::string>(), false); auto toks = common_tokenize(vocab, el[0].get<std::string>(), false);
for (auto tok : toks) { for (auto tok : toks) {
params.sampling.logit_bias.push_back({tok, bias}); params.sampling.logit_bias.push_back({tok, bias});
} }
@ -1131,7 +1133,7 @@ struct server_slot {
common_speculative * spec = nullptr; common_speculative * spec = nullptr;
std::vector<common_lora_adapter_info> lora; std::vector<common_adapter_lora_info> lora;
// the index relative to completion multi-task request // the index relative to completion multi-task request
size_t index = 0; size_t index = 0;
@ -1633,6 +1635,8 @@ struct server_context {
llama_model * model = nullptr; llama_model * model = nullptr;
llama_context * ctx = nullptr; llama_context * ctx = nullptr;
const llama_vocab * vocab = nullptr;
llama_model * model_dft = nullptr; llama_model * model_dft = nullptr;
llama_context_params cparams_dft; llama_context_params cparams_dft;
@ -1690,10 +1694,12 @@ struct server_context {
return false; return false;
} }
vocab = llama_model_get_vocab(model);
n_ctx = llama_n_ctx(ctx); n_ctx = llama_n_ctx(ctx);
add_bos_token = llama_add_bos_token(model); add_bos_token = llama_vocab_get_add_bos(vocab);
has_eos_token = llama_token_eos(model) != LLAMA_TOKEN_NULL; has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL;
if (!params_base.speculative.model.empty()) { if (!params_base.speculative.model.empty()) {
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str()); SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
@ -1736,7 +1742,8 @@ struct server_context {
bool validate_builtin_chat_template() const { bool validate_builtin_chat_template() const {
llama_chat_message chat[] = {{"user", "test"}}; llama_chat_message chat[] = {{"user", "test"}};
int32_t chat_res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0); const char * tmpl = llama_model_chat_template(model);
const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
return chat_res > 0; return chat_res > 0;
} }
@ -1756,7 +1763,7 @@ struct server_context {
if (model_dft) { if (model_dft) {
slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft); slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft);
if (slot.ctx_dft == nullptr) { if (slot.ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create draft context\n"); SRV_ERR("%s", "failed to create draft context\n");
return; return;
@ -1891,7 +1898,7 @@ struct server_context {
} }
if (slot.params.ignore_eos && has_eos_token) { if (slot.params.ignore_eos && has_eos_token) {
slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY}); slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY});
} }
{ {
@ -2047,14 +2054,14 @@ struct server_context {
slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx); slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
} }
if (llama_token_is_eog(model, result.tok)) { if (llama_vocab_is_eog(vocab, result.tok)) {
slot.stop = STOP_TYPE_EOS; slot.stop = STOP_TYPE_EOS;
slot.has_next_token = false; slot.has_next_token = false;
SLT_DBG(slot, "%s", "stopped by EOS\n"); SLT_DBG(slot, "%s", "stopped by EOS\n");
} }
const auto n_ctx_train = llama_n_ctx_train(model); const auto n_ctx_train = llama_model_n_ctx_train(model);
if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
slot.truncated = true; slot.truncated = true;
@ -2074,7 +2081,7 @@ struct server_context {
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) { void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
size_t n_probs = slot.params.sampling.n_probs; size_t n_probs = slot.params.sampling.n_probs;
size_t n_vocab = llama_n_vocab(llama_get_model(ctx)); size_t n_vocab = llama_vocab_n_tokens(vocab);
if (post_sampling) { if (post_sampling) {
const auto * cur_p = common_sampler_get_candidates(slot.smpl); const auto * cur_p = common_sampler_get_candidates(slot.smpl);
const size_t max_probs = cur_p->size; const size_t max_probs = cur_p->size;
@ -2225,7 +2232,7 @@ struct server_context {
res->n_tokens = slot.n_prompt_tokens; res->n_tokens = slot.n_prompt_tokens;
res->oaicompat = slot.params.oaicompat; res->oaicompat = slot.params.oaicompat;
const int n_embd = llama_n_embd(model); const int n_embd = llama_model_n_embd(model);
std::vector<float> embd_res(n_embd, 0.0f); std::vector<float> embd_res(n_embd, 0.0f);
@ -2927,7 +2934,7 @@ struct server_context {
// make sure we're in the right embedding mode // make sure we're in the right embedding mode
llama_set_embeddings(ctx, slot_batched->is_non_causal()); llama_set_embeddings(ctx, slot_batched->is_non_causal());
// apply lora, only need to do it once per batch // apply lora, only need to do it once per batch
common_lora_adapters_apply(ctx, slot_batched->lora); common_set_adapter_lora(ctx, slot_batched->lora);
} }
// process the created batch of tokens // process the created batch of tokens
@ -3129,12 +3136,12 @@ struct server_context {
json model_meta() const { json model_meta() const {
return json { return json {
{"vocab_type", llama_vocab_type (model)}, {"vocab_type", llama_vocab_type (vocab)},
{"n_vocab", llama_n_vocab (model)}, {"n_vocab", llama_vocab_n_tokens (vocab)},
{"n_ctx_train", llama_n_ctx_train (model)}, {"n_ctx_train", llama_model_n_ctx_train(model)},
{"n_embd", llama_n_embd (model)}, {"n_embd", llama_model_n_embd (model)},
{"n_params", llama_model_n_params(model)}, {"n_params", llama_model_n_params (model)},
{"size", llama_model_size (model)}, {"size", llama_model_size (model)},
}; };
} }
}; };
@ -3639,7 +3646,7 @@ int main(int argc, char ** argv) {
std::vector<server_task> tasks; std::vector<server_task> tasks;
try { try {
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true); std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
tasks.reserve(tokenized_prompts.size()); tasks.reserve(tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) { for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(type); server_task task = server_task(type);
@ -3649,7 +3656,6 @@ int main(int argc, char ** argv) {
task.prompt_tokens = std::move(tokenized_prompts[i]); task.prompt_tokens = std::move(tokenized_prompts[i]);
task.params = server_task::params_from_json_cmpl( task.params = server_task::params_from_json_cmpl(
ctx_server.model,
ctx_server.ctx, ctx_server.ctx,
ctx_server.params_base, ctx_server.params_base,
data); data);
@ -3745,13 +3751,13 @@ int main(int argc, char ** argv) {
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
// check model compatibility // check model compatibility
std::string err; std::string err;
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) { if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
err += "prefix token is missing. "; err += "prefix token is missing. ";
} }
if (llama_token_fim_suf(ctx_server.model) == LLAMA_TOKEN_NULL) { if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
err += "suffix token is missing. "; err += "suffix token is missing. ";
} }
if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) { if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
err += "middle token is missing. "; err += "middle token is missing. ";
} }
if (!err.empty()) { if (!err.empty()) {
@ -3797,10 +3803,10 @@ int main(int argc, char ** argv) {
data["input_extra"] = input_extra; // default to empty array if it's not exist data["input_extra"] = input_extra; // default to empty array if it's not exist
std::string prompt = json_value(data, "prompt", std::string()); std::string prompt = json_value(data, "prompt", std::string());
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, false, true); std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, false, true);
SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
data["prompt"] = format_infill( data["prompt"] = format_infill(
ctx_server.ctx, ctx_server.vocab,
data.at("input_prefix"), data.at("input_prefix"),
data.at("input_suffix"), data.at("input_suffix"),
data.at("input_extra"), data.at("input_extra"),
@ -3857,7 +3863,7 @@ int main(int argc, char ** argv) {
const bool add_special = json_value(body, "add_special", false); const bool add_special = json_value(body, "add_special", false);
const bool with_pieces = json_value(body, "with_pieces", false); const bool with_pieces = json_value(body, "with_pieces", false);
llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true); llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, true);
if (with_pieces) { if (with_pieces) {
for (const auto& token : tokens) { for (const auto& token : tokens) {
@ -3933,7 +3939,7 @@ int main(int argc, char ** argv) {
} }
} }
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true); std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
for (const auto & tokens : tokenized_prompts) { for (const auto & tokens : tokenized_prompts) {
// this check is necessary for models that do not add BOS token to the input // this check is necessary for models that do not add BOS token to the input
if (tokens.empty()) { if (tokens.empty()) {
@ -4033,20 +4039,20 @@ int main(int argc, char ** argv) {
return; return;
} }
llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.ctx, query, /* add_special */ false, true)[0]; llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, query, /* add_special */ false, true)[0];
// create and queue the task // create and queue the task
json responses = json::array(); json responses = json::array();
bool error = false; bool error = false;
{ {
std::vector<server_task> tasks; std::vector<server_task> tasks;
std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.ctx, documents, /* add_special */ false, true); std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
tasks.reserve(tokenized_docs.size()); tasks.reserve(tokenized_docs.size());
for (size_t i = 0; i < tokenized_docs.size(); i++) { for (size_t i = 0; i < tokenized_docs.size(); i++) {
server_task task = server_task(SERVER_TASK_TYPE_RERANK); server_task task = server_task(SERVER_TASK_TYPE_RERANK);
task.id = ctx_server.queue_tasks.get_new_id(); task.id = ctx_server.queue_tasks.get_new_id();
task.index = i; task.index = i;
task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]); task.prompt_tokens = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
tasks.push_back(task); tasks.push_back(task);
} }
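
For reference (not part of the diff): llama_chat_apply_template no longer receives the model; the template string is fetched separately with llama_model_chat_template. A condensed sketch of the server's validate_builtin_chat_template above:

    #include "llama.h"

    // condensed sketch: the template string is taken from the model and passed explicitly
    static bool has_builtin_chat_template(const llama_model * model) {
        const char * tmpl = llama_model_chat_template(model);          // was passed implicitly via (model, nullptr, ...)
        llama_chat_message chat[] = {{"user", "test"}};
        const int32_t res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
        return res > 0;
    }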


@ -118,7 +118,7 @@ static json json_get_nested_values(const std::vector<std::string> & paths, const
* - only string, example: "string" * - only string, example: "string"
* - mixed string and tokens, example: [12, 34, "string", 56, 78] * - mixed string and tokens, example: [12, 34, "string", 56, 78]
*/ */
static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) { static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
// If `add_bos` is true, we only add BOS, when json_prompt is a string, // If `add_bos` is true, we only add BOS, when json_prompt is a string,
// or the first element of the json_prompt array is a string. // or the first element of the json_prompt array is a string.
llama_tokens prompt_tokens; llama_tokens prompt_tokens;
@ -131,10 +131,10 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
llama_tokens p; llama_tokens p;
if (first) { if (first) {
p = common_tokenize(ctx, s, add_special, parse_special); p = common_tokenize(vocab, s, add_special, parse_special);
first = false; first = false;
} else { } else {
p = common_tokenize(ctx, s, false, parse_special); p = common_tokenize(vocab, s, false, parse_special);
} }
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
@ -148,7 +148,7 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
} }
} else { } else {
auto s = json_prompt.template get<std::string>(); auto s = json_prompt.template get<std::string>();
prompt_tokens = common_tokenize(ctx, s, add_special, parse_special); prompt_tokens = common_tokenize(vocab, s, add_special, parse_special);
} }
return prompt_tokens; return prompt_tokens;
@ -166,11 +166,11 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
* - "prompt": [[12, 34, 56], [78, 90, 12]] * - "prompt": [[12, 34, 56], [78, 90, 12]]
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]] * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
*/ */
static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) { static std::vector<llama_tokens> tokenize_input_prompts(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
std::vector<llama_tokens> result; std::vector<llama_tokens> result;
if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
// string or mixed // string or mixed
result.push_back(tokenize_mixed(ctx, json_prompt, add_special, parse_special)); result.push_back(tokenize_mixed(vocab, json_prompt, add_special, parse_special));
} else if (json_is_array_of_numbers(json_prompt)) { } else if (json_is_array_of_numbers(json_prompt)) {
// array of tokens // array of tokens
result.push_back(json_prompt.get<llama_tokens>()); result.push_back(json_prompt.get<llama_tokens>());
@ -179,7 +179,7 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
result.reserve(json_prompt.size()); result.reserve(json_prompt.size());
for (const auto & p : json_prompt) { for (const auto & p : json_prompt) {
if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) { if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
result.push_back(tokenize_mixed(ctx, p, add_special, parse_special)); result.push_back(tokenize_mixed(vocab, p, add_special, parse_special));
} else if (json_is_array_of_numbers(p)) { } else if (json_is_array_of_numbers(p)) {
// array of tokens // array of tokens
result.push_back(p.get<llama_tokens>()); result.push_back(p.get<llama_tokens>());
@ -231,21 +231,23 @@ static size_t validate_utf8(const std::string& text) {
// //
// format rerank task: [BOS]query[EOS][SEP]doc[EOS] // format rerank task: [BOS]query[EOS][SEP]doc[EOS]
static llama_tokens format_rerank(const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) { static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
llama_tokens result; llama_tokens result;
result.reserve(doc.size() + query.size() + 4); result.reserve(doc.size() + query.size() + 4);
result.push_back(llama_token_bos(model)); result.push_back(llama_vocab_bos(vocab));
result.insert(result.end(), query.begin(), query.end()); result.insert(result.end(), query.begin(), query.end());
result.push_back(llama_token_eos(model)); result.push_back(llama_vocab_eos(vocab));
result.push_back(llama_token_sep(model)); result.push_back(llama_vocab_sep(vocab));
result.insert(result.end(), doc.begin(), doc.end()); result.insert(result.end(), doc.begin(), doc.end());
result.push_back(llama_token_eos(model)); result.push_back(llama_vocab_eos(vocab));
return result; return result;
} }
// format infill task // format infill task
static llama_tokens format_infill( static llama_tokens format_infill(
const llama_context * ctx, const llama_vocab * vocab,
const json & input_prefix, const json & input_prefix,
const json & input_suffix, const json & input_suffix,
const json & input_extra, const json & input_extra,
@ -272,15 +274,14 @@ static llama_tokens format_infill(
llama_tokens extra_tokens; llama_tokens extra_tokens;
extra_tokens.reserve(n_ctx); extra_tokens.reserve(n_ctx);
auto model = llama_get_model(ctx); auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false);
auto tokens_prefix = tokenize_mixed(ctx, input_prefix, false, false); auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false);
auto tokens_suffix = tokenize_mixed(ctx, input_suffix, false, false);
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) { if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) {
// TODO: make project name an input // TODO: make project name an input
static const auto k_fim_repo = common_tokenize(ctx, "myproject\n", false, false); static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false);
extra_tokens.push_back(llama_token_fim_rep(model)); extra_tokens.push_back(llama_vocab_fim_rep(vocab));
extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
} }
for (const auto & chunk : input_extra) { for (const auto & chunk : input_extra) {
@ -288,28 +289,28 @@ static llama_tokens format_infill(
const std::string text = json_value(chunk, "text", std::string()); const std::string text = json_value(chunk, "text", std::string());
const std::string filename = json_value(chunk, "filename", std::string("tmp")); const std::string filename = json_value(chunk, "filename", std::string("tmp"));
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) { if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
const auto k_fim_file = common_tokenize(ctx, filename + "\n", false, false); const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false);
extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model)); extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab));
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
} else { } else {
// chunk separator in binary form to avoid confusing the AI // chunk separator in binary form to avoid confusing the AI
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
static const auto k_chunk_prefix_tokens = common_tokenize(ctx, k_chunk_prefix_str, false, false); static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false);
extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
} }
const auto chunk_tokens = common_tokenize(ctx, text, false, false); const auto chunk_tokens = common_tokenize(vocab, text, false, false);
extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
} }
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) { if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
// TODO: current filename // TODO: current filename
static const auto k_fim_file = common_tokenize(ctx, "filename\n", false, false); static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false);
extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model)); extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab));
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
} }
@ -325,15 +326,15 @@ static llama_tokens format_infill(
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
tokens_suffix.resize(n_suffix_take); tokens_suffix.resize(n_suffix_take);
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model)); tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab));
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model)); tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab));
auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
auto embd_end = spm_infill ? tokens_prefix : tokens_suffix; auto embd_end = spm_infill ? tokens_prefix : tokens_suffix;
if (llama_add_bos_token(model)) { if (llama_vocab_get_add_bos(vocab)) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model)); embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab));
} }
SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
@ -342,7 +343,7 @@ static llama_tokens format_infill(
embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
embd_inp.push_back(llama_token_fim_mid(model)); embd_inp.push_back(llama_vocab_fim_mid(vocab));
return embd_inp; return embd_inp;
} }
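As a quick sanity check before calling the helper above, a caller might verify that the vocab actually defines the FIM tokens; a minimal sketch (the `model` handle is assumed):

const llama_vocab * vocab = llama_model_get_vocab(model);
const bool has_fim =
    llama_vocab_fim_pre(vocab) != LLAMA_TOKEN_NULL &&
    llama_vocab_fim_suf(vocab) != LLAMA_TOKEN_NULL &&
    llama_vocab_fim_mid(vocab) != LLAMA_TOKEN_NULL;
if (!has_fim) {
    // this model's vocab does not expose fill-in-the-middle tokens
}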
@ -764,14 +765,18 @@ static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias)
return data; return data;
} }
static std::string safe_json_to_str(json data) { static std::string safe_json_to_str(const json & data) {
return data.dump(-1, ' ', false, json::error_handler_t::replace); return data.dump(-1, ' ', false, json::error_handler_t::replace);
} }
static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) { static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
std::vector<llama_token_data> cur; std::vector<llama_token_data> cur;
const auto * logits = llama_get_logits_ith(ctx, idx); const auto * logits = llama_get_logits_ith(ctx, idx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_vocab_n_tokens(vocab);
cur.resize(n_vocab); cur.resize(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) { for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
@ -799,8 +804,8 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx
} }
static bool are_lora_equal( static bool are_lora_equal(
const std::vector<common_lora_adapter_info> & l1, const std::vector<common_adapter_lora_info> & l1,
const std::vector<common_lora_adapter_info> & l2) { const std::vector<common_adapter_lora_info> & l2) {
if (l1.size() != l2.size()) { if (l1.size() != l2.size()) {
return false; return false;
} }
@ -814,10 +819,10 @@ static bool are_lora_equal(
} }
// parse lora config from JSON request, returned a copy of lora_base with updated scale // parse lora config from JSON request, returned a copy of lora_base with updated scale
static std::vector<common_lora_adapter_info> parse_lora_request( static std::vector<common_adapter_lora_info> parse_lora_request(
const std::vector<common_lora_adapter_info> & lora_base, const std::vector<common_adapter_lora_info> & lora_base,
const json & data) { const json & data) {
std::vector<common_lora_adapter_info> lora(lora_base); std::vector<common_adapter_lora_info> lora(lora_base);
int max_idx = lora.size(); int max_idx = lora.size();
// clear existing value // clear existing value


@ -75,12 +75,14 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
const llama_vocab * vocab = llama_model_get_vocab(model);
// initialize the context // initialize the context
llama_context_params ctx_params = llama_context_default_params(); llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = n_ctx; ctx_params.n_ctx = n_ctx;
ctx_params.n_batch = n_ctx; ctx_params.n_batch = n_ctx;
llama_context * ctx = llama_new_context_with_model(model, ctx_params); llama_context * ctx = llama_init_from_model(model, ctx_params);
if (!ctx) { if (!ctx) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
return 1; return 1;
@ -97,9 +99,9 @@ int main(int argc, char ** argv) {
std::string response; std::string response;
// tokenize the prompt // tokenize the prompt
const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true); const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
std::vector<llama_token> prompt_tokens(n_prompt_tokens); std::vector<llama_token> prompt_tokens(n_prompt_tokens);
if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) { if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
GGML_ABORT("failed to tokenize the prompt\n"); GGML_ABORT("failed to tokenize the prompt\n");
} }
@ -124,13 +126,13 @@ int main(int argc, char ** argv) {
new_token_id = llama_sampler_sample(smpl, ctx, -1); new_token_id = llama_sampler_sample(smpl, ctx, -1);
// is it an end of generation? // is it an end of generation?
if (llama_token_is_eog(model, new_token_id)) { if (llama_vocab_is_eog(vocab, new_token_id)) {
break; break;
} }
// convert the token to a string, print it and add it to the response // convert the token to a string, print it and add it to the response
char buf[256]; char buf[256];
int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true); int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true);
if (n < 0) { if (n < 0) {
GGML_ABORT("failed to convert token to piece\n"); GGML_ABORT("failed to convert token to piece\n");
} }
@ -159,12 +161,14 @@ int main(int argc, char ** argv) {
break; break;
} }
const char * tmpl = llama_model_chat_template(model);
// add the user input to the message list and format it // add the user input to the message list and format it
messages.push_back({"user", strdup(user.c_str())}); messages.push_back({"user", strdup(user.c_str())});
int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size()); int new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
if (new_len > (int)formatted.size()) { if (new_len > (int)formatted.size()) {
formatted.resize(new_len); formatted.resize(new_len);
new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size()); new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
} }
if (new_len < 0) { if (new_len < 0) {
fprintf(stderr, "failed to apply the chat template\n"); fprintf(stderr, "failed to apply the chat template\n");
@ -181,7 +185,7 @@ int main(int argc, char ** argv) {
// add the response to the messages // add the response to the messages
messages.push_back({"assistant", strdup(response.c_str())}); messages.push_back({"assistant", strdup(response.c_str())});
prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0); prev_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), false, nullptr, 0);
if (prev_len < 0) { if (prev_len < 0) {
fprintf(stderr, "failed to apply the chat template\n"); fprintf(stderr, "failed to apply the chat template\n");
return 1; return 1;
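In short, the template string is now fetched from the model once and passed to llama_chat_apply_template directly; a minimal sketch of the grow-and-retry pattern used above (message contents are placeholders):

const char * tmpl = llama_model_chat_template(model);
std::vector<llama_chat_message> messages = { {"user", "Hello"} };
std::vector<char> formatted(1024);
int n = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
if (n > (int) formatted.size()) {
    // buffer too small: grow to the reported size and apply again
    formatted.resize(n);
    n = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
}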


@ -84,6 +84,7 @@ int main(int argc, char ** argv) {
model_params.n_gpu_layers = ngl; model_params.n_gpu_layers = ngl;
llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params); llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params);
const llama_vocab * vocab = llama_model_get_vocab(model);
if (model == NULL) { if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__); fprintf(stderr , "%s: error: unable to load model\n" , __func__);
@ -93,11 +94,11 @@ int main(int argc, char ** argv) {
// tokenize the prompt // tokenize the prompt
// find the number of tokens in the prompt // find the number of tokens in the prompt
const int n_prompt = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true); const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
// allocate space for the tokens and tokenize the prompt // allocate space for the tokens and tokenize the prompt
std::vector<llama_token> prompt_tokens(n_prompt); std::vector<llama_token> prompt_tokens(n_prompt);
if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) { if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__); fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__);
return 1; return 1;
} }
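The negated call above is the usual sizing idiom: with a NULL output buffer, llama_tokenize returns the negative of the required token count; a minimal sketch with the vocab-based signature:

// first pass: query the required size, second pass: fill the buffer
const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
std::vector<llama_token> prompt_tokens(n_prompt);
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
    // tokenization failed
}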
@ -112,7 +113,7 @@ int main(int argc, char ** argv) {
// enable performance counters // enable performance counters
ctx_params.no_perf = false; ctx_params.no_perf = false;
llama_context * ctx = llama_new_context_with_model(model, ctx_params); llama_context * ctx = llama_init_from_model(model, ctx_params);
if (ctx == NULL) { if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
@ -131,7 +132,7 @@ int main(int argc, char ** argv) {
for (auto id : prompt_tokens) { for (auto id : prompt_tokens) {
char buf[128]; char buf[128];
int n = llama_token_to_piece(model, id, buf, sizeof(buf), 0, true); int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true);
if (n < 0) { if (n < 0) {
fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__); fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
return 1; return 1;
@ -164,12 +165,12 @@ int main(int argc, char ** argv) {
new_token_id = llama_sampler_sample(smpl, ctx, -1); new_token_id = llama_sampler_sample(smpl, ctx, -1);
// is it an end of generation? // is it an end of generation?
if (llama_token_is_eog(model, new_token_id)) { if (llama_vocab_is_eog(vocab, new_token_id)) {
break; break;
} }
char buf[128]; char buf[128];
int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true); int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true);
if (n < 0) { if (n < 0) {
fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__); fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
return 1; return 1;


@ -45,6 +45,8 @@ int main(int argc, char ** argv) {
model_tgt = llama_init_tgt.model.get(); model_tgt = llama_init_tgt.model.get();
ctx_tgt = llama_init_tgt.context.get(); ctx_tgt = llama_init_tgt.context.get();
const llama_vocab * vocab = llama_model_get_vocab(model_tgt);
// load the draft model // load the draft model
params.devices = params.speculative.devices; params.devices = params.speculative.devices;
params.model = params.speculative.model; params.model = params.speculative.model;
@ -196,7 +198,7 @@ int main(int argc, char ** argv) {
id_last = ids[i]; id_last = ids[i];
if (llama_token_is_eog(model_tgt, id_last)) { if (llama_vocab_is_eog(vocab, id_last)) {
has_eos = true; has_eos = true;
break; break;
} }


@ -90,10 +90,13 @@ int main(int argc, char ** argv) {
model_dft = llama_init_dft.model.get(); model_dft = llama_init_dft.model.get();
ctx_dft = llama_init_dft.context.get(); ctx_dft = llama_init_dft.context.get();
const bool vocab_type_tgt = llama_vocab_type(model_tgt); const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
const bool vocab_type_tgt = llama_vocab_type(vocab_tgt);
LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt); LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
const bool vocab_type_dft = llama_vocab_type(model_dft); const bool vocab_type_dft = llama_vocab_type(vocab_dft);
LOG_DBG("vocab_type dft: %d\n", vocab_type_dft); LOG_DBG("vocab_type dft: %d\n", vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) { if (vocab_type_tgt != vocab_type_dft) {
@ -103,18 +106,18 @@ int main(int argc, char ** argv) {
} }
if ( if (
llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) || llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) || llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
llama_token_bos(model_tgt) != llama_token_bos(model_dft) || llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) ||
llama_token_eos(model_tgt) != llama_token_eos(model_dft) llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)
) { ) {
LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__); LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
return 1; return 1;
} }
{ {
const int n_vocab_tgt = llama_n_vocab(model_tgt); const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
const int n_vocab_dft = llama_n_vocab(model_dft); const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
const int vocab_diff = n_vocab_tgt > n_vocab_dft const int vocab_diff = n_vocab_tgt > n_vocab_dft
? n_vocab_tgt - n_vocab_dft ? n_vocab_tgt - n_vocab_dft
: n_vocab_dft - n_vocab_tgt; : n_vocab_dft - n_vocab_tgt;
@ -122,13 +125,13 @@ int main(int argc, char ** argv) {
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__); LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__);
LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return 1; return 1;
} }
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
const char * token_text_tgt = llama_token_get_text(model_tgt, i); const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
const char * token_text_dft = llama_token_get_text(model_dft, i); const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) { if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__); LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i, LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
@ -170,7 +173,7 @@ int main(int argc, char ** argv) {
const auto t_enc_end = ggml_time_us(); const auto t_enc_end = ggml_time_us();
// the 2 models should have the same vocab // the 2 models should have the same vocab
//GGML_ASSERT(n_vocab == llama_n_vocab(model_dft)); //GGML_ASSERT(n_vocab == llama_vocab_n_tokens(model_dft));
// how many tokens to draft each time // how many tokens to draft each time
int n_draft = params.speculative.n_max; int n_draft = params.speculative.n_max;
@ -386,7 +389,7 @@ int main(int argc, char ** argv) {
} }
} }
if (llama_token_is_eog(model_tgt, token_id)) { if (llama_vocab_is_eog(vocab_tgt, token_id)) {
has_eos = true; has_eos = true;
} }
++n_predict; ++n_predict;
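Condensed, the compatibility gate above reduces to comparing a handful of vocab properties between the two models; a sketch with the renamed calls:

const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
const bool compatible =
    llama_vocab_type(vocab_tgt)        == llama_vocab_type(vocab_dft)        &&
    llama_vocab_get_add_bos(vocab_tgt) == llama_vocab_get_add_bos(vocab_dft) &&
    llama_vocab_get_add_eos(vocab_tgt) == llama_vocab_get_add_eos(vocab_dft) &&
    llama_vocab_bos(vocab_tgt)         == llama_vocab_bos(vocab_dft)         &&
    llama_vocab_eos(vocab_tgt)         == llama_vocab_eos(vocab_dft);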


@ -344,8 +344,10 @@ int main(int raw_argc, char ** raw_argv) {
return 1; return 1;
} }
const llama_vocab * vocab = llama_model_get_vocab(model);
llama_context_params ctx_params = llama_context_default_params(); llama_context_params ctx_params = llama_context_default_params();
llama_context * ctx = llama_new_context_with_model(model, ctx_params); llama_context * ctx = llama_init_from_model(model, ctx_params);
if (!ctx) { if (!ctx) {
fprintf(stderr, "Error: could not create context.\n"); fprintf(stderr, "Error: could not create context.\n");
return 1; return 1;
@ -365,7 +367,7 @@ int main(int raw_argc, char ** raw_argv) {
prompt = stdin_buffer.str(); prompt = stdin_buffer.str();
} }
const bool model_wants_add_bos = llama_add_bos_token(model); const bool model_wants_add_bos = llama_vocab_get_add_bos(vocab);
const bool add_bos = model_wants_add_bos && !no_bos; const bool add_bos = model_wants_add_bos && !no_bos;
const bool parse_special = !no_parse_special; const bool parse_special = !no_parse_special;
const bool escape = !no_escape; const bool escape = !no_escape;
@ -375,7 +377,7 @@ int main(int raw_argc, char ** raw_argv) {
} }
std::vector<llama_token> tokens; std::vector<llama_token> tokens;
tokens = common_tokenize(model, prompt, add_bos, parse_special); tokens = common_tokenize(vocab, prompt, add_bos, parse_special);
if (printing_ids) { if (printing_ids) {
printf("["); printf("[");


@ -414,15 +414,15 @@ static void prompt_add(llama_tokens & prompt, const llama_tokens & tokens) {
prompt.insert(prompt.end(), tokens.begin(), tokens.end()); prompt.insert(prompt.end(), tokens.begin(), tokens.end());
} }
static void prompt_add(llama_tokens & prompt, const llama_model * model, const std::string & txt, bool add_special, bool parse_special) { static void prompt_add(llama_tokens & prompt, const llama_vocab * vocab, const std::string & txt, bool add_special, bool parse_special) {
auto tmp = common_tokenize(model, txt, add_special, parse_special); auto tmp = common_tokenize(vocab, txt, add_special, parse_special);
prompt_add(prompt, tmp); prompt_add(prompt, tmp);
} }
static void prompt_init(llama_tokens & prompt, const llama_model * model) { static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
prompt.clear(); prompt.clear();
prompt_add(prompt, model, "<|im_start|>\n", true, true); prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
} }
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
@ -462,6 +462,8 @@ int main(int argc, char ** argv) {
model_ttc = llama_init_ttc.model.get(); model_ttc = llama_init_ttc.model.get();
ctx_ttc = llama_init_ttc.context.get(); ctx_ttc = llama_init_ttc.context.get();
const llama_vocab * vocab = llama_model_get_vocab(model_ttc);
// TODO: refactor in a common struct // TODO: refactor in a common struct
params.model = params.vocoder.model; params.model = params.vocoder.model;
params.model_url = params.vocoder.model_url; params.model_url = params.vocoder.model_url;
@ -499,9 +501,9 @@ int main(int argc, char ** argv) {
std::vector<llama_token> prompt_inp; std::vector<llama_token> prompt_inp;
prompt_init(prompt_inp, model_ttc); prompt_init(prompt_inp, vocab);
prompt_add(prompt_inp, model_ttc, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true); prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
// convert the input text into the necessary format expected by OuteTTS // convert the input text into the necessary format expected by OuteTTS
{ {
@ -509,10 +511,10 @@ int main(int argc, char ** argv) {
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str()); LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
prompt_add(prompt_inp, model_ttc, prompt_clean, false, true); prompt_add(prompt_inp, vocab, prompt_clean, false, true);
} }
prompt_add(prompt_inp, model_ttc, "<|text_end|>\n", false, true); prompt_add(prompt_inp, vocab, "<|text_end|>\n", false, true);
// disabled to save time on tokenizing each time // disabled to save time on tokenizing each time
// TODO: load voices from the json files // TODO: load voices from the json files
@ -549,7 +551,7 @@ it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><
looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|> looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)"; lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";
auto tmp = common_tokenize(model_ttc, voice_data, false, true); auto tmp = common_tokenize(vocab, voice_data, false, true);
printf("\n\n"); printf("\n\n");
for (int i = 0; i < tmp.size(); ++i) { for (int i = 0; i < tmp.size(); ++i) {
printf("%d, ", tmp[i]); printf("%d, ", tmp[i]);
@ -735,9 +737,9 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
const auto * cands = common_sampler_get_candidates(smpl[i]); const auto * cands = common_sampler_get_candidates(smpl[i]);
// is it an end of generation? -> mark the stream as finished // is it an end of generation? -> mark the stream as finished
if (llama_token_is_eog(model_ttc, new_token_id) || n_decode == n_predict) { if (llama_vocab_is_eog(vocab, new_token_id) || n_decode == n_predict) {
std::string reason; std::string reason;
if (llama_token_is_eog(model_ttc, new_token_id)) { if (llama_vocab_is_eog(vocab, new_token_id)) {
reason = "eos"; reason = "eos";
} else { } else {
reason = "n_predict"; reason = "n_predict";
@ -873,7 +875,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
#if 1 #if 1
// spectral operations // spectral operations
const int n_embd = llama_n_embd(model_cts); const int n_embd = llama_model_n_embd(model_cts);
const float * embd = llama_get_embeddings(ctx_cts); const float * embd = llama_get_embeddings(ctx_cts);
auto audio = embd_to_audio(embd, n_codes, n_embd, params.cpuparams.n_threads); auto audio = embd_to_audio(embd, n_codes, n_embd, params.cpuparams.n_threads);


@ -20,11 +20,11 @@ struct llama_sampler_deleter {
void operator()(llama_sampler * sampler) { llama_sampler_free(sampler); } void operator()(llama_sampler * sampler) { llama_sampler_free(sampler); }
}; };
struct llama_lora_adapter_deleter { struct llama_adapter_lora_deleter {
void operator()(llama_lora_adapter * lora_adapter) { llama_lora_adapter_free(lora_adapter); } void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); }
}; };
typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr; typedef std::unique_ptr<llama_model, llama_model_deleter> llama_model_ptr;
typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr; typedef std::unique_ptr<llama_context, llama_context_deleter> llama_context_ptr;
typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr; typedef std::unique_ptr<llama_sampler, llama_sampler_deleter> llama_sampler_ptr;
typedef std::unique_ptr<llama_lora_adapter, llama_lora_adapter_deleter> llama_lora_adapter_ptr; typedef std::unique_ptr<llama_adapter_lora, llama_adapter_lora_deleter> llama_adapter_lora_ptr;


@ -56,7 +56,7 @@ extern "C" {
// TODO: show sample usage // TODO: show sample usage
// //
// struct llama_vocab; // TODO: add in the future struct llama_vocab;
struct llama_model; struct llama_model;
struct llama_context; struct llama_context;
struct llama_sampler; struct llama_sampler;
@ -385,8 +385,7 @@ extern "C" {
} llama_chat_message; } llama_chat_message;
// lora adapter // lora adapter
// TODO: rename to llama_adapter_lora struct llama_adapter_lora;
struct llama_lora_adapter;
// Helpers for getting default parameters // Helpers for getting default parameters
// TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172) // TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
@ -400,18 +399,19 @@ extern "C" {
// Call once at the start of the program // Call once at the start of the program
LLAMA_API void llama_backend_init(void); LLAMA_API void llama_backend_init(void);
// Call once at the end of the program - currently only used for MPI
LLAMA_API void llama_backend_free(void);
//optional: //optional:
LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa); LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa);
// Optional: an auto threadpool gets created in ggml if not passed explicitly // Optional: an auto threadpool gets created in ggml if not passed explicitly
LLAMA_API void llama_attach_threadpool( LLAMA_API void llama_attach_threadpool(
struct llama_context * ctx, struct llama_context * ctx,
ggml_threadpool_t threadpool, ggml_threadpool_t threadpool,
ggml_threadpool_t threadpool_batch); ggml_threadpool_t threadpool_batch);
LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
// Call once at the end of the program - currently only used for MPI LLAMA_API void llama_detach_threadpool(struct llama_context * ctx);
LLAMA_API void llama_backend_free(void);
DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file( DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file(
const char * path_model, const char * path_model,
@ -427,11 +427,15 @@ extern "C" {
LLAMA_API void llama_model_free(struct llama_model * model); LLAMA_API void llama_model_free(struct llama_model * model);
// TODO: rename to llama_init_from_model LLAMA_API struct llama_context * llama_init_from_model(
LLAMA_API struct llama_context * llama_new_context_with_model(
struct llama_model * model, struct llama_model * model,
struct llama_context_params params); struct llama_context_params params);
DEPRECATED(LLAMA_API struct llama_context * llama_new_context_with_model(
struct llama_model * model,
struct llama_context_params params),
"use llama_init_from_model instead");
// Frees all allocated memory // Frees all allocated memory
LLAMA_API void llama_free(struct llama_context * ctx); LLAMA_API void llama_free(struct llama_context * ctx);
@ -449,20 +453,30 @@ extern "C" {
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead");
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead");
LLAMA_API int32_t llama_n_embd (const struct llama_model * model); DEPRECATED(LLAMA_API int32_t llama_n_layer (const struct llama_model * model), "use llama_model_n_layer instead");
LLAMA_API int32_t llama_n_layer (const struct llama_model * model); DEPRECATED(LLAMA_API int32_t llama_n_head (const struct llama_model * model), "use llama_model_n_head instead");
LLAMA_API int32_t llama_n_head (const struct llama_model * model);
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_tokens instead");
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx);
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model);
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
// Get the model's RoPE frequency scaling factor // Get the model's RoPE frequency scaling factor
LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model);
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab);
LLAMA_API int32_t llama_vocab_n_tokens(const struct llama_vocab * vocab);
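Taken together, a quick sketch of the renamed accessors next to their deprecated counterparts:

const llama_vocab * vocab = llama_model_get_vocab(model);
const int32_t n_embd   = llama_model_n_embd(model);   // was: llama_n_embd(model)
const int32_t n_layer  = llama_model_n_layer(model);  // was: llama_n_layer(model)
const int32_t n_vocab  = llama_vocab_n_tokens(vocab); // was: llama_n_vocab(model)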
// Functions to access the model's GGUF metadata scalar values // Functions to access the model's GGUF metadata scalar values
// - The functions return the length of the string on success, or -1 on failure // - The functions return the length of the string on success, or -1 on failure
@ -488,6 +502,9 @@ extern "C" {
// Returns the total size of all the tensors in the model in bytes // Returns the total size of all the tensors in the model in bytes
LLAMA_API uint64_t llama_model_size(const struct llama_model * model); LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
// Get the default chat template. Returns nullptr if not available
LLAMA_API const char * llama_model_chat_template(const struct llama_model * model);
// Returns the total number of parameters in the model // Returns the total number of parameters in the model
LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
@ -515,34 +532,31 @@ extern "C" {
// //
// Load a LoRA adapter from file // Load a LoRA adapter from file
// TODO: rename to llama_adapter_lora_init LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init(
LLAMA_API struct llama_lora_adapter * llama_lora_adapter_init(
struct llama_model * model, struct llama_model * model,
const char * path_lora); const char * path_lora);
// Manually free a LoRA adapter
// Note: loaded adapters will be free when the associated model is deleted
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
// The following functions operate on a llama_context, hence the naming: llama_verb_...
// Add a loaded LoRA adapter to given context // Add a loaded LoRA adapter to given context
// This will not modify model's weight // This will not modify model's weight
// TODO: rename to llama_set_adapter_lora LLAMA_API int32_t llama_set_adapter_lora(
LLAMA_API int32_t llama_lora_adapter_set(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_lora_adapter * adapter, struct llama_adapter_lora * adapter,
float scale); float scale);
// Remove a specific LoRA adapter from given context // Remove a specific LoRA adapter from given context
// Return -1 if the adapter is not present in the context // Return -1 if the adapter is not present in the context
// TODO: rename to llama_rm_adapter_lora LLAMA_API int32_t llama_rm_adapter_lora(
LLAMA_API int32_t llama_lora_adapter_remove(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_lora_adapter * adapter); struct llama_adapter_lora * adapter);
// Remove all LoRA adapters from given context // Remove all LoRA adapters from given context
// TODO: rename to llama_clear_adapter_lora LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx);
LLAMA_API void llama_lora_adapter_clear(struct llama_context * ctx);
// Manually free a LoRA adapter
// Note: loaded adapters will be free when the associated model is deleted
// TODO: rename to llama_adapter_lora_free
LLAMA_API void llama_lora_adapter_free(struct llama_lora_adapter * adapter);
// Apply a loaded control vector to a llama_context, or if data is NULL, clear // Apply a loaded control vector to a llama_context, or if data is NULL, clear
// the currently loaded vector. // the currently loaded vector.
@ -550,9 +564,8 @@ extern "C" {
// to an n_embd x n_layers buffer starting from layer 1. // to an n_embd x n_layers buffer starting from layer 1.
// il_start and il_end are the layer range the vector should apply to (both inclusive) // il_start and il_end are the layer range the vector should apply to (both inclusive)
// See llama_control_vector_load in common to load a control vector. // See llama_control_vector_load in common to load a control vector.
// TODO: rename to llama_adapter_cvec_apply LLAMA_API int32_t llama_apply_adapter_cvec(
LLAMA_API int32_t llama_control_vector_apply( struct llama_context * ctx,
struct llama_context * lctx,
const float * data, const float * data,
size_t len, size_t len,
int32_t n_embd, int32_t n_embd,
@ -908,41 +921,57 @@ extern "C" {
// Vocab // Vocab
// //
LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token); LLAMA_API const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token);
LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token); LLAMA_API float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token);
LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token); LLAMA_API enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token);
// Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token); LLAMA_API bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token);
// Identify if Token Id is a control token or a render-able token // Identify if Token Id is a control token or a render-able token
LLAMA_API bool llama_token_is_control(const struct llama_model * model, llama_token token); LLAMA_API bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token);
// Special tokens // Special tokens
LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence LLAMA_API llama_token llama_vocab_bos(const struct llama_vocab * vocab); // beginning-of-sentence
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence LLAMA_API llama_token llama_vocab_eos(const struct llama_vocab * vocab); // end-of-sentence
LLAMA_API llama_token llama_token_eot(const struct llama_model * model); // end-of-turn LLAMA_API llama_token llama_vocab_eot(const struct llama_vocab * vocab); // end-of-turn
LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification LLAMA_API llama_token llama_vocab_cls(const struct llama_vocab * vocab); // classification
LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
LLAMA_API bool llama_add_bos_token(const struct llama_model * model); LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
LLAMA_API bool llama_add_eos_token(const struct llama_model * model); LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
// infill tokens LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab);
DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_model * model), "use llama_token_fim_pre instead"); LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab);
DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_model * model), "use llama_token_fim_mid instead"); LLAMA_API llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab);
DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_model * model), "use llama_token_fim_suf instead"); LLAMA_API llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_token_fim_pre(const struct llama_model * model); DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_text instead");
LLAMA_API llama_token llama_token_fim_suf(const struct llama_model * model); DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_score instead");
LLAMA_API llama_token llama_token_fim_mid(const struct llama_model * model); DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead");
LLAMA_API llama_token llama_token_fim_pad(const struct llama_model * model); DEPRECATED(LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_eog instead");
LLAMA_API llama_token llama_token_fim_rep(const struct llama_model * model); DEPRECATED(LLAMA_API bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_control instead");
LLAMA_API llama_token llama_token_fim_sep(const struct llama_model * model); DEPRECATED(LLAMA_API llama_token llama_token_bos(const struct llama_vocab * vocab), "use llama_vocab_bos instead");
DEPRECATED(LLAMA_API llama_token llama_token_eos(const struct llama_vocab * vocab), "use llama_vocab_eos instead");
DEPRECATED(LLAMA_API llama_token llama_token_eot(const struct llama_vocab * vocab), "use llama_vocab_eot instead");
DEPRECATED(LLAMA_API llama_token llama_token_cls(const struct llama_vocab * vocab), "use llama_vocab_cls instead");
DEPRECATED(LLAMA_API llama_token llama_token_sep(const struct llama_vocab * vocab), "use llama_vocab_sep instead");
DEPRECATED(LLAMA_API llama_token llama_token_nl (const struct llama_vocab * vocab), "use llama_vocab_nl instead");
DEPRECATED(LLAMA_API llama_token llama_token_pad(const struct llama_vocab * vocab), "use llama_vocab_pad instead");
DEPRECATED(LLAMA_API bool llama_add_bos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_bos instead");
DEPRECATED(LLAMA_API bool llama_add_eos_token(const struct llama_vocab * vocab), "use llama_vocab_get_add_eos instead");
DEPRECATED(LLAMA_API llama_token llama_token_fim_pre(const struct llama_vocab * vocab), "use llama_vocab_fim_pre instead");
DEPRECATED(LLAMA_API llama_token llama_token_fim_suf(const struct llama_vocab * vocab), "use llama_vocab_fim_suf instead");
DEPRECATED(LLAMA_API llama_token llama_token_fim_mid(const struct llama_vocab * vocab), "use llama_vocab_fim_mid instead");
DEPRECATED(LLAMA_API llama_token llama_token_fim_pad(const struct llama_vocab * vocab), "use llama_vocab_fim_pad instead");
DEPRECATED(LLAMA_API llama_token llama_token_fim_rep(const struct llama_vocab * vocab), "use llama_vocab_fim_rep instead");
DEPRECATED(LLAMA_API llama_token llama_token_fim_sep(const struct llama_vocab * vocab), "use llama_vocab_fim_sep instead");
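A short sketch of the new vocab-level queries (the `model` handle is assumed; behavior mirrors the deprecated model-level calls):

const llama_vocab * vocab = llama_model_get_vocab(model);
const llama_token bos = llama_vocab_bos(vocab);
std::vector<llama_token> tokens;
if (llama_vocab_get_add_bos(vocab) && bos != LLAMA_TOKEN_NULL) {
    tokens.push_back(bos); // prepend BOS only when the vocab asks for it
}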
// //
// Tokenization // Tokenization
@ -958,7 +987,7 @@ extern "C" {
/// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
/// as plaintext. Does not insert a leading space. /// as plaintext. Does not insert a leading space.
LLAMA_API int32_t llama_tokenize( LLAMA_API int32_t llama_tokenize(
const struct llama_model * model, const struct llama_vocab * vocab,
const char * text, const char * text,
int32_t text_len, int32_t text_len,
llama_token * tokens, llama_token * tokens,
@ -972,7 +1001,7 @@ extern "C" {
// User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix') // User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix')
// @param special If true, special tokens are rendered in the output. // @param special If true, special tokens are rendered in the output.
LLAMA_API int32_t llama_token_to_piece( LLAMA_API int32_t llama_token_to_piece(
const struct llama_model * model, const struct llama_vocab * vocab,
llama_token token, llama_token token,
char * buf, char * buf,
int32_t length, int32_t length,
@ -986,7 +1015,7 @@ extern "C" {
/// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so. /// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so.
/// @param unparse_special If true, special tokens are rendered in the output. /// @param unparse_special If true, special tokens are rendered in the output.
LLAMA_API int32_t llama_detokenize( LLAMA_API int32_t llama_detokenize(
const struct llama_model * model, const struct llama_vocab * vocab,
const llama_token * tokens, const llama_token * tokens,
int32_t n_tokens, int32_t n_tokens,
char * text, char * text,
@ -1009,7 +1038,6 @@ extern "C" {
/// @param length The size of the allocated buffer /// @param length The size of the allocated buffer
/// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template. /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.
LLAMA_API int32_t llama_chat_apply_template( LLAMA_API int32_t llama_chat_apply_template(
const struct llama_model * model,
const char * tmpl, const char * tmpl,
const struct llama_chat_message * chat, const struct llama_chat_message * chat,
size_t n_msg, size_t n_msg,
@ -1057,7 +1085,6 @@ extern "C" {
// llama_sampler_free(smpl); // llama_sampler_free(smpl);
// //
// TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
// TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
// //
typedef void * llama_sampler_context_t; typedef void * llama_sampler_context_t;
@ -1157,7 +1184,7 @@ extern "C" {
float eta); float eta);
LLAMA_API struct llama_sampler * llama_sampler_init_grammar( LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
const struct llama_model * model, const struct llama_vocab * vocab,
const char * grammar_str, const char * grammar_str,
const char * grammar_root); const char * grammar_root);
@ -1169,8 +1196,9 @@ extern "C" {
float penalty_present); // 0.0 = disabled float penalty_present); // 0.0 = disabled
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
LLAMA_API struct llama_sampler * llama_sampler_init_dry( LLAMA_API struct llama_sampler * llama_sampler_init_dry(
const struct llama_model * model, const struct llama_vocab * vocab,
int32_t n_ctx_train,
float dry_multiplier, float dry_multiplier,
float dry_base, float dry_base,
int32_t dry_allowed_length, int32_t dry_allowed_length,
@ -1204,7 +1232,7 @@ extern "C" {
// 3. discard non-EOG tokens with low prob // 3. discard non-EOG tokens with low prob
// 4. if no tokens are left -> pick EOT // 4. if no tokens are left -> pick EOT
// //
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model); LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab);
// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl); LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
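Samplers that need token metadata now take the vocab explicitly (and llama_sampler_init_dry additionally takes the training context length); a minimal sketch, where grammar_str is assumed to hold a GBNF grammar with a "root" rule:

const llama_vocab * vocab = llama_model_get_vocab(model);
llama_sampler * grmr   = llama_sampler_init_grammar(vocab, grammar_str, "root");
llama_sampler * infill = llama_sampler_init_infill(vocab);
// ... use the samplers ...
llama_sampler_free(infill);
llama_sampler_free(grmr);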


@ -1,5 +1,7 @@
#include "llama-adapter.h" #include "llama-adapter.h"
#include "llama-impl.h"
#include "llama-mmap.h"
#include "llama-model.h" #include "llama-model.h"
#include <algorithm> #include <algorithm>
@ -9,7 +11,7 @@
// vec // vec
struct ggml_tensor * llama_control_vector::tensor_for(int il) const { struct ggml_tensor * llama_adapter_cvec::tensor_for(int il) const {
if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) { if (il < 0 || il < layer_start || il > layer_end || (size_t) il >= tensors.size()) {
return nullptr; return nullptr;
} }
@ -17,7 +19,7 @@ struct ggml_tensor * llama_control_vector::tensor_for(int il) const {
return tensors[il]; return tensors[il];
} }
struct ggml_tensor * llama_control_vector::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const { struct ggml_tensor * llama_adapter_cvec::apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const {
ggml_tensor * layer_dir = tensor_for(il); ggml_tensor * layer_dir = tensor_for(il);
if (layer_dir != nullptr) { if (layer_dir != nullptr) {
cur = ggml_add(ctx, cur, layer_dir); cur = ggml_add(ctx, cur, layer_dir);
@ -26,12 +28,12 @@ struct ggml_tensor * llama_control_vector::apply_to(struct ggml_context * ctx, s
return cur; return cur;
} }
static bool llama_control_vector_init(struct llama_control_vector & cvec, const llama_model & model) { bool llama_adapter_cvec::init(const llama_model & model) {
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
GGML_ASSERT(cvec.tensors.empty()); GGML_ASSERT(tensors.empty());
GGML_ASSERT(cvec.ctxs.empty()); GGML_ASSERT(ctxs.empty());
GGML_ASSERT(cvec.bufs.empty()); GGML_ASSERT(bufs.empty());
// create a context for each buffer type // create a context for each buffer type
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map; std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
@ -50,7 +52,7 @@ static bool llama_control_vector_init(struct llama_control_vector & cvec, const
} }
ctx_map[buft] = ctx; ctx_map[buft] = ctx;
cvec.ctxs.emplace_back(ctx); ctxs.emplace_back(ctx);
return ctx; return ctx;
} }
@ -59,21 +61,21 @@ static bool llama_control_vector_init(struct llama_control_vector & cvec, const
}; };
// make tensors // make tensors
cvec.tensors.reserve(hparams.n_layer); tensors.reserve(hparams.n_layer);
cvec.tensors.push_back(nullptr); // there's never a tensor for layer 0 tensors.push_back(nullptr); // there's never a tensor for layer 0
for (size_t il = 1; il < hparams.n_layer; il++) { for (size_t il = 1; il < hparams.n_layer; il++) {
ggml_backend_buffer_type_t buft = llama_model_select_buft(model, il); ggml_backend_buffer_type_t buft = model.select_buft(il);
ggml_context * ctx = ctx_for_buft(buft); ggml_context * ctx = ctx_for_buft(buft);
if (!ctx) { if (!ctx) {
LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__); LLAMA_LOG_ERROR("%s: failed to allocate context for control vector\n", __func__);
return false; return false;
} }
ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd); ggml_tensor * tensor = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd);
cvec.tensors.push_back(tensor); tensors.push_back(tensor);
} }
// allocate tensors / buffers and zero // allocate tensors / buffers and zero
cvec.bufs.reserve(ctx_map.size()); bufs.reserve(ctx_map.size());
for (auto it : ctx_map) { for (auto it : ctx_map) {
ggml_backend_buffer_type_t buft = it.first; ggml_backend_buffer_type_t buft = it.first;
ggml_context * ctx = it.second; ggml_context * ctx = it.second;
@ -83,14 +85,13 @@ static bool llama_control_vector_init(struct llama_control_vector & cvec, const
return false; return false;
} }
ggml_backend_buffer_clear(buf, 0); ggml_backend_buffer_clear(buf, 0);
cvec.bufs.emplace_back(buf); bufs.emplace_back(buf);
} }
return true; return true;
} }
int32_t llama_control_vector_apply( int32_t llama_adapter_cvec::apply(
struct llama_control_vector & cvec,
const llama_model & model, const llama_model & model,
const float * data, const float * data,
size_t len, size_t len,
@ -101,8 +102,8 @@ int32_t llama_control_vector_apply(
if (data == nullptr) { if (data == nullptr) {
// disable the current control vector (but leave allocated for later) // disable the current control vector (but leave allocated for later)
cvec.layer_start = -1; layer_start = -1;
cvec.layer_end = -1; layer_end = -1;
return 0; return 0;
} }
@ -111,21 +112,21 @@ int32_t llama_control_vector_apply(
return 1; return 1;
} }
if (cvec.tensors.empty()) { if (tensors.empty()) {
if (!llama_control_vector_init(cvec, model)) { if (!init(model)) {
return 1; return 1;
} }
} }
cvec.layer_start = il_start; layer_start = il_start;
cvec.layer_end = il_end; layer_end = il_end;
for (size_t il = 1; il < hparams.n_layer; il++) { for (size_t il = 1; il < hparams.n_layer; il++) {
assert(cvec.tensors[il] != nullptr); assert(tensors[il] != nullptr);
const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present const size_t off = n_embd * (il - 1); // buffer doesn't have data for layer 0, since it's never present
if (off + n_embd <= len) { if (off + n_embd <= len) {
ggml_backend_tensor_set(cvec.tensors[il], data + off, 0, n_embd * ggml_element_size(cvec.tensors[il])); ggml_backend_tensor_set(tensors[il], data + off, 0, n_embd * ggml_element_size(tensors[il]));
} }
} }
@ -134,7 +135,7 @@ int32_t llama_control_vector_apply(
// lora // lora
llama_lora_weight * llama_lora_adapter::get_weight(struct ggml_tensor * w) { llama_adapter_lora_weight * llama_adapter_lora::get_weight(struct ggml_tensor * w) {
const std::string name(w->name); const std::string name(w->name);
const auto pos = ab_map.find(name); const auto pos = ab_map.find(name);
@ -145,11 +146,7 @@ llama_lora_weight * llama_lora_adapter::get_weight(struct ggml_tensor * w) {
return nullptr; return nullptr;
} }
void llama_lora_adapter_free(struct llama_lora_adapter * adapter) { static void llama_adapter_lora_init_impl(struct llama_model & model, const char * path_lora, struct llama_adapter_lora & adapter) {
delete adapter;
}
static void llama_lora_adapter_init_impl(struct llama_model & model, const char * path_lora, struct llama_lora_adapter & adapter) {
LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora); LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
ggml_context * ctx_init; ggml_context * ctx_init;
@ -221,7 +218,7 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char
}; };
// bundle lora_a and lora_b into pairs // bundle lora_a and lora_b into pairs
std::map<std::string, llama_lora_weight> ab_map; std::map<std::string, llama_adapter_lora_weight> ab_map;
auto str_endswith = [](const std::string & str, const std::string & suffix) { auto str_endswith = [](const std::string & str, const std::string & suffix) {
return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0;
}; };
@ -231,14 +228,14 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char
if (str_endswith(name, ".lora_a")) { if (str_endswith(name, ".lora_a")) {
replace_all(name, ".lora_a", ""); replace_all(name, ".lora_a", "");
if (ab_map.find(name) == ab_map.end()) { if (ab_map.find(name) == ab_map.end()) {
ab_map[name] = llama_lora_weight(cur, nullptr); ab_map[name] = llama_adapter_lora_weight(cur, nullptr);
} else { } else {
ab_map[name].a = cur; ab_map[name].a = cur;
} }
} else if (str_endswith(name, ".lora_b")) { } else if (str_endswith(name, ".lora_b")) {
replace_all(name, ".lora_b", ""); replace_all(name, ".lora_b", "");
if (ab_map.find(name) == ab_map.end()) { if (ab_map.find(name) == ab_map.end()) {
ab_map[name] = llama_lora_weight(nullptr, cur); ab_map[name] = llama_adapter_lora_weight(nullptr, cur);
} else { } else {
ab_map[name].b = cur; ab_map[name].b = cur;
} }
@ -254,7 +251,7 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char
// add tensors // add tensors
for (auto & it : ab_map) { for (auto & it : ab_map) {
const std::string & name = it.first; const std::string & name = it.first;
llama_lora_weight & w = it.second; llama_adapter_lora_weight & w = it.second;
bool is_token_embd = str_endswith(name, "token_embd.weight"); bool is_token_embd = str_endswith(name, "token_embd.weight");
if (!w.a || !w.b) { if (!w.a || !w.b) {
@ -262,7 +259,7 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char
} }
// device buft and device ctx // device buft and device ctx
auto * model_tensor = llama_model_get_tensor(model, name.c_str()); const auto * model_tensor = model.get_tensor(name.c_str());
if (!model_tensor) { if (!model_tensor) {
throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)"); throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)");
} }
@ -288,7 +285,7 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char
struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b); struct ggml_tensor * tensor_b = ggml_dup_tensor(dev_ctx, w.b);
ggml_set_name(tensor_a, w.a->name); ggml_set_name(tensor_a, w.a->name);
ggml_set_name(tensor_b, w.b->name); ggml_set_name(tensor_b, w.b->name);
adapter.ab_map[name] = llama_lora_weight(tensor_a, tensor_b); adapter.ab_map[name] = llama_adapter_lora_weight(tensor_a, tensor_b);
} }
// allocate tensors / buffers and zero // allocate tensors / buffers and zero
@ -330,11 +327,11 @@ static void llama_lora_adapter_init_impl(struct llama_model & model, const char
LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2); LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
} }
struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model, const char * path_lora) { struct llama_adapter_lora * llama_adapter_lora_init(struct llama_model * model, const char * path_lora) {
struct llama_lora_adapter * adapter = new llama_lora_adapter(); struct llama_adapter_lora * adapter = new llama_adapter_lora();
try { try {
llama_lora_adapter_init_impl(*model, path_lora, *adapter); llama_adapter_lora_init_impl(*model, path_lora, *adapter);
return adapter; return adapter;
} catch (const std::exception & err) { } catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what()); LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
@ -344,3 +341,7 @@ struct llama_lora_adapter * llama_lora_adapter_init(struct llama_model * model,
return nullptr; return nullptr;
} }
void llama_adapter_lora_free(struct llama_adapter_lora * adapter) {
delete adapter;
}
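For illustration only: a minimal sketch of how the renamed public LoRA adapter entry points above fit together. The file paths are placeholders and error handling is reduced to the essentials; this is not taken from the diff itself.

#include "llama.h"

int main(void) {
    llama_backend_init();

    // llama_load_model_from_file -> llama_model_load_from_file
    llama_model * model = llama_model_load_from_file("model.gguf", llama_model_default_params());
    if (model == NULL) {
        llama_backend_free();
        return 1;
    }

    // llama_lora_adapter_init -> llama_adapter_lora_init
    llama_adapter_lora * adapter = llama_adapter_lora_init(model, "adapter.gguf");
    if (adapter != NULL) {
        // ... attach the adapter to a context and decode as usual ...
        // llama_lora_adapter_free -> llama_adapter_lora_free
        llama_adapter_lora_free(adapter);
    }

    llama_model_free(model);
    llama_backend_free();
    return 0;
}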
View File
@ -1,73 +1,74 @@
#pragma once #pragma once
#include "llama-impl.h" #include "llama.h"
#include "llama-hparams.h"
#include "ggml-cpp.h" #include "ggml-cpp.h"
#include <string>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
// TODO: pimpl
// //
// llama_adapter_cvec // llama_adapter_cvec
// //
// TODO: rename to llama_adapter_cvec struct llama_adapter_cvec {
struct llama_control_vector { struct ggml_tensor * tensor_for(int il) const;
std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
std::vector<struct ggml_tensor *> tensors; // per layer struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const;
int32_t apply(
const llama_model & model,
const float * data,
size_t len,
int32_t n_embd,
int32_t il_start,
int32_t il_end);
private:
bool init(const llama_model & model);
int32_t layer_start = -1; int32_t layer_start = -1;
int32_t layer_end = -1; int32_t layer_end = -1;
struct ggml_tensor * tensor_for(int il) const; std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs;
struct ggml_tensor * apply_to(struct ggml_context * ctx, struct ggml_tensor * cur, int il) const; std::vector<struct ggml_tensor *> tensors; // per layer
}; };
int32_t llama_control_vector_apply(
struct llama_control_vector & cvec,
const llama_model & model,
const float * data,
size_t len,
int32_t n_embd,
int32_t il_start,
int32_t il_end);
// //
// llama_adapter_lora // llama_adapter_lora
// //
// TODO: rename to llama_adapter_lora_weight struct llama_adapter_lora_weight {
struct llama_lora_weight {
struct ggml_tensor * a = nullptr; struct ggml_tensor * a = nullptr;
struct ggml_tensor * b = nullptr; struct ggml_tensor * b = nullptr;
// get actual scale based on rank and alpha // get actual scale based on rank and alpha
float get_scale(float alpha, float adapter_scale) { float get_scale(float alpha, float adapter_scale) const {
const float rank = (float) b->ne[0]; const float rank = (float) b->ne[0];
const float scale = alpha ? adapter_scale * alpha / rank : adapter_scale; const float scale = alpha ? adapter_scale * alpha / rank : adapter_scale;
return scale; return scale;
} }
llama_lora_weight() = default; llama_adapter_lora_weight() = default;
llama_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {} llama_adapter_lora_weight(struct ggml_tensor * a, struct ggml_tensor * b) : a(a), b(b) {}
}; };
// TODO: rename to llama_adapter_lora struct llama_adapter_lora {
struct llama_lora_adapter {
// map tensor name to lora_a_b // map tensor name to lora_a_b
std::unordered_map<std::string, struct llama_lora_weight> ab_map; std::unordered_map<std::string, struct llama_adapter_lora_weight> ab_map;
std::vector<ggml_context_ptr> ctxs; std::vector<ggml_context_ptr> ctxs;
std::vector<ggml_backend_buffer_ptr> bufs; std::vector<ggml_backend_buffer_ptr> bufs;
float alpha; float alpha;
llama_lora_adapter() = default; llama_adapter_lora() = default;
~llama_lora_adapter() = default; ~llama_adapter_lora() = default;
llama_lora_weight * get_weight(struct ggml_tensor * w); llama_adapter_lora_weight * get_weight(struct ggml_tensor * w);
}; };
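The get_scale() helper in the header above encodes the usual LoRA scaling rule: the effective scale is adapter_scale * alpha / rank (rank taken from the B matrix), falling back to adapter_scale when alpha is 0. A tiny standalone illustration of the arithmetic, detached from ggml tensors:

#include <cstdio>

// same rule as llama_adapter_lora_weight::get_scale(), with the rank passed in directly
static float lora_scale(float alpha, float adapter_scale, float rank) {
    return alpha ? adapter_scale * alpha / rank : adapter_scale;
}

int main() {
    // e.g. alpha = 16, rank = 8, user scale = 1.0 -> effective scale 2.0
    std::printf("%.2f\n", lora_scale(16.0f, 1.0f, 8.0f));
    return 0;
}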
View File
@ -178,6 +178,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat.template" },
{ LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
{ LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
{ LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },
View File
@ -176,6 +176,7 @@ enum llm_kv {
LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_HF_JSON,
LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_RWKV,
LLM_KV_TOKENIZER_CHAT_TEMPLATE,
LLM_KV_TOKENIZER_FIM_PRE_ID, LLM_KV_TOKENIZER_FIM_PRE_ID,
LLM_KV_TOKENIZER_FIM_SUF_ID, LLM_KV_TOKENIZER_FIM_SUF_ID,
LLM_KV_TOKENIZER_FIM_MID_ID, LLM_KV_TOKENIZER_FIM_MID_ID,
View File
@ -1,5 +1,8 @@
#include "llama-context.h" #include "llama-context.h"
#include "llama-impl.h"
#include "llama-mmap.h"
#include <cassert> #include <cassert>
#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
@ -467,11 +470,12 @@ void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) { size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
const auto & cparams = lctx.cparams; const auto & cparams = lctx.cparams;
const auto & hparams = lctx.model.hparams; const auto & hparams = lctx.model.hparams;
const auto & vocab = lctx.model.vocab;
const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max); const size_t n_outputs_max = std::max(n_outputs, (size_t) cparams.n_seq_max);
const auto n_batch = cparams.n_batch; const auto n_batch = cparams.n_batch;
const auto n_vocab = hparams.n_vocab; const auto n_vocab = vocab.n_tokens();
const auto n_embd = hparams.n_embd; const auto n_embd = hparams.n_embd;
// TODO: use a per-batch flag for logits presence instead // TODO: use a per-batch flag for logits presence instead
@ -504,7 +508,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
auto * buft = ggml_backend_cpu_buffer_type(); auto * buft = ggml_backend_cpu_buffer_type();
// try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory // try to use the host buffer of the device where the output tensor is allocated for faster transfer to system memory
auto * output_dev = lctx.model.dev_output.dev; auto * output_dev = lctx.model.dev_output();
auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr; auto * output_dev_host_buft = output_dev ? ggml_backend_dev_host_buffer_type(output_dev) : nullptr;
if (output_dev_host_buft) { if (output_dev_host_buft) {
buft = output_dev_host_buft; buft = output_dev_host_buft;
@ -538,7 +542,7 @@ size_t llama_output_reserve(struct llama_context & lctx, size_t n_outputs) {
void llama_output_reorder(struct llama_context & ctx) { void llama_output_reorder(struct llama_context & ctx) {
std::vector<size_t> & out_ids = ctx.sbatch.out_ids; std::vector<size_t> & out_ids = ctx.sbatch.out_ids;
if (!out_ids.empty()) { if (!out_ids.empty()) {
const uint32_t n_vocab = ctx.model.hparams.n_vocab; const uint32_t n_vocab = ctx.model.vocab.n_tokens();
const uint32_t n_embd = ctx.model.hparams.n_embd; const uint32_t n_embd = ctx.model.hparams.n_embd;
const int32_t n_outputs = ctx.n_outputs; const int32_t n_outputs = ctx.n_outputs;
@ -722,7 +726,7 @@ float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
} }
return ctx->logits + j*ctx->model.hparams.n_vocab; return ctx->logits + j*ctx->model.vocab.n_tokens();
} catch (const std::exception & err) { } catch (const std::exception & err) {
LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
#ifndef NDEBUG #ifndef NDEBUG
@ -882,7 +886,7 @@ struct llama_data_write {
} }
void write_logits(const struct llama_context * ctx) { void write_logits(const struct llama_context * ctx) {
const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.hparams.n_vocab); const uint64_t logits_size = std::min((uint64_t) ctx->logits_size, (uint64_t) ctx->n_outputs * ctx->model.vocab.n_tokens());
write(&logits_size, sizeof(logits_size)); write(&logits_size, sizeof(logits_size));
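For illustration only: the per-output logits rows sized above keep the same layout on the public side, one value per vocab entry, with the row stride now coming from the vocab object rather than hparams. A minimal sketch, assuming ctx already holds a decoded batch and i indexes a row that has logits:

#include "llama.h"

static float max_logit(llama_context * ctx, int32_t i) {
    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);

    const int32_t n_vocab = llama_vocab_n_tokens(vocab);  // row stride, one value per token
    const float * row     = llama_get_logits_ith(ctx, i); // one row of n_vocab logits

    float best = row[0];
    for (int32_t t = 1; t < n_vocab; ++t) {
        if (row[t] > best) {
            best = row[t];
        }
    }
    return best;
}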
View File
@ -22,12 +22,12 @@ struct llama_context {
const struct llama_model & model; const struct llama_model & model;
struct llama_cparams cparams; struct llama_cparams cparams;
struct llama_sbatch sbatch; // TODO: revisit if needed struct llama_sbatch sbatch; // TODO: revisit if needed
struct llama_kv_cache kv_self; struct llama_kv_cache kv_self;
struct llama_control_vector cvec; struct llama_adapter_cvec cvec;
std::unordered_map<struct llama_lora_adapter *, float> lora_adapters; std::unordered_map<struct llama_adapter_lora *, float> lora;
std::vector<ggml_backend_ptr> backends; std::vector<ggml_backend_ptr> backends;
std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns; std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
View File
@ -1092,9 +1092,9 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
for (size_t i = 0; i < cur_p->size; ++i) { for (size_t i = 0; i < cur_p->size; ++i) {
const llama_token id = cur_p->data[i].id; const llama_token id = cur_p->data[i].id;
const std::string & piece = grammar.vocab->cache_token_to_piece.at(id); const std::string & piece = grammar.vocab->token_to_piece(id);
if (llama_token_is_eog_impl(*grammar.vocab, id)) { if (grammar.vocab->is_eog(id)) {
if (!allow_eog) { if (!allow_eog) {
cur_p->data[i].logit = -INFINITY; cur_p->data[i].logit = -INFINITY;
} }
@ -1115,7 +1115,7 @@ void llama_grammar_apply_impl(const struct llama_grammar & grammar, llama_token_
void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) { void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token) {
GGML_ASSERT(grammar.vocab != nullptr); GGML_ASSERT(grammar.vocab != nullptr);
if (llama_token_is_eog_impl(*grammar.vocab, token)) { if (grammar.vocab->is_eog(token)) {
for (const auto & stack : grammar.stacks) { for (const auto & stack : grammar.stacks) {
if (stack.empty()) { if (stack.empty()) {
return; return;
@ -1124,7 +1124,7 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }
const std::string & piece = grammar.vocab->cache_token_to_piece.at(token); const std::string & piece = grammar.vocab->token_to_piece(token);
// Note terminating 0 in decoded string // Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece, grammar.partial_utf8); const auto decoded = decode_utf8(piece, grammar.partial_utf8);
View File
@ -30,7 +30,6 @@ struct llama_hparams {
bool use_par_res; bool use_par_res;
bool swin_norm; bool swin_norm;
uint32_t n_vocab = 0;
uint32_t n_ctx_train; // context size the model was trained on uint32_t n_ctx_train; // context size the model was trained on
uint32_t n_embd; uint32_t n_embd;
uint32_t n_embd_features = 0; uint32_t n_embd_features = 0;
@ -41,7 +40,6 @@ struct llama_hparams {
uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head
uint32_t n_expert = 0; uint32_t n_expert = 0;
uint32_t n_expert_used = 0; uint32_t n_expert_used = 0;
uint32_t n_vocab_type = 0; // for BERT-style token types
uint32_t n_rel_attn_bkts = 0; uint32_t n_rel_attn_bkts = 0;
// for WavTokenizer // for WavTokenizer
View File
@ -79,7 +79,7 @@ bool llama_kv_cache_init(
ggml_backend_buffer_type_t buft; ggml_backend_buffer_type_t buft;
if (offload) { if (offload) {
auto * dev = model.dev_layer.at(i).dev; auto * dev = model.dev_layer(i);
buft = ggml_backend_dev_buffer_type(dev); buft = ggml_backend_dev_buffer_type(dev);
} else { } else {
buft = ggml_backend_cpu_buffer_type(); buft = ggml_backend_cpu_buffer_type();
View File
@ -35,7 +35,7 @@
// TODO: consider moving to llama-impl.h if needed in more places // TODO: consider moving to llama-impl.h if needed in more places
#if defined(_WIN32) #if defined(_WIN32)
std::string llama_format_win_err(DWORD err) { static std::string llama_format_win_err(DWORD err) {
LPSTR buf; LPSTR buf;
size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL); NULL, err, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&buf, 0, NULL);
View File
@ -7,6 +7,10 @@
#include <cstring> #include <cstring>
#include <future> #include <future>
static const size_t kiB = 1024;
static const size_t MiB = 1024*kiB;
static const size_t GiB = 1024*MiB;
const char * llama_file_version_name(llama_fver version) { const char * llama_file_version_name(llama_fver version) {
switch (version) { switch (version) {
case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)"; case GGUF_FILE_VERSION_V1: return "GGUF V1 (support until nov 2023)";
@ -17,6 +21,49 @@ const char * llama_file_version_name(llama_fver version) {
return "unknown"; return "unknown";
} }
static std::string llama_model_ftype_name(llama_ftype ftype) {
if (ftype & LLAMA_FTYPE_GUESSED) {
return llama_model_ftype_name((enum llama_ftype) (ftype & ~LLAMA_FTYPE_GUESSED)) + " (guessed)";
}
switch (ftype) {
case LLAMA_FTYPE_ALL_F32: return "all F32";
case LLAMA_FTYPE_MOSTLY_F16: return "F16";
case LLAMA_FTYPE_MOSTLY_BF16: return "BF16";
case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0";
case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1";
case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0";
case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1";
case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0";
case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small";
case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small";
case LLAMA_FTYPE_MOSTLY_Q3_K_M: return "Q3_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q3_K_L: return "Q3_K - Large";
case LLAMA_FTYPE_MOSTLY_Q4_K_S: return "Q4_K - Small";
case LLAMA_FTYPE_MOSTLY_Q4_K_M: return "Q4_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q5_K_S: return "Q5_K - Small";
case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "Q5_K - Medium";
case LLAMA_FTYPE_MOSTLY_Q6_K: return "Q6_K";
case LLAMA_FTYPE_MOSTLY_TQ1_0: return "TQ1_0 - 1.69 bpw ternary";
case LLAMA_FTYPE_MOSTLY_TQ2_0: return "TQ2_0 - 2.06 bpw ternary";
case LLAMA_FTYPE_MOSTLY_IQ2_XXS: return "IQ2_XXS - 2.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_XS: return "IQ2_XS - 2.3125 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_S: return "IQ2_S - 2.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ2_M: return "IQ2_M - 2.7 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_XS: return "IQ3_XS - 3.3 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_XXS: return "IQ3_XXS - 3.0625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ1_S: return "IQ1_S - 1.5625 bpw";
case LLAMA_FTYPE_MOSTLY_IQ1_M: return "IQ1_M - 1.75 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_NL: return "IQ4_NL - 4.5 bpw";
case LLAMA_FTYPE_MOSTLY_IQ4_XS: return "IQ4_XS - 4.25 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_S: return "IQ3_S - 3.4375 bpw";
case LLAMA_FTYPE_MOSTLY_IQ3_M: return "IQ3_S mix - 3.66 bpw";
default: return "unknown, may not work";
}
}
namespace GGUFMeta { namespace GGUFMeta {
template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int64_t)> template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int64_t)>
struct GKV_Base_Type { struct GKV_Base_Type {
@ -1009,3 +1056,17 @@ bool llama_model_loader::load_all_data(
return true; return true;
} }
std::string llama_model_loader::ftype_name() const {
return llama_model_ftype_name(ftype);
}
void llama_model_loader::print_info() const {
LLAMA_LOG_INFO("%s: file format = %s\n", __func__, llama_file_version_name(fver));
LLAMA_LOG_INFO("%s: file type = %s\n", __func__, llama_model_ftype_name(ftype).c_str());
if (n_bytes < GiB) {
LLAMA_LOG_INFO("%s: file size = %.2f MiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0, n_bytes*8.0/n_elements);
} else {
LLAMA_LOG_INFO("%s: file size = %.2f GiB (%.2f BPW) \n", __func__, n_bytes/1024.0/1024.0/1024.0, n_bytes*8.0/n_elements);
}
}
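The BPW figure printed by print_info() above is simply total file bits divided by total parameter count. An illustrative calculation with made-up numbers:

#include <cstdio>

int main() {
    // hypothetical figures, for illustration only
    const double n_elements = 7.0e9;                           // total parameters
    const double n_bytes    = 3.5 * 1024.0 * 1024.0 * 1024.0;  // file size: 3.5 GiB

    std::printf("%.2f BPW\n", n_bytes * 8.0 / n_elements);     // ~4.29 bits per weight
    return 0;
}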
View File
@ -155,4 +155,8 @@ struct llama_model_loader {
llama_mlocks * lmlocks, llama_mlocks * lmlocks,
llama_progress_callback progress_callback, llama_progress_callback progress_callback,
void * progress_callback_user_data); void * progress_callback_user_data);
std::string ftype_name() const;
void print_info() const;
}; };
File diff suppressed because it is too large
View File
@ -4,79 +4,80 @@
#include "llama-arch.h" #include "llama-arch.h"
#include "llama-hparams.h" #include "llama-hparams.h"
#include "llama-vocab.h" #include "llama-vocab.h"
#include "llama-mmap.h"
#include "ggml-cpp.h"
#include <memory>
#include <string>
#include <unordered_map>
#include <vector> #include <vector>
struct llama_model_loader;
// available models // available models
// TODO: this enum does not follow the enum naming convention
enum llm_type { enum llm_type {
MODEL_UNKNOWN, LLM_TYPE_UNKNOWN,
MODEL_14M, LLM_TYPE_14M,
MODEL_17M, LLM_TYPE_17M,
MODEL_22M, LLM_TYPE_22M,
MODEL_33M, LLM_TYPE_33M,
MODEL_60M, LLM_TYPE_60M,
MODEL_70M, LLM_TYPE_70M,
MODEL_80M, LLM_TYPE_80M,
MODEL_109M, LLM_TYPE_109M,
MODEL_137M, LLM_TYPE_137M,
MODEL_160M, LLM_TYPE_160M,
MODEL_220M, LLM_TYPE_220M,
MODEL_250M, LLM_TYPE_250M,
MODEL_270M, LLM_TYPE_270M,
MODEL_335M, LLM_TYPE_335M,
MODEL_410M, LLM_TYPE_410M,
MODEL_450M, LLM_TYPE_450M,
MODEL_770M, LLM_TYPE_770M,
MODEL_780M, LLM_TYPE_780M,
MODEL_0_5B, LLM_TYPE_0_5B,
MODEL_1B, LLM_TYPE_1B,
MODEL_1_3B, LLM_TYPE_1_3B,
MODEL_1_4B, LLM_TYPE_1_4B,
MODEL_1_5B, LLM_TYPE_1_5B,
MODEL_1_6B, LLM_TYPE_1_6B,
MODEL_2B, LLM_TYPE_2B,
MODEL_2_8B, LLM_TYPE_2_8B,
MODEL_3B, LLM_TYPE_3B,
MODEL_4B, LLM_TYPE_4B,
MODEL_6B, LLM_TYPE_6B,
MODEL_6_9B, LLM_TYPE_6_9B,
MODEL_7B, LLM_TYPE_7B,
MODEL_8B, LLM_TYPE_8B,
MODEL_9B, LLM_TYPE_9B,
MODEL_11B, LLM_TYPE_11B,
MODEL_12B, LLM_TYPE_12B,
MODEL_13B, LLM_TYPE_13B,
MODEL_14B, LLM_TYPE_14B,
MODEL_15B, LLM_TYPE_15B,
MODEL_16B, LLM_TYPE_16B,
MODEL_20B, LLM_TYPE_20B,
MODEL_30B, LLM_TYPE_30B,
MODEL_32B, LLM_TYPE_32B,
MODEL_34B, LLM_TYPE_34B,
MODEL_35B, LLM_TYPE_35B,
MODEL_40B, LLM_TYPE_40B,
MODEL_65B, LLM_TYPE_65B,
MODEL_70B, LLM_TYPE_70B,
MODEL_236B, LLM_TYPE_236B,
MODEL_314B, LLM_TYPE_314B,
MODEL_671B, LLM_TYPE_671B,
MODEL_SMALL, LLM_TYPE_SMALL,
MODEL_MEDIUM, LLM_TYPE_MEDIUM,
MODEL_LARGE, LLM_TYPE_LARGE,
MODEL_XL, LLM_TYPE_XL,
MODEL_A1_7B, LLM_TYPE_A1_7B,
MODEL_A2_7B, LLM_TYPE_A2_7B,
MODEL_8x7B, LLM_TYPE_8x7B,
MODEL_8x22B, LLM_TYPE_8x22B,
MODEL_16x12B, LLM_TYPE_16x12B,
MODEL_16x3_8B, LLM_TYPE_16x3_8B,
MODEL_10B_128x3_66B, LLM_TYPE_10B_128x3_66B,
MODEL_57B_A14B, LLM_TYPE_57B_A14B,
MODEL_27B, LLM_TYPE_27B,
}; };
struct llama_layer_posnet { struct llama_layer_posnet {
@ -286,11 +287,9 @@ struct llama_layer {
}; };
struct llama_model { struct llama_model {
llm_type type = MODEL_UNKNOWN; llm_type type = LLM_TYPE_UNKNOWN;
llm_arch arch = LLM_ARCH_UNKNOWN; llm_arch arch = LLM_ARCH_UNKNOWN;
llama_ftype ftype = LLAMA_FTYPE_ALL_F32;
std::string name = "n/a"; std::string name = "n/a";
llama_hparams hparams = {}; llama_hparams hparams = {};
@ -319,78 +318,55 @@ struct llama_model {
std::vector<llama_layer> layers; std::vector<llama_layer> layers;
llama_model_params params;
// gguf metadata // gguf metadata
std::unordered_map<std::string, std::string> gguf_kv; std::unordered_map<std::string, std::string> gguf_kv;
llama_split_mode split_mode;
int main_gpu;
int n_gpu_layers;
std::vector<std::string> rpc_servers; std::vector<std::string> rpc_servers;
// list of devices used in this model // list of devices used in this model
std::vector<ggml_backend_dev_t> devices; std::vector<ggml_backend_dev_t> devices;
// lists of buffer types used for each layer
using buft_list_t = std::vector<std::pair<ggml_backend_dev_t, ggml_backend_buffer_type_t>>;
buft_list_t cpu_buft_list;
std::map<ggml_backend_dev_t, buft_list_t> gpu_buft_list;
struct layer_dev {
ggml_backend_dev_t dev;
buft_list_t * buft_list;
};
layer_dev dev_input = {};
layer_dev dev_output = {};
std::vector<layer_dev> dev_layer;
// contexts where the model tensors metadata is stored
std::vector<ggml_context_ptr> ctxs;
// the model memory buffers for the tensor data
std::vector<ggml_backend_buffer_ptr> bufs;
// model memory mapped files
llama_mmaps mappings;
// objects representing data potentially being locked in memory
llama_mlocks mlock_bufs;
llama_mlocks mlock_mmaps;
// for quantize-stats only // for quantize-stats only
std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name; std::vector<std::pair<std::string, struct ggml_tensor *>> tensors_by_name;
int64_t t_load_us = 0; int64_t t_load_us = 0;
int64_t t_start_us = 0; int64_t t_start_us = 0;
// total number of parameters in the model explicit llama_model(const struct llama_model_params & params);
uint64_t n_elements = 0; ~llama_model();
// total size of all the tensors in the model in bytes void load_stats (llama_model_loader & ml);
size_t n_bytes = 0; void load_arch (llama_model_loader & ml);
void load_hparams(llama_model_loader & ml);
void load_vocab (llama_model_loader & ml);
bool load_tensors(llama_model_loader & ml); // returns false if cancelled by progress_callback
std::string arch_name() const;
std::string type_name() const;
std::string desc() const;
size_t size() const;
size_t max_nodes() const;
size_t n_devices() const;
// total number of parameters in the model
uint64_t n_elements() const;
void print_info() const;
ggml_backend_dev_t dev_layer(int il) const;
ggml_backend_dev_t dev_output() const;
ggml_backend_buffer_type_t select_buft(int il) const;
const struct ggml_tensor * get_tensor(const char * name) const;
private:
struct impl;
std::unique_ptr<impl> pimpl;
}; };
const char * llm_type_name(llm_type type); const char * llm_type_name(llm_type type);
std::string llama_model_arch_name (const llama_model & model);
std::string llama_model_type_name (const llama_model & model);
std::string llama_model_ftype_name(const llama_model & model);
// used by llama_adapter_cvec
ggml_backend_buffer_type_t llama_model_select_buft(const llama_model & model, int il);
// used by llama_adapter_lora
struct ggml_tensor * llama_model_get_tensor(const struct llama_model & model, const char * name);
size_t llama_model_max_nodes(const llama_model & model);
struct llama_model_loader;
// TODO: become llama_model methods
void llm_load_stats (llama_model_loader & ml, llama_model & model);
void llm_load_arch (llama_model_loader & ml, llama_model & model);
void llm_load_hparams (llama_model_loader & ml, llama_model & model);
void llm_load_vocab (llama_model_loader & ml, llama_model & model);
void llm_load_print_meta(llama_model_loader & ml, llama_model & model);
View File
@ -235,7 +235,7 @@ static ggml_type llama_tensor_get_type(quantize_state_impl & qs, ggml_type new_t
else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) && else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K; use_more_bits(qs.i_attention_wv, qs.n_attention_wv)) new_type = GGML_TYPE_Q6_K;
else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K; else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && qs.i_attention_wv < 4) new_type = GGML_TYPE_Q5_K;
if (qs.model.type == MODEL_70B) { if (qs.model.type == LLM_TYPE_70B) {
// In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is // In the 70B model we have 8 heads sharing the same attn_v weights. As a result, the attn_v.weight tensor is
// 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with // 8x smaller compared to attn_q.weight. Hence, we can get a nice boost in quantization accuracy with
// nearly negligible increase in model size by quantizing this tensor with more bits: // nearly negligible increase in model size by quantizing this tensor with more bits:
@ -525,18 +525,20 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides; auto v = (std::vector<llama_model_kv_override>*)params->kv_overrides;
kv_overrides = v->data(); kv_overrides = v->data();
} }
llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides); llama_model_loader ml(fname_inp, use_mmap, /*check_tensors*/ true, kv_overrides);
ml.init_mappings(false); // no prefetching ml.init_mappings(false); // no prefetching
llama_model model; llama_model model(llama_model_default_params());
llm_load_arch (ml, model);
llm_load_hparams(ml, model); model.load_arch (ml);
llm_load_stats (ml, model); model.load_hparams(ml);
model.load_stats (ml);
struct quantize_state_impl qs(model, params); struct quantize_state_impl qs(model, params);
if (params->only_copy) { if (params->only_copy) {
ftype = model.ftype; ftype = ml.ftype;
} }
const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr; const std::unordered_map<std::string, std::vector<float>> * imatrix_data = nullptr;
if (params->imatrix) { if (params->imatrix) {
View File
@ -371,7 +371,10 @@ void llama_sampler_free(struct llama_sampler * smpl) {
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
const auto * logits = llama_get_logits_ith(ctx, idx); const auto * logits = llama_get_logits_ith(ctx, idx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_vocab_n_tokens(vocab);
// TODO: do not allocate each time // TODO: do not allocate each time
std::vector<llama_token_data> cur; std::vector<llama_token_data> cur;
@ -1445,7 +1448,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr); auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr);
// copy the state // copy the state
{ {
@ -1481,19 +1484,19 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
/* .free = */ llama_sampler_grammar_free, /* .free = */ llama_sampler_grammar_free,
}; };
struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) { struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
auto * ctx = new llama_sampler_grammar; auto * ctx = new llama_sampler_grammar;
if (grammar_str != nullptr && grammar_str[0] != '\0') { if (grammar_str != nullptr && grammar_str[0] != '\0') {
*ctx = { *ctx = {
/* .vocab = */ &vocab, /* .vocab = */ vocab,
/* .grammar_str = */ grammar_str, /* .grammar_str = */ grammar_str,
/* .grammar_root = */ grammar_root, /* .grammar_root = */ grammar_root,
/* .grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root), /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root),
}; };
} else { } else {
*ctx = { *ctx = {
/* .vocab = */ &vocab, /* .vocab = */ vocab,
/* .grammar_str = */ {}, /* .grammar_str = */ {},
/* .grammar_root = */ {}, /* .grammar_root = */ {},
/* .grammar = */ nullptr, /* .grammar = */ nullptr,
@ -1663,8 +1666,8 @@ struct llama_sampler_dry {
// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am) // Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) { static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) { for (llama_token token_id = 0; token_id < (llama_token) vocab.n_tokens(); token_id++) {
std::string word = llama_detokenize(vocab, {token_id}, true); std::string word = vocab.detokenize({token_id}, true);
if (word.find(str) != std::string::npos) { if (word.find(str) != std::string::npos) {
token_sequences.emplace(token_id, std::vector<llama_token>()); token_sequences.emplace(token_id, std::vector<llama_token>());
} else { } else {
@ -1681,7 +1684,7 @@ static void get_overlapping_token_sequences(const llama_vocab & vocab, const std
} }
} }
if (match) { if (match) {
std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false); std::vector<llama_token> tokenization = vocab.tokenize(str.substr(i), false, false);
if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) { if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
tokenization.resize(max_tail_len); tokenization.resize(max_tail_len);
} }
@ -1937,7 +1940,7 @@ static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler
llama_vocab dummy_vocab; llama_vocab dummy_vocab;
// dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
auto * result = llama_sampler_init_dry_impl(dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0); auto * result = llama_sampler_init_dry(&dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
// Copy the state, including the processed breakers // Copy the state, including the processed breakers
{ {
@ -1964,7 +1967,7 @@ static struct llama_sampler_i llama_sampler_dry_i = {
/* .free = */ llama_sampler_dry_free, /* .free = */ llama_sampler_dry_free,
}; };
struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0); int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers; std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
const int MAX_CHAR_LEN = 40; const int MAX_CHAR_LEN = 40;
@ -1991,7 +1994,7 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
sequence_break.resize(MAX_CHAR_LEN); sequence_break.resize(MAX_CHAR_LEN);
} }
get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN); get_overlapping_token_sequences(*vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
} }
} }
@ -2014,7 +2017,7 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
// wrapper for test-sampling.cpp // wrapper for test-sampling.cpp
struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) { struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
llama_vocab dummy_vocab; llama_vocab dummy_vocab;
auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0); auto * result = llama_sampler_init_dry(&dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
auto * ctx = (llama_sampler_dry *) result->ctx; auto * ctx = (llama_sampler_dry *) result->ctx;
// Process the token-based sequence breakers // Process the token-based sequence breakers
@ -2153,7 +2156,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
float p_eog_sum = 0.0f; float p_eog_sum = 0.0f;
for (size_t i = 0; i < cur_p->size; ++i) { for (size_t i = 0; i < cur_p->size; ++i) {
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) { if (ctx->vocab->is_eog(cur_p->data[i].id)) {
p_eog_sum += cur_p->data[i].p; p_eog_sum += cur_p->data[i].p;
} else { } else {
p_txt_sum += cur_p->data[i].p; p_txt_sum += cur_p->data[i].p;
@ -2175,7 +2178,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
float p_sum = 0.0f; float p_sum = 0.0f;
for (size_t i = 0; i < size_org; ++i) { for (size_t i = 0; i < size_org; ++i) {
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) { if (ctx->vocab->is_eog(cur_p->data[i].id)) {
p_sum += cur_p->data[i].p; p_sum += cur_p->data[i].p;
cur_p->data[cur_p->size++] = cur_p->data[i]; cur_p->data[cur_p->size++] = cur_p->data[i];
@ -2203,17 +2206,17 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
continue; continue;
} }
int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false); int len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
if (len0 < 0) { if (len0 < 0) {
ctx->buf0.resize(len0); ctx->buf0.resize(len0);
len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false); len0 = ctx->vocab->token_to_piece(cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
assert(len0 > 0); assert(len0 > 0);
} }
int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false); int len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
if (len1 < 0) { if (len1 < 0) {
ctx->buf1.resize(len1); ctx->buf1.resize(len1);
len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false); len1 = ctx->vocab->token_to_piece(cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
assert(len1 > 0); assert(len1 > 0);
} }
@ -2248,7 +2251,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold); LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
for (size_t i = 0; i < size_org; ++i) { for (size_t i = 0; i < size_org; ++i) {
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id); const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
if (cur_p->data[i].p < thold && !is_eog) { if (cur_p->data[i].p < thold && !is_eog) {
continue; continue;
@ -2269,7 +2272,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
// if no non-EOG tokens are left -> reduce cur_p to single EOT token // if no non-EOG tokens are left -> reduce cur_p to single EOT token
if (n_non_eog == 0) { if (n_non_eog == 0) {
cur_p->size = 1; cur_p->size = 1;
cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab); cur_p->data[0].id = ctx->vocab->token_eot();
cur_p->data[0].logit = 1.0f; cur_p->data[0].logit = 1.0f;
return; return;
@ -2291,7 +2294,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold); LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
for (size_t i = 0; i < size_org; ++i) { for (size_t i = 0; i < size_org; ++i) {
const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id); const bool is_eog = ctx->vocab->is_eog(cur_p->data[i].id);
if (cur_p->data[i].p < thold && !is_eog) { if (cur_p->data[i].p < thold && !is_eog) {
continue; continue;
@ -2314,7 +2317,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_infill *) smpl->ctx; const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
return llama_sampler_init_infill_impl(*ctx->vocab); return llama_sampler_init_infill(ctx->vocab);
} }
static void llama_sampler_infill_free(struct llama_sampler * smpl) { static void llama_sampler_infill_free(struct llama_sampler * smpl) {
@ -2330,14 +2333,13 @@ static struct llama_sampler_i llama_sampler_infill_i = {
/* .free = */ llama_sampler_infill_free, /* .free = */ llama_sampler_infill_free,
}; };
struct llama_sampler * llama_sampler_init_infill_impl( struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
const struct llama_vocab & vocab) {
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_infill_i, /* .iface = */ &llama_sampler_infill_i,
/* .ctx = */ new llama_sampler_infill { /* .ctx = */ new llama_sampler_infill {
/* .vocab = */ &vocab, /* .vocab = */ vocab,
/* .buf0 = */ std::vector<char>(512), /* .buf0 = */ std::vector<char>(512),
/* .buf1 = */ std::vector<char>(512), /* .buf1 = */ std::vector<char>(512),
}, },
}; };
} }
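For illustration only: a usage sketch of the sampler constructors above, which are now public and take a const llama_vocab pointer directly instead of going through the internal *_impl forms. The grammar string is a trivial placeholder and a loaded model is assumed.

#include "llama.h"

static llama_sampler * make_grammar_sampler(const llama_model * model) {
    const llama_vocab * vocab = llama_model_get_vocab(model);

    // previously internal: llama_sampler_init_grammar_impl(*vocab, str, root)
    return llama_sampler_init_grammar(vocab, "root ::= [a-z]+", "root");
}

static llama_sampler * make_infill_sampler(const llama_model * model) {
    const llama_vocab * vocab = llama_model_get_vocab(model);

    // previously internal: llama_sampler_init_infill_impl(*vocab)
    return llama_sampler_init_infill(vocab);
}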
View File
@ -2,7 +2,9 @@
// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ? // TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
#include "llama-grammar.h" #include "llama.h"
#include <vector>
struct llama_vocab; struct llama_vocab;
struct llama_grammar; struct llama_grammar;
@ -21,24 +23,6 @@ struct llama_sampler_chain {
mutable int32_t n_sample; mutable int32_t n_sample;
}; };
struct llama_sampler * llama_sampler_init_grammar_impl(
const struct llama_vocab & vocab,
const char * grammar_str,
const char * grammar_root);
struct llama_sampler * llama_sampler_init_infill_impl(
const struct llama_vocab & vocab);
struct llama_sampler * llama_sampler_init_dry_impl(
const struct llama_vocab & vocab,
int32_t context_size,
float dry_multiplier,
float dry_base,
int32_t dry_allowed_length,
int32_t dry_penalty_last_n,
const char ** seq_breakers,
size_t num_breakers);
struct llama_sampler * llama_sampler_init_dry_testing( struct llama_sampler * llama_sampler_init_dry_testing(
int32_t context_size, int32_t context_size,
float dry_multiplier, float dry_multiplier,
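Likewise for the DRY sampler, whose internal _impl declaration is removed above in favour of the public constructor taking the vocab pointer. A hedged sketch; the numeric arguments are illustrative values, not recommended defaults, and a loaded model is assumed.

#include "llama.h"

static llama_sampler * make_dry_sampler(const llama_model * model, int32_t n_ctx) {
    const llama_vocab * vocab = llama_model_get_vocab(model);

    const char * seq_breakers[] = { "\n", ":", "\"" };

    // previously internal: llama_sampler_init_dry_impl(*vocab, ...)
    return llama_sampler_init_dry(vocab, n_ctx,
                                  /*dry_multiplier     */ 0.8f,
                                  /*dry_base           */ 1.75f,
                                  /*dry_allowed_length */ 2,
                                  /*dry_penalty_last_n */ -1,
                                  seq_breakers, 3);
}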
File diff suppressed because it is too large
View File
@ -4,179 +4,123 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_map> #include <memory>
#include <map>
#include <set>
static const char * llama_model_vocab_type_name(enum llama_vocab_type type){ struct LLM_KV;
switch (type) { struct llama_model_loader;
case LLAMA_VOCAB_TYPE_NONE: return "no vocab";
case LLAMA_VOCAB_TYPE_SPM: return "SPM";
case LLAMA_VOCAB_TYPE_BPE: return "BPE";
case LLAMA_VOCAB_TYPE_WPM: return "WPM";
case LLAMA_VOCAB_TYPE_UGM: return "UGM";
case LLAMA_VOCAB_TYPE_RWKV: return "RWKV";
default: return "unknown";
}
}
struct llm_tokenizer;
struct llama_vocab { struct llama_vocab {
using id = llama_token;
using token = std::string;
using tattr = llama_token_attr;
struct token_data { struct token_data {
token text; std::string text;
float score; float score;
tattr attr; llama_token_attr attr;
}; };
uint32_t n_vocab = 0; // TODO: not great because has to keep in sync with hparams.n_vocab llama_vocab();
enum llama_vocab_type type = LLAMA_VOCAB_TYPE_SPM;
enum llama_vocab_pre_type type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
int max_token_len = 0; // used for optimizing longest token search
std::unordered_map<token, id> token_to_id;
std::vector<token_data> id_to_token;
std::vector<id> cache_special_tokens;
std::vector<token> cache_token_to_piece; // llama_token_to_piece(special = true);
std::map<std::pair<std::string, std::string>, int> bpe_ranks;
// default LLaMA special tokens
// TODO: should we set all of these to LLAMA_TOKEN_NULL?
id special_bos_id = 1;
id special_eos_id = 2;
id special_eot_id = LLAMA_TOKEN_NULL;
id special_eom_id = LLAMA_TOKEN_NULL;
id special_unk_id = 0;
id special_sep_id = LLAMA_TOKEN_NULL;
id special_pad_id = LLAMA_TOKEN_NULL;
id special_cls_id = LLAMA_TOKEN_NULL; // TODO: revisit if this is really needed https://github.com/ggerganov/llama.cpp/pull/10930
id special_mask_id = LLAMA_TOKEN_NULL;
id linefeed_id = 13;
// fim tokens
id special_fim_pre_id = LLAMA_TOKEN_NULL;
id special_fim_suf_id = LLAMA_TOKEN_NULL;
id special_fim_mid_id = LLAMA_TOKEN_NULL;
id special_fim_pad_id = LLAMA_TOKEN_NULL;
id special_fim_rep_id = LLAMA_TOKEN_NULL; // repo
id special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator
// set of all tokens that cause "end of generation"
std::set<id> special_eog_ids;
// tokenizer flags
bool tokenizer_add_space_prefix = false;
bool tokenizer_add_bos = false;
bool tokenizer_add_eos = false;
bool tokenizer_ignore_merges = false;
bool tokenizer_clean_spaces = false; // clean_up_tokenization_spaces
bool tokenizer_remove_extra_whitespaces = false;
bool tokenizer_escape_whitespaces = true;
bool tokenizer_treat_whitespace_as_suffix = false;
std::vector<char> precompiled_charsmap;
llm_tokenizer * tokenizer = nullptr;
llama_vocab() = default;
~llama_vocab(); ~llama_vocab();
void load(llama_model_loader & ml, const LLM_KV & kv);
enum llama_vocab_type get_type() const;
enum llama_vocab_pre_type get_pre_type() const;
uint32_t n_tokens() const;
uint32_t n_token_types() const;
std::string type_name() const;
bool is_normal (llama_token id) const;
bool is_unknown (llama_token id) const;
bool is_control (llama_token id) const;
bool is_byte (llama_token id) const;
bool is_user_defined(llama_token id) const;
bool is_unused (llama_token id) const;
bool is_eog (llama_token id) const;
uint8_t token_to_byte(llama_token id) const;
llama_token byte_to_token(uint8_t ch) const;
llama_token text_to_token(const std::string & text) const;
const token_data & get_token_data(llama_token id) const;
const char * token_get_text (llama_token id) const;
float token_get_score(llama_token id) const;
llama_token_attr token_get_attr (llama_token id) const;
llama_token token_bos() const;
llama_token token_eos() const;
llama_token token_eot() const;
llama_token token_eom() const;
llama_token token_unk() const;
llama_token token_cls() const;
llama_token token_sep() const;
llama_token token_nl () const;
llama_token token_pad() const;
llama_token token_prefix() const;
llama_token token_middle() const;
llama_token token_suffix() const;
llama_token token_fim_pre() const;
llama_token token_fim_suf() const;
llama_token token_fim_mid() const;
llama_token token_fim_pad() const;
llama_token token_fim_rep() const;
llama_token token_fim_sep() const;
bool get_add_space_prefix () const;
bool get_add_bos () const;
bool get_add_eos () const;
bool get_ignore_merges () const;
bool get_clean_spaces () const;
bool get_remove_extra_whitespaces () const;
bool get_escape_whitespaces () const;
bool get_treat_whitespace_as_suffix() const;
int max_token_len() const;
int find_bpe_rank(const std::string & token_left, const std::string & token_right) const; int find_bpe_rank(const std::string & token_left, const std::string & token_right) const;
void init_tokenizer(); int32_t tokenize(
const char * text,
int32_t text_len,
llama_token * tokens,
int32_t n_tokens_max,
bool add_special,
bool parse_special) const;
std::vector<llama_token> tokenize(
const std::string & raw_text,
bool add_special,
bool parse_special = false) const;
// does not write null-terminator to buf
int32_t token_to_piece(
llama_token token,
char * buf,
int32_t length,
int32_t lstrip,
bool special) const;
// use cached data
const std::string & token_to_piece(llama_token token) const;
int32_t detokenize(
const llama_token * tokens,
int32_t n_tokens,
char * text,
int32_t text_len_max,
bool remove_special,
bool unparse_special) const;
std::string detokenize(
const std::vector<llama_token> & tokens,
bool special) const;
void print_info() const;
private:
struct impl;
std::unique_ptr<impl> pimpl;
}; };
//
// internal API
//
// TODO: rename to llama_tokenize_impl
// TODO: This should probably be in llama.h
std::vector<llama_vocab::id> llama_tokenize_internal(
const llama_vocab & vocab,
std::string raw_text,
bool add_special,
bool parse_special = false);
// TODO: move the API below as member functions of llama_vocab
llama_token llama_byte_to_token_impl(const llama_vocab & vocab, uint8_t ch);
const char * llama_token_get_text_impl(const struct llama_vocab & vocab, llama_token token);
float llama_token_get_score_impl(const struct llama_vocab & vocab, llama_token token);
llama_token_attr llama_token_get_attr_impl(const struct llama_vocab & vocab, llama_token token);
bool llama_token_is_eog_impl(const struct llama_vocab & vocab, llama_token token);
bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token token);
llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
llama_token llama_token_eot_impl(const struct llama_vocab & vocab);
llama_token llama_token_eom_impl(const struct llama_vocab & vocab);
llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab);
llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab);
llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab);
llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab);
llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab);
llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab);
bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
int32_t llama_tokenize_impl(
const struct llama_vocab & vocab,
const char * text,
int32_t text_len,
llama_token * tokens,
int32_t n_tokens_max,
bool add_special,
bool parse_special);
// does not write null-terminator to buf
int32_t llama_token_to_piece_impl(
const struct llama_vocab & vocab,
llama_token token,
char * buf,
int32_t length,
int32_t lstrip,
bool special);
// check if token0 is contained as a prefix in token1
bool llama_token_is_prefix_impl(
const struct llama_vocab & vocab,
llama_token token0,
llama_token token1);
int32_t llama_detokenize_impl(
const struct llama_vocab & vocab,
const llama_token * tokens,
int32_t n_tokens,
char * text,
int32_t text_len_max,
bool remove_special,
bool unparse_special);
std::string llama_detokenize(
const struct llama_vocab & vocab,
const std::vector<llama_token> & tokens,
bool special);
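To make the surface change concrete, a small internal-API sketch (assuming a populated llama_vocab, e.g. taken from a loaded model) of the member functions that replace the old free-function and *_impl forms removed above:

#include "llama-vocab.h"

#include <string>
#include <vector>

static bool contains_eog(const llama_vocab & vocab, const std::string & text) {
    // llama_tokenize_internal(vocab, text, ...) -> vocab.tokenize(text, ...)
    std::vector<llama_token> toks = vocab.tokenize(text, /*add_special=*/true, /*parse_special=*/false);

    // llama_token_is_eog_impl(vocab, id) -> vocab.is_eog(id)
    for (llama_token id : toks) {
        if (vocab.is_eog(id)) {
            return true;
        }
    }

    // llama_detokenize(vocab, toks, special) -> vocab.detokenize(toks, special)
    const std::string roundtrip = vocab.detokenize(toks, /*special=*/false);
    (void) roundtrip;

    return false;
}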
File diff suppressed because it is too large
View File
@ -14,7 +14,7 @@ int main(int argc, char ** argv) {
std::thread([&model_path]() { std::thread([&model_path]() {
llama_backend_init(); llama_backend_init();
auto * model = llama_model_load_from_file(model_path, llama_model_default_params()); auto * model = llama_model_load_from_file(model_path, llama_model_default_params());
auto * ctx = llama_new_context_with_model(model, llama_context_default_params()); auto * ctx = llama_init_from_model(model, llama_context_default_params());
llama_free(ctx); llama_free(ctx);
llama_model_free(model); llama_model_free(model);
llama_backend_free(); llama_backend_free();
View File
@ -157,7 +157,7 @@ int main(void) {
} }
// test invalid chat template // test invalid chat template
res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size()); res = llama_chat_apply_template("INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
assert(res < 0); assert(res < 0);
for (size_t i = 0; i < templates.size(); i++) { for (size_t i = 0; i < templates.size(); i++) {
@ -165,7 +165,6 @@ int main(void) {
std::string expected = expected_output[i]; std::string expected = expected_output[i];
formatted_chat.resize(1024); formatted_chat.resize(1024);
res = llama_chat_apply_template( res = llama_chat_apply_template(
nullptr,
custom_template.c_str(), custom_template.c_str(),
conversation, conversation,
message_count, message_count,
View File
@ -161,7 +161,7 @@ int main(int argc, char **argv) {
auto cparams = llama_context_default_params(); auto cparams = llama_context_default_params();
ctx = llama_new_context_with_model(model, cparams); ctx = llama_init_from_model(model, cparams);
if (ctx == NULL) { if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
View File
@ -55,7 +55,7 @@ int main(int argc, char **argv) {
auto cparams = llama_context_default_params(); auto cparams = llama_context_default_params();
ctx = llama_new_context_with_model(model, cparams); ctx = llama_init_from_model(model, cparams);
if (ctx == NULL) { if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
@ -64,8 +64,10 @@ int main(int argc, char **argv) {
} }
} }
//GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_BPE); const llama_vocab * vocab = llama_model_get_vocab(model);
if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) {
//GGML_ASSERT(llama_vocab_type(vocab) == LLAMA_VOCAB_TYPE_BPE);
if (llama_vocab_type(vocab) != LLAMA_VOCAB_TYPE_BPE) {
return 99; return 99;
} }
@ -75,7 +77,7 @@ int main(int argc, char **argv) {
atexit([]() { console::cleanup(); }); atexit([]() { console::cleanup(); });
#endif #endif
const int n_vocab = llama_n_vocab(model); const int n_vocab = llama_vocab_n_tokens(vocab);
for (int i = 0; i < n_vocab; ++i) { for (int i = 0; i < n_vocab; ++i) {
std::string str = common_detokenize(ctx, std::vector<int>(1, i)); std::string str = common_detokenize(ctx, std::vector<int>(1, i));
View File
@ -43,7 +43,7 @@ int main(int argc, char ** argv) {
auto cparams = llama_context_default_params(); auto cparams = llama_context_default_params();
ctx = llama_new_context_with_model(model, cparams); ctx = llama_init_from_model(model, cparams);
if (ctx == NULL) { if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str());
@ -52,8 +52,10 @@ int main(int argc, char ** argv) {
} }
} }
const llama_vocab * vocab = llama_model_get_vocab(model);
//GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); //GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_SPM) { if (llama_vocab_type(vocab) != LLAMA_VOCAB_TYPE_SPM) {
return 99; return 99;
} }
@ -63,7 +65,7 @@ int main(int argc, char ** argv) {
atexit([]() { console::cleanup(); }); atexit([]() { console::cleanup(); });
#endif #endif
const int n_vocab = llama_n_vocab(model); const int n_vocab = llama_vocab_n_tokens(vocab);
for (int i = 0; i < n_vocab; ++i) { for (int i = 0; i < n_vocab; ++i) {
std::string str = common_detokenize(ctx, std::vector<int>(1, i), true); std::string str = common_detokenize(ctx, std::vector<int>(1, i), true);
View File
@ -76,7 +76,7 @@ class LibLlamaModel:
self.ffi = libllama.ffi self.ffi = libllama.ffi
if isinstance(mparams, dict): if isinstance(mparams, dict):
mparams = libllama.model_default_params(**mparams) mparams = libllama.model_default_params(**mparams)
self.model = self.lib.llama_load_model_from_file(path_model.encode(), mparams) self.model = self.lib.llama_model_load_from_file(path_model.encode(), mparams)
if not self.model: if not self.model:
raise RuntimeError("error: failed to load model '%s'" % path_model) raise RuntimeError("error: failed to load model '%s'" % path_model)
if isinstance(cparams, dict): if isinstance(cparams, dict):
@ -92,7 +92,7 @@ class LibLlamaModel:
if self.ctx: if self.ctx:
self.lib.llama_free(self.ctx) self.lib.llama_free(self.ctx)
if self.model: if self.model:
self.lib.llama_free_model(self.model) self.lib.llama_model_free(self.model)
self.ctx = None self.ctx = None
self.model = None self.model = None
self.lib = None self.lib = None
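As a closing illustration of the renames exercised by the tests and bindings above, a minimal end-to-end sketch using only the new names (old names noted in comments); the model path comes from argv and no inference is performed.

#include "llama.h"
#include <cstdio>

int main(int argc, char ** argv) {
    if (argc < 2) {
        return 1;
    }

    llama_backend_init();

    // llama_load_model_from_file -> llama_model_load_from_file
    llama_model * model = llama_model_load_from_file(argv[1], llama_model_default_params());
    if (model == NULL) {
        llama_backend_free();
        return 1;
    }

    // llama_new_context_with_model -> llama_init_from_model
    llama_context * ctx = llama_init_from_model(model, llama_context_default_params());

    const llama_vocab * vocab = llama_model_get_vocab(model);
    // llama_n_vocab(model) -> llama_vocab_n_tokens(vocab)
    std::printf("n_tokens = %d\n", llama_vocab_n_tokens(vocab));

    llama_free(ctx);
    // llama_free_model -> llama_model_free
    llama_model_free(model);
    llama_backend_free();
    return 0;
}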