From 30e70334f71b3bd115024affcf98cac3d79aaa95 Mon Sep 17 00:00:00 2001 From: "k.h.lai" Date: Mon, 13 May 2024 22:02:36 +0800 Subject: [PATCH] llava-cli: fix base64 prompt (#7248) --- examples/llava/llava-cli.cpp | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index da60ddf2f..a6d67e5d7 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -300,14 +300,10 @@ int main(int argc, char ** argv) { return 1; } - for (auto & image : params.image) { + if (prompt_contains_image(params.prompt)) { auto ctx_llava = llava_init_context(¶ms, model); - auto image_embed = load_image(ctx_llava, ¶ms, image); - if (!image_embed) { - std::cerr << "error: failed to load image " << image << ". Terminating\n\n"; - return 1; - } + auto image_embed = load_image(ctx_llava, ¶ms, ""); // process the prompt process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); @@ -316,7 +312,26 @@ int main(int argc, char ** argv) { llava_image_embed_free(image_embed); ctx_llava->model = NULL; llava_free(ctx_llava); + } else { + for (auto & image : params.image) { + auto ctx_llava = llava_init_context(¶ms, model); + + auto image_embed = load_image(ctx_llava, ¶ms, image); + if (!image_embed) { + std::cerr << "error: failed to load image " << image << ". Terminating\n\n"; + return 1; + } + + // process the prompt + process_prompt(ctx_llava, image_embed, ¶ms, params.prompt); + + llama_print_timings(ctx_llava->ctx_llama); + llava_image_embed_free(image_embed); + ctx_llava->model = NULL; + llava_free(ctx_llava); + } } + llama_free_model(model); return 0;