From 5f2d4e60e202aabee10051e6615bb821e51787be Mon Sep 17 00:00:00 2001
From: slaren
Date: Wed, 3 Jul 2024 19:33:31 +0200
Subject: [PATCH] ppl : fix n_seq_max for perplexity (#8277)

* ppl : fix n_seq_max for perplexity

* use 1 seq for kl_divergence
---
 examples/perplexity/perplexity.cpp | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index efde8dfdf..dbe445391 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -1991,6 +1991,12 @@ int main(int argc, char ** argv) {
         params.n_batch = std::min(params.n_batch, n_kv);
     } else {
         params.n_batch = std::min(params.n_batch, params.n_ctx);
+        if (params.kl_divergence) {
+            params.n_parallel = 1;
+        } else {
+            // ensure there's at least enough seq_ids for HellaSwag
+            params.n_parallel = std::max(4, params.n_parallel);
+        }
     }
 
     if (params.ppl_stride > 0) {
@@ -2015,9 +2021,6 @@ int main(int argc, char ** argv) {
     llama_model * model;
     llama_context * ctx;
 
-    // ensure there's at least enough seq_ids for HellaSwag
-    params.n_parallel = std::max(4, params.n_parallel);
-
     // load the model and apply lora adapter, if any
     std::tie(model, ctx) = llama_init_from_gpt_params(params);
     if (model == NULL) {
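
For context, below is a minimal, self-contained sketch of the clamping logic this
patch introduces. It is an illustration only, not code from llama.cpp: the
ppl_params struct and adjust_n_parallel helper are hypothetical names standing in
for the relevant fields of gpt_params.

    #include <algorithm>
    #include <cstdio>

    // Hypothetical stand-in for the relevant fields of gpt_params.
    struct ppl_params {
        int  n_parallel    = 1;     // number of parallel sequences (seq_ids)
        bool kl_divergence = false; // --kl-divergence mode
    };

    // Per the commit message: kl_divergence uses 1 sequence; otherwise
    // ensure at least 4 seq_ids so HellaSwag can score its 4 answer endings.
    void adjust_n_parallel(ppl_params & params) {
        if (params.kl_divergence) {
            params.n_parallel = 1;
        } else {
            params.n_parallel = std::max(4, params.n_parallel);
        }
    }

    int main() {
        ppl_params kl;
        kl.kl_divergence = true;
        adjust_n_parallel(kl);
        std::printf("kl_divergence: n_parallel = %d\n", kl.n_parallel); // 1

        ppl_params hs; // default (perplexity / HellaSwag path)
        adjust_n_parallel(hs);
        std::printf("default:       n_parallel = %d\n", hs.n_parallel); // 4
        return 0;
    }

The net effect of the patch is that the clamp, previously applied unconditionally
just before llama_init_from_gpt_params, now runs inside the batch-size branch and
special-cases kl_divergence, so that mode no longer reserves sequence slots it
does not use.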