diff --git a/examples/common.cpp b/examples/common.cpp index 2bf0dc597..9b23b1f63 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -405,6 +405,37 @@ std::vector llama_tokenize(struct llama_context * ctx, const std::s return res; } +struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { + auto lparams = llama_context_default_params(); + + lparams.n_ctx = params.n_ctx; + lparams.n_parts = params.n_parts; + lparams.seed = params.seed; + lparams.f16_kv = params.memory_f16; + lparams.use_mmap = params.use_mmap; + lparams.use_mlock = params.use_mlock; + + llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams); + + if (lctx == NULL) { + fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); + return NULL; + } + + if (!params.lora_adapter.empty()) { + int err = llama_apply_lora_from_file(lctx, + params.lora_adapter.c_str(), + params.lora_base.empty() ? NULL : params.lora_base.c_str(), + params.n_threads); + if (err != 0) { + fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); + return NULL; + } + } + + return lctx; +} + /* Keep track of current color of output, and emit ANSI code if it changes. */ void set_console_color(console_state & con_st, console_color_t color) { if (con_st.use_color && con_st.color != color) { diff --git a/examples/common.h b/examples/common.h index 627696e30..138d0ded0 100644 --- a/examples/common.h +++ b/examples/common.h @@ -77,6 +77,12 @@ std::string gpt_random_prompt(std::mt19937 & rng); std::vector llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos); +// +// Model utils +// + +struct llama_context * llama_init_from_gpt_params(const gpt_params & params); + // // Console utils // diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 1e9d8a8ce..e4b729128 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -35,24 +35,10 @@ int main(int argc, char ** argv) { llama_context * ctx; // load the model - { - auto lparams = llama_context_default_params(); - - lparams.n_ctx = params.n_ctx; - lparams.n_parts = params.n_parts; - lparams.seed = params.seed; - lparams.f16_kv = params.memory_f16; - lparams.logits_all = params.perplexity; - lparams.use_mmap = params.use_mmap; - lparams.use_mlock = params.use_mlock; - lparams.embedding = params.embedding; - - ctx = llama_init_from_file(params.model.c_str(), lparams); - - if (ctx == NULL) { - fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); - return 1; - } + ctx = llama_init_from_gpt_params(params); + if (ctx == NULL) { + fprintf(stderr, "%s: error: unable to load model\n", __func__); + return 1; } // print system information diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 54836b365..a10256abf 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -101,34 +101,11 @@ int main(int argc, char ** argv) { llama_context * ctx; g_ctx = &ctx; - // load the model - { - auto lparams = llama_context_default_params(); - - lparams.n_ctx = params.n_ctx; - lparams.n_parts = params.n_parts; - lparams.seed = params.seed; - lparams.f16_kv = params.memory_f16; - lparams.use_mmap = params.use_mmap; - lparams.use_mlock = params.use_mlock; - - ctx = llama_init_from_file(params.model.c_str(), lparams); - - if (ctx == NULL) { - fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); - return 1; - } - } - - if (!params.lora_adapter.empty()) { - int err = llama_apply_lora_from_file(ctx, - params.lora_adapter.c_str(), - params.lora_base.empty() ? NULL : params.lora_base.c_str(), - params.n_threads); - if (err != 0) { - fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); - return 1; - } + // load the model and apply lora adapter, if any + ctx = llama_init_from_gpt_params(params); + if (ctx == NULL) { + fprintf(stderr, "%s: error: unable to load model\n", __func__); + return 1; } // print system information diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index d474bc50f..299a19999 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -122,36 +122,11 @@ int main(int argc, char ** argv) { llama_context * ctx; - // load the model - { - auto lparams = llama_context_default_params(); - - lparams.n_ctx = params.n_ctx; - lparams.n_parts = params.n_parts; - lparams.seed = params.seed; - lparams.f16_kv = params.memory_f16; - lparams.logits_all = params.perplexity; - lparams.use_mmap = params.use_mmap; - lparams.use_mlock = params.use_mlock; - lparams.embedding = params.embedding; - - ctx = llama_init_from_file(params.model.c_str(), lparams); - - if (ctx == NULL) { - fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, params.model.c_str()); - return 1; - } - } - - if (!params.lora_adapter.empty()) { - int err = llama_apply_lora_from_file(ctx, - params.lora_adapter.c_str(), - params.lora_base.empty() ? NULL : params.lora_base.c_str(), - params.n_threads); - if (err != 0) { - fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__); - return 1; - } + // load the model and apply lora adapter, if any + ctx = llama_init_from_gpt_params(params); + if (ctx == NULL) { + fprintf(stderr, "%s: error: unable to load model\n", __func__); + return 1; } // print system information