llama : add struct llama_vocab to the API (#11156)

ggml-ci
This commit is contained in:
Georgi Gerganov 2025-01-09 15:28:52 +02:00
parent 609ec7e0a0
commit c725f691ea
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
46 changed files with 592 additions and 495 deletions

View File

@ -857,21 +857,23 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams; return iparams;
} }
const llama_vocab * vocab = llama_get_vocab(model);
if (params.reranking) { if (params.reranking) {
bool ok = true; bool ok = true;
if (llama_token_bos(model) == LLAMA_TOKEN_NULL) { if (llama_token_bos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: model does not have a BOS token, reranking will not work\n", __func__); LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__);
ok = false; ok = false;
} }
if (llama_token_eos(model) == LLAMA_TOKEN_NULL) { if (llama_token_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: model does not have an EOS token, reranking will not work\n", __func__); LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__);
ok = false; ok = false;
} }
if (llama_token_sep(model) == LLAMA_TOKEN_NULL) { if (llama_token_sep(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: model does not have a SEP token, reranking will not work\n", __func__); LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__);
ok = false; ok = false;
} }
@ -941,14 +943,14 @@ struct common_init_result common_init_from_params(common_params & params) {
common_lora_adapters_apply(lctx, params.lora_adapters); common_lora_adapters_apply(lctx, params.lora_adapters);
} }
if (params.sampling.ignore_eos && llama_token_eos(model) == LLAMA_TOKEN_NULL) { if (params.sampling.ignore_eos && llama_token_eos(vocab) == LLAMA_TOKEN_NULL) {
LOG_WRN("%s: warning: model does not have an EOS token, ignoring --ignore-eos\n", __func__); LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__);
params.sampling.ignore_eos = false; params.sampling.ignore_eos = false;
} }
if (params.sampling.ignore_eos) { if (params.sampling.ignore_eos) {
for (llama_token i = 0; i < llama_n_vocab(model); i++) { for (llama_token i = 0; i < llama_n_vocab(vocab); i++) {
if (llama_token_is_eog(model, i)) { if (llama_token_is_eog(vocab, i)) {
LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY); LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY);
params.sampling.logit_bias.push_back({i, -INFINITY}); params.sampling.logit_bias.push_back({i, -INFINITY});
} }
@ -969,8 +971,9 @@ struct common_init_result common_init_from_params(common_params & params) {
LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__);
std::vector<llama_token> tmp; std::vector<llama_token> tmp;
llama_token bos = llama_token_bos(model); llama_token bos = llama_token_bos(vocab);
llama_token eos = llama_token_eos(model); llama_token eos = llama_token_eos(vocab);
// some models (e.g. T5) don't have a BOS token // some models (e.g. T5) don't have a BOS token
if (bos != LLAMA_TOKEN_NULL) { if (bos != LLAMA_TOKEN_NULL) {
tmp.push_back(bos); tmp.push_back(bos);
@ -1559,21 +1562,23 @@ std::vector<llama_token> common_tokenize(
const std::string & text, const std::string & text,
bool add_special, bool add_special,
bool parse_special) { bool parse_special) {
return common_tokenize(llama_get_model(ctx), text, add_special, parse_special); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
return common_tokenize(vocab, text, add_special, parse_special);
} }
std::vector<llama_token> common_tokenize( std::vector<llama_token> common_tokenize(
const struct llama_model * model, const struct llama_vocab * vocab,
const std::string & text, const std::string & text,
bool add_special, bool add_special,
bool parse_special) { bool parse_special) {
// upper limit for the number of tokens // upper limit for the number of tokens
int n_tokens = text.length() + 2 * add_special; int n_tokens = text.length() + 2 * add_special;
std::vector<llama_token> result(n_tokens); std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); n_tokens = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
if (n_tokens < 0) { if (n_tokens < 0) {
result.resize(-n_tokens); result.resize(-n_tokens);
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_special, parse_special); int check = llama_tokenize(vocab, text.data(), text.length(), result.data(), result.size(), add_special, parse_special);
GGML_ASSERT(check == -n_tokens); GGML_ASSERT(check == -n_tokens);
} else { } else {
result.resize(n_tokens); result.resize(n_tokens);
@ -1582,12 +1587,18 @@ std::vector<llama_token> common_tokenize(
} }
std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
return common_token_to_piece(vocab, token, special);
}
std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token token, bool special) {
std::string piece; std::string piece;
piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n'
const int n_chars = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); const int n_chars = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
if (n_chars < 0) { if (n_chars < 0) {
piece.resize(-n_chars); piece.resize(-n_chars);
int check = llama_token_to_piece(llama_get_model(ctx), token, &piece[0], piece.size(), 0, special); int check = llama_token_to_piece(vocab, token, &piece[0], piece.size(), 0, special);
GGML_ASSERT(check == -n_chars); GGML_ASSERT(check == -n_chars);
} }
else { else {
@ -1597,13 +1608,19 @@ std::string common_token_to_piece(const struct llama_context * ctx, llama_token
return piece; return piece;
} }
std::string common_detokenize(llama_context * ctx, const std::vector<llama_token> & tokens, bool special) { std::string common_detokenize(const struct llama_context * ctx, const std::vector<llama_token> & tokens, bool special) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
return common_detokenize(vocab, tokens, special);
}
std::string common_detokenize(const struct llama_vocab * vocab, const std::vector<llama_token> & tokens, bool special) {
std::string text; std::string text;
text.resize(std::max(text.capacity(), tokens.size())); text.resize(std::max(text.capacity(), tokens.size()));
int32_t n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); int32_t n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
if (n_chars < 0) { if (n_chars < 0) {
text.resize(-n_chars); text.resize(-n_chars);
n_chars = llama_detokenize(llama_get_model(ctx), tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special); n_chars = llama_detokenize(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization GGML_ASSERT(n_chars <= (int32_t)text.size()); // whitespace trimming is performed after per-token detokenization
} }
@ -1631,7 +1648,7 @@ std::string common_get_builtin_chat_template(const struct llama_model * model) {
bool common_chat_verify_template(const std::string & tmpl) { bool common_chat_verify_template(const std::string & tmpl) {
llama_chat_message chat[] = {{"user", "test"}}; llama_chat_message chat[] = {{"user", "test"}};
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0); const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0);
return res >= 0; return res >= 0;
} }
@ -1642,16 +1659,16 @@ std::string common_chat_apply_template(const struct llama_model * model,
int alloc_size = 0; int alloc_size = 0;
bool fallback = false; // indicate if we must fallback to default chatml bool fallback = false; // indicate if we must fallback to default chatml
std::vector<llama_chat_message> chat; std::vector<llama_chat_message> chat;
for (auto & msg : msgs) { for (const auto & msg : msgs) {
chat.push_back({msg.role.c_str(), msg.content.c_str()}); chat.push_back({msg.role.c_str(), msg.content.c_str()});
alloc_size += (msg.role.size() + msg.content.size()) * 1.25; alloc_size += (msg.role.size() + msg.content.size()) * 1.25;
} }
const char * ptr_tmpl = tmpl.empty() ? nullptr : tmpl.c_str(); const char * ptr_tmpl = tmpl.empty() ? llama_model_chat_template(model) : tmpl.c_str();
std::vector<char> buf(alloc_size); std::vector<char> buf(alloc_size);
// run the first time to get the total output length // run the first time to get the total output length
int32_t res = llama_chat_apply_template(model, ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size()); int32_t res = llama_chat_apply_template(ptr_tmpl, chat.data(), chat.size(), add_ass, buf.data(), buf.size());
// error: chat template is not supported // error: chat template is not supported
if (res < 0) { if (res < 0) {
@ -1659,18 +1676,17 @@ std::string common_chat_apply_template(const struct llama_model * model,
// if the custom "tmpl" is not supported, we throw an error // if the custom "tmpl" is not supported, we throw an error
// this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template() // this is a bit redundant (for good), since we're not sure if user validated the custom template with llama_chat_verify_template()
throw std::runtime_error("this custom template is not supported"); throw std::runtime_error("this custom template is not supported");
} else {
// If the built-in template is not supported, we default to chatml
res = llama_chat_apply_template(nullptr, "chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
fallback = true;
} }
// If the built-in template is not supported, we default to chatml
res = llama_chat_apply_template("chatml", chat.data(), chat.size(), add_ass, buf.data(), buf.size());
fallback = true;
} }
// if it turns out that our buffer is too small, we resize it // if it turns out that our buffer is too small, we resize it
if ((size_t) res > buf.size()) { if ((size_t) res > buf.size()) {
buf.resize(res); buf.resize(res);
res = llama_chat_apply_template( res = llama_chat_apply_template(
fallback ? nullptr : model,
fallback ? "chatml" : ptr_tmpl, fallback ? "chatml" : ptr_tmpl,
chat.data(), chat.size(), add_ass, buf.data(), buf.size()); chat.data(), chat.size(), add_ass, buf.data(), buf.size());
} }

View File

@ -541,7 +541,7 @@ std::vector<llama_token> common_tokenize(
bool parse_special = false); bool parse_special = false);
std::vector<llama_token> common_tokenize( std::vector<llama_token> common_tokenize(
const struct llama_model * model, const struct llama_vocab * vocab,
const std::string & text, const std::string & text,
bool add_special, bool add_special,
bool parse_special = false); bool parse_special = false);
@ -553,11 +553,21 @@ std::string common_token_to_piece(
llama_token token, llama_token token,
bool special = true); bool special = true);
std::string common_token_to_piece(
const struct llama_vocab * vocab,
llama_token token,
bool special = true);
// detokenizes a vector of tokens into a string // detokenizes a vector of tokens into a string
// should work similar to Python's `tokenizer.decode` // should work similar to Python's `tokenizer.decode`
// optionally renders special/control tokens // optionally renders special/control tokens
std::string common_detokenize( std::string common_detokenize(
llama_context * ctx, const struct llama_context * ctx,
const std::vector<llama_token> & tokens,
bool special = true);
std::string common_detokenize(
const struct llama_vocab * vocab,
const std::vector<llama_token> & tokens, const std::vector<llama_token> & tokens,
bool special = true); bool special = true);

View File

@ -113,7 +113,10 @@ struct common_sampler {
void set_logits(struct llama_context * ctx, int idx) { void set_logits(struct llama_context * ctx, int idx) {
const auto * logits = llama_get_logits_ith(ctx, idx); const auto * logits = llama_get_logits_ith(ctx, idx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
const int n_vocab = llama_n_vocab(vocab);
cur.resize(n_vocab); cur.resize(n_vocab);
@ -142,13 +145,15 @@ std::string common_params_sampling::print() const {
} }
struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
const llama_vocab * vocab = llama_get_vocab(model);
llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
lparams.no_perf = params.no_perf; lparams.no_perf = params.no_perf;
auto * result = new common_sampler { auto * result = new common_sampler {
/* .params = */ params, /* .params = */ params,
/* .grmr = */ llama_sampler_init_grammar(model, params.grammar.c_str(), "root"), /* .grmr = */ llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"),
/* .chain = */ llama_sampler_chain_init(lparams), /* .chain = */ llama_sampler_chain_init(lparams),
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)), /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
/* .cur = */ {}, /* .cur = */ {},
@ -157,7 +162,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
llama_sampler_chain_add(result->chain, llama_sampler_chain_add(result->chain,
llama_sampler_init_logit_bias( llama_sampler_init_logit_bias(
llama_n_vocab(model), llama_n_vocab(vocab),
params.logit_bias.size(), params.logit_bias.size(),
params.logit_bias.data())); params.logit_bias.data()));
@ -172,7 +177,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
c_breakers.push_back(str.c_str()); c_breakers.push_back(str.c_str());
} }
llama_sampler_chain_add(result->chain, llama_sampler_init_dry (model, params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size()));
} }
break; break;
case COMMON_SAMPLER_TYPE_TOP_K: case COMMON_SAMPLER_TYPE_TOP_K:
@ -194,7 +199,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent)); llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break; break;
case COMMON_SAMPLER_TYPE_INFILL: case COMMON_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model)); llama_sampler_chain_add(result->chain, llama_sampler_init_infill (vocab));
break; break;
case COMMON_SAMPLER_TYPE_PENALTIES: case COMMON_SAMPLER_TYPE_PENALTIES:
llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present)); llama_sampler_chain_add(result->chain, llama_sampler_init_penalties(params.penalty_last_n, params.penalty_repeat, params.penalty_freq, params.penalty_present));
@ -206,7 +211,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed));
} else if (params.mirostat == 1) { } else if (params.mirostat == 1) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(model), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100));
} else if (params.mirostat == 2) { } else if (params.mirostat == 2) {
llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp));
llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta));

View File

@ -79,10 +79,13 @@ bool common_speculative_are_compatible(
const struct llama_model * model_tgt = llama_get_model(ctx_tgt); const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
const struct llama_model * model_dft = llama_get_model(ctx_dft); const struct llama_model * model_dft = llama_get_model(ctx_dft);
const bool vocab_type_tgt = llama_vocab_type(model_tgt); const struct llama_vocab * vocab_tgt = llama_get_vocab(model_tgt);
const struct llama_vocab * vocab_dft = llama_get_vocab(model_dft);
const bool vocab_type_tgt = llama_vocab_type(vocab_tgt);
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt); LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
const bool vocab_type_dft = llama_vocab_type(model_dft); const bool vocab_type_dft = llama_vocab_type(vocab_dft);
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft); LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) { if (vocab_type_tgt != vocab_type_dft) {
@ -91,34 +94,34 @@ bool common_speculative_are_compatible(
return false; return false;
} }
if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) || if (llama_add_bos_token(vocab_tgt) != llama_add_bos_token(vocab_dft) ||
llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) || llama_add_eos_token(vocab_tgt) != llama_add_eos_token(vocab_dft) ||
llama_token_bos(model_tgt) != llama_token_bos(model_dft) || llama_token_bos(vocab_tgt) != llama_token_bos(vocab_dft) ||
llama_token_eos(model_tgt) != llama_token_eos(model_dft)) { llama_token_eos(vocab_tgt) != llama_token_eos(vocab_dft)) {
LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__); LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__);
LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_tgt), llama_add_bos_token(model_tgt), llama_token_eos(model_tgt), llama_add_eos_token(model_tgt)); LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(vocab_tgt), llama_add_bos_token(vocab_tgt), llama_token_eos(vocab_tgt), llama_add_eos_token(vocab_tgt));
LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(model_dft), llama_add_bos_token(model_dft), llama_token_eos(model_dft), llama_add_eos_token(model_dft)); LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(vocab_dft), llama_add_bos_token(vocab_dft), llama_token_eos(vocab_dft), llama_add_eos_token(vocab_dft));
return false; return false;
} }
{ {
const int n_vocab_tgt = llama_n_vocab(model_tgt); const int n_vocab_tgt = llama_n_vocab(vocab_tgt);
const int n_vocab_dft = llama_n_vocab(model_dft); const int n_vocab_dft = llama_n_vocab(vocab_dft);
const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft); const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_ERR("%s: draft model vocab must closely match target model to use speculation but " LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
"target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
__func__, n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); __func__, n_vocab_tgt, llama_n_vocab(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return false; return false;
} }
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
const char * token_text_tgt = llama_token_get_text(model_tgt, i); const char * token_text_tgt = llama_token_get_text(vocab_tgt, i);
const char * token_text_dft = llama_token_get_text(model_dft, i); const char * token_text_dft = llama_token_get_text(vocab_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) { if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
LOG_ERR("%s: draft model vocab must match target model to use speculation but " LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but "
"token %d content differs - target '%s', draft '%s'\n", __func__, i, "token %d content differs - target '%s', draft '%s'\n", __func__, i,
common_token_to_piece(ctx_tgt, i).c_str(), common_token_to_piece(ctx_tgt, i).c_str(),
common_token_to_piece(ctx_dft, i).c_str()); common_token_to_piece(ctx_dft, i).c_str());

