diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 72a025077..dedaa34fd 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -986,7 +986,12 @@ int main(int argc, char ** argv) { test t(inst, lmodel, ctx); // warmup run - test_gen(ctx, 1, 0, t.n_threads); + if (t.n_prompt > 0) { + test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads); + } + if (t.n_gen > 0) { + test_gen(ctx, 1, 0, t.n_threads); + } for (int i = 0; i < params.reps; i++) { uint64_t t_start = get_time_ns();