diff --git a/common/common.cpp b/common/common.cpp index a30d6920e..07ac7ccad 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -857,22 +857,22 @@ struct common_init_result common_init_from_params(common_params & params) { return iparams; } - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); if (params.reranking) { bool ok = true; - if (llama_token_bos(vocab) == LLAMA_TOKEN_NULL) { + if (llama_vocab_bos(vocab) == LLAMA_TOKEN_NULL) { LOG_WRN("%s: warning: vocab does not have a BOS token, reranking will not work\n", __func__); ok = false; } - if (llama_token_eos(vocab) == LLAMA_TOKEN_NULL) { + if (llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { LOG_WRN("%s: warning: vocab does not have an EOS token, reranking will not work\n", __func__); ok = false; } - if (llama_token_sep(vocab) == LLAMA_TOKEN_NULL) { + if (llama_vocab_sep(vocab) == LLAMA_TOKEN_NULL) { LOG_WRN("%s: warning: vocab does not have a SEP token, reranking will not work\n", __func__); ok = false; } @@ -886,7 +886,7 @@ struct common_init_result common_init_from_params(common_params & params) { auto cparams = common_context_params_to_llama(params); - llama_context * lctx = llama_new_context_with_model(model, cparams); + llama_context * lctx = llama_init_from_model(model, cparams); if (lctx == NULL) { LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.c_str()); llama_model_free(model); @@ -900,7 +900,7 @@ struct common_init_result common_init_from_params(common_params & params) { if (!params.control_vectors.empty()) { if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1; - if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model); + if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_model_n_layer(model); const auto cvec = common_control_vector_load(params.control_vectors); if (cvec.n_embd == -1) { @@ -944,14 +944,14 @@ struct common_init_result common_init_from_params(common_params & params) { common_set_adapter_lora(lctx, params.lora_adapters); } - if (params.sampling.ignore_eos && llama_token_eos(vocab) == LLAMA_TOKEN_NULL) { + if (params.sampling.ignore_eos && llama_vocab_eos(vocab) == LLAMA_TOKEN_NULL) { LOG_WRN("%s: warning: vocab does not have an EOS token, ignoring --ignore-eos\n", __func__); params.sampling.ignore_eos = false; } if (params.sampling.ignore_eos) { - for (llama_token i = 0; i < llama_n_vocab(vocab); i++) { - if (llama_token_is_eog(vocab, i)) { + for (llama_token i = 0; i < llama_vocab_n_vocab(vocab); i++) { + if (llama_vocab_is_eog(vocab, i)) { LOG_INF("%s: added %s logit bias = %f\n", __func__, common_token_to_piece(lctx, i).c_str(), -INFINITY); params.sampling.logit_bias.push_back({i, -INFINITY}); } @@ -972,8 +972,8 @@ struct common_init_result common_init_from_params(common_params & params) { LOG_WRN("%s: warming up the model with an empty run - please wait ... (--no-warmup to disable)\n", __func__); std::vector tmp; - llama_token bos = llama_token_bos(vocab); - llama_token eos = llama_token_eos(vocab); + llama_token bos = llama_vocab_bos(vocab); + llama_token eos = llama_vocab_eos(vocab); // some models (e.g. T5) don't have a BOS token if (bos != LLAMA_TOKEN_NULL) { @@ -1564,7 +1564,7 @@ std::vector common_tokenize( bool add_special, bool parse_special) { const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); return common_tokenize(vocab, text, add_special, parse_special); } @@ -1589,7 +1589,7 @@ std::vector common_tokenize( std::string common_token_to_piece(const struct llama_context * ctx, llama_token token, bool special) { const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); return common_token_to_piece(vocab, token, special); } @@ -1611,7 +1611,7 @@ std::string common_token_to_piece(const struct llama_vocab * vocab, llama_token std::string common_detokenize(const struct llama_context * ctx, const std::vector & tokens, bool special) { const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); return common_detokenize(vocab, tokens, special); } diff --git a/common/sampling.cpp b/common/sampling.cpp index e9f7701b4..1d2c1815e 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -114,9 +114,9 @@ struct common_sampler { const auto * logits = llama_get_logits_ith(ctx, idx); const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); - const int n_vocab = llama_n_vocab(vocab); + const int n_vocab = llama_vocab_n_vocab(vocab); cur.resize(n_vocab); @@ -145,7 +145,7 @@ std::string common_params_sampling::print() const { } struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); @@ -162,7 +162,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co llama_sampler_chain_add(result->chain, llama_sampler_init_logit_bias( - llama_n_vocab(vocab), + llama_vocab_n_vocab(vocab), params.logit_bias.size(), params.logit_bias.data())); @@ -177,7 +177,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co c_breakers.push_back(str.c_str()); } - llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); + llama_sampler_chain_add(result->chain, llama_sampler_init_dry (vocab, llama_model_n_ctx_train(model), params.dry_multiplier, params.dry_base, params.dry_allowed_length, params.dry_penalty_last_n, c_breakers.data(), c_breakers.size())); } break; case COMMON_SAMPLER_TYPE_TOP_K: @@ -211,7 +211,7 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co llama_sampler_chain_add(result->chain, llama_sampler_init_dist(params.seed)); } else if (params.mirostat == 1) { llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); - llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_n_vocab(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); + llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat(llama_vocab_n_vocab(vocab), params.seed, params.mirostat_tau, params.mirostat_eta, 100)); } else if (params.mirostat == 2) { llama_sampler_chain_add(result->chain, llama_sampler_init_temp(params.temp)); llama_sampler_chain_add(result->chain, llama_sampler_init_mirostat_v2(params.seed, params.mirostat_tau, params.mirostat_eta)); diff --git a/common/speculative.cpp b/common/speculative.cpp index eb3bfe301..a582122ee 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -79,8 +79,8 @@ bool common_speculative_are_compatible( const struct llama_model * model_tgt = llama_get_model(ctx_tgt); const struct llama_model * model_dft = llama_get_model(ctx_dft); - const struct llama_vocab * vocab_tgt = llama_get_vocab(model_tgt); - const struct llama_vocab * vocab_dft = llama_get_vocab(model_dft); + const struct llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); + const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); const bool vocab_type_tgt = llama_vocab_type(vocab_tgt); LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt); @@ -94,32 +94,32 @@ bool common_speculative_are_compatible( return false; } - if (llama_add_bos_token(vocab_tgt) != llama_add_bos_token(vocab_dft) || - llama_add_eos_token(vocab_tgt) != llama_add_eos_token(vocab_dft) || - llama_token_bos(vocab_tgt) != llama_token_bos(vocab_dft) || - llama_token_eos(vocab_tgt) != llama_token_eos(vocab_dft)) { + if (llama_vocab_add_bos(vocab_tgt) != llama_vocab_add_bos(vocab_dft) || + llama_vocab_add_eos(vocab_tgt) != llama_vocab_add_eos(vocab_dft) || + llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) || + llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft)) { LOG_ERR("%s: draft vocab special tokens must match target vocab to use speculation\n", __func__); - LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(vocab_tgt), llama_add_bos_token(vocab_tgt), llama_token_eos(vocab_tgt), llama_add_eos_token(vocab_tgt)); - LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_token_bos(vocab_dft), llama_add_bos_token(vocab_dft), llama_token_eos(vocab_dft), llama_add_eos_token(vocab_dft)); + LOG_ERR("%s: tgt: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_tgt), llama_vocab_add_bos(vocab_tgt), llama_vocab_eos(vocab_tgt), llama_vocab_add_eos(vocab_tgt)); + LOG_ERR("%s: dft: bos = %d (%d), eos = %d (%d)\n", __func__, llama_vocab_bos(vocab_dft), llama_vocab_add_bos(vocab_dft), llama_vocab_eos(vocab_dft), llama_vocab_add_eos(vocab_dft)); return false; } { - const int n_vocab_tgt = llama_n_vocab(vocab_tgt); - const int n_vocab_dft = llama_n_vocab(vocab_dft); + const int n_vocab_tgt = llama_vocab_n_vocab(vocab_tgt); + const int n_vocab_dft = llama_vocab_n_vocab(vocab_dft); const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft); if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { LOG_ERR("%s: draft model vocab must closely match target model to use speculation but " "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", - __func__, n_vocab_tgt, llama_n_vocab(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); + __func__, n_vocab_tgt, llama_vocab_n_vocab(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); return false; } for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { - const char * token_text_tgt = llama_token_get_text(vocab_tgt, i); - const char * token_text_dft = llama_token_get_text(vocab_dft, i); + const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); + const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); if (std::strcmp(token_text_tgt, token_text_dft) != 0) { LOG_ERR("%s: draft vocab vocab must match target vocab to use speculation but " "token %d content differs - target '%s', draft '%s'\n", __func__, i, diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index dd75ff9f1..0659ab6f1 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -50,7 +50,7 @@ int main(int argc, char ** argv) { // ensure enough sequences are available ctx_params.n_seq_max = n_pl.empty() ? 1 : *std::max_element(n_pl.begin(), n_pl.end()); - llama_context * ctx = llama_new_context_with_model(model, ctx_params); + llama_context * ctx = llama_init_from_model(model, ctx_params); if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index d196a01ed..371917b2e 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -141,7 +141,7 @@ while n_cur <= n_len { let new_token_id = llama_sampler_sample(smpl, context, i_batch[i]) // is it an end of stream? -> mark the stream as finished - if llama_token_is_eog(model, new_token_id) || n_cur == n_len { + if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len { i_batch[i] = -1 // print("") if n_parallel > 1 { diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 3ce9cee9c..21b95ef5e 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -48,7 +48,7 @@ int main(int argc, char ** argv) { return 1; } - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); // tokenize the prompt @@ -64,7 +64,7 @@ int main(int argc, char ** argv) { ctx_params.n_ctx = n_kv_req; ctx_params.n_batch = std::max(n_predict, n_parallel); - llama_context * ctx = llama_new_context_with_model(model, ctx_params); + llama_context * ctx = llama_init_from_model(model, ctx_params); auto sparams = llama_sampler_chain_default_params(); sparams.no_perf = false; @@ -123,7 +123,7 @@ int main(int argc, char ** argv) { llama_token decoder_start_token_id = llama_model_decoder_start_token(model); if (decoder_start_token_id == LLAMA_TOKEN_NULL) { - decoder_start_token_id = llama_token_bos(vocab); + decoder_start_token_id = llama_vocab_bos(vocab); } common_batch_clear(batch); @@ -176,7 +176,7 @@ int main(int argc, char ** argv) { const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]); // is it an end of generation? -> mark the stream as finished - if (llama_token_is_eog(vocab, new_token_id) || n_cur == n_predict) { + if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) { i_batch[i] = -1; LOG("\n"); if (n_parallel > 1) { diff --git a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp index 1256abb17..bdf0eed2a 100644 --- a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +++ b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp @@ -911,7 +911,7 @@ int main(int argc, char ** argv) { load_vocab(params.fn_vocab_model, &config, &vocab); struct my_llama_model model; - model.hparams.n_vocab = config.vocab_size; //llama_n_vocab(lctx); + model.hparams.n_vocab = config.vocab_size; //llama_vocab_n_vocab(lctx); model.hparams.n_ctx = params.n_ctx; model.hparams.n_embd = config.dim; //params.n_embd; model.hparams.n_ff = config.hidden_dim; diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index 8cebf237f..384df05fe 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -274,8 +274,8 @@ struct tokenized_prompt { tokenized_prompt(llama_context * ctx, std::string pos, std::string neg) { const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); - const bool add_bos = llama_add_bos_token(vocab); + const llama_vocab * vocab = llama_model_get_vocab(model); + const bool add_bos = llama_vocab_add_bos(vocab); tokens_pos = common_tokenize(ctx, pos, add_bos, true); tokens_neg = common_tokenize(ctx, neg, add_bos, true); max_seq_len = std::max(tokens_pos.size(), tokens_neg.size()); @@ -423,8 +423,8 @@ int main(int argc, char ** argv) { llama_context * ctx = llama_init.context.get(); // int n_ctx = llama_n_ctx(ctx); - int n_layers = llama_n_layer(model); - int n_embd = llama_n_embd(model); + int n_layers = llama_model_n_layer(model); + int n_embd = llama_model_n_embd(model); // get model hint param (a.k.a model arch name) char model_hint[128]; diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index a40934bca..38d22c90f 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -105,9 +105,9 @@ int main(int argc, char ** argv) { return 1; } - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); - const int n_ctx_train = llama_n_ctx_train(model); + const int n_ctx_train = llama_model_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); @@ -150,7 +150,7 @@ int main(int argc, char ** argv) { // check if the last token is SEP // it should be automatically added by the tokenizer when 'tokenizer.ggml.add_eos_token' is set to 'true' for (auto & inp : inputs) { - if (inp.empty() || inp.back() != llama_token_sep(vocab)) { + if (inp.empty() || inp.back() != llama_vocab_sep(vocab)) { LOG_WRN("%s: last token in the prompt is not SEP\n", __func__); LOG_WRN("%s: 'tokenizer.ggml.add_eos_token' should be set to 'true' in the GGUF header\n", __func__); } @@ -183,7 +183,7 @@ int main(int argc, char ** argv) { } // allocate output - const int n_embd = llama_n_embd(model); + const int n_embd = llama_model_n_embd(model); std::vector embeddings(n_embd_count * n_embd, 0); float * emb = embeddings.data(); diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index 3e89454a3..65577157e 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -128,9 +128,9 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { static bool run(llama_context * ctx, const common_params & params) { const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); - const bool add_bos = llama_add_bos_token(vocab); + const bool add_bos = llama_vocab_add_bos(vocab); std::vector tokens = common_tokenize(ctx, params.prompt, add_bos); diff --git a/examples/export-lora/export-lora.cpp b/examples/export-lora/export-lora.cpp index 124184f77..99063b5d5 100644 --- a/examples/export-lora/export-lora.cpp +++ b/examples/export-lora/export-lora.cpp @@ -8,7 +8,6 @@ #include #include #include -#include #include static bool g_verbose = false; diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index b88090bd0..72eb46257 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -11,7 +11,7 @@ static std::vector> encode(llama_context * ctx, const std::ve std::vector> result; const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); @@ -26,7 +26,7 @@ static std::vector> encode(llama_context * ctx, const std::ve // GritLM seems to have EOS = "" // https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18 - // inputs.push_back(llama_token_eos(vocab)); + // inputs.push_back(llama_vocab_eos(vocab)); // we want to ignore instruction tokens for mean pooling const int32_t n_inst = common_tokenize(vocab, instruction, true, false).size(); @@ -53,7 +53,7 @@ static std::vector> encode(llama_context * ctx, const std::ve llama_decode(ctx, batch); // get embedding dimensions - uint64_t n_embd = llama_n_embd(model); + uint64_t n_embd = llama_model_n_embd(model); // allocate embedding output std::vector emb_unorm(n_embd, 0.0f); @@ -98,9 +98,9 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std std::string result; const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); - llama_token eos_token = llama_token_eos(vocab); + llama_token eos_token = llama_vocab_eos(vocab); llama_kv_cache_clear(ctx); llama_set_embeddings(ctx, false); @@ -171,7 +171,7 @@ int main(int argc, char * argv[]) { llama_model * model = llama_model_load_from_file(params.model.c_str(), mparams); // create generation context - llama_context * ctx = llama_new_context_with_model(model, cparams); + llama_context * ctx = llama_init_from_model(model, cparams); auto sparams = llama_sampler_chain_default_params(); @@ -200,7 +200,7 @@ int main(int argc, char * argv[]) { const std::vector> d_rep = encode(ctx, documents, gritlm_instruction("")); const std::vector> q_rep = encode(ctx, queries, gritlm_instruction(instruction)); - const int n_embd = llama_n_embd(model); + const int n_embd = llama_model_n_embd(model); const float cosine_sim_q0_d0 = common_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd); const float cosine_sim_q0_d1 = common_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd); diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index eec653142..b63fdf4d6 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include #include #include @@ -40,7 +39,7 @@ public: void set_params(common_params params) { m_params = std::move(params); } bool collect_imatrix(struct ggml_tensor * t, bool ask, void * user_data); void save_imatrix(int ncall = -1) const; - bool load_imatrix(const char * file_name); + bool load_imatrix(const char * fname); private: std::unordered_map m_stats; common_params m_params; @@ -430,12 +429,12 @@ static void process_logits( static bool compute_imatrix(llama_context * ctx, const common_params & params) { const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); - const bool add_bos = llama_add_bos_token(vocab); + const bool add_bos = llama_vocab_add_bos(vocab); const int n_ctx = llama_n_ctx(ctx); - GGML_ASSERT(!llama_add_eos_token(vocab)); + GGML_ASSERT(!llama_vocab_add_eos(vocab)); auto tim1 = std::chrono::high_resolution_clock::now(); LOG_INF("%s: tokenizing the input ..\n", __func__); @@ -471,7 +470,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { const int n_chunk_max = tokens.size() / n_ctx; const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); - const int n_vocab = llama_n_vocab(vocab); + const int n_vocab = llama_vocab_n_vocab(vocab); const int n_batch = params.n_batch; int count = 0; @@ -511,7 +510,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { // add BOS token for the first batch of each chunk if (add_bos && j == 0) { - tokens[batch_start] = llama_token_bos(vocab); + tokens[batch_start] = llama_vocab_bos(vocab); } common_batch_clear(batch); @@ -630,7 +629,7 @@ int main(int argc, char ** argv) { return 1; } - const int n_ctx_train = llama_n_ctx_train(model); + const int n_ctx_train = llama_model_n_ctx_train(model); if (params.n_ctx > n_ctx_train) { LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, params.n_ctx); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 70e899fe2..eb848414a 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -139,9 +139,9 @@ int main(int argc, char ** argv) { return 1; } - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); - const int n_ctx_train = llama_n_ctx_train(model); + const int n_ctx_train = llama_model_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); LOG_DBG("n_ctx: %d\n", n_ctx); @@ -154,28 +154,28 @@ int main(int argc, char ** argv) { LOG_INF("\n"); LOG_INF("%s\n", common_params_get_system_info(params).c_str()); } - const bool add_bos = llama_add_bos_token(vocab); - GGML_ASSERT(!llama_add_eos_token(vocab)); + const bool add_bos = llama_vocab_add_bos(vocab); + GGML_ASSERT(!llama_vocab_add_eos(vocab)); std::vector embd_inp; std::vector embd_end; std::vector inp_pfx = common_tokenize(ctx, params.input_prefix, false); std::vector inp_sfx = common_tokenize(ctx, params.input_suffix, false); - GGML_ASSERT(llama_token_fim_pre(vocab) >= 0); - GGML_ASSERT(llama_token_fim_suf(vocab) >= 0); + GGML_ASSERT(llama_vocab_fim_pre(vocab) >= 0); + GGML_ASSERT(llama_vocab_fim_suf(vocab) >= 0); - inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(vocab)); - inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(vocab)); + inp_pfx.insert(inp_pfx.begin(), llama_vocab_fim_pre(vocab)); + inp_sfx.insert(inp_sfx.begin(), llama_vocab_fim_suf(vocab)); embd_inp = params.spm_infill ? inp_sfx : inp_pfx; embd_end = params.spm_infill ? inp_pfx : inp_sfx; if (add_bos) { - embd_inp.insert(embd_inp.begin(), llama_token_bos(vocab)); + embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); } embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); - const llama_token middle_token = llama_token_fim_mid(vocab); + const llama_token middle_token = llama_vocab_fim_mid(vocab); if (middle_token >= 0) { embd_inp.push_back(middle_token); } @@ -187,7 +187,7 @@ int main(int argc, char ** argv) { // Should not run without any tokens if (embd_inp.empty()) { - embd_inp.push_back(llama_token_bos(vocab)); + embd_inp.push_back(llama_vocab_bos(vocab)); LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str()); } @@ -422,10 +422,10 @@ int main(int argc, char ** argv) { // if not currently processing queued inputs; if ((int) embd_inp.size() <= n_consumed) { // deal with eot token in infill mode - if ((common_sampler_last(smpl) == llama_token_eot(vocab) || is_interacting) && params.interactive){ + if ((common_sampler_last(smpl) == llama_vocab_eot(vocab) || is_interacting) && params.interactive){ if (is_interacting && !params.interactive_first) { // print an eot token - LOG("%s", common_token_to_piece(ctx, llama_token_eot(vocab)).c_str()); + LOG("%s", common_token_to_piece(ctx, llama_vocab_eot(vocab)).c_str()); } LOG("\n"); console::set_display(console::user_input); @@ -465,13 +465,13 @@ int main(int argc, char ** argv) { std::vector inp_pfx = common_tokenize(ctx, params.input_prefix, false); std::vector inp_sfx = common_tokenize(ctx, params.input_suffix, false); - inp_pfx.insert(inp_pfx.begin(), llama_token_fim_pre(vocab)); - inp_sfx.insert(inp_sfx.begin(), llama_token_fim_suf(vocab)); + inp_pfx.insert(inp_pfx.begin(), llama_vocab_fim_pre(vocab)); + inp_sfx.insert(inp_sfx.begin(), llama_vocab_fim_suf(vocab)); embd_inp = params.spm_infill ? inp_sfx : inp_pfx; embd_end = params.spm_infill ? inp_pfx : inp_sfx; if (add_bos) { - embd_inp.insert(embd_inp.begin(), llama_token_bos(vocab)); + embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); } embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); @@ -486,7 +486,7 @@ int main(int argc, char ** argv) { is_interacting = false; } // deal with end of generation tokens in interactive mode - else if (llama_token_is_eog(vocab, common_sampler_last(smpl))) { + else if (llama_vocab_is_eog(vocab, common_sampler_last(smpl))) { LOG_DBG("found EOS token\n"); if (params.interactive) { @@ -502,7 +502,7 @@ int main(int argc, char ** argv) { if (params.input_prefix_bos) { LOG_DBG("adding input prefix BOS token\n"); - embd_inp.push_back(llama_token_bos(vocab)); + embd_inp.push_back(llama_vocab_bos(vocab)); } std::string buffer; @@ -565,7 +565,7 @@ int main(int argc, char ** argv) { } // end of generation - if (!embd.empty() && llama_token_is_eog(vocab, embd.back()) && !params.interactive) { + if (!embd.empty() && llama_vocab_is_eog(vocab, embd.back()) && !params.interactive) { break; } @@ -577,7 +577,7 @@ int main(int argc, char ** argv) { } } if (!params.interactive && n_remain <= 0) { - LOG("%s", common_token_to_piece(ctx, llama_token_eot(vocab)).c_str()); + LOG("%s", common_token_to_piece(ctx, llama_vocab_eot(vocab)).c_str()); } LOG("\n"); diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 3f36b65bd..3439e5bec 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1401,8 +1401,8 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th llama_set_n_threads(ctx, n_threads, n_threads); const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); - const int32_t n_vocab = llama_n_vocab(vocab); + const llama_vocab * vocab = llama_model_get_vocab(model); + const int32_t n_vocab = llama_vocab_n_vocab(vocab); std::vector tokens(n_batch); @@ -1410,7 +1410,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th while (n_processed < n_prompt) { int n_tokens = std::min(n_prompt - n_processed, n_batch); - tokens[0] = n_processed == 0 && llama_add_bos_token(vocab) ? llama_token_bos(vocab) : std::rand() % n_vocab; + tokens[0] = n_processed == 0 && llama_vocab_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab; for (int i = 1; i < n_tokens; i++) { tokens[i] = std::rand() % n_vocab; } @@ -1425,10 +1425,10 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) { llama_set_n_threads(ctx, n_threads, n_threads); const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); - const int32_t n_vocab = llama_n_vocab(vocab); + const llama_vocab * vocab = llama_model_get_vocab(model); + const int32_t n_vocab = llama_vocab_n_vocab(vocab); - llama_token token = llama_add_bos_token(vocab) ? llama_token_bos(vocab) : std::rand() % n_vocab; + llama_token token = llama_vocab_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab; for (int i = 0; i < n_gen; i++) { llama_decode(ctx, llama_batch_get_one(&token, 1)); @@ -1539,7 +1539,7 @@ int main(int argc, char ** argv) { prev_inst = &inst; } - llama_context * ctx = llama_new_context_with_model(lmodel, inst.to_llama_cparams()); + llama_context * ctx = llama_init_from_model(lmodel, inst.to_llama_cparams()); if (ctx == NULL) { fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, inst.model.c_str()); llama_model_free(lmodel); diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index d5dd78c2c..99b14961d 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -405,7 +405,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( const auto batch = reinterpret_cast(batch_pointer); const auto sampler = reinterpret_cast(sampler_pointer); const auto model = llama_get_model(context); - const auto vocab = llama_get_vocab(model); + const auto vocab = llama_model_get_vocab(model); if (!la_int_var) la_int_var = env->GetObjectClass(intvar_ncur); if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I"); @@ -415,7 +415,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( const auto new_token_id = llama_sampler_sample(sampler, context, -1); const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value); - if (llama_token_is_eog(vocab, new_token_id) || n_cur == n_len) { + if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) { return nullptr; } diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index f88b1fdcc..477c3e6f2 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -151,7 +151,7 @@ actor LlamaContext { new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1) - if llama_token_is_eog(model, new_token_id) || n_cur == n_len { + if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len { print("\n") is_done = true let new_token_str = String(cString: temporary_invalid_cchars + [0]) diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index c065d9097..40aa0876f 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -49,10 +49,10 @@ static const char * sample(struct common_sampler * smpl, common_sampler_accept(smpl, id, true); const llama_model * model = llama_get_model(ctx_llama); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); static std::string ret; - if (llama_token_is_eog(vocab, id)) { + if (llama_vocab_is_eog(vocab, id)) { ret = ""; } else { ret = common_token_to_piece(ctx_llama, id); @@ -243,11 +243,10 @@ static struct llava_context * llava_init_context(common_params * params, llama_m auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); - llama_context_params ctx_params = common_context_params_to_llama(*params); ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings - llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); + llama_context * ctx_llama = llama_init_from_model(model, ctx_params); if (ctx_llama == NULL) { LOG_ERR("%s: failed to create the llama_context\n" , __func__); diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 16f30c56c..c598caf3d 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -384,7 +384,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) { // make sure that the correct mmproj was used, i.e., compare apples to apples - int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama)); + int n_llama_embd = llama_model_n_embd(llama_get_model(ctx_llama)); auto n_image_embd = clip_n_mmproj_embd(ctx_clip); if (n_image_embd != n_llama_embd) { LOG_ERR("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd); @@ -456,7 +456,7 @@ struct llava_embd_batch { }; bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) { - int n_embd = llama_n_embd(llama_get_model(ctx_llama)); + int n_embd = llama_model_n_embd(llama_get_model(ctx_llama)); for (int i = 0; i < image_embed->n_image_pos; i += n_batch) { int n_eval = image_embed->n_image_pos - i; diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index 06c961531..38c44e130 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -54,7 +54,7 @@ static struct llava_context * llava_init_context(common_params * params, llama_m ctx_params.n_ctx = params->n_ctx; } - llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); + llama_context * ctx_llama = llama_init_from_model(model, ctx_params); if (ctx_llama == NULL) { LOG_ERR("%s: failed to create the llama_context\n" , __func__); @@ -169,10 +169,10 @@ static const char * sample(struct common_sampler * smpl, common_sampler_accept(smpl, id, true); const llama_model * model = llama_get_model(ctx_llama); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); static std::string ret; - if (llama_token_is_eog(vocab, id)) { + if (llama_vocab_is_eog(vocab, id)) { ret = ""; } else { ret = common_token_to_piece(ctx_llama, id); diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index 784357d32..132a7da54 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -27,7 +27,7 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past, int * st_pos_id, struct clip_image_size * image_size) { - int n_embd = llama_n_embd(llama_get_model(ctx_llama)); + int n_embd = llama_model_n_embd(llama_get_model(ctx_llama)); const int patch_size = 14 * 2; const int ph = image_size->height / patch_size + (image_size->height % patch_size > 0); const int pw = image_size->width / patch_size + (image_size->width % patch_size > 0); @@ -134,10 +134,10 @@ static const char * sample(struct common_sampler * smpl, common_sampler_accept(smpl, id, true); const llama_model * model = llama_get_model(ctx_llama); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); static std::string ret; - if (llama_token_is_eog(vocab, id)) { + if (llama_vocab_is_eog(vocab, id)) { ret = ""; } else { ret = common_token_to_piece(ctx_llama, id); @@ -332,11 +332,10 @@ static struct llava_context * llava_init_context(common_params * params, llama_m auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); - llama_context_params ctx_params = common_context_params_to_llama(*params); ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings - llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params); + llama_context * ctx_llama = llama_init_from_model(model, ctx_params); if (ctx_llama == NULL) { LOG_ERR("%s: failed to create the llama_context\n" , __func__); @@ -485,7 +484,7 @@ static void debug_test_mrope_2d() { } static void debug_dump_img_embed(struct llava_context * ctx_llava) { - int n_embd = llama_n_embd(llama_get_model(ctx_llava->ctx_llama)); + int n_embd = llama_model_n_embd(llama_get_model(ctx_llava->ctx_llama)); int ne = n_embd * 4; float vals[56 * 56 * 3]; // float embd[ne]; diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 576d4fa5d..e23060e24 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -61,7 +61,7 @@ int main(int argc, char ** argv) { llama_model * model = llama_init.model.get(); llama_context * ctx = llama_init.context.get(); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); // Tokenize the prompt std::vector inp; @@ -149,7 +149,7 @@ int main(int argc, char ** argv) { } // here we keep adding new n-grams as we go - ngram_container ngrams_observed(llama_n_vocab(vocab), N, G); + ngram_container ngrams_observed(llama_vocab_n_vocab(vocab), N, G); // debug struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1); @@ -299,7 +299,7 @@ int main(int argc, char ** argv) { } fflush(stdout); - if (llama_token_is_eog(vocab, id)) { + if (llama_vocab_is_eog(vocab, id)) { has_eos = true; } diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index f451f9c38..dbd0444ec 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -36,7 +36,7 @@ int main(int argc, char ** argv){ llama_model * model = llama_init.model.get(); llama_context * ctx = llama_init.context.get(); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); // tokenize the prompt std::vector inp; @@ -138,7 +138,7 @@ int main(int argc, char ** argv){ LOG("%s", token_str.c_str()); } - if (llama_token_is_eog(vocab, id)) { + if (llama_vocab_is_eog(vocab, id)) { has_eos = true; } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 9e899d6d8..7246fe910 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -5,7 +5,6 @@ #include "sampling.h" #include "llama.h" -#include #include #include #include @@ -163,7 +162,7 @@ int main(int argc, char ** argv) { return 1; } - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); LOG_INF("%s: llama threadpool init, n_threads = %d\n", __func__, (int) params.cpuparams.n_threads); @@ -198,7 +197,7 @@ int main(int argc, char ** argv) { llama_attach_threadpool(ctx, threadpool, threadpool_batch); - const int n_ctx_train = llama_n_ctx_train(model); + const int n_ctx_train = llama_model_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); if (n_ctx > n_ctx_train) { @@ -243,9 +242,9 @@ int main(int argc, char ** argv) { } } - const bool add_bos = llama_add_bos_token(vocab); + const bool add_bos = llama_vocab_add_bos(vocab); if (!llama_model_has_encoder(model)) { - GGML_ASSERT(!llama_add_eos_token(vocab)); + GGML_ASSERT(!llama_vocab_add_eos(vocab)); } LOG_DBG("n_ctx: %d, add_bos: %d\n", n_ctx, add_bos); @@ -271,7 +270,7 @@ int main(int argc, char ** argv) { // Should not run without any tokens if (embd_inp.empty()) { if (add_bos) { - embd_inp.push_back(llama_token_bos(vocab)); + embd_inp.push_back(llama_vocab_bos(vocab)); LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str()); } else { LOG_ERR("input is empty\n"); @@ -497,7 +496,7 @@ int main(int argc, char ** argv) { llama_token decoder_start_token_id = llama_model_decoder_start_token(model); if (decoder_start_token_id == LLAMA_TOKEN_NULL) { - decoder_start_token_id = llama_token_bos(vocab); + decoder_start_token_id = llama_vocab_bos(vocab); } embd_inp.clear(); @@ -744,7 +743,7 @@ int main(int argc, char ** argv) { } // deal with end of generation tokens in interactive mode - if (llama_token_is_eog(vocab, common_sampler_last(smpl))) { + if (llama_vocab_is_eog(vocab, common_sampler_last(smpl))) { LOG_DBG("found an EOG token\n"); if (params.interactive) { @@ -778,7 +777,7 @@ int main(int argc, char ** argv) { if (params.input_prefix_bos) { LOG_DBG("adding input prefix BOS token\n"); - embd_inp.push_back(llama_token_bos(vocab)); + embd_inp.push_back(llama_vocab_bos(vocab)); } std::string buffer; @@ -832,8 +831,8 @@ int main(int argc, char ** argv) { // if user stop generation mid-way, we must add EOT to finish model's last response if (need_insert_eot && format_chat) { - llama_token eot = llama_token_eot(vocab); - embd_inp.push_back(eot == LLAMA_TOKEN_NULL ? llama_token_eos(vocab) : eot); + llama_token eot = llama_vocab_eot(vocab); + embd_inp.push_back(eot == LLAMA_TOKEN_NULL ? llama_vocab_eos(vocab) : eot); need_insert_eot = false; } @@ -868,7 +867,7 @@ int main(int argc, char ** argv) { } // end of generation - if (!embd.empty() && llama_token_is_eog(vocab, embd.back()) && !(params.interactive)) { + if (!embd.empty() && llama_vocab_is_eog(vocab, embd.back()) && !(params.interactive)) { LOG(" [end of text]\n"); break; } diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index cefcdb9d1..7ef43d5e1 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -135,7 +135,7 @@ int main(int argc, char ** argv) { llama_model * model = llama_init.model.get(); llama_context * ctx = llama_init.context.get(); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); // load the prompts from an external file if there are any if (params.prompt.empty()) { @@ -360,7 +360,7 @@ int main(int argc, char ** argv) { // client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str()); if (client.n_decoded > 2 && - (llama_token_is_eog(vocab, id) || + (llama_vocab_is_eog(vocab, id) || (params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) || client.response.find("User:") != std::string::npos || client.response.find('\n') != std::string::npos)) { diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index b6cb2d587..5953928d4 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -70,17 +70,17 @@ int main(int argc, char ** argv) { return 1; } - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); // initialize the context llama_context_params ctx_params = common_context_params_to_llama(params); - ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep; + ctx_params.n_ctx = llama_model_n_ctx_train(model)*n_grp + n_keep; GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp"); - llama_context * ctx = llama_new_context_with_model(model, ctx_params); + llama_context * ctx = llama_init_from_model(model, ctx_params); if (ctx == NULL) { LOG_ERR("%s: failed to create the llama_context\n" , __func__); return 1; @@ -225,7 +225,7 @@ int main(int argc, char ** argv) { const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); // is it an end of generation? - if (llama_token_is_eog(vocab, new_token_id) || n_cur == n_len) { + if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) { LOG("\n"); break; diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 0fc8f0dcc..fa0098004 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -297,10 +297,10 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params // BOS tokens will be added for each chunk before eval const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); - const bool add_bos = llama_add_bos_token(vocab); - GGML_ASSERT(!llama_add_eos_token(vocab)); + const bool add_bos = llama_vocab_add_bos(vocab); + GGML_ASSERT(!llama_vocab_add_eos(vocab)); LOG_INF("%s: tokenizing the input ..\n", __func__); @@ -341,7 +341,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); const int n_batch = params.n_batch; - const int n_vocab = llama_n_vocab(vocab); + const int n_vocab = llama_vocab_n_vocab(vocab); int count = 0; double nll = 0.0; @@ -385,7 +385,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params // add BOS token for the first batch of each chunk if (add_bos && j == 0) { - tokens[batch_start] = llama_token_bos(vocab); + tokens[batch_start] = llama_vocab_bos(vocab); } const auto * batch_logits = llama_get_logits(ctx); @@ -448,10 +448,10 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & // BOS tokens will be added for each chunk before eval const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); - const bool add_bos = llama_add_bos_token(vocab); - GGML_ASSERT(!llama_add_eos_token(vocab)); + const bool add_bos = llama_vocab_add_bos(vocab); + GGML_ASSERT(!llama_vocab_add_eos(vocab)); std::ofstream logits_stream; if (!params.logits_file.empty()) { @@ -491,7 +491,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); const int n_batch = params.n_batch; - const int n_vocab = llama_n_vocab(vocab); + const int n_vocab = llama_vocab_n_vocab(vocab); int count = 0; double nll = 0.0; @@ -563,7 +563,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & // add BOS token for the first batch of each chunk if (add_bos && j == 0) { - tokens[seq_start] = llama_token_bos(vocab); + tokens[seq_start] = llama_vocab_bos(vocab); } for (int k = 0; k < batch_size; ++k) { @@ -739,7 +739,7 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto static void hellaswag_score(llama_context * ctx, const common_params & params) { const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); // Calculates hellaswag score (acc_norm) from prompt // @@ -857,7 +857,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { const int n_ctx = llama_n_ctx(ctx); const int n_batch = params.n_batch; - const int n_vocab = llama_n_vocab(vocab); + const int n_vocab = llama_vocab_n_vocab(vocab); const int max_tasks_per_batch = 32; const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); @@ -1082,7 +1082,7 @@ static std::vector load_winogrande_from_csv(const std::string */ static void winogrande_score(llama_context * ctx, const common_params & params) { const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); constexpr int k_min_trailing_ctx = 3; @@ -1141,7 +1141,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) const int n_ctx = llama_n_ctx(ctx); const int n_batch = params.n_batch; - const int n_vocab = llama_n_vocab(vocab); + const int n_vocab = llama_vocab_n_vocab(vocab); const int max_tasks_per_batch = 128; const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); @@ -1386,7 +1386,7 @@ static bool multiple_choice_prepare_one_task(llama_context * ctx, multiple_choic // static void multiple_choice_score(llama_context * ctx, const common_params & params) { const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); std::istringstream strstream(params.prompt); uint32_t n_task; @@ -1495,7 +1495,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par const int n_ctx = llama_n_ctx(ctx); const int n_batch = params.n_batch; - const int n_vocab = llama_n_vocab(vocab); + const int n_vocab = llama_vocab_n_vocab(vocab); const int max_tasks_per_batch = 32; const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); @@ -1669,7 +1669,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par static void kl_divergence(llama_context * ctx, const common_params & params) { const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); if (params.logits_file.empty()) { LOG_ERR("%s: you must provide a name of a file containing the log probabilities of the base model\n", __func__); @@ -1704,8 +1704,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { LOG_ERR("%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str()); return; } - if (n_vocab != llama_n_vocab(vocab)) { - LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(vocab)); + if (n_vocab != llama_vocab_n_vocab(vocab)) { + LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_vocab_n_vocab(vocab)); } std::vector tokens(size_t(n_ctx) * n_chunk); @@ -1717,8 +1717,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { const int n_batch = params.n_batch; const int num_batches = (n_ctx + n_batch - 1)/n_batch; const int nv = 2*((n_vocab + 1)/2) + 4; - const bool add_bos = llama_add_bos_token(vocab); - GGML_ASSERT(!llama_add_eos_token(vocab)); + const bool add_bos = llama_vocab_add_bos(vocab); + GGML_ASSERT(!llama_vocab_add_eos(vocab)); std::vector log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv); std::vector kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk); @@ -1777,7 +1777,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { // add BOS token for the first batch of each chunk if (add_bos && j == 0) { - tokens[batch_start] = llama_token_bos(vocab); + tokens[batch_start] = llama_vocab_bos(vocab); } common_batch_clear(batch); @@ -2011,7 +2011,7 @@ int main(int argc, char ** argv) { return 1; } - const int n_ctx_train = llama_n_ctx_train(model); + const int n_ctx_train = llama_model_n_ctx_train(model); if (params.n_ctx > n_ctx_train) { LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", diff --git a/examples/quantize-stats/quantize-stats.cpp b/examples/quantize-stats/quantize-stats.cpp index 9bfbb8862..bd2f73467 100644 --- a/examples/quantize-stats/quantize-stats.cpp +++ b/examples/quantize-stats/quantize-stats.cpp @@ -319,7 +319,7 @@ int main(int argc, char ** argv) { auto cparams = llama_context_default_params(); cparams.n_ctx = 256; - ctx = llama_new_context_with_model(model, cparams); + ctx = llama_init_from_model(model, cparams); if (ctx == NULL) { fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str()); diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 3d69f6d0d..2439022a2 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -159,9 +159,9 @@ int main(int argc, char ** argv) { return 1; } - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); - const int n_ctx_train = llama_n_ctx_train(model); + const int n_ctx_train = llama_model_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); @@ -194,8 +194,8 @@ int main(int argc, char ** argv) { return 1; } // add eos if not present - if (llama_token_eos(vocab) >= 0 && (inp.empty() || inp.back() != llama_token_eos(vocab))) { - inp.push_back(llama_token_eos(vocab)); + if (llama_vocab_eos(vocab) >= 0 && (inp.empty() || inp.back() != llama_vocab_eos(vocab))) { + inp.push_back(llama_vocab_eos(vocab)); } chunk.tokens = inp; } @@ -217,7 +217,7 @@ int main(int argc, char ** argv) { struct llama_batch batch = llama_batch_init(n_batch, 0, 1); // allocate output - const int n_embd = llama_n_embd(model); + const int n_embd = llama_model_n_embd(model); std::vector embeddings(n_chunks * n_embd, 0); float * emb = embeddings.data(); diff --git a/examples/run/run.cpp b/examples/run/run.cpp index d53d8c07e..bfa8378bb 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -685,7 +685,7 @@ class LlamaData { // Initializes the context with the specified parameters llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) { - llama_context_ptr context(llama_new_context_with_model(model.get(), opt.ctx_params)); + llama_context_ptr context(llama_init_from_model(model.get(), opt.ctx_params)); if (!context) { printe("%s: error: failed to create the llama_context\n", __func__); } @@ -773,7 +773,7 @@ static void print_word_and_concatenate_to_response(const std::string & piece, st // helper function to evaluate a prompt and generate a response static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) { - const llama_vocab * vocab = llama_get_vocab(llama_data.model.get()); + const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get()); std::vector tokens; if (tokenize_prompt(vocab, prompt, tokens) < 0) { @@ -792,7 +792,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str // sample the next token, check is it an end of generation? new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1); - if (llama_token_is_eog(vocab, new_token_id)) { + if (llama_vocab_is_eog(vocab, new_token_id)) { break; } diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index cd03661cf..cf7cbd815 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -97,7 +97,7 @@ int main(int argc, char ** argv) { printf("\n\n"); // make new context - llama_context * ctx2 = llama_new_context_with_model(model, common_context_params_to_llama(params)); + llama_context * ctx2 = llama_init_from_model(model, common_context_params_to_llama(params)); llama_sampler * smpl2 = llama_sampler_chain_init(sparams); @@ -154,7 +154,7 @@ int main(int argc, char ** argv) { } // make new context - llama_context * ctx3 = llama_new_context_with_model(model, common_context_params_to_llama(params)); + llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params)); llama_sampler * smpl3 = llama_sampler_chain_init(sparams); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 7096a883e..6861d745c 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -207,7 +207,7 @@ struct server_task { const common_params & params_base, const json & data) { const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); slot_params params; @@ -331,7 +331,7 @@ struct server_task { const auto & logit_bias = data.find("logit_bias"); if (logit_bias != data.end() && logit_bias->is_array()) { - const int n_vocab = llama_n_vocab(vocab); + const int n_vocab = llama_vocab_n_vocab(vocab); for (const auto & el : *logit_bias) { // TODO: we may want to throw errors here, in case "el" is incorrect if (el.is_array() && el.size() == 2) { @@ -1694,12 +1694,12 @@ struct server_context { return false; } - vocab = llama_get_vocab(model); + vocab = llama_model_get_vocab(model); n_ctx = llama_n_ctx(ctx); - add_bos_token = llama_add_bos_token(vocab); - has_eos_token = llama_token_eos(vocab) != LLAMA_TOKEN_NULL; + add_bos_token = llama_vocab_add_bos(vocab); + has_eos_token = llama_vocab_eos(vocab) != LLAMA_TOKEN_NULL; if (!params_base.speculative.model.empty()) { SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str()); @@ -1763,7 +1763,7 @@ struct server_context { if (model_dft) { slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); - slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft); + slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); if (slot.ctx_dft == nullptr) { SRV_ERR("%s", "failed to create draft context\n"); return; @@ -1898,7 +1898,7 @@ struct server_context { } if (slot.params.ignore_eos && has_eos_token) { - slot.params.sampling.logit_bias.push_back({llama_token_eos(vocab), -INFINITY}); + slot.params.sampling.logit_bias.push_back({llama_vocab_eos(vocab), -INFINITY}); } { @@ -2054,14 +2054,14 @@ struct server_context { slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx); } - if (llama_token_is_eog(vocab, result.tok)) { + if (llama_vocab_is_eog(vocab, result.tok)) { slot.stop = STOP_TYPE_EOS; slot.has_next_token = false; SLT_DBG(slot, "%s", "stopped by EOS\n"); } - const auto n_ctx_train = llama_n_ctx_train(model); + const auto n_ctx_train = llama_model_n_ctx_train(model); if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) { slot.truncated = true; @@ -2081,7 +2081,7 @@ struct server_context { void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) { size_t n_probs = slot.params.sampling.n_probs; - size_t n_vocab = llama_n_vocab(vocab); + size_t n_vocab = llama_vocab_n_vocab(vocab); if (post_sampling) { const auto * cur_p = common_sampler_get_candidates(slot.smpl); const size_t max_probs = cur_p->size; @@ -2232,7 +2232,7 @@ struct server_context { res->n_tokens = slot.n_prompt_tokens; res->oaicompat = slot.params.oaicompat; - const int n_embd = llama_n_embd(model); + const int n_embd = llama_model_n_embd(model); std::vector embd_res(n_embd, 0.0f); @@ -3136,12 +3136,12 @@ struct server_context { json model_meta() const { return json { - {"vocab_type", llama_vocab_type (vocab)}, - {"n_vocab", llama_n_vocab (vocab)}, - {"n_ctx_train", llama_n_ctx_train (model)}, - {"n_embd", llama_n_embd (model)}, - {"n_params", llama_model_n_params(model)}, - {"size", llama_model_size (model)}, + {"vocab_type", llama_vocab_type (vocab)}, + {"n_vocab", llama_vocab_n_vocab (vocab)}, + {"n_ctx_train", llama_model_n_ctx_train(model)}, + {"n_embd", llama_model_n_embd (model)}, + {"n_params", llama_model_n_params (model)}, + {"size", llama_model_size (model)}, }; } }; @@ -3751,13 +3751,13 @@ int main(int argc, char ** argv) { const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { // check model compatibility std::string err; - if (llama_token_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { err += "prefix token is missing. "; } - if (llama_token_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) { err += "suffix token is missing. "; } - if (llama_token_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) { err += "middle token is missing. "; } if (!err.empty()) { diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index e373c44be..b365e8302 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -235,12 +235,12 @@ static llama_tokens format_rerank(const struct llama_vocab * vocab, const llama_ llama_tokens result; result.reserve(doc.size() + query.size() + 4); - result.push_back(llama_token_bos(vocab)); + result.push_back(llama_vocab_bos(vocab)); result.insert(result.end(), query.begin(), query.end()); - result.push_back(llama_token_eos(vocab)); - result.push_back(llama_token_sep(vocab)); + result.push_back(llama_vocab_eos(vocab)); + result.push_back(llama_vocab_sep(vocab)); result.insert(result.end(), doc.begin(), doc.end()); - result.push_back(llama_token_eos(vocab)); + result.push_back(llama_vocab_eos(vocab)); return result; } @@ -277,11 +277,11 @@ static llama_tokens format_infill( auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false); auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false); - if (llama_token_fim_rep(vocab) != LLAMA_TOKEN_NULL) { + if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) { // TODO: make project name an input static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false); - extra_tokens.push_back(llama_token_fim_rep(vocab)); + extra_tokens.push_back(llama_vocab_fim_rep(vocab)); extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); } for (const auto & chunk : input_extra) { @@ -289,10 +289,10 @@ static llama_tokens format_infill( const std::string text = json_value(chunk, "text", std::string()); const std::string filename = json_value(chunk, "filename", std::string("tmp")); - if (llama_token_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false); - extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); } else { // chunk separator in binary form to avoid confusing the AI @@ -306,11 +306,11 @@ static llama_tokens format_infill( extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); } - if (llama_token_fim_sep(vocab) != LLAMA_TOKEN_NULL) { + if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { // TODO: current filename static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false); - extra_tokens.insert(extra_tokens.end(), llama_token_fim_sep(vocab)); + extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); } @@ -326,15 +326,15 @@ static llama_tokens format_infill( tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); tokens_suffix.resize(n_suffix_take); - tokens_prefix.insert(tokens_prefix.begin(), llama_token_fim_pre(vocab)); + tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab)); tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); - tokens_suffix.insert(tokens_suffix.begin(), llama_token_fim_suf(vocab)); + tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab)); auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; auto embd_end = spm_infill ? tokens_prefix : tokens_suffix; - if (llama_add_bos_token(vocab)) { - embd_inp.insert(embd_inp.begin(), llama_token_bos(vocab)); + if (llama_vocab_add_bos(vocab)) { + embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); } SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); @@ -343,7 +343,7 @@ static llama_tokens format_infill( embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); - embd_inp.push_back(llama_token_fim_mid(vocab)); + embd_inp.push_back(llama_vocab_fim_mid(vocab)); return embd_inp; } @@ -774,9 +774,9 @@ static std::vector get_token_probabilities(llama_context * ctx const auto * logits = llama_get_logits_ith(ctx, idx); const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); - const int n_vocab = llama_n_vocab(vocab); + const int n_vocab = llama_vocab_n_vocab(vocab); cur.resize(n_vocab); for (llama_token token_id = 0; token_id < n_vocab; token_id++) { diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index 3888b9d43..e8eda9c22 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -75,14 +75,14 @@ int main(int argc, char ** argv) { return 1; } - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); // initialize the context llama_context_params ctx_params = llama_context_default_params(); ctx_params.n_ctx = n_ctx; ctx_params.n_batch = n_ctx; - llama_context * ctx = llama_new_context_with_model(model, ctx_params); + llama_context * ctx = llama_init_from_model(model, ctx_params); if (!ctx) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); return 1; @@ -126,7 +126,7 @@ int main(int argc, char ** argv) { new_token_id = llama_sampler_sample(smpl, ctx, -1); // is it an end of generation? - if (llama_token_is_eog(vocab, new_token_id)) { + if (llama_vocab_is_eog(vocab, new_token_id)) { break; } diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index c38386004..10e79a0a6 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -84,7 +84,7 @@ int main(int argc, char ** argv) { model_params.n_gpu_layers = ngl; llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); if (model == NULL) { fprintf(stderr , "%s: error: unable to load model\n" , __func__); @@ -113,7 +113,7 @@ int main(int argc, char ** argv) { // enable performance counters ctx_params.no_perf = false; - llama_context * ctx = llama_new_context_with_model(model, ctx_params); + llama_context * ctx = llama_init_from_model(model, ctx_params); if (ctx == NULL) { fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); @@ -165,7 +165,7 @@ int main(int argc, char ** argv) { new_token_id = llama_sampler_sample(smpl, ctx, -1); // is it an end of generation? - if (llama_token_is_eog(vocab, new_token_id)) { + if (llama_vocab_is_eog(vocab, new_token_id)) { break; } diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 550143cd6..403ba2dd2 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -45,7 +45,7 @@ int main(int argc, char ** argv) { model_tgt = llama_init_tgt.model.get(); ctx_tgt = llama_init_tgt.context.get(); - const llama_vocab * vocab = llama_get_vocab(model_tgt); + const llama_vocab * vocab = llama_model_get_vocab(model_tgt); // load the draft model params.devices = params.speculative.devices; @@ -198,7 +198,7 @@ int main(int argc, char ** argv) { id_last = ids[i]; - if (llama_token_is_eog(vocab, id_last)) { + if (llama_vocab_is_eog(vocab, id_last)) { has_eos = true; break; } diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 4ace5b758..db791245c 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -90,8 +90,8 @@ int main(int argc, char ** argv) { model_dft = llama_init_dft.model.get(); ctx_dft = llama_init_dft.context.get(); - const llama_vocab * vocab_tgt = llama_get_vocab(model_tgt); - const llama_vocab * vocab_dft = llama_get_vocab(model_dft); + const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); + const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); const bool vocab_type_tgt = llama_vocab_type(vocab_tgt); LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt); @@ -106,18 +106,18 @@ int main(int argc, char ** argv) { } if ( - llama_add_bos_token(vocab_tgt) != llama_add_bos_token(vocab_dft) || - llama_add_eos_token(vocab_tgt) != llama_add_eos_token(vocab_dft) || - llama_token_bos(vocab_tgt) != llama_token_bos(vocab_dft) || - llama_token_eos(vocab_tgt) != llama_token_eos(vocab_dft) + llama_vocab_add_bos(vocab_tgt) != llama_vocab_add_bos(vocab_dft) || + llama_vocab_add_eos(vocab_tgt) != llama_vocab_add_eos(vocab_dft) || + llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft) || + llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft) ) { LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__); return 1; } { - const int n_vocab_tgt = llama_n_vocab(vocab_tgt); - const int n_vocab_dft = llama_n_vocab(vocab_dft); + const int n_vocab_tgt = llama_vocab_n_vocab(vocab_tgt); + const int n_vocab_dft = llama_vocab_n_vocab(vocab_dft); const int vocab_diff = n_vocab_tgt > n_vocab_dft ? n_vocab_tgt - n_vocab_dft : n_vocab_dft - n_vocab_tgt; @@ -125,13 +125,13 @@ int main(int argc, char ** argv) { if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) { LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__); LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n", - n_vocab_tgt, llama_n_vocab(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); + n_vocab_tgt, llama_vocab_n_vocab(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE); return 1; } for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { - const char * token_text_tgt = llama_token_get_text(vocab_tgt, i); - const char * token_text_dft = llama_token_get_text(vocab_dft, i); + const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); + const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); if (std::strcmp(token_text_tgt, token_text_dft) != 0) { LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__); LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i, @@ -173,7 +173,7 @@ int main(int argc, char ** argv) { const auto t_enc_end = ggml_time_us(); // the 2 models should have the same vocab - //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft)); + //GGML_ASSERT(n_vocab == llama_vocab_n_vocab(model_dft)); // how many tokens to draft each time int n_draft = params.speculative.n_max; @@ -389,7 +389,7 @@ int main(int argc, char ** argv) { } } - if (llama_token_is_eog(vocab_tgt, token_id)) { + if (llama_vocab_is_eog(vocab_tgt, token_id)) { has_eos = true; } ++n_predict; diff --git a/examples/tokenize/tokenize.cpp b/examples/tokenize/tokenize.cpp index 9d3d8233a..dc7bb4869 100644 --- a/examples/tokenize/tokenize.cpp +++ b/examples/tokenize/tokenize.cpp @@ -344,10 +344,10 @@ int main(int raw_argc, char ** raw_argv) { return 1; } - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); llama_context_params ctx_params = llama_context_default_params(); - llama_context * ctx = llama_new_context_with_model(model, ctx_params); + llama_context * ctx = llama_init_from_model(model, ctx_params); if (!ctx) { fprintf(stderr, "Error: could not create context.\n"); return 1; @@ -367,7 +367,7 @@ int main(int raw_argc, char ** raw_argv) { prompt = stdin_buffer.str(); } - const bool model_wants_add_bos = llama_add_bos_token(vocab); + const bool model_wants_add_bos = llama_vocab_add_bos(vocab); const bool add_bos = model_wants_add_bos && !no_bos; const bool parse_special = !no_parse_special; const bool escape = !no_escape; diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index d1155dad5..5a9161181 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -462,7 +462,7 @@ int main(int argc, char ** argv) { model_ttc = llama_init_ttc.model.get(); ctx_ttc = llama_init_ttc.context.get(); - const llama_vocab * vocab = llama_get_vocab(model_ttc); + const llama_vocab * vocab = llama_model_get_vocab(model_ttc); // TODO: refactor in a common struct params.model = params.vocoder.model; @@ -737,9 +737,9 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 const auto * cands = common_sampler_get_candidates(smpl[i]); // is it an end of generation? -> mark the stream as finished - if (llama_token_is_eog(vocab, new_token_id) || n_decode == n_predict) { + if (llama_vocab_is_eog(vocab, new_token_id) || n_decode == n_predict) { std::string reason; - if (llama_token_is_eog(vocab, new_token_id)) { + if (llama_vocab_is_eog(vocab, new_token_id)) { reason = "eos"; } else { reason = "n_predict"; @@ -875,7 +875,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 #if 1 // spectral operations - const int n_embd = llama_n_embd(model_cts); + const int n_embd = llama_model_n_embd(model_cts); const float * embd = llama_get_embeddings(ctx_cts); auto audio = embd_to_audio(embd, n_codes, n_embd, params.cpuparams.n_threads); diff --git a/include/llama.h b/include/llama.h index 302df87e9..2e8c0e94d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -399,18 +399,19 @@ extern "C" { // Call once at the start of the program LLAMA_API void llama_backend_init(void); + // Call once at the end of the program - currently only used for MPI + LLAMA_API void llama_backend_free(void); + //optional: LLAMA_API void llama_numa_init(enum ggml_numa_strategy numa); // Optional: an auto threadpool gets created in ggml if not passed explicitly LLAMA_API void llama_attach_threadpool( - struct llama_context * ctx, - ggml_threadpool_t threadpool, - ggml_threadpool_t threadpool_batch); - LLAMA_API void llama_detach_threadpool(struct llama_context * ctx); + struct llama_context * ctx, + ggml_threadpool_t threadpool, + ggml_threadpool_t threadpool_batch); - // Call once at the end of the program - currently only used for MPI - LLAMA_API void llama_backend_free(void); + LLAMA_API void llama_detach_threadpool(struct llama_context * ctx); DEPRECATED(LLAMA_API struct llama_model * llama_load_model_from_file( const char * path_model, @@ -426,11 +427,15 @@ extern "C" { LLAMA_API void llama_model_free(struct llama_model * model); - // TODO: rename to llama_init_from_model - LLAMA_API struct llama_context * llama_new_context_with_model( + LLAMA_API struct llama_context * llama_init_from_model( struct llama_model * model, struct llama_context_params params); + DEPRECATED(LLAMA_API struct llama_context * llama_new_context_with_model( + struct llama_model * model, + struct llama_context_params params), + "use llama_init_from_model instead"); + // Frees all allocated memory LLAMA_API void llama_free(struct llama_context * ctx); @@ -448,22 +453,30 @@ extern "C" { LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); - LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model); - LLAMA_API int32_t llama_n_embd (const struct llama_model * model); - LLAMA_API int32_t llama_n_layer (const struct llama_model * model); - LLAMA_API int32_t llama_n_head (const struct llama_model * model); + DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead"); + DEPRECATED(LLAMA_API int32_t llama_n_embd (const struct llama_model * model), "use llama_model_n_embd instead"); + DEPRECATED(LLAMA_API int32_t llama_n_layer (const struct llama_model * model), "use llama_model_n_layer instead"); + DEPRECATED(LLAMA_API int32_t llama_n_head (const struct llama_model * model), "use llama_model_n_head instead"); - LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab); + DEPRECATED(LLAMA_API int32_t llama_n_vocab (const struct llama_vocab * vocab), "use llama_vocab_n_vocab instead"); - LLAMA_API const struct llama_model * llama_get_model(const struct llama_context * ctx); - LLAMA_API const struct llama_vocab * llama_get_vocab(const struct llama_model * model); + LLAMA_API const struct llama_model * llama_get_model (const struct llama_context * ctx); + LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); - LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); - LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_vocab * vocab); - LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); + LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); + LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); + + LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); // Get the model's RoPE frequency scaling factor - LLAMA_API float llama_rope_freq_scale_train(const struct llama_model * model); + LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); + + LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab); + + LLAMA_API int32_t llama_vocab_n_vocab(const struct llama_vocab * vocab); // Functions to access the model's GGUF metadata scalar values // - The functions return the length of the string on success, or -1 on failure @@ -908,41 +921,57 @@ extern "C" { // Vocab // - LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token); + LLAMA_API const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token); - LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token); + LLAMA_API float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token); - LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token); + LLAMA_API enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token); // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) - LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token); + LLAMA_API bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token); // Identify if Token Id is a control token or a render-able token - LLAMA_API bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token); + LLAMA_API bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token); // Special tokens - LLAMA_API llama_token llama_token_bos(const struct llama_vocab * vocab); // beginning-of-sentence - LLAMA_API llama_token llama_token_eos(const struct llama_vocab * vocab); // end-of-sentence - LLAMA_API llama_token llama_token_eot(const struct llama_vocab * vocab); // end-of-turn - LLAMA_API llama_token llama_token_cls(const struct llama_vocab * vocab); // classification - LLAMA_API llama_token llama_token_sep(const struct llama_vocab * vocab); // sentence separator - LLAMA_API llama_token llama_token_nl (const struct llama_vocab * vocab); // next-line - LLAMA_API llama_token llama_token_pad(const struct llama_vocab * vocab); // padding + LLAMA_API llama_token llama_vocab_bos(const struct llama_vocab * vocab); // beginning-of-sentence + LLAMA_API llama_token llama_vocab_eos(const struct llama_vocab * vocab); // end-of-sentence + LLAMA_API llama_token llama_vocab_eot(const struct llama_vocab * vocab); // end-of-turn + LLAMA_API llama_token llama_vocab_cls(const struct llama_vocab * vocab); // classification + LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator + LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line + LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding - LLAMA_API bool llama_add_bos_token(const struct llama_vocab * vocab); - LLAMA_API bool llama_add_eos_token(const struct llama_vocab * vocab); + LLAMA_API bool llama_vocab_add_bos(const struct llama_vocab * vocab); + LLAMA_API bool llama_vocab_add_eos(const struct llama_vocab * vocab); - // infill tokens - DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_vocab * vocab), "use llama_token_fim_pre instead"); - DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_vocab * vocab), "use llama_token_fim_mid instead"); - DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_vocab * vocab), "use llama_token_fim_suf instead"); + LLAMA_API llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab); + LLAMA_API llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab); - LLAMA_API llama_token llama_token_fim_pre(const struct llama_vocab * vocab); - LLAMA_API llama_token llama_token_fim_suf(const struct llama_vocab * vocab); - LLAMA_API llama_token llama_token_fim_mid(const struct llama_vocab * vocab); - LLAMA_API llama_token llama_token_fim_pad(const struct llama_vocab * vocab); - LLAMA_API llama_token llama_token_fim_rep(const struct llama_vocab * vocab); - LLAMA_API llama_token llama_token_fim_sep(const struct llama_vocab * vocab); + DEPRECATED(LLAMA_API const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token), "use llama_vocabable_get_text instead"); + DEPRECATED(LLAMA_API float llama_token_get_score(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_score instead"); + DEPRECATED(LLAMA_API enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_get_attr instead"); + DEPRECATED(LLAMA_API bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_eog instead"); + DEPRECATED(LLAMA_API bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token), "use llama_vocab_is_control instead"); + DEPRECATED(LLAMA_API llama_token llama_token_bos(const struct llama_vocab * vocab), "use llama_vocab_bos instead"); + DEPRECATED(LLAMA_API llama_token llama_token_eos(const struct llama_vocab * vocab), "use llama_vocab_eos instead"); + DEPRECATED(LLAMA_API llama_token llama_token_eot(const struct llama_vocab * vocab), "use llama_vocab_eot instead"); + DEPRECATED(LLAMA_API llama_token llama_token_cls(const struct llama_vocab * vocab), "use llama_vocab_cls instead"); + DEPRECATED(LLAMA_API llama_token llama_token_sep(const struct llama_vocab * vocab), "use llama_vocab_sep instead"); + DEPRECATED(LLAMA_API llama_token llama_token_nl (const struct llama_vocab * vocab), "use llama_vocab_nl instead"); + DEPRECATED(LLAMA_API llama_token llama_token_pad(const struct llama_vocab * vocab), "use llama_vocab_pad instead"); + DEPRECATED(LLAMA_API bool llama_add_bos_token(const struct llama_vocab * vocab), "use llama_vocab_add_bos instead"); + DEPRECATED(LLAMA_API bool llama_add_eos_token(const struct llama_vocab * vocab), "use llama_vocab_add_eos instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_pre(const struct llama_vocab * vocab), "use llama_vocab_fim_pre instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_suf(const struct llama_vocab * vocab), "use llama_vocab_fim_suf instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_mid(const struct llama_vocab * vocab), "use llama_vocab_fim_mid instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_pad(const struct llama_vocab * vocab), "use llama_vocab_fim_pad instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_rep(const struct llama_vocab * vocab), "use llama_vocab_fim_rep instead"); + DEPRECATED(LLAMA_API llama_token llama_token_fim_sep(const struct llama_vocab * vocab), "use llama_vocab_fim_sep instead"); // // Tokenization diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b81ed9437..27afbba03 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -1244,7 +1244,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.use_alibi = true; } - hparams.rope_type = llama_rope_type(this); + hparams.rope_type = llama_model_rope_type(this); } void llama_model::load_vocab(llama_model_loader & ml) { @@ -3735,7 +3735,7 @@ struct llama_model_params llama_model_default_params() { return result; } -const struct llama_vocab * llama_get_vocab(const struct llama_model * model) { +const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model) { return &model->vocab; } @@ -3747,23 +3747,43 @@ void llama_model_free(struct llama_model * model) { delete model; } -int32_t llama_n_ctx_train(const struct llama_model * model) { +int32_t llama_model_n_ctx_train(const struct llama_model * model) { return model->hparams.n_ctx_train; } -int32_t llama_n_embd(const struct llama_model * model) { +int32_t llama_model_n_embd(const struct llama_model * model) { return model->hparams.n_embd; } -int32_t llama_n_layer(const struct llama_model * model) { +int32_t llama_model_n_layer(const struct llama_model * model) { return model->hparams.n_layer; } -int32_t llama_n_head(const struct llama_model * model) { +int32_t llama_model_n_head(const struct llama_model * model) { return model->hparams.n_head(); } -enum llama_rope_type llama_rope_type(const struct llama_model * model) { +// deprecated +int32_t llama_n_ctx_train(const struct llama_model * model) { + return llama_model_n_ctx_train(model); +} + +// deprecated +int32_t llama_n_embd(const struct llama_model * model) { + return llama_model_n_embd(model); +} + +// deprecated +int32_t llama_n_layer(const struct llama_model * model) { + return llama_model_n_layer(model); +} + +// deprecated +int32_t llama_n_head(const struct llama_model * model) { + return llama_model_n_head(model); +} + +enum llama_rope_type llama_model_rope_type(const struct llama_model * model) { switch (model->arch) { // these models do not use RoPE case LLM_ARCH_GPT2: @@ -3841,7 +3861,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { return LLAMA_ROPE_TYPE_NONE; } -float llama_rope_freq_scale_train(const struct llama_model * model) { +float llama_model_rope_freq_scale_train(const struct llama_model * model) { return model->hparams.rope_freq_scale_train; } diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index e1faa48e8..8775cf5e3 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -372,9 +372,9 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte const auto * logits = llama_get_logits_ith(ctx, idx); const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); - const int n_vocab = llama_n_vocab(vocab); + const int n_vocab = llama_vocab_n_vocab(vocab); // TODO: do not allocate each time std::vector cur; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 425b9e83f..e5d9d29d4 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -3026,106 +3026,199 @@ void llama_vocab::print_info() const { // interface implementation // -int32_t llama_n_vocab(const struct llama_vocab * vocab) { +int32_t llama_vocab_n_vocab(const struct llama_vocab * vocab) { return vocab->n_vocab(); } +// deprecated +int32_t llama_n_vocab(const struct llama_vocab * vocab) { + return llama_vocab_n_vocab(vocab); +} + enum llama_vocab_type llama_vocab_type(const struct llama_vocab * vocab) { return vocab->get_type(); } -const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) { +const char * llama_vocab_get_text(const struct llama_vocab * vocab, llama_token token) { return vocab->token_get_text(token); } -float llama_token_get_score(const struct llama_vocab * vocab, llama_token token) { +float llama_vocab_get_score(const struct llama_vocab * vocab, llama_token token) { return vocab->token_get_score(token); } -enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token) { +enum llama_token_attr llama_vocab_get_attr(const struct llama_vocab * vocab, llama_token token) { return vocab->token_get_attr(token); } -bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token) { +bool llama_vocab_is_eog(const struct llama_vocab * vocab, llama_token token) { return vocab->is_eog(token); } -bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token) { +bool llama_vocab_is_control(const struct llama_vocab * vocab, llama_token token) { return vocab->is_control(token); } -llama_token llama_token_bos(const struct llama_vocab * vocab) { +llama_token llama_vocab_bos(const struct llama_vocab * vocab) { return vocab->token_bos(); } -llama_token llama_token_eos(const struct llama_vocab * vocab) { +llama_token llama_vocab_eos(const struct llama_vocab * vocab) { return vocab->token_eos(); } -llama_token llama_token_eot(const struct llama_vocab * vocab) { +llama_token llama_vocab_eot(const struct llama_vocab * vocab) { return vocab->token_eot(); } -llama_token llama_token_cls(const struct llama_vocab * vocab) { +llama_token llama_vocab_cls(const struct llama_vocab * vocab) { return vocab->token_cls(); } -llama_token llama_token_sep(const struct llama_vocab * vocab) { +llama_token llama_vocab_sep(const struct llama_vocab * vocab) { return vocab->token_sep(); } -llama_token llama_token_nl (const struct llama_vocab * vocab) { +llama_token llama_vocab_nl (const struct llama_vocab * vocab) { return vocab->token_nl(); } -llama_token llama_token_pad(const struct llama_vocab * vocab) { +llama_token llama_vocab_pad(const struct llama_vocab * vocab) { return vocab->token_pad(); } -bool llama_add_bos_token(const struct llama_vocab * vocab) { +bool llama_vocab_add_bos(const struct llama_vocab * vocab) { return vocab->add_bos_token(); } -bool llama_add_eos_token(const struct llama_vocab * vocab) { +bool llama_vocab_add_eos(const struct llama_vocab * vocab) { return vocab->add_eos_token(); } -llama_token llama_token_prefix(const struct llama_vocab * vocab) { - return vocab->token_prefix(); -} - -llama_token llama_token_middle(const struct llama_vocab * vocab) { - return vocab->token_middle(); -} - -llama_token llama_token_suffix(const struct llama_vocab * vocab) { - return vocab->token_suffix(); -} - -llama_token llama_token_fim_pre(const struct llama_vocab * vocab) { +llama_token llama_vocab_fim_pre(const struct llama_vocab * vocab) { return vocab->token_fim_pre(); } -llama_token llama_token_fim_suf(const struct llama_vocab * vocab) { +llama_token llama_vocab_fim_suf(const struct llama_vocab * vocab) { return vocab->token_fim_suf(); } -llama_token llama_token_fim_mid(const struct llama_vocab * vocab) { +llama_token llama_vocab_fim_mid(const struct llama_vocab * vocab) { return vocab->token_fim_mid(); } -llama_token llama_token_fim_pad(const struct llama_vocab * vocab) { +llama_token llama_vocab_fim_pad(const struct llama_vocab * vocab) { return vocab->token_fim_pad(); } -llama_token llama_token_fim_rep(const struct llama_vocab * vocab) { +llama_token llama_vocab_fim_rep(const struct llama_vocab * vocab) { return vocab->token_fim_rep(); } -llama_token llama_token_fim_sep(const struct llama_vocab * vocab) { +llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) { return vocab->token_fim_sep(); } +// deprecated +const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) { + return llama_vocab_get_text(vocab, token); +} + +// deprecated +float llama_token_get_score(const struct llama_vocab * vocab, llama_token token) { + return llama_vocab_get_score(vocab, token); +} + +// deprecated +enum llama_token_attr llama_token_get_attr(const struct llama_vocab * vocab, llama_token token) { + return llama_vocab_get_attr(vocab, token); +} + +// deprecated +bool llama_token_is_eog(const struct llama_vocab * vocab, llama_token token) { + return llama_vocab_is_eog(vocab, token); +} + +// deprecated +bool llama_token_is_control(const struct llama_vocab * vocab, llama_token token) { + return llama_vocab_is_control(vocab, token); +} + +// deprecated +llama_token llama_token_bos(const struct llama_vocab * vocab) { + return llama_vocab_bos(vocab); +} + +// deprecated +llama_token llama_token_eos(const struct llama_vocab * vocab) { + return llama_vocab_eos(vocab); +} + +// deprecated +llama_token llama_token_eot(const struct llama_vocab * vocab) { + return llama_vocab_eot(vocab); +} + +// deprecated +llama_token llama_token_cls(const struct llama_vocab * vocab) { + return llama_vocab_cls(vocab); +} + +// deprecated +llama_token llama_token_sep(const struct llama_vocab * vocab) { + return llama_vocab_sep(vocab); +} + +// deprecated +llama_token llama_token_nl (const struct llama_vocab * vocab) { + return llama_vocab_nl(vocab); +} + +// deprecated +llama_token llama_token_pad(const struct llama_vocab * vocab) { + return llama_vocab_pad(vocab); +} + +// deprecated +bool llama_add_bos_token(const struct llama_vocab * vocab) { + return llama_vocab_add_bos(vocab); +} + +// deprecated +bool llama_add_eos_token(const struct llama_vocab * vocab) { + return llama_vocab_add_eos(vocab); +} + +// deprecated +llama_token llama_token_fim_pre(const struct llama_vocab * vocab) { + return llama_vocab_fim_pre(vocab); +} + +// deprecated +llama_token llama_token_fim_suf(const struct llama_vocab * vocab) { + return llama_vocab_fim_suf(vocab); +} + +// deprecated +llama_token llama_token_fim_mid(const struct llama_vocab * vocab) { + return llama_vocab_fim_mid(vocab); +} + +// deprecated +llama_token llama_token_fim_pad(const struct llama_vocab * vocab) { + return llama_vocab_fim_pad(vocab); +} + +// deprecated +llama_token llama_token_fim_rep(const struct llama_vocab * vocab) { + return llama_vocab_fim_rep(vocab); +} + +// deprecated +llama_token llama_token_fim_sep(const struct llama_vocab * vocab) { + return llama_vocab_fim_sep(vocab); +} + // // tokenization // diff --git a/src/llama.cpp b/src/llama.cpp index a1c62eb31..fa8dff09d 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9501,7 +9501,7 @@ struct llama_model * llama_model_load_from_file( return model; } -struct llama_context * llama_new_context_with_model( +struct llama_context * llama_init_from_model( struct llama_model * model, struct llama_context_params params) { @@ -9852,6 +9852,12 @@ struct llama_context * llama_new_context_with_model( return ctx; } +struct llama_context * llama_new_context_with_model( + struct llama_model * model, + struct llama_context_params params) { + return llama_init_from_model(model, params); +} + // // kv cache // diff --git a/tests/test-autorelease.cpp b/tests/test-autorelease.cpp index ba084a91a..35b09aaea 100644 --- a/tests/test-autorelease.cpp +++ b/tests/test-autorelease.cpp @@ -14,7 +14,7 @@ int main(int argc, char ** argv) { std::thread([&model_path]() { llama_backend_init(); auto * model = llama_model_load_from_file(model_path, llama_model_default_params()); - auto * ctx = llama_new_context_with_model(model, llama_context_default_params()); + auto * ctx = llama_init_from_model(model, llama_context_default_params()); llama_free(ctx); llama_model_free(model); llama_backend_free(); diff --git a/tests/test-tokenizer-0.cpp b/tests/test-tokenizer-0.cpp index 121c2c60c..59dda4877 100644 --- a/tests/test-tokenizer-0.cpp +++ b/tests/test-tokenizer-0.cpp @@ -161,7 +161,7 @@ int main(int argc, char **argv) { auto cparams = llama_context_default_params(); - ctx = llama_new_context_with_model(model, cparams); + ctx = llama_init_from_model(model, cparams); if (ctx == NULL) { fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); diff --git a/tests/test-tokenizer-1-bpe.cpp b/tests/test-tokenizer-1-bpe.cpp index 8a2d88d72..9360a061e 100644 --- a/tests/test-tokenizer-1-bpe.cpp +++ b/tests/test-tokenizer-1-bpe.cpp @@ -55,7 +55,7 @@ int main(int argc, char **argv) { auto cparams = llama_context_default_params(); - ctx = llama_new_context_with_model(model, cparams); + ctx = llama_init_from_model(model, cparams); if (ctx == NULL) { fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); @@ -64,7 +64,7 @@ int main(int argc, char **argv) { } } - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); //GGML_ASSERT(llama_vocab_type(vocab) == LLAMA_VOCAB_TYPE_BPE); if (llama_vocab_type(vocab) != LLAMA_VOCAB_TYPE_BPE) { @@ -77,7 +77,7 @@ int main(int argc, char **argv) { atexit([]() { console::cleanup(); }); #endif - const int n_vocab = llama_n_vocab(vocab); + const int n_vocab = llama_vocab_n_vocab(vocab); for (int i = 0; i < n_vocab; ++i) { std::string str = common_detokenize(ctx, std::vector(1, i)); diff --git a/tests/test-tokenizer-1-spm.cpp b/tests/test-tokenizer-1-spm.cpp index 633d284f6..da84308f2 100644 --- a/tests/test-tokenizer-1-spm.cpp +++ b/tests/test-tokenizer-1-spm.cpp @@ -43,7 +43,7 @@ int main(int argc, char ** argv) { auto cparams = llama_context_default_params(); - ctx = llama_new_context_with_model(model, cparams); + ctx = llama_init_from_model(model, cparams); if (ctx == NULL) { fprintf(stderr, "%s: error: failed to load vocab '%s'\n", __func__, fname.c_str()); @@ -52,7 +52,7 @@ int main(int argc, char ** argv) { } } - const llama_vocab * vocab = llama_get_vocab(model); + const llama_vocab * vocab = llama_model_get_vocab(model); //GGML_ASSERT(llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM); if (llama_vocab_type(vocab) != LLAMA_VOCAB_TYPE_SPM) { @@ -65,7 +65,7 @@ int main(int argc, char ** argv) { atexit([]() { console::cleanup(); }); #endif - const int n_vocab = llama_n_vocab(vocab); + const int n_vocab = llama_vocab_n_vocab(vocab); for (int i = 0; i < n_vocab; ++i) { std::string str = common_detokenize(ctx, std::vector(1, i), true);