View File

@ -23,12 +23,12 @@ defer {
} }
let model_params = llama_model_default_params() let model_params = llama_model_default_params()
guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), model_params) else { guard let model = llama_model_load_from_file(modelPath.cString(using: .utf8), model_params) else {
print("Failed to load model") print("Failed to load model")
exit(1) exit(1)
} }
defer { defer {
llama_free_model(model) llama_model_free(model)
} }
var tokens = tokenize(text: prompt, add_bos: true) var tokens = tokenize(text: prompt, add_bos: true)

View File

@ -48,10 +48,12 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
const llama_vocab * vocab = llama_get_vocab(model);
// tokenize the prompt // tokenize the prompt
std::vector<llama_token> tokens_list; std::vector<llama_token> tokens_list;
tokens_list = common_tokenize(model, params.prompt, true); tokens_list = common_tokenize(vocab, params.prompt, true);
const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size())*n_parallel; const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size())*n_parallel;
@ -121,7 +123,7 @@ int main(int argc, char ** argv) {
llama_token decoder_start_token_id = llama_model_decoder_start_token(model); llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == LLAMA_TOKEN_NULL) { if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
decoder_start_token_id = llama_token_bos(model); decoder_start_token_id = llama_token_bos(vocab);
} }
common_batch_clear(batch); common_batch_clear(batch);
@ -174,7 +176,7 @@ int main(int argc, char ** argv) {
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]); const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
// is it an end of generation? -> mark the stream as finished // is it an end of generation? -> mark the stream as finished
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) { if (llama_token_is_eog(vocab, new_token_id) || n_cur == n_predict) {
i_batch[i] = -1; i_batch[i] = -1;
LOG("\n"); LOG("\n");
if (n_parallel > 1) { if (n_parallel > 1) {

View File

@ -273,7 +273,9 @@ struct tokenized_prompt {
size_t max_seq_len; size_t max_seq_len;
tokenized_prompt(llama_context * ctx, std::string pos, std::string neg) { tokenized_prompt(llama_context * ctx, std::string pos, std::string neg) {
const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
const bool add_bos = llama_add_bos_token(vocab);
tokens_pos = common_tokenize(ctx, pos, add_bos, true); tokens_pos = common_tokenize(ctx, pos, add_bos, true);
tokens_neg = common_tokenize(ctx, neg, add_bos, true); tokens_neg = common_tokenize(ctx, neg, add_bos, true);
max_seq_len = std::max(tokens_pos.size(), tokens_neg.size()); max_seq_len = std::max(tokens_pos.size(), tokens_neg.size());

View File

@ -105,6 +105,8 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
const llama_vocab * vocab = llama_get_vocab(model);
const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
@ -148,7 +150,7 @@ int main(int argc, char ** argv) {
// check if the last token is SEP // check if the last token is SEP
// it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true' // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true'
for (auto & inp : inputs) { for (auto & inp : inputs) {
if (inp.empty() || inp.back() != llama_token_sep(model)) { if (inp.empty() || inp.back() != llama_token_sep(vocab)) {
LOG_WRN("%s: last token in the prompt is not SEP\n", __func__); LOG_WRN("%s: last token in the prompt is not SEP\n", __func__);
LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__); LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__);
} }

View File

@ -127,7 +127,10 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
} }
static bool run(llama_context * ctx, const common_params & params) { static bool run(llama_context * ctx, const common_params & params) {
const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
const bool add_bos = llama_add_bos_token(vocab);
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos); std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);

View File

@ -11,6 +11,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
std::vector<std::vector<float>> result; std::vector<std::vector<float>> result;
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
@ -19,16 +20,16 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
const std::string input_string = instruction + sentences[i]; const std::string input_string = instruction + sentences[i];
std::vector<llama_token> inputs = common_tokenize(model, input_string, true, false); std::vector<llama_token> inputs = common_tokenize(vocab, input_string, true, false);
const int32_t n_toks = inputs.size(); const int32_t n_toks = inputs.size();
// GritLM seems to have EOS = "" // GritLM seems to have EOS = ""
// https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18 // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
// inputs.push_back(llama_token_eos(model)); // inputs.push_back(llama_token_eos(vocab));
// we want to ignore instruction tokens for mean pooling // we want to ignore instruction tokens for mean pooling
const int32_t n_inst = common_tokenize(model, instruction, true, false).size(); const int32_t n_inst = common_tokenize(vocab, instruction, true, false).size();
#ifdef GRIT_DEBUG #ifdef GRIT_DEBUG
// debug tokens - should be matching as referenced in the GritLM sample // debug tokens - should be matching as referenced in the GritLM sample
@ -97,7 +98,9 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
std::string result; std::string result;
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
llama_token eos_token = llama_token_eos(model); const llama_vocab * vocab = llama_get_vocab(model);
llama_token eos_token = llama_token_eos(vocab);
llama_kv_cache_clear(ctx); llama_kv_cache_clear(ctx);
llama_set_embeddings(ctx, false); llama_set_embeddings(ctx, false);
@ -105,7 +108,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
std::vector<llama_token> inputs = common_tokenize(model, prompt, false, true); std::vector<llama_token> inputs = common_tokenize(vocab, prompt, false, true);
int32_t i_current_token = 0; int32_t i_current_token = 0;
while (true) { while (true) {

View File

@ -429,10 +429,13 @@ static void process_logits(
} }
static bool compute_imatrix(llama_context * ctx, const common_params & params) { static bool compute_imatrix(llama_context * ctx, const common_params & params) {
const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
const bool add_bos = llama_add_bos_token(vocab);
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); GGML_ASSERT(!llama_add_eos_token(vocab));
auto tim1 = std::chrono::high_resolution_clock::now(); auto tim1 = std::chrono::high_resolution_clock::now();
LOG_INF("%s: tokenizing the input ..\n", __func__); LOG_INF("%s: tokenizing the input ..\n", __func__);
@ -468,7 +471,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
const int n_chunk_max = tokens.size() / n_ctx; const int n_chunk_max = tokens.size() / n_ctx;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_vocab = llama_n_vocab(vocab);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
int count = 0; int count = 0;
@ -508,7 +511,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
// add BOS token for the first batch of each chunk // add BOS token for the first batch of each chunk
if (add_bos && j == 0) { if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); tokens[batch_start] = llama_token_bos(vocab);
} }
common_batch_clear(batch); common_batch_clear(batch);

View File

@ -139,6 +139,8 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
const llama_vocab * vocab = llama_get_vocab(model);
const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
LOG_DBG("n_ctx: %d\n", n_ctx); LOG_DBG("n_ctx: %d\n", n_ctx);
@ -152,28 +154,28 @@ int main(int argc, char ** argv) {
LOG_INF("\n"); LOG_INF("\n");
LOG_INF("%s\n", common_params_get_system_info(params).c_str()); LOG_INF("%s\n", common_params_get_system_info(params).c_str());
} }
const bool add_bos = llama_add_bos_token(model); const bool add_bos = llama_add_bos_token(vocab);
GGML_ASSERT(!llama_add_eos_token(model)); GGML_ASSERT(!llama_add_eos_token(vocab));
std::vector<llama_token> embd_inp; std::vector<llama_token> embd_inp;
std::vector<llama_token> embd_end; std::vector<llama_token> embd_end;
std::vector<llama_token> inp_pfx = common_tokenize(ctx, params.input_prefix, false); std::vector<llama_token> inp_pfx = common_tokenize(ctx, params.input_prefix, false);
std::vector<llama_token> inp_sfx = common_tokenize(ctx, params.input_suffix, false); std::vector<llama_token> inp_sfx = common_tokenize(ctx, params.input_suffix, false);
GGML_ASSERT(llama_token_fim_pre(model) >= 0); GGML_ASSERT(llama_token_fim_pre(vocab) >= 0);
GGML_ASSERT(llama_token_fim_suf(model) >= 0); GGML_ASSERT(llama_token_fim_suf(vocab) >= 0);
inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model)); inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(vocab));
inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model)); inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(vocab));
embd_inp = params.spm_infill ? inp_sfx : inp_pfx; embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
embd_end = params.spm_infill ? inp_pfx : inp_sfx; embd_end = params.spm_infill ? inp_pfx : inp_sfx;
if (add_bos) { if (add_bos) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model)); embd_inp.insert(embd_inp.begin(), llama_token_bos(vocab));
} }
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
const llama_token middle_token = llama_token_fim_mid(model); const llama_token middle_token = llama_token_fim_mid(vocab);
if (middle_token >= 0) { if (middle_token >= 0) {
embd_inp.push_back(middle_token); embd_inp.push_back(middle_token);
} }
@ -185,7 +187,7 @@ int main(int argc, char ** argv) {
// Should not run without any tokens // Should not run without any tokens
if (embd_inp.empty()) { if (embd_inp.empty()) {
embd_inp.push_back(llama_token_bos(model)); embd_inp.push_back(llama_token_bos(vocab));
LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str()); LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str());
} }
@ -420,10 +422,10 @@ int main(int argc, char ** argv) {
// if not currently processing queued inputs; // if not currently processing queued inputs;
if ((int) embd_inp.size() <= n_consumed) { if ((int) embd_inp.size() <= n_consumed) {
// deal with eot token in infill mode // deal with eot token in infill mode
if ((common_sampler_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){ if ((common_sampler_last(smpl) == llama_token_eot(vocab) || is_interacting) && params.interactive){
if (is_interacting && !params.interactive_first) { if (is_interacting && !params.interactive_first) {
// print an eot token // print an eot token
LOG("%s", common_token_to_piece(ctx, llama_token_eot(model)).c_str()); LOG("%s", common_token_to_piece(ctx, llama_token_eot(vocab)).c_str());
} }
LOG("\n"); LOG("\n");
console::set_display(console::user_input); console::set_display(console::user_input);
@ -463,13 +465,13 @@ int main(int argc, char ** argv) {
std::vector<llama_token> inp_pfx = common_tokenize(ctx, params.input_prefix, false); std::vector<llama_token> inp_pfx = common_tokenize(ctx, params.input_prefix, false);
std::vector<llama_token> inp_sfx = common_tokenize(ctx, params.input_suffix, false); std::vector<llama_token> inp_sfx = common_tokenize(ctx, params.input_suffix, false);
inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(model)); inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(vocab));
inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(model)); inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(vocab));
embd_inp = params.spm_infill ? inp_sfx : inp_pfx; embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
embd_end = params.spm_infill ? inp_pfx : inp_sfx; embd_end = params.spm_infill ? inp_pfx : inp_sfx;
if (add_bos) { if (add_bos) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model)); embd_inp.insert(embd_inp.begin(), llama_token_bos(vocab));
} }
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
@ -484,7 +486,7 @@ int main(int argc, char ** argv) {
is_interacting = false; is_interacting = false;
} }
// deal with end of generation tokens in interactive mode // deal with end of generation tokens in interactive mode
else if (llama_token_is_eog(model, common_sampler_last(smpl))) { else if (llama_token_is_eog(vocab, common_sampler_last(smpl))) {
LOG_DBG("found EOS token\n"); LOG_DBG("found EOS token\n");
if (params.interactive) { if (params.interactive) {
@ -500,7 +502,7 @@ int main(int argc, char ** argv) {
if (params.input_prefix_bos) { if (params.input_prefix_bos) {
LOG_DBG("adding input prefix BOS token\n"); LOG_DBG("adding input prefix BOS token\n");
embd_inp.push_back(llama_token_bos(model)); embd_inp.push_back(llama_token_bos(vocab));
} }
std::string buffer; std::string buffer;
@ -563,7 +565,7 @@ int main(int argc, char ** argv) {
} }
// end of generation // end of generation
if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !params.interactive) { if (!embd.empty() && llama_token_is_eog(vocab, embd.back()) && !params.interactive) {
break; break;
} }
@ -575,7 +577,7 @@ int main(int argc, char ** argv) {
} }
} }
if (!params.interactive && n_remain <= 0) { if (!params.interactive && n_remain <= 0) {
LOG("%s", common_token_to_piece(ctx, llama_token_eot(model)).c_str()); LOG("%s", common_token_to_piece(ctx, llama_token_eot(vocab)).c_str());
} }
LOG("\n"); LOG("\n");

View File

@ -1401,7 +1401,8 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
llama_set_n_threads(ctx, n_threads, n_threads); llama_set_n_threads(ctx, n_threads, n_threads);
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
const int32_t n_vocab = llama_n_vocab(model); const llama_vocab * vocab = llama_get_vocab(model);
const int32_t n_vocab = llama_n_vocab(vocab);
std::vector<llama_token> tokens(n_batch); std::vector<llama_token> tokens(n_batch);
@ -1409,7 +1410,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th
while (n_processed < n_prompt) { while (n_processed < n_prompt) {
int n_tokens = std::min(n_prompt - n_processed, n_batch); int n_tokens = std::min(n_prompt - n_processed, n_batch);
tokens[0] = n_processed == 0 && llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab; tokens[0] = n_processed == 0 && llama_add_bos_token(vocab) ? llama_token_bos(vocab) : std::rand() % n_vocab;
for (int i = 1; i < n_tokens; i++) { for (int i = 1; i < n_tokens; i++) {
tokens[i] = std::rand() % n_vocab; tokens[i] = std::rand() % n_vocab;
} }
@ -1424,9 +1425,10 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) {
llama_set_n_threads(ctx, n_threads, n_threads); llama_set_n_threads(ctx, n_threads, n_threads);
const llama_model * model = llama_get_model(ctx); const llama_model * model = llama_get_model(ctx);
const int32_t n_vocab = llama_n_vocab(model); const llama_vocab * vocab = llama_get_vocab(model);
const int32_t n_vocab = llama_n_vocab(vocab);
llama_token token = llama_add_bos_token(model) ? llama_token_bos(model) : std::rand() % n_vocab; llama_token token = llama_add_bos_token(vocab) ? llama_token_bos(vocab) : std::rand() % n_vocab;
for (int i = 0; i < n_gen; i++) { for (int i = 0; i < n_gen; i++) {
llama_decode(ctx, llama_batch_get_one(&token, 1)); llama_decode(ctx, llama_batch_get_one(&token, 1));

View File

@ -87,7 +87,7 @@ Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring fi
auto path_to_model = env->GetStringUTFChars(filename, 0); auto path_to_model = env->GetStringUTFChars(filename, 0);
LOGi("Loading model from %s", path_to_model); LOGi("Loading model from %s", path_to_model);
auto model = llama_load_model_from_file(path_to_model, model_params); auto model = llama_model_load_from_file(path_to_model, model_params);
env->ReleaseStringUTFChars(filename, path_to_model); env->ReleaseStringUTFChars(filename, path_to_model);
if (!model) { if (!model) {
@ -102,7 +102,7 @@ Java_android_llama_cpp_LLamaAndroid_load_1model(JNIEnv *env, jobject, jstring fi
extern "C" extern "C"
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) { Java_android_llama_cpp_LLamaAndroid_free_1model(JNIEnv *, jobject, jlong model) {
llama_free_model(reinterpret_cast<llama_model *>(model)); llama_model_free(reinterpret_cast<llama_model *>(model));
} }
extern "C" extern "C"
@ -405,6 +405,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer); const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer); const auto sampler = reinterpret_cast<llama_sampler *>(sampler_pointer);
const auto model = llama_get_model(context); const auto model = llama_get_model(context);
const auto vocab = llama_get_vocab(model);
if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur); if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur);
if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I"); if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
@ -414,7 +415,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
const auto new_token_id = llama_sampler_sample(sampler, context, -1); const auto new_token_id = llama_sampler_sample(sampler, context, -1);
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { if (llama_token_is_eog(vocab, new_token_id) || n_cur == n_len) {
return nullptr; return nullptr;
} }

View File

@ -52,8 +52,8 @@ actor LlamaContext {
deinit { deinit {
llama_sampler_free(sampling) llama_sampler_free(sampling)
llama_batch_free(batch) llama_batch_free(batch)
llama_model_free(model)
llama_free(context) llama_free(context)
llama_free_model(model)
llama_backend_free() llama_backend_free()
} }
@ -65,7 +65,7 @@ actor LlamaContext {
model_params.n_gpu_layers = 0 model_params.n_gpu_layers = 0
print("Running on simulator, force use n_gpu_layers = 0") print("Running on simulator, force use n_gpu_layers = 0")
#endif #endif
let model = llama_load_model_from_file(path, model_params) let model = llama_model_load_from_file(path, model_params)
guard let model else { guard let model else {
print("Could not load model at \(path)") print("Could not load model at \(path)")
throw LlamaError.couldNotInitializeContext throw LlamaError.couldNotInitializeContext

View File

@ -47,8 +47,12 @@ static const char * sample(struct common_sampler * smpl,
int * n_past) { int * n_past) {
const llama_token id = common_sampler_sample(smpl, ctx_llama, -1); const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
common_sampler_accept(smpl, id, true); common_sampler_accept(smpl, id, true);
const llama_model * model = llama_get_model(ctx_llama);
const llama_vocab * vocab = llama_get_vocab(model);
static std::string ret; static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { if (llama_token_is_eog(vocab, id)) {
ret = "</s>"; ret = "</s>";
} else { } else {
ret = common_token_to_piece(ctx_llama, id); ret = common_token_to_piece(ctx_llama, id);

View File

@ -167,8 +167,12 @@ static const char * sample(struct common_sampler * smpl,
int * n_past) { int * n_past) {
const llama_token id = common_sampler_sample(smpl, ctx_llama, -1); const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
common_sampler_accept(smpl, id, true); common_sampler_accept(smpl, id, true);
const llama_model * model = llama_get_model(ctx_llama);
const llama_vocab * vocab = llama_get_vocab(model);
static std::string ret; static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { if (llama_token_is_eog(vocab, id)) {
ret = "</s>"; ret = "</s>";
} else { } else {
ret = common_token_to_piece(ctx_llama, id); ret = common_token_to_piece(ctx_llama, id);

View File

@ -132,8 +132,12 @@ static const char * sample(struct common_sampler * smpl,
int * n_past, int * st_pos_id) { int * n_past, int * st_pos_id) {
const llama_token id = common_sampler_sample(smpl, ctx_llama, -1); const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
common_sampler_accept(smpl, id, true); common_sampler_accept(smpl, id, true);
const llama_model * model = llama_get_model(ctx_llama);
const llama_vocab * vocab = llama_get_vocab(model);
static std::string ret; static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) { if (llama_token_is_eog(vocab, id)) {
ret = "</s>"; ret = "</s>";
} else { } else {
ret = common_token_to_piece(ctx_llama, id); ret = common_token_to_piece(ctx_llama, id);

View File

@ -61,6 +61,8 @@ int main(int argc, char ** argv) {
llama_model * model = llama_init.model.get(); llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get(); llama_context * ctx = llama_init.context.get();
const llama_vocab * vocab = llama_get_vocab(model);
// Tokenize the prompt // Tokenize the prompt
std::vector<llama_token> inp; std::vector<llama_token> inp;
std::vector<llama_token> all; std::vector<llama_token> all;
@ -147,7 +149,7 @@ int main(int argc, char ** argv) {
} }
// here we keep adding new n-grams as we go // here we keep adding new n-grams as we go
ngram_container ngrams_observed(llama_n_vocab(model), N, G); ngram_container ngrams_observed(llama_n_vocab(vocab), N, G);
// debug // debug
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1); struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1);
@ -297,7 +299,7 @@ int main(int argc, char ** argv) {
} }
fflush(stdout); fflush(stdout);
if (llama_token_is_eog(model, id)) { if (llama_token_is_eog(vocab, id)) {
has_eos = true; has_eos = true;
} }

View File

@ -36,6 +36,8 @@ int main(int argc, char ** argv){
llama_model * model = llama_init.model.get(); llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get(); llama_context * ctx = llama_init.context.get();
const llama_vocab * vocab = llama_get_vocab(model);
// tokenize the prompt // tokenize the prompt
std::vector<llama_token> inp; std::vector<llama_token> inp;
inp = common_tokenize(ctx, params.prompt, true, true); inp = common_tokenize(ctx, params.prompt, true, true);
@ -136,7 +138,7 @@ int main(int argc, char ** argv){
LOG("%s", token_str.c_str()); LOG("%s", token_str.c_str());
} }
if (llama_token_is_eog(model, id)) { if (llama_token_is_eog(vocab, id)) {
has_eos = true; has_eos = true;
} }

View File

@ -163,6 +163,8 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
const llama_vocab * vocab = llama_get_vocab(model);
LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads);
auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU)); auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU));
@ -241,9 +243,9 @@ int main(int argc, char ** argv) {
} }
} }
const bool add_bos = llama_add_bos_token(model); const bool add_bos = llama_add_bos_token(vocab);
if (!llama_model_has_encoder(model)) { if (!llama_model_has_encoder(model)) {
GGML_ASSERT(!llama_add_eos_token(model)); GGML_ASSERT(!llama_add_eos_token(vocab));
} }
LOG_DBG("n_ctx: %d, add_bos: %d\n", n_ctx, add_bos); LOG_DBG("n_ctx: %d, add_bos: %d\n", n_ctx, add_bos);
@ -269,7 +271,7 @@ int main(int argc, char ** argv) {
// Should not run without any tokens // Should not run without any tokens
if (embd_inp.empty()) { if (embd_inp.empty()) {
if (add_bos) { if (add_bos) {
embd_inp.push_back(llama_token_bos(model)); embd_inp.push_back(llama_token_bos(vocab));
LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str()); LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str());
} else { } else {
LOG_ERR("input is empty\n"); LOG_ERR("input is empty\n");
@ -495,7 +497,7 @@ int main(int argc, char ** argv) {
llama_token decoder_start_token_id = llama_model_decoder_start_token(model); llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == LLAMA_TOKEN_NULL) { if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
decoder_start_token_id = llama_token_bos(model); decoder_start_token_id = llama_token_bos(vocab);
} }
embd_inp.clear(); embd_inp.clear();
@ -742,7 +744,7 @@ int main(int argc, char ** argv) {
} }
// deal with end of generation tokens in interactive mode // deal with end of generation tokens in interactive mode
if (llama_token_is_eog(model, common_sampler_last(smpl))) { if (llama_token_is_eog(vocab, common_sampler_last(smpl))) {
LOG_DBG("found an EOG token\n"); LOG_DBG("found an EOG token\n");
if (params.interactive) { if (params.interactive) {
@ -776,7 +778,7 @@ int main(int argc, char ** argv) {
if (params.input_prefix_bos) { if (params.input_prefix_bos) {
LOG_DBG("adding input prefix BOS token\n"); LOG_DBG("adding input prefix BOS token\n");
embd_inp.push_back(llama_token_bos(model)); embd_inp.push_back(llama_token_bos(vocab));
} }
std::string buffer; std::string buffer;
@ -830,8 +832,8 @@ int main(int argc, char ** argv) {
// if user stop generation mid-way, we must add EOT to finish model's last response // if user stop generation mid-way, we must add EOT to finish model's last response
if (need_insert_eot && format_chat) { if (need_insert_eot && format_chat) {
llama_token eot = llama_token_eot(model); llama_token eot = llama_token_eot(vocab);
embd_inp.push_back(eot == LLAMA_TOKEN_NULL ? llama_token_eos(model) : eot); embd_inp.push_back(eot == LLAMA_TOKEN_NULL ? llama_token_eos(vocab) : eot);
need_insert_eot = false; need_insert_eot = false;
} }
@ -866,7 +868,7 @@ int main(int argc, char ** argv) {
} }
// end of generation // end of generation
if (!embd.empty() && llama_token_is_eog(model, embd.back()) && !(params.interactive)) { if (!embd.empty() && llama_token_is_eog(vocab, embd.back()) && !(params.interactive)) {
LOG(" [end of text]\n"); LOG(" [end of text]\n");
break; break;
} }

View File

@ -135,6 +135,8 @@ int main(int argc, char ** argv) {
llama_model * model = llama_init.model.get(); llama_model * model = llama_init.model.get();
llama_context * ctx = llama_init.context.get(); llama_context * ctx = llama_init.context.get();
const llama_vocab * vocab = llama_get_vocab(model);
// load the prompts from an external file if there are any // load the prompts from an external file if there are any
if (params.prompt.empty()) { if (params.prompt.empty()) {
LOG_INF("\033[32mNo new questions so proceed with build-in defaults.\033[0m\n"); LOG_INF("\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
@ -358,7 +360,7 @@ int main(int argc, char ** argv) {
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str()); // client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
if (client.n_decoded > 2 && if (client.n_decoded > 2 &&
(llama_token_is_eog(model, id) || (llama_token_is_eog(vocab, id) ||
(params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) || (params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) ||
client.response.find("User:") != std::string::npos || client.response.find("User:") != std::string::npos ||
client.response.find('\n') != std::string::npos)) { client.response.find('\n') != std::string::npos)) {

View File

@ -70,6 +70,8 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
const llama_vocab * vocab = llama_get_vocab(model);
// initialize the context // initialize the context
llama_context_params ctx_params = common_context_params_to_llama(params); llama_context_params ctx_params = common_context_params_to_llama(params);
@ -223,7 +225,7 @@ int main(int argc, char ** argv) {
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
// is it an end of generation? // is it an end of generation?
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) { if (llama_token_is_eog(vocab, new_token_id) || n_cur == n_len) {
LOG("\n"); LOG("\n");
break; break;

View File

@ -296,8 +296,11 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
// Output: `perplexity: 13.5106 [114/114]` // Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval // BOS tokens will be added for each chunk before eval
const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); const llama_model * model = llama_get_model(ctx);
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); const llama_vocab * vocab = llama_get_vocab(model);
const bool add_bos = llama_add_bos_token(vocab);
GGML_ASSERT(!llama_add_eos_token(vocab));
LOG_INF("%s: tokenizing the input ..\n", __func__); LOG_INF("%s: tokenizing the input ..\n", __func__);
@ -338,7 +341,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_vocab = llama_n_vocab(vocab);
int count = 0; int count = 0;
double nll = 0.0; double nll = 0.0;
@ -382,7 +385,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params
// add BOS token for the first batch of each chunk // add BOS token for the first batch of each chunk
if (add_bos && j == 0) { if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); tokens[batch_start] = llama_token_bos(vocab);
} }
const auto * batch_logits = llama_get_logits(ctx); const auto * batch_logits = llama_get_logits(ctx);
@ -444,8 +447,11 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
// Output: `perplexity: 13.5106 [114/114]` // Output: `perplexity: 13.5106 [114/114]`
// BOS tokens will be added for each chunk before eval // BOS tokens will be added for each chunk before eval
const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); const llama_model * model = llama_get_model(ctx);
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); const llama_vocab * vocab = llama_get_vocab(model);
const bool add_bos = llama_add_bos_token(vocab);
GGML_ASSERT(!llama_add_eos_token(vocab));
std::ofstream logits_stream; std::ofstream logits_stream;
if (!params.logits_file.empty()) { if (!params.logits_file.empty()) {
@ -485,7 +491,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_vocab = llama_n_vocab(vocab);
int count = 0; int count = 0;
double nll = 0.0; double nll = 0.0;
@ -557,7 +563,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params &
// add BOS token for the first batch of each chunk // add BOS token for the first batch of each chunk
if (add_bos && j == 0) { if (add_bos && j == 0) {
tokens[seq_start] = llama_token_bos(llama_get_model(ctx)); tokens[seq_start] = llama_token_bos(vocab);
} }
for (int k = 0; k < batch_size; ++k) { for (int k = 0; k < batch_size; ++k) {
@ -732,6 +738,9 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto
} }
static void hellaswag_score(llama_context * ctx, const common_params & params) { static void hellaswag_score(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
// Calculates hellaswag score (acc_norm) from prompt // Calculates hellaswag score (acc_norm) from prompt
// //
// Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl // Data extracted from the HellaSwag validation dataset (MIT license) https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl
@ -765,7 +774,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
size_t hs_task_count = prompt_lines.size()/6; size_t hs_task_count = prompt_lines.size()/6;
LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count); LOG_INF("%s : loaded %zu tasks from prompt.\n", __func__, hs_task_count);
const bool is_spm = llama_vocab_type(llama_get_model(ctx)) == LLAMA_VOCAB_TYPE_SPM; const bool is_spm = llama_vocab_type(vocab) == LLAMA_VOCAB_TYPE_SPM;
LOG_INF("================================= is_spm = %d\n", is_spm); LOG_INF("================================= is_spm = %d\n", is_spm);
// The tasks should be randomized so the score stabilizes quickly. // The tasks should be randomized so the score stabilizes quickly.
@ -848,7 +857,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) {
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_vocab = llama_n_vocab(vocab);
const int max_tasks_per_batch = 32; const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
@ -1072,6 +1081,8 @@ static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string
* *
*/ */
static void winogrande_score(llama_context * ctx, const common_params & params) { static void winogrande_score(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
constexpr int k_min_trailing_ctx = 3; constexpr int k_min_trailing_ctx = 3;
@ -1130,7 +1141,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params)
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_vocab = llama_n_vocab(vocab);
const int max_tasks_per_batch = 128; const int max_tasks_per_batch = 128;
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
@ -1374,6 +1385,8 @@ static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choic
// https://huggingface.co/datasets/truthful_qa // https://huggingface.co/datasets/truthful_qa
// //
static void multiple_choice_score(llama_context * ctx, const common_params & params) { static void multiple_choice_score(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
std::istringstream strstream(params.prompt); std::istringstream strstream(params.prompt);
uint32_t n_task; uint32_t n_task;
@ -1482,7 +1495,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_vocab = llama_n_vocab(vocab);
const int max_tasks_per_batch = 32; const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
@ -1655,6 +1668,9 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par
} }
static void kl_divergence(llama_context * ctx, const common_params & params) { static void kl_divergence(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
if (params.logits_file.empty()) { if (params.logits_file.empty()) {
LOG_ERR("%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__); LOG_ERR("%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__);
return; return;
@ -1688,8 +1704,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str()); LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
return; return;
} }
if (n_vocab != llama_n_vocab(llama_get_model(ctx))) { if (n_vocab != llama_n_vocab(vocab)) {
LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx))); LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(vocab));
} }
std::vector<llama_token> tokens(size_t(n_ctx) * n_chunk); std::vector<llama_token> tokens(size_t(n_ctx) * n_chunk);
@ -1701,8 +1717,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int num_batches = (n_ctx + n_batch - 1)/n_batch; const int num_batches = (n_ctx + n_batch - 1)/n_batch;
const int nv = 2*((n_vocab + 1)/2) + 4; const int nv = 2*((n_vocab + 1)/2) + 4;
const bool add_bos = llama_add_bos_token(llama_get_model(ctx)); const bool add_bos = llama_add_bos_token(vocab);
GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx))); GGML_ASSERT(!llama_add_eos_token(vocab));
std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv); std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk); std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
@ -1761,7 +1777,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
// add BOS token for the first batch of each chunk // add BOS token for the first batch of each chunk
if (add_bos && j == 0) { if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); tokens[batch_start] = llama_token_bos(vocab);
} }
common_batch_clear(batch); common_batch_clear(batch);

View File

@ -159,6 +159,8 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
const llama_vocab * vocab = llama_get_vocab(model);
const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx_train = llama_n_ctx_train(model);
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
@ -192,8 +194,8 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
// add eos if not present // add eos if not present
if (llama_token_eos(model) >= 0 && (inp.empty() || inp.back() != llama_token_eos(model))) { if (llama_token_eos(vocab) >= 0 && (inp.empty() || inp.back() != llama_token_eos(vocab))) {
inp.push_back(llama_token_eos(model)); inp.push_back(llama_token_eos(vocab));
} }
chunk.tokens = inp; chunk.tokens = inp;
} }

View File

@ -713,11 +713,11 @@ static void add_message(const char * role, const std::string & text, LlamaData &
// Function to apply the chat template and resize `formatted` if needed // Function to apply the chat template and resize `formatted` if needed
static int apply_chat_template(LlamaData & llama_data, const bool append) { static int apply_chat_template(LlamaData & llama_data, const bool append) {
int result = llama_chat_apply_template( int result = llama_chat_apply_template(
llama_data.model.get(), nullptr, llama_data.messages.data(), llama_data.messages.size(), append, llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(), llama_data.messages.size(), append,
append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0); append ? llama_data.fmtted.data() : nullptr, append ? llama_data.fmtted.size() : 0);
if (append && result > static_cast<int>(llama_data.fmtted.size())) { if (append && result > static_cast<int>(llama_data.fmtted.size())) {
llama_data.fmtted.resize(result); llama_data.fmtted.resize(result);
result = llama_chat_apply_template(llama_data.model.get(), nullptr, llama_data.messages.data(), result = llama_chat_apply_template(llama_model_chat_template(llama_data.model.get()), llama_data.messages.data(),
llama_data.messages.size(), append, llama_data.fmtted.data(), llama_data.messages.size(), append, llama_data.fmtted.data(),
llama_data.fmtted.size()); llama_data.fmtted.size());
} }
@ -726,11 +726,11 @@ static int apply_chat_template(LlamaData & llama_data, const bool append) {
} }
// Function to tokenize the prompt // Function to tokenize the prompt
static int tokenize_prompt(const llama_model_ptr & model, const std::string & prompt, static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt,
std::vector<llama_token> & prompt_tokens) { std::vector<llama_token> & prompt_tokens) {
const int n_prompt_tokens = -llama_tokenize(model.get(), prompt.c_str(), prompt.size(), NULL, 0, true, true); const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
prompt_tokens.resize(n_prompt_tokens); prompt_tokens.resize(n_prompt_tokens);
if (llama_tokenize(model.get(), prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true,
true) < 0) { true) < 0) {
printe("failed to tokenize the prompt\n"); printe("failed to tokenize the prompt\n");
return -1; return -1;
@ -753,9 +753,9 @@ static int check_context_size(const llama_context_ptr & ctx, const llama_batch &
} }
// convert the token to a string // convert the token to a string
static int convert_token_to_string(const llama_model_ptr & model, const llama_token token_id, std::string & piece) { static int convert_token_to_string(const llama_vocab * vocab, const llama_token token_id, std::string & piece) {
char buf[256]; char buf[256];
int n = llama_token_to_piece(model.get(), token_id, buf, sizeof(buf), 0, true); int n = llama_token_to_piece(vocab, token_id, buf, sizeof(buf), 0, true);
if (n < 0) { if (n < 0) {
printe("failed to convert token to piece\n"); printe("failed to convert token to piece\n");
return 1; return 1;
@ -773,8 +773,10 @@ static void print_word_and_concatenate_to_response(const std::string & piece, st
// helper function to evaluate a prompt and generate a response // helper function to evaluate a prompt and generate a response
static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) { static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) {
const llama_vocab * vocab = llama_get_vocab(llama_data.model.get());
std::vector<llama_token> tokens; std::vector<llama_token> tokens;
if (tokenize_prompt(llama_data.model, prompt, tokens) < 0) { if (tokenize_prompt(vocab, prompt, tokens) < 0) {
return 1; return 1;
} }
@ -790,12 +792,12 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str
// sample the next token, check is it an end of generation? // sample the next token, check is it an end of generation?
new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1); new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1);
if (llama_token_is_eog(llama_data.model.get(), new_token_id)) { if (llama_token_is_eog(vocab, new_token_id)) {
break; break;
} }
std::string piece; std::string piece;
if (convert_token_to_string(llama_data.model, new_token_id, piece)) { if (convert_token_to_string(vocab, new_token_id, piece)) {
return 1; return 1;
} }

View File

@ -203,10 +203,12 @@ struct server_task {
server_task(server_task_type type) : type(type) {} server_task(server_task_type type) : type(type) {}
static slot_params params_from_json_cmpl( static slot_params params_from_json_cmpl(
const llama_model * model,
const llama_context * ctx, const llama_context * ctx,
const common_params & params_base, const common_params & params_base,
const json & data) { const json & data) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
slot_params params; slot_params params;
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them) // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
@ -329,7 +331,7 @@ struct server_task {
const auto & logit_bias = data.find("logit_bias"); const auto & logit_bias = data.find("logit_bias");
if (logit_bias != data.end() && logit_bias->is_array()) { if (logit_bias != data.end() && logit_bias->is_array()) {
const int n_vocab = llama_n_vocab(model); const int n_vocab = llama_n_vocab(vocab);
for (const auto & el : *logit_bias) { for (const auto & el : *logit_bias) {
// TODO: we may want to throw errors here, in case "el" is incorrect // TODO: we may want to throw errors here, in case "el" is incorrect
if (el.is_array() && el.size() == 2) { if (el.is_array() && el.size() == 2) {
@ -348,7 +350,7 @@ struct server_task {
params.sampling.logit_bias.push_back({tok, bias}); params.sampling.logit_bias.push_back({tok, bias});
} }
} else if (el[0].is_string()) { } else if (el[0].is_string()) {
auto toks = common_tokenize(model, el[0].get<std::string>(), false); auto toks = common_tokenize(vocab, el[0].get<std::string>(), false);
for (auto tok : toks) { for (auto tok : toks) {
params.sampling.logit_bias.push_back({tok, bias}); params.sampling.logit_bias.push_back({tok, bias});
} }
@ -1633,6 +1635,8 @@ struct server_context {
llama_model * model = nullptr; llama_model * model = nullptr;
llama_context * ctx = nullptr; llama_context * ctx = nullptr;
const llama_vocab * vocab = nullptr;
llama_model * model_dft = nullptr; llama_model * model_dft = nullptr;
llama_context_params cparams_dft; llama_context_params cparams_dft;
@ -1690,10 +1694,12 @@ struct server_context {
return false; return false;
} }
vocab = llama_get_vocab(model);
n_ctx = llama_n_ctx(ctx); n_ctx = llama_n_ctx(ctx);
add_bos_token = llama_add_bos_token(model); add_bos_token = llama_add_bos_token(vocab);
has_eos_token = llama_token_eos(model) != LLAMA_TOKEN_NULL; has_eos_token = llama_token_eos(vocab) != LLAMA_TOKEN_NULL;
if (!params_base.speculative.model.empty()) { if (!params_base.speculative.model.empty()) {
SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str()); SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
@ -1736,7 +1742,8 @@ struct server_context {
bool validate_builtin_chat_template() const { bool validate_builtin_chat_template() const {
llama_chat_message chat[] = {{"user", "test"}}; llama_chat_message chat[] = {{"user", "test"}};
int32_t chat_res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0); const char * tmpl = llama_model_chat_template(model);
const int32_t chat_res = llama_chat_apply_template(tmpl, chat, 1, true, nullptr, 0);
return chat_res > 0; return chat_res > 0;
} }
@ -1891,7 +1898,7 @@ struct server_context {
} }
if (slot.params.ignore_eos && has_eos_token) { if (slot.params.ignore_eos && has_eos_token) {
slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY}); slot.params.sampling.logit_bias.push_back({llama_token_eos(vocab), -INFINITY});
} }
{ {
@ -2047,7 +2054,7 @@ struct server_context {
slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx); slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
} }
if (llama_token_is_eog(model, result.tok)) { if (llama_token_is_eog(vocab, result.tok)) {
slot.stop = STOP_TYPE_EOS; slot.stop = STOP_TYPE_EOS;
slot.has_next_token = false; slot.has_next_token = false;
@ -2074,7 +2081,7 @@ struct server_context {
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) { void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
size_t n_probs = slot.params.sampling.n_probs; size_t n_probs = slot.params.sampling.n_probs;
size_t n_vocab = llama_n_vocab(llama_get_model(ctx)); size_t n_vocab = llama_n_vocab(vocab);
if (post_sampling) { if (post_sampling) {
const auto * cur_p = common_sampler_get_candidates(slot.smpl); const auto * cur_p = common_sampler_get_candidates(slot.smpl);
const size_t max_probs = cur_p->size; const size_t max_probs = cur_p->size;
@ -3129,8 +3136,8 @@ struct server_context {
json model_meta() const { json model_meta() const {
return json { return json {
{"vocab_type", llama_vocab_type (model)}, {"vocab_type", llama_vocab_type (vocab)},
{"n_vocab", llama_n_vocab (model)}, {"n_vocab", llama_n_vocab (vocab)},
{"n_ctx_train", llama_n_ctx_train (model)}, {"n_ctx_train", llama_n_ctx_train (model)},
{"n_embd", llama_n_embd (model)}, {"n_embd", llama_n_embd (model)},
{"n_params", llama_model_n_params(model)}, {"n_params", llama_model_n_params(model)},
@ -3639,7 +3646,7 @@ int main(int argc, char ** argv) {
std::vector<server_task> tasks; std::vector<server_task> tasks;
try { try {
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true); std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, data.at("prompt"), true, true);
tasks.reserve(tokenized_prompts.size()); tasks.reserve(tokenized_prompts.size());
for (size_t i = 0; i < tokenized_prompts.size(); i++) { for (size_t i = 0; i < tokenized_prompts.size(); i++) {
server_task task = server_task(type); server_task task = server_task(type);
@ -3649,7 +3656,6 @@ int main(int argc, char ** argv) {
task.prompt_tokens = std::move(tokenized_prompts[i]); task.prompt_tokens = std::move(tokenized_prompts[i]);
task.params = server_task::params_from_json_cmpl( task.params = server_task::params_from_json_cmpl(
ctx_server.model,
ctx_server.ctx, ctx_server.ctx,
ctx_server.params_base, ctx_server.params_base,
data); data);
@ -3745,13 +3751,13 @@ int main(int argc, char ** argv) {
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
// check model compatibility // check model compatibility
std::string err; std::string err;
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) { if (llama_token_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
err += "prefix token is missing. "; err += "prefix token is missing. ";
} }
if (llama_token_fim_suf(ctx_server.model) == LLAMA_TOKEN_NULL) { if (llama_token_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
err += "suffix token is missing. "; err += "suffix token is missing. ";
} }
if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) { if (llama_token_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) {
err += "middle token is missing. "; err += "middle token is missing. ";
} }
if (!err.empty()) { if (!err.empty()) {
@ -3797,10 +3803,10 @@ int main(int argc, char ** argv) {
data["input_extra"] = input_extra; // default to empty array if it's not exist data["input_extra"] = input_extra; // default to empty array if it's not exist
std::string prompt = json_value(data, "prompt", std::string()); std::string prompt = json_value(data, "prompt", std::string());
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, false, true); std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, false, true);
SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
data["prompt"] = format_infill( data["prompt"] = format_infill(
ctx_server.ctx, ctx_server.vocab,
data.at("input_prefix"), data.at("input_prefix"),
data.at("input_suffix"), data.at("input_suffix"),
data.at("input_extra"), data.at("input_extra"),
@ -3857,7 +3863,7 @@ int main(int argc, char ** argv) {
const bool add_special = json_value(body, "add_special", false); const bool add_special = json_value(body, "add_special", false);
const bool with_pieces = json_value(body, "with_pieces", false); const bool with_pieces = json_value(body, "with_pieces", false);
llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true); llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, true);
if (with_pieces) { if (with_pieces) {
for (const auto& token : tokens) { for (const auto& token : tokens) {
@ -3933,7 +3939,7 @@ int main(int argc, char ** argv) {
} }
} }
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true); std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, prompt, true, true);
for (const auto & tokens : tokenized_prompts) { for (const auto & tokens : tokenized_prompts) {
// this check is necessary for models that do not add BOS token to the input // this check is necessary for models that do not add BOS token to the input
if (tokens.empty()) { if (tokens.empty()) {
@ -4033,20 +4039,20 @@ int main(int argc, char ** argv) {
return; return;
} }
llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.ctx, query, /* add_special */ false, true)[0]; llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.vocab, query, /* add_special */ false, true)[0];
// create and queue the task // create and queue the task
json responses = json::array(); json responses = json::array();
bool error = false; bool error = false;
{ {
std::vector<server_task> tasks; std::vector<server_task> tasks;
std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.ctx, documents, /* add_special */ false, true); std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.vocab, documents, /* add_special */ false, true);
tasks.reserve(tokenized_docs.size()); tasks.reserve(tokenized_docs.size());
for (size_t i = 0; i < tokenized_docs.size(); i++) { for (size_t i = 0; i < tokenized_docs.size(); i++) {
server_task task = server_task(SERVER_TASK_TYPE_RERANK); server_task task = server_task(SERVER_TASK_TYPE_RERANK);
task.id = ctx_server.queue_tasks.get_new_id(); task.id = ctx_server.queue_tasks.get_new_id();
task.index = i; task.index = i;
task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]); task.prompt_tokens = format_rerank(ctx_server.vocab, tokenized_query, tokenized_docs[i]);
tasks.push_back(task); tasks.push_back(task);
} }

View File

@ -118,7 +118,7 @@ static json json_get_nested_values(const std::vector<std::string> & paths, const
* - only string, example: "string" * - only string, example: "string"
* - mixed string and tokens, example: [12, 34, "string", 56, 78] * - mixed string and tokens, example: [12, 34, "string", 56, 78]
*/ */
static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) { static llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
// If `add_bos` is true, we only add BOS, when json_prompt is a string, // If `add_bos` is true, we only add BOS, when json_prompt is a string,
// or the first element of the json_prompt array is a string. // or the first element of the json_prompt array is a string.
llama_tokens prompt_tokens; llama_tokens prompt_tokens;
@ -131,10 +131,10 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
llama_tokens p; llama_tokens p;
if (first) { if (first) {
p = common_tokenize(ctx, s, add_special, parse_special); p = common_tokenize(vocab, s, add_special, parse_special);
first = false; first = false;
} else { } else {
p = common_tokenize(ctx, s, false, parse_special); p = common_tokenize(vocab, s, false, parse_special);
} }
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
@ -148,7 +148,7 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
} }
} else { } else {
auto s = json_prompt.template get<std::string>(); auto s = json_prompt.template get<std::string>();
prompt_tokens = common_tokenize(ctx, s, add_special, parse_special); prompt_tokens = common_tokenize(vocab, s, add_special, parse_special);
} }
return prompt_tokens; return prompt_tokens;
@ -166,11 +166,11 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
* - "prompt": [[12, 34, 56], [78, 90, 12]] * - "prompt": [[12, 34, 56], [78, 90, 12]]
* - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]] * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
*/ */
static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) { static std::vector<llama_tokens> tokenize_input_prompts(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) {
std::vector<llama_tokens> result; std::vector<llama_tokens> result;
if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) {
// string or mixed // string or mixed
result.push_back(tokenize_mixed(ctx, json_prompt, add_special, parse_special)); result.push_back(tokenize_mixed(vocab, json_prompt, add_special, parse_special));
} else if (json_is_array_of_numbers(json_prompt)) { } else if (json_is_array_of_numbers(json_prompt)) {
// array of tokens // array of tokens
result.push_back(json_prompt.get<llama_tokens>()); result.push_back(json_prompt.get<llama_tokens>());
@ -179,7 +179,7 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
result.reserve(json_prompt.size()); result.reserve(json_prompt.size());
for (const auto & p : json_prompt) { for (const auto & p : json_prompt) {
if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) { if (p.is_string() || json_is_array_of_mixed_numbers_strings(p)) {
result.push_back(tokenize_mixed(ctx, p, add_special, parse_special)); result.push_back(tokenize_mixed(vocab, p, add_special, parse_special));
} else if (json_is_array_of_numbers(p)) { } else if (json_is_array_of_numbers(p)) {
// array of tokens // array of tokens
result.push_back(p.get<llama_tokens>()); result.push_back(p.get<llama_tokens>());
@ -231,21 +231,23 @@ static size_t validate_utf8(const std::string& text) {
// //
// format rerank task: [BOS]query[EOS][SEP]doc[EOS] // format rerank task: [BOS]query[EOS][SEP]doc[EOS]
static llama_tokens format_rerank(const struct llama_model * model, const llama_tokens & query, const llama_tokens & doc) { static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_tokens & query, const llama_tokens & doc) {
llama_tokens result; llama_tokens result;
result.reserve(doc.size() + query.size() + 4); result.reserve(doc.size() + query.size() + 4);
result.push_back(llama_token_bos(model)); result.push_back(llama_token_bos(vocab));
result.insert(result.end(), query.begin(), query.end()); result.insert(result.end(), query.begin(), query.end());
result.push_back(llama_token_eos(model)); result.push_back(llama_token_eos(vocab));
result.push_back(llama_token_sep(model)); result.push_back(llama_token_sep(vocab));
result.insert(result.end(), doc.begin(), doc.end()); result.insert(result.end(), doc.begin(), doc.end());
result.push_back(llama_token_eos(model)); result.push_back(llama_token_eos(vocab));
return result; return result;
} }
// format infill task // format infill task
static llama_tokens format_infill( static llama_tokens format_infill(
const llama_context * ctx, const llama_vocab * vocab,
const json & input_prefix, const json & input_prefix,
const json & input_suffix, const json & input_suffix,
const json & input_extra, const json & input_extra,
@ -272,15 +274,14 @@ static llama_tokens format_infill(
llama_tokens extra_tokens; llama_tokens extra_tokens;
extra_tokens.reserve(n_ctx); extra_tokens.reserve(n_ctx);
auto model = llama_get_model(ctx); auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false);
auto tokens_prefix = tokenize_mixed(ctx, input_prefix, false, false); auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false);
auto tokens_suffix = tokenize_mixed(ctx, input_suffix, false, false);
if (llama_token_fim_rep(model) != LLAMA_TOKEN_NULL) { if (llama_token_fim_rep(vocab) != LLAMA_TOKEN_NULL) {
// TODO: make project name an input // TODO: make project name an input
static const auto k_fim_repo = common_tokenize(ctx, "myproject\n", false, false); static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false);
extra_tokens.push_back(llama_token_fim_rep(model)); extra_tokens.push_back(llama_token_fim_rep(vocab));
extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end());
} }
for (const auto & chunk : input_extra) { for (const auto & chunk : input_extra) {
@ -288,28 +289,28 @@ static llama_tokens format_infill(
const std::string text = json_value(chunk, "text", std::string()); const std::string text = json_value(chunk, "text", std::string());
const std::string filename = json_value(chunk, "filename", std::string("tmp")); const std::string filename = json_value(chunk, "filename", std::string("tmp"));
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) { if (llama_token_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
const auto k_fim_file = common_tokenize(ctx, filename + "\n", false, false); const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false);
extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model)); extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(vocab));
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
} else { } else {
// chunk separator in binary form to avoid confusing the AI // chunk separator in binary form to avoid confusing the AI
static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00};
static const auto k_chunk_prefix_tokens = common_tokenize(ctx, k_chunk_prefix_str, false, false); static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false);
extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end());
} }
const auto chunk_tokens = common_tokenize(ctx, text, false, false); const auto chunk_tokens = common_tokenize(vocab, text, false, false);
extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end());
} }
if (llama_token_fim_sep(model) != LLAMA_TOKEN_NULL) { if (llama_token_fim_sep(vocab) != LLAMA_TOKEN_NULL) {
// TODO: current filename // TODO: current filename
static const auto k_fim_file = common_tokenize(ctx, "filename\n", false, false); static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false);
extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(model)); extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(vocab));
extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end());
} }
@ -325,15 +326,15 @@ static llama_tokens format_infill(
tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take);
tokens_suffix.resize(n_suffix_take); tokens_suffix.resize(n_suffix_take);
tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(model)); tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(vocab));
tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end());
tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(model)); tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(vocab));
auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix;
auto embd_end = spm_infill ? tokens_prefix : tokens_suffix; auto embd_end = spm_infill ? tokens_prefix : tokens_suffix;
if (llama_add_bos_token(model)) { if (llama_add_bos_token(vocab)) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model)); embd_inp.insert(embd_inp.begin(), llama_token_bos(vocab));
} }
SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size());
@ -342,7 +343,7 @@ static llama_tokens format_infill(
embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end());
embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
embd_inp.push_back(llama_token_fim_mid(model)); embd_inp.push_back(llama_token_fim_mid(vocab));
return embd_inp; return embd_inp;
} }
@ -764,14 +765,18 @@ static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias)
return data; return data;
} }
static std::string safe_json_to_str(json data) { static std::string safe_json_to_str(const json & data) {
return data.dump(-1, ' ', false, json::error_handler_t::replace); return data.dump(-1, ' ', false, json::error_handler_t::replace);
} }
static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) { static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
std::vector<llama_token_data> cur; std::vector<llama_token_data> cur;
const auto * logits = llama_get_logits_ith(ctx, idx); const auto * logits = llama_get_logits_ith(ctx, idx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
const int n_vocab = llama_n_vocab(vocab);
cur.resize(n_vocab); cur.resize(n_vocab);
for (llama_token token_id = 0; token_id < n_vocab; token_id++) { for (llama_token token_id = 0; token_id < n_vocab; token_id++) {

View File

@ -75,6 +75,8 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
const llama_vocab * vocab = llama_get_vocab(model);
// initialize the context // initialize the context
llama_context_params ctx_params = llama_context_default_params(); llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = n_ctx; ctx_params.n_ctx = n_ctx;
@ -97,9 +99,9 @@ int main(int argc, char ** argv) {
std::string response; std::string response;
// tokenize the prompt // tokenize the prompt
const int n_prompt_tokens = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true); const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
std::vector<llama_token> prompt_tokens(n_prompt_tokens); std::vector<llama_token> prompt_tokens(n_prompt_tokens);
if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) { if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), llama_get_kv_cache_used_cells(ctx) == 0, true) < 0) {
GGML_ABORT("failed to tokenize the prompt\n"); GGML_ABORT("failed to tokenize the prompt\n");
} }
@ -124,13 +126,13 @@ int main(int argc, char ** argv) {
new_token_id = llama_sampler_sample(smpl, ctx, -1); new_token_id = llama_sampler_sample(smpl, ctx, -1);
// is it an end of generation? // is it an end of generation?
if (llama_token_is_eog(model, new_token_id)) { if (llama_token_is_eog(vocab, new_token_id)) {
break; break;
} }
// convert the token to a string, print it and add it to the response // convert the token to a string, print it and add it to the response
char buf[256]; char buf[256];
int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true); int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true);
if (n < 0) { if (n < 0) {
GGML_ABORT("failed to convert token to piece\n"); GGML_ABORT("failed to convert token to piece\n");
} }
@ -159,12 +161,14 @@ int main(int argc, char ** argv) {
break; break;
} }
const char * tmpl = llama_model_chat_template(model);
// add the user input to the message list and format it // add the user input to the message list and format it
messages.push_back({"user", strdup(user.c_str())}); messages.push_back({"user", strdup(user.c_str())});
int new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size()); int new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
if (new_len > (int)formatted.size()) { if (new_len > (int)formatted.size()) {
formatted.resize(new_len); formatted.resize(new_len);
new_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), true, formatted.data(), formatted.size()); new_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), true, formatted.data(), formatted.size());
} }
if (new_len < 0) { if (new_len < 0) {
fprintf(stderr, "failed to apply the chat template\n"); fprintf(stderr, "failed to apply the chat template\n");
@ -181,7 +185,7 @@ int main(int argc, char ** argv) {
// add the response to the messages // add the response to the messages
messages.push_back({"assistant", strdup(response.c_str())}); messages.push_back({"assistant", strdup(response.c_str())});
prev_len = llama_chat_apply_template(model, nullptr, messages.data(), messages.size(), false, nullptr, 0); prev_len = llama_chat_apply_template(tmpl, messages.data(), messages.size(), false, nullptr, 0);
if (prev_len < 0) { if (prev_len < 0) {
fprintf(stderr, "failed to apply the chat template\n"); fprintf(stderr, "failed to apply the chat template\n");
return 1; return 1;

View File

@ -84,6 +84,7 @@ int main(int argc, char ** argv) {
model_params.n_gpu_layers = ngl; model_params.n_gpu_layers = ngl;
llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params); llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params);
const llama_vocab * vocab = llama_get_vocab(model);
if (model == NULL) { if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__); fprintf(stderr , "%s: error: unable to load model\n" , __func__);
@ -93,11 +94,11 @@ int main(int argc, char ** argv) {
// tokenize the prompt // tokenize the prompt
// find the number of tokens in the prompt // find the number of tokens in the prompt
const int n_prompt = -llama_tokenize(model, prompt.c_str(), prompt.size(), NULL, 0, true, true); const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
// allocate space for the tokens and tokenize the prompt // allocate space for the tokens and tokenize the prompt
std::vector<llama_token> prompt_tokens(n_prompt); std::vector<llama_token> prompt_tokens(n_prompt);
if (llama_tokenize(model, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) { if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__); fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__);
return 1; return 1;
} }
@ -131,7 +132,7 @@ int main(int argc, char ** argv) {
for (auto id : prompt_tokens) { for (auto id : prompt_tokens) {
char buf[128]; char buf[128];
int n = llama_token_to_piece(model, id, buf, sizeof(buf), 0, true); int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true);
if (n < 0) { if (n < 0) {
fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__); fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
return 1; return 1;
@ -164,12 +165,12 @@ int main(int argc, char ** argv) {
new_token_id = llama_sampler_sample(smpl, ctx, -1); new_token_id = llama_sampler_sample(smpl, ctx, -1);
// is it an end of generation? // is it an end of generation?
if (llama_token_is_eog(model, new_token_id)) { if (llama_token_is_eog(vocab, new_token_id)) {
break; break;
} }
char buf[128]; char buf[128];
int n = llama_token_to_piece(model, new_token_id, buf, sizeof(buf), 0, true); int n = llama_token_to_piece(vocab, new_token_id, buf, sizeof(buf), 0, true);
if (n < 0) { if (n < 0) {
fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__); fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
return 1; return 1;

View File

@ -45,6 +45,8 @@ int main(int argc, char ** argv) {
model_tgt = llama_init_tgt.model.get(); model_tgt = llama_init_tgt.model.get();
ctx_tgt = llama_init_tgt.context.get(); ctx_tgt = llama_init_tgt.context.get();
const llama_vocab * vocab = llama_get_vocab(model_tgt);
// load the draft model // load the draft model
params.devices = params.speculative.devices; params.devices = params.speculative.devices;
params.model = params.speculative.model; params.model = params.speculative.model;
@ -196,7 +198,7 @@ int main(int argc, char ** argv) {
id_last = ids[i]; id_last = ids[i];
if (llama_token_is_eog(model_tgt, id_last)) { if (llama_token_is_eog(vocab, id_last)) {
has_eos = true; has_eos = true;
break; break;
} }

View File

@ -90,10 +90,13 @@ int main(int argc, char ** argv) {
model_dft = llama_init_dft.model.get(); model_dft = llama_init_dft.model.get();
ctx_dft = llama_init_dft.context.get(); ctx_dft = llama_init_dft.context.get();
const bool vocab_type_tgt = llama_vocab_type(model_tgt); const llama_vocab * vocab_tgt = llama_get_vocab(model_tgt);
const llama_vocab * vocab_dft = llama_get_vocab(model_dft);
const bool vocab_type_tgt = llama_vocab_type(vocab_tgt);
LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt); LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
const bool vocab_type_dft = llama_vocab_type(model_dft); const bool vocab_type_dft = llama_vocab_type(vocab_dft);
LOG_DBG("vocab_type dft: %d\n", vocab_type_dft); LOG_DBG("vocab_type dft: %d\n", vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) { if (vocab_type_tgt != vocab_type_dft) {
@ -103,18 +106,18 @@ int main(int argc, char ** argv) {
} }
if ( if (
llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) || llama_add_bos_token(vocab_tgt) != llama_add_bos_token(vocab_dft) ||
llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) || llama_add_eos_token(vocab_tgt) != llama_add_eos_token(vocab_dft) ||
llama_token_bos(model_tgt) != llama_token_bos(model_dft) || llama_token_bos(vocab_tgt) != llama_token_bos(vocab_dft) ||
llama_token_eos(model_tgt) != llama_token_eos(model_dft) llama_token_eos(vocab_tgt) != llama_token_eos(vocab_dft)
) { ) {
LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__); LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
return 1; return 1;
} }
{ {
const int n_vocab_tgt = llama_n_vocab(model_tgt); const int n_vocab_tgt = llama_n_vocab(vocab_tgt);
const int n_vocab_dft = llama_n_vocab(model_dft); const int n_vocab_dft = llama_n_vocab(vocab_dft);
const int vocab_diff = n_vocab_tgt > n_vocab_dft const int vocab_diff = n_vocab_tgt > n_vocab_dft
? n_vocab_tgt - n_vocab_dft ? n_vocab_tgt - n_vocab_dft
: n_vocab_dft - n_vocab_tgt; : n_vocab_dft - n_vocab_tgt;
@ -122,13 +125,13 @@ int main(int argc, char ** argv) {
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__); LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__);
LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); n_vocab_tgt, llama_n_vocab(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return 1; return 1;
} }
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
const char * token_text_tgt = llama_token_get_text(model_tgt, i); const char * token_text_tgt = llama_token_get_text(vocab_tgt, i);
const char * token_text_dft = llama_token_get_text(model_dft, i); const char * token_text_dft = llama_token_get_text(vocab_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) { if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__); LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i, LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
@ -386,7 +389,7 @@ int main(int argc, char ** argv) {
} }
} }
if (llama_token_is_eog(model_tgt, token_id)) { if (llama_token_is_eog(vocab_tgt, token_id)) {
has_eos = true; has_eos = true;
} }
++n_predict; ++n_predict;

View File

@ -344,6 +344,8 @@ int main(int raw_argc, char ** raw_argv) {
return 1; return 1;
} }
const llama_vocab * vocab = llama_get_vocab(model);
llama_context_params ctx_params = llama_context_default_params(); llama_context_params ctx_params = llama_context_default_params();
llama_context * ctx = llama_new_context_with_model(model, ctx_params); llama_context * ctx = llama_new_context_with_model(model, ctx_params);
if (!ctx) { if (!ctx) {
@ -365,7 +367,7 @@ int main(int raw_argc, char ** raw_argv) {
prompt = stdin_buffer.str(); prompt = stdin_buffer.str();
} }
const bool model_wants_add_bos = llama_add_bos_token(model); const bool model_wants_add_bos = llama_add_bos_token(vocab);
const bool add_bos = model_wants_add_bos && !no_bos; const bool add_bos = model_wants_add_bos && !no_bos;
const bool parse_special = !no_parse_special; const bool parse_special = !no_parse_special;
const bool escape = !no_escape; const bool escape = !no_escape;
@ -375,7 +377,7 @@ int main(int raw_argc, char ** raw_argv) {
} }
std::vector<llama_token> tokens; std::vector<llama_token> tokens;
tokens = common_tokenize(model, prompt, add_bos, parse_special); tokens = common_tokenize(vocab, prompt, add_bos, parse_special);
if (printing_ids) { if (printing_ids) {
printf("["); printf("[");

View File

@ -414,15 +414,15 @@ static void prompt_add(llama_tokens & prompt, const llama_tokens & tokens) {
prompt.insert(prompt.end(), tokens.begin(), tokens.end()); prompt.insert(prompt.end(), tokens.begin(), tokens.end());
} }
static void prompt_add(llama_tokens & prompt, const llama_model * model, const std::string & txt, bool add_special, bool parse_special) { static void prompt_add(llama_tokens & prompt, const llama_vocab * vocab, const std::string & txt, bool add_special, bool parse_special) {
auto tmp = common_tokenize(model, txt, add_special, parse_special); auto tmp = common_tokenize(vocab, txt, add_special, parse_special);
prompt_add(prompt, tmp); prompt_add(prompt, tmp);
} }
static void prompt_init(llama_tokens & prompt, const llama_model * model) { static void prompt_init(llama_tokens & prompt, const llama_vocab * vocab) {
prompt.clear(); prompt.clear();
prompt_add(prompt, model, "<|im_start|>\n", true, true); prompt_add(prompt, vocab, "<|im_start|>\n", true, true);
} }
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
@ -462,6 +462,8 @@ int main(int argc, char ** argv) {
model_ttc = llama_init_ttc.model.get(); model_ttc = llama_init_ttc.model.get();
ctx_ttc = llama_init_ttc.context.get(); ctx_ttc = llama_init_ttc.context.get();
const llama_vocab * vocab = llama_get_vocab(model_ttc);
// TODO: refactor in a common struct // TODO: refactor in a common struct
params.model = params.vocoder.model; params.model = params.vocoder.model;
params.model_url = params.vocoder.model_url; params.model_url = params.vocoder.model_url;
@ -499,9 +501,9 @@ int main(int argc, char ** argv) {
std::vector<llama_token> prompt_inp; std::vector<llama_token> prompt_inp;
prompt_init(prompt_inp, model_ttc); prompt_init(prompt_inp, vocab);
prompt_add(prompt_inp, model_ttc, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true); prompt_add(prompt_inp, vocab, "<|text_start|>the<|text_sep|>overall<|text_sep|>package<|text_sep|>from<|text_sep|>just<|text_sep|>two<|text_sep|>people<|text_sep|>is<|text_sep|>pretty<|text_sep|>remarkable<|text_sep|>sure<|text_sep|>i<|text_sep|>have<|text_sep|>some<|text_sep|>critiques<|text_sep|>about<|text_sep|>some<|text_sep|>of<|text_sep|>the<|text_sep|>gameplay<|text_sep|>aspects<|text_sep|>but<|text_sep|>its<|text_sep|>still<|text_sep|>really<|text_sep|>enjoyable<|text_sep|>and<|text_sep|>it<|text_sep|>looks<|text_sep|>lovely<|text_sep|>", false, true);
// convert the input text into the necessary format expected by OuteTTS // convert the input text into the necessary format expected by OuteTTS
{ {
@ -509,10 +511,10 @@ int main(int argc, char ** argv) {
LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str()); LOG_INF("%s: prompt: '%s'\n", __func__, prompt_clean.c_str());
prompt_add(prompt_inp, model_ttc, prompt_clean, false, true); prompt_add(prompt_inp, vocab, prompt_clean, false, true);
} }
prompt_add(prompt_inp, model_ttc, "<|text_end|>\n", false, true); prompt_add(prompt_inp, vocab, "<|text_end|>\n", false, true);
// disabled to save time on tokenizing each time // disabled to save time on tokenizing each time
// TODO: load voices from the json files // TODO: load voices from the json files
@ -549,7 +551,7 @@ it<|t_0.09|><|code_start|><|848|><|1366|><|395|><|1601|><|1513|><|593|><|1302|><
looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|> looks<|t_0.27|><|code_start|><|1281|><|1266|><|1755|><|572|><|248|><|1751|><|1257|><|695|><|1380|><|457|><|659|><|585|><|1315|><|1105|><|1776|><|736|><|24|><|736|><|654|><|1027|><|code_end|>
lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)"; lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|1481|><|1721|><|1123|><|438|><|1246|><|1251|><|795|><|659|><|1381|><|1658|><|217|><|1772|><|562|><|952|><|107|><|1129|><|1112|><|467|><|550|><|1079|><|840|><|1615|><|1469|><|1380|><|168|><|917|><|836|><|1827|><|437|><|583|><|67|><|595|><|1087|><|1646|><|1493|><|1677|><|code_end|>)";
auto tmp = common_tokenize(model_ttc, voice_data, false, true); auto tmp = common_tokenize(vocab, voice_data, false, true);
printf("\n\n"); printf("\n\n");
for (int i = 0; i < tmp.size(); ++i) { for (int i = 0; i < tmp.size(); ++i) {
printf("%d, ", tmp[i]); printf("%d, ", tmp[i]);
@ -735,9 +737,9 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14
const auto * cands = common_sampler_get_candidates(smpl[i]); const auto * cands = common_sampler_get_candidates(smpl[i]);
// is it an end of generation? -> mark the stream as finished // is it an end of generation? -> mark the stream as finished
if (llama_token_is_eog(model_ttc, new_token_id) || n_decode == n_predict) { if (llama_token_is_eog(vocab, new_token_id) || n_decode == n_predict) {
std::string reason; std::string reason;
if (llama_token_is_eog(model_ttc, new_token_id)) { if (llama_token_is_eog(vocab, new_token_id)) {
reason = "eos"; reason = "eos";
} else { } else {
reason = "n_predict"; reason = "n_predict";

View File

@ -56,7 +56,7 @@ extern "C" {
// TODO: show sample usage // TODO: show sample usage
// //
// struct llama_vocab; // TODO: add in the future struct llama_vocab;
struct llama_model; struct llama_model;
struct llama_context; struct llama_context;
struct llama_sampler; struct llama_sampler;
@ -449,16 +449,18 @@ extern "C" {
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
LLAMA_API int32_t llama_n_embd (const struct llama_model * model); LLAMA_API int32_t llama_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_n_layer (const struct llama_model * model); LLAMA_API int32_t llama_n_layer (const struct llama_model * model);
LLAMA_API int32_t llama_n_head (const struct llama_model * model); LLAMA_API int32_t llama_n_head (const struct llama_model * model);
LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab);
LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx);
LLAMA_API const struct llama_vocab * llama_get_vocab(const struct llama_model * model);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);
LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_vocab * vocab);
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
// Get the model's RoPE frequency scaling factor // Get the model's RoPE frequency scaling factor
@ -488,6 +490,9 @@ extern "C" {
// Returns the total size of all the tensors in the model in bytes // Returns the total size of all the tensors in the model in bytes
LLAMA_API uint64_t llama_model_size(const struct llama_model * model); LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
// Get the default chat template. Returns nullptr if not available
LLAMA_API const char * llama_model_chat_template(const struct llama_model * model);
// Returns the total number of parameters in the model // Returns the total number of parameters in the model
LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model); LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
@ -908,41 +913,41 @@ extern "C" {
// Vocab // Vocab
// //
LLAMA_API const char * llama_token_get_text(const struct llama_model * model, llama_token token); LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token);
LLAMA_API float llama_token_get_score(const struct llama_model * model, llama_token token); LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token);
LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token); LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token);
// Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.)
LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token); LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token);
// Identify if Token Id is a control token or a render-able token // Identify if Token Id is a control token or a render-able token
LLAMA_API bool llama_token_is_control(const struct llama_model * model, llama_token token); LLAMA_API bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token);
// Special tokens // Special tokens
LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence LLAMA_API llama_token llama_token_bos(const struct llama_vocab * vocab); // beginning-of-sentence
LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_vocab * vocab); // end-of-sentence
LLAMA_API llama_token llama_token_eot(const struct llama_model * model); // end-of-turn LLAMA_API llama_token llama_token_eot(const struct llama_vocab * vocab); // end-of-turn
LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification LLAMA_API llama_token llama_token_cls(const struct llama_vocab * vocab); // classification
LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator LLAMA_API llama_token llama_token_sep(const struct llama_vocab * vocab); // sentence separator
LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line LLAMA_API llama_token llama_token_nl (const struct llama_vocab * vocab); // next-line
LLAMA_API llama_token llama_token_pad(const struct llama_model * model); // padding LLAMA_API llama_token llama_token_pad(const struct llama_vocab * vocab); // padding
LLAMA_API bool llama_add_bos_token(const struct llama_model * model); LLAMA_API bool llama_add_bos_token(const struct llama_vocab * vocab);
LLAMA_API bool llama_add_eos_token(const struct llama_model * model); LLAMA_API bool llama_add_eos_token(const struct llama_vocab * vocab);
// infill tokens // infill tokens
DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_model * model), "use llama_token_fim_pre instead"); DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_vocab * vocab), "use llama_token_fim_pre instead");
DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_model * model), "use llama_token_fim_mid instead"); DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_vocab * vocab), "use llama_token_fim_mid instead");
DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_model * model), "use llama_token_fim_suf instead"); DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_vocab * vocab), "use llama_token_fim_suf instead");
LLAMA_API llama_token llama_token_fim_pre(const struct llama_model * model); LLAMA_API llama_token llama_token_fim_pre(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_token_fim_suf(const struct llama_model * model); LLAMA_API llama_token llama_token_fim_suf(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_token_fim_mid(const struct llama_model * model); LLAMA_API llama_token llama_token_fim_mid(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_token_fim_pad(const struct llama_model * model); LLAMA_API llama_token llama_token_fim_pad(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_token_fim_rep(const struct llama_model * model); LLAMA_API llama_token llama_token_fim_rep(const struct llama_vocab * vocab);
LLAMA_API llama_token llama_token_fim_sep(const struct llama_model * model); LLAMA_API llama_token llama_token_fim_sep(const struct llama_vocab * vocab);
// //
// Tokenization // Tokenization
@ -958,7 +963,7 @@ extern "C" {
/// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated /// @param parse_special Allow tokenizing special and/or control tokens which otherwise are not exposed and treated
/// as plaintext. Does not insert a leading space. /// as plaintext. Does not insert a leading space.
LLAMA_API int32_t llama_tokenize( LLAMA_API int32_t llama_tokenize(
const struct llama_model * model, const struct llama_vocab * vocab,
const char * text, const char * text,
int32_t text_len, int32_t text_len,
llama_token * tokens, llama_token * tokens,
@ -972,7 +977,7 @@ extern "C" {
// User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix') // User can skip up to 'lstrip' leading spaces before copying (useful when encoding/decoding multiple tokens with 'add_space_prefix')
// @param special If true, special tokens are rendered in the output. // @param special If true, special tokens are rendered in the output.
LLAMA_API int32_t llama_token_to_piece( LLAMA_API int32_t llama_token_to_piece(
const struct llama_model * model, const struct llama_vocab * vocab,
llama_token token, llama_token token,
char * buf, char * buf,
int32_t length, int32_t length,
@ -986,7 +991,7 @@ extern "C" {
/// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so. /// @param remove_special Allow to remove BOS and EOS tokens if model is configured to do so.
/// @param unparse_special If true, special tokens are rendered in the output. /// @param unparse_special If true, special tokens are rendered in the output.
LLAMA_API int32_t llama_detokenize( LLAMA_API int32_t llama_detokenize(
const struct llama_model * model, const struct llama_vocab * vocab,
const llama_token * tokens, const llama_token * tokens,
int32_t n_tokens, int32_t n_tokens,
char * text, char * text,
@ -1009,7 +1014,6 @@ extern "C" {
/// @param length The size of the allocated buffer /// @param length The size of the allocated buffer
/// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template. /// @return The total number of bytes of the formatted prompt. If is it larger than the size of buffer, you may need to re-alloc it and then re-apply the template.
LLAMA_API int32_t llama_chat_apply_template( LLAMA_API int32_t llama_chat_apply_template(
const struct llama_model * model,
const char * tmpl, const char * tmpl,
const struct llama_chat_message * chat, const struct llama_chat_message * chat,
size_t n_msg, size_t n_msg,
@ -1057,7 +1061,6 @@ extern "C" {
// llama_sampler_free(smpl); // llama_sampler_free(smpl);
// //
// TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU).
// TODO: in the future, the entire sampling API that uses llama_model should start using llama_vocab
// //
typedef void * llama_sampler_context_t; typedef void * llama_sampler_context_t;
@ -1157,7 +1160,7 @@ extern "C" {
float eta); float eta);
LLAMA_API struct llama_sampler * llama_sampler_init_grammar( LLAMA_API struct llama_sampler * llama_sampler_init_grammar(
const struct llama_model * model, const struct llama_vocab * vocab,
const char * grammar_str, const char * grammar_str,
const char * grammar_root); const char * grammar_root);
@ -1169,8 +1172,9 @@ extern "C" {
float penalty_present); // 0.0 = disabled float penalty_present); // 0.0 = disabled
/// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982 /// @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
LLAMA_API struct llama_sampler * llama_sampler_init_dry( LLAMA_API struct llama_sampler * llama_sampler_init_dry(
const struct llama_model * model, const struct llama_vocab * vocab,
int32_t n_ctx_train,
float dry_multiplier, float dry_multiplier,
float dry_base, float dry_base,
int32_t dry_allowed_length, int32_t dry_allowed_length,
@ -1204,7 +1208,7 @@ extern "C" {
// 3. discard non-EOG tokens with low prob // 3. discard non-EOG tokens with low prob
// 4. if no tokens are left -> pick EOT // 4. if no tokens are left -> pick EOT
// //
LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model); LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab);
// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl); LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);

View File

@ -178,6 +178,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" },
{ LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" },
{ LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" },
{ LLM_KV_TOKENIZER_CHAT_TEMPLATE, "tokenizer.chat.template" },
{ LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" }, { LLM_KV_TOKENIZER_FIM_PRE_ID, "tokenizer.ggml.fim_pre_token_id" },
{ LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" },
{ LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" },

View File

@ -176,6 +176,7 @@ enum llm_kv {
LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
LLM_KV_TOKENIZER_HF_JSON, LLM_KV_TOKENIZER_HF_JSON,
LLM_KV_TOKENIZER_RWKV, LLM_KV_TOKENIZER_RWKV,
LLM_KV_TOKENIZER_CHAT_TEMPLATE,
LLM_KV_TOKENIZER_FIM_PRE_ID, LLM_KV_TOKENIZER_FIM_PRE_ID,
LLM_KV_TOKENIZER_FIM_SUF_ID, LLM_KV_TOKENIZER_FIM_SUF_ID,
LLM_KV_TOKENIZER_FIM_MID_ID, LLM_KV_TOKENIZER_FIM_MID_ID,

View File

@ -3738,6 +3738,10 @@ struct llama_model_params llama_model_default_params() {
return result; return result;
} }
const struct llama_vocab * llama_get_vocab(const struct llama_model * model) {
return &model->vocab;
}
void llama_free_model(struct llama_model * model) { void llama_free_model(struct llama_model * model) {
llama_model_free(model); llama_model_free(model);
} }
@ -3746,14 +3750,6 @@ void llama_model_free(struct llama_model * model) {
delete model; delete model;
} }
enum llama_vocab_type llama_vocab_type(const struct llama_model * model) {
return model->vocab.get_type();
}
int32_t llama_n_vocab(const struct llama_model * model) {
return model->hparams.n_vocab;
}
int32_t llama_n_ctx_train(const struct llama_model * model) { int32_t llama_n_ctx_train(const struct llama_model * model) {
return model->hparams.n_ctx_train; return model->hparams.n_ctx_train;
} }
@ -3899,6 +3895,15 @@ uint64_t llama_model_size(const struct llama_model * model) {
return model->size(); return model->size();
} }
const char * llama_model_chat_template(const struct llama_model * model) {
const auto & it = model->gguf_kv.find(LLM_KV(model->arch)(LLM_KV_TOKENIZER_CHAT_TEMPLATE));
if (it == model->gguf_kv.end()) {
return nullptr;
}
return it->second.c_str();
}
uint64_t llama_model_n_params(const struct llama_model * model) { uint64_t llama_model_n_params(const struct llama_model * model) {
return model->n_elements(); return model->n_elements();
} }

View File

@ -371,7 +371,10 @@ void llama_sampler_free(struct llama_sampler * smpl) {
llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
const auto * logits = llama_get_logits_ith(ctx, idx); const auto * logits = llama_get_logits_ith(ctx, idx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_get_vocab(model);
const int n_vocab = llama_n_vocab(vocab);
// TODO: do not allocate each time // TODO: do not allocate each time
std::vector<llama_token_data> cur; std::vector<llama_token_data> cur;
@ -1445,7 +1448,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_grammar *) smpl->ctx; const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr); auto * result = llama_sampler_init_grammar(ctx->vocab, nullptr, nullptr);
// copy the state // copy the state
{ {
@ -1481,19 +1484,19 @@ static struct llama_sampler_i llama_sampler_grammar_i = {
/* .free = */ llama_sampler_grammar_free, /* .free = */ llama_sampler_grammar_free,
}; };
struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) { struct llama_sampler * llama_sampler_init_grammar(const struct llama_vocab * vocab, const char * grammar_str, const char * grammar_root) {
auto * ctx = new llama_sampler_grammar; auto * ctx = new llama_sampler_grammar;
if (grammar_str != nullptr && grammar_str[0] != '\0') { if (grammar_str != nullptr && grammar_str[0] != '\0') {
*ctx = { *ctx = {
/* .vocab = */ &vocab, /* .vocab = */ vocab,
/* .grammar_str = */ grammar_str, /* .grammar_str = */ grammar_str,
/* .grammar_root = */ grammar_root, /* .grammar_root = */ grammar_root,
/* .grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root), /* .grammar = */ llama_grammar_init_impl(vocab, grammar_str, grammar_root),
}; };
} else { } else {
*ctx = { *ctx = {
/* .vocab = */ &vocab, /* .vocab = */ vocab,
/* .grammar_str = */ {}, /* .grammar_str = */ {},
/* .grammar_root = */ {}, /* .grammar_root = */ {},
/* .grammar = */ nullptr, /* .grammar = */ nullptr,
@ -1937,7 +1940,7 @@ static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler
llama_vocab dummy_vocab; llama_vocab dummy_vocab;
// dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying // dummy vocab is passed because it is only needed for raw sequence breaker processing, which we have already done and will simply be copying
auto * result = llama_sampler_init_dry_impl(dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0); auto * result = llama_sampler_init_dry(&dummy_vocab, ctx->total_context_size, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
// Copy the state, including the processed breakers // Copy the state, including the processed breakers
{ {
@ -1964,7 +1967,7 @@ static struct llama_sampler_i llama_sampler_dry_i = {
/* .free = */ llama_sampler_dry_free, /* .free = */ llama_sampler_dry_free,
}; };
struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0); int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers; std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
const int MAX_CHAR_LEN = 40; const int MAX_CHAR_LEN = 40;
@ -1991,7 +1994,7 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
sequence_break.resize(MAX_CHAR_LEN); sequence_break.resize(MAX_CHAR_LEN);
} }
get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN); get_overlapping_token_sequences(*vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
} }
} }
@ -2014,7 +2017,7 @@ struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vo
// wrapper for test-sampling.cpp // wrapper for test-sampling.cpp
struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) { struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
llama_vocab dummy_vocab; llama_vocab dummy_vocab;
auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0); auto * result = llama_sampler_init_dry(&dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
auto * ctx = (llama_sampler_dry *) result->ctx; auto * ctx = (llama_sampler_dry *) result->ctx;
// Process the token-based sequence breakers // Process the token-based sequence breakers
@ -2314,7 +2317,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) { static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_infill *) smpl->ctx; const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
return llama_sampler_init_infill_impl(*ctx->vocab); return llama_sampler_init_infill(ctx->vocab);
} }
static void llama_sampler_infill_free(struct llama_sampler * smpl) { static void llama_sampler_infill_free(struct llama_sampler * smpl) {
@ -2330,14 +2333,13 @@ static struct llama_sampler_i llama_sampler_infill_i = {
/* .free = */ llama_sampler_infill_free, /* .free = */ llama_sampler_infill_free,
}; };
struct llama_sampler * llama_sampler_init_infill_impl( struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) {
const struct llama_vocab & vocab) {
return new llama_sampler { return new llama_sampler {
/* .iface = */ &llama_sampler_infill_i, /* .iface = */ &llama_sampler_infill_i,
/* .ctx = */ new llama_sampler_infill { /* .ctx = */ new llama_sampler_infill {
/* .vocab = */ &vocab, /* .vocab = */ vocab,
/* .buf0 = */ std::vector<char>(512), /* .buf0 = */ std::vector<char>(512),
/* .buf1 = */ std::vector<char>(512), /* .buf1 = */ std::vector<char>(512),
}, },
}; };
} }

View File

@ -2,7 +2,9 @@
// TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ? // TODO: rename llama-sampling.h/.cpp to llama-sampler.h/.cpp ?
#include "llama-grammar.h" #include "llama.h"
#include <vector>
struct llama_vocab; struct llama_vocab;
struct llama_grammar; struct llama_grammar;
@ -21,24 +23,6 @@ struct llama_sampler_chain {
mutable int32_t n_sample; mutable int32_t n_sample;
}; };
struct llama_sampler * llama_sampler_init_grammar_impl(
const struct llama_vocab & vocab,
const char * grammar_str,
const char * grammar_root);
struct llama_sampler * llama_sampler_init_infill_impl(
const struct llama_vocab & vocab);
struct llama_sampler * llama_sampler_init_dry_impl(
const struct llama_vocab & vocab,
int32_t context_size,
float dry_multiplier,
float dry_base,
int32_t dry_allowed_length,
int32_t dry_penalty_last_n,
const char ** seq_breakers,
size_t num_breakers);
struct llama_sampler * llama_sampler_init_dry_testing( struct llama_sampler * llama_sampler_init_dry_testing(
int32_t context_size, int32_t context_size,
float dry_multiplier, float dry_multiplier,

View File

@ -1987,16 +1987,20 @@ void llama_vocab::load(llama_model_loader & ml, const LLM_KV & kv) {
for (auto id : cache_special_tokens) { for (auto id : cache_special_tokens) {
_set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true); _set_tokenid_attr(id, LLAMA_TOKEN_ATTR_RSTRIP, true);
} }
for (auto token : {"</s>"}) { for (const auto * token : {"</s>"}) {
_set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true); _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, true);
} }
for (auto token : {"<unk>", "<s>", "<|endoftext|>"}) { for (const auto * token : {"<unk>", "<s>", "<|endoftext|>"}) {
_set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false); _set_token_attr(token, LLAMA_TOKEN_ATTR_RSTRIP, false);
} }
} }
} }
} }
int32_t llama_n_vocab(const struct llama_vocab * vocab) {
return vocab->n_vocab();
}
enum llama_vocab_type llama_vocab::get_type() const { enum llama_vocab_type llama_vocab::get_type() const {
return type; return type;
} }
@ -2851,3 +2855,140 @@ std::string llama_vocab::detokenize(const std::vector<llama_token> & tokens, boo
// NOTE: the original tokenizer decodes bytes after collecting the pieces. // NOTE: the original tokenizer decodes bytes after collecting the pieces.
return text; return text;
} }
//
// interface implementation
//
enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab) {
return vocab->get_type();
}
const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) {
return vocab->token_get_text(token);
}
float llama_token_get_score(const struct llama_vocab * vocab, llama_token token) {
return vocab->token_get_score(token);
}
enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token) {
return vocab->token_get_attr(token);
}
bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token) {
return vocab->is_eog(token);
}
bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token) {
return vocab->is_control(token);
}
llama_token llama_token_bos(const struct llama_vocab * vocab) {
return vocab->token_bos();
}
llama_token llama_token_eos(const struct llama_vocab * vocab) {
return vocab->token_eos();
}
llama_token llama_token_eot(const struct llama_vocab * vocab) {
return vocab->token_eot();
}
llama_token llama_token_cls(const struct llama_vocab * vocab) {
return vocab->token_cls();
}
llama_token llama_token_sep(const struct llama_vocab * vocab) {
return vocab->token_sep();
}
llama_token llama_token_nl (const struct llama_vocab * vocab) {
return vocab->token_nl();
}
llama_token llama_token_pad(const struct llama_vocab * vocab) {
return vocab->token_pad();
}
bool llama_add_bos_token(const struct llama_vocab * vocab) {
return vocab->add_bos_token();
}
bool llama_add_eos_token(const struct llama_vocab * vocab) {
return vocab->add_eos_token();
}
llama_token llama_token_prefix(const struct llama_vocab * vocab) {
return vocab->token_prefix();
}
llama_token llama_token_middle(const struct llama_vocab * vocab) {
return vocab->token_middle();
}
llama_token llama_token_suffix(const struct llama_vocab * vocab) {
return vocab->token_suffix();
}
llama_token llama_token_fim_pre(const struct llama_vocab * vocab) {
return vocab->token_fim_pre();
}
llama_token llama_token_fim_suf(const struct llama_vocab * vocab) {
return vocab->token_fim_suf();
}
llama_token llama_token_fim_mid(const struct llama_vocab * vocab) {
return vocab->token_fim_mid();
}
llama_token llama_token_fim_pad(const struct llama_vocab * vocab) {
return vocab->token_fim_pad();
}
llama_token llama_token_fim_rep(const struct llama_vocab * vocab) {
return vocab->token_fim_rep();
}
llama_token llama_token_fim_sep(const struct llama_vocab * vocab) {
return vocab->token_fim_sep();
}
//
// tokenization
//
int32_t llama_tokenize(
const struct llama_vocab * vocab,
const char * text,
int32_t text_len,
llama_token * tokens,
int32_t n_tokens_max,
bool add_special,
bool parse_special) {
return vocab->tokenize(text, text_len, tokens, n_tokens_max, add_special, parse_special);
}
int32_t llama_token_to_piece(
const struct llama_vocab * vocab,
llama_token token,
char * buf,
int32_t length,
int32_t lstrip,
bool special) {
return vocab->token_to_piece(token, buf, length, lstrip, special);
}
int32_t llama_detokenize(
const struct llama_vocab * vocab,
const llama_token * tokens,
int32_t n_tokens,
char * text,
int32_t text_len_max,
bool remove_special,
bool unparse_special) {
return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
}

View File

@ -9235,7 +9235,7 @@ static void llama_kv_cache_update_impl(struct llama_context & lctx) {
// build worst-case graph // build worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch); uint32_t n_tokens = std::min(lctx.cparams.n_ctx, lctx.cparams.n_ubatch);
llama_token token = llama_token_bos(&lctx.model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph llama_token token = lctx.model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true); ggml_cgraph * gf = llama_build_graph(lctx, ubatch, true);
@ -9805,7 +9805,7 @@ struct llama_context * llama_new_context_with_model(
// initialize scheduler with the worst-case graph // initialize scheduler with the worst-case graph
uint32_t n_seqs = 1; // TODO: worst-case number of sequences uint32_t n_seqs = 1; // TODO: worst-case number of sequences
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
llama_token token = llama_token_bos(&ctx->model); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph llama_token token = ctx->model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
ggml_cgraph * gf_pp = llama_build_graph(*ctx, ubatch_pp, true); ggml_cgraph * gf_pp = llama_build_graph(*ctx, ubatch_pp, true);
@ -9954,166 +9954,18 @@ int32_t llama_decode(
return ret; return ret;
} }
//
// vocab
//
// TODO: tmp bridges below until `struct llama_vocab` is exposed through the public API
const char * llama_token_get_text(const struct llama_model * model, llama_token token) {
return model->vocab.token_get_text(token);
}
float llama_token_get_score(const struct llama_model * model, llama_token token) {
return model->vocab.token_get_score(token);
}
enum llama_token_attr llama_token_get_attr(const struct llama_model * model, llama_token token) {
return model->vocab.token_get_attr(token);
}
bool llama_token_is_eog(const struct llama_model * model, llama_token token) {
return model->vocab.is_eog(token);
}
bool llama_token_is_control(const struct llama_model * model, llama_token token) {
return model->vocab.is_control(token);
}
llama_token llama_token_bos(const struct llama_model * model) {
return model->vocab.token_bos();
}
llama_token llama_token_eos(const struct llama_model * model) {
return model->vocab.token_eos();
}
llama_token llama_token_eot(const struct llama_model * model) {
return model->vocab.token_eot();
}
llama_token llama_token_cls(const struct llama_model * model) {
return model->vocab.token_cls();
}
llama_token llama_token_sep(const struct llama_model * model) {
return model->vocab.token_sep();
}
llama_token llama_token_nl (const struct llama_model * model) {
return model->vocab.token_nl();
}
llama_token llama_token_pad(const struct llama_model * model) {
return model->vocab.token_pad();
}
bool llama_add_bos_token(const struct llama_model * model) {
return model->vocab.add_bos_token();
}
bool llama_add_eos_token(const struct llama_model * model) {
return model->vocab.add_eos_token();
}
llama_token llama_token_prefix(const struct llama_model * model) {
return model->vocab.token_prefix();
}
llama_token llama_token_middle(const struct llama_model * model) {
return model->vocab.token_middle();
}
llama_token llama_token_suffix(const struct llama_model * model) {
return model->vocab.token_suffix();
}
llama_token llama_token_fim_pre(const struct llama_model * model) {
return model->vocab.token_fim_pre();
}
llama_token llama_token_fim_suf(const struct llama_model * model) {
return model->vocab.token_fim_suf();
}
llama_token llama_token_fim_mid(const struct llama_model * model) {
return model->vocab.token_fim_mid();
}
llama_token llama_token_fim_pad(const struct llama_model * model) {
return model->vocab.token_fim_pad();
}
llama_token llama_token_fim_rep(const struct llama_model * model) {
return model->vocab.token_fim_rep();
}
llama_token llama_token_fim_sep(const struct llama_model * model) {
return model->vocab.token_fim_sep();
}
//
// tokenization
//
int32_t llama_tokenize(
const struct llama_model * model,
const char * text,
int32_t text_len,
llama_token * tokens,
int32_t n_tokens_max,
bool add_special,
bool parse_special) {
return model->vocab.tokenize(text, text_len, tokens, n_tokens_max, add_special, parse_special);
}
int32_t llama_token_to_piece(
const struct llama_model * model,
llama_token token,
char * buf,
int32_t length,
int32_t lstrip,
bool special) {
return model->vocab.token_to_piece(token, buf, length, lstrip, special);
}
int32_t llama_detokenize(
const struct llama_model * model,
const llama_token * tokens,
int32_t n_tokens,
char * text,
int32_t text_len_max,
bool remove_special,
bool unparse_special) {
return model->vocab.detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
}
// //
// chat templates // chat templates
// //
int32_t llama_chat_apply_template( int32_t llama_chat_apply_template(
const struct llama_model * model,
const char * tmpl, const char * tmpl,
const struct llama_chat_message * chat, const struct llama_chat_message * chat,
size_t n_msg, size_t n_msg,
bool add_ass, bool add_ass,
char * buf, char * buf,
int32_t length) { int32_t length) {
std::string curr_tmpl(tmpl == nullptr ? "" : tmpl); const std::string curr_tmpl(tmpl == nullptr ? "chatml" : tmpl);
if (tmpl == nullptr) {
GGML_ASSERT(model != nullptr);
// load template from model, if available
const auto & it = model->gguf_kv.find("tokenizer.chat_template");
if (it != model->gguf_kv.end() && it->second.size() > 0) {
curr_tmpl = it->second;
}
else {
// worst case: there is no information about template, we will use chatml by default
curr_tmpl = "chatml"; // see llm_chat_apply_template
}
}
// format the chat to string // format the chat to string
std::vector<const llama_chat_message *> chat_vec; std::vector<const llama_chat_message *> chat_vec;
@ -10137,23 +9989,6 @@ int32_t llama_chat_apply_template(
return res; return res;
} }
//
// sampling
//
// TODO: remove indirection when vocab becomes accesible in llama-sampling.cpp
struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * model, const char * grammar_str, const char * grammar_root) {
return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
}
struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model) {
return llama_sampler_init_infill_impl(model->vocab);
}
struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
return llama_sampler_init_dry_impl(model->vocab, llama_n_ctx_train(model), dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers, num_breakers);
}
// //
// model split // model split
// //

View File

@ -157,7 +157,7 @@ int main(void) {
} }
// test invalid chat template // test invalid chat template
res = llama_chat_apply_template(nullptr, "INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size()); res = llama_chat_apply_template("INVALID TEMPLATE", conversation, message_count, true, formatted_chat.data(), formatted_chat.size());
assert(res < 0); assert(res < 0);
for (size_t i = 0; i < templates.size(); i++) { for (size_t i = 0; i < templates.size(); i++) {
@ -165,7 +165,6 @@ int main(void) {
std::string expected = expected_output[i]; std::string expected = expected_output[i];
formatted_chat.resize(1024); formatted_chat.resize(1024);
res = llama_chat_apply_template( res = llama_chat_apply_template(
nullptr,
custom_template.c_str(), custom_template.c_str(),
conversation, conversation,
message_count, message_count,

View File

@ -64,8 +64,10 @@ int main(int argc, char **argv) {
} }
} }
//GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_BPE); const llama_vocab * vocab = llama_get_vocab(model);
if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_BPE) {
//GGML_ASSERT(llama_vocab_type(vocab) == LLAMA_VOCAB_TYPE_BPE);
if (llama_vocab_type(vocab) != LLAMA_VOCAB_TYPE_BPE) {
return 99; return 99;
} }
@ -75,7 +77,7 @@ int main(int argc, char **argv) {
atexit([]() { console::cleanup(); }); atexit([]() { console::cleanup(); });
#endif #endif
const int n_vocab = llama_n_vocab(model); const int n_vocab = llama_n_vocab(vocab);
for (int i = 0; i < n_vocab; ++i) { for (int i = 0; i < n_vocab; ++i) {
std::string str = common_detokenize(ctx, std::vector<int>(1, i)); std::string str = common_detokenize(ctx, std::vector<int>(1, i));

View File

@ -52,8 +52,10 @@ int main(int argc, char ** argv) {
} }
} }
const llama_vocab * vocab = llama_get_vocab(model);
//GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); //GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM);
if (llama_vocab_type(model) != LLAMA_VOCAB_TYPE_SPM) { if (llama_vocab_type(vocab) != LLAMA_VOCAB_TYPE_SPM) {
return 99; return 99;
} }
@ -63,7 +65,7 @@ int main(int argc, char ** argv) {
atexit([]() { console::cleanup(); }); atexit([]() { console::cleanup(); });
#endif #endif
const int n_vocab = llama_n_vocab(model); const int n_vocab = llama_n_vocab(vocab);
for (int i = 0; i < n_vocab; ++i) { for (int i = 0; i < n_vocab; ++i) {
std::string str = common_detokenize(ctx, std::vector<int>(1, i), true); std::string str = common_detokenize(ctx, std::vector<int>(1, i), true);

View File

@ -76,7 +76,7 @@ class LibLlamaModel:
self.ffi = libllama.ffi self.ffi = libllama.ffi
if isinstance(mparams, dict): if isinstance(mparams, dict):
mparams = libllama.model_default_params(**mparams) mparams = libllama.model_default_params(**mparams)
self.model = self.lib.llama_load_model_from_file(path_model.encode(), mparams) self.model = self.lib.llama_model_load_from_file(path_model.encode(), mparams)
if not self.model: if not self.model:
raise RuntimeError("error: failed to load model '%s'" % path_model) raise RuntimeError("error: failed to load model '%s'" % path_model)
if isinstance(cparams, dict): if isinstance(cparams, dict):
@ -92,7 +92,7 @@ class LibLlamaModel:
if self.ctx: if self.ctx:
self.lib.llama_free(self.ctx) self.lib.llama_free(self.ctx)
if self.model: if self.model:
self.lib.llama_free_model(self.model) self.lib.llama_model_free(self.model)
self.ctx = None self.ctx = None
self.model = None self.model = None
self.lib = None self.lib = None