From 32a392fe6878e40fd02c2ba0c72163b252e68b75 Mon Sep 17 00:00:00 2001
From: Jared Van Bortel
Date: Fri, 19 Jan 2024 17:10:23 -0500
Subject: [PATCH] try a different fix

---
 examples/perplexity/perplexity.cpp | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index b86a688e1..f91f5795a 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -458,23 +458,24 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
     return true;
 }
 
+#define K_TOKEN_CHUNK 4
+
 static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
         const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
-    constexpr int k_token_chunk = 4;
     if (eval_results.size() != eval_pairs.size()) {
         eval_results.resize(eval_pairs.size());
     }
     if (eval_pairs.empty()) return;
 
-    size_t max_threads = std::min((eval_pairs.size() + k_token_chunk - 1)/k_token_chunk, workers.size());
+    size_t max_threads = std::min((eval_pairs.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK, workers.size());
 
     std::atomic<int> counter(0);
-    auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab, k_token_chunk] () {
-        float local_logprobs[k_token_chunk];
+    auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
+        float local_logprobs[K_TOKEN_CHUNK];
         while (true) {
-            size_t first = counter.fetch_add(k_token_chunk, std::memory_order_relaxed);
+            size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed);
             if (first >= eval_results.size()) break;
-            size_t last = std::min(first + k_token_chunk, eval_results.size());
+            size_t last = std::min(first + K_TOKEN_CHUNK, eval_results.size());
             for (size_t i = first; i < last; ++i) {
                 auto logits = batch_logits + eval_pairs[i].first * n_vocab;
                 float max_logit = logits[0];
@@ -497,7 +498,6 @@ static void compute_logprobs(const float * batch_logits, int n_vocab, std::vecto
     for (size_t it = 0; it < max_threads; ++it) {
         workers[it].join();
     }
-
 }
 
 static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
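
Note (not part of the patch): the diff above drops the function-local constexpr k_token_chunk, and with it the lambda capture of that constant, in favor of a K_TOKEN_CHUNK preprocessor define. The sketch below is a standalone illustration of the same chunked work-distribution pattern (an atomic counter handing out fixed-size chunks to a capped pool of worker threads). The input data and the per-item work are made up for illustration; this is not code from perplexity.cpp.

// Standalone sketch of the chunking pattern used by compute_logprobs.
// The work items and the per-item step are hypothetical; only the
// K_TOKEN_CHUNK / atomic-counter structure mirrors the patch.
#include <algorithm>
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

#define K_TOKEN_CHUNK 4

int main() {
    std::vector<int>   inputs(103, 2);           // hypothetical work items
    std::vector<float> results(inputs.size());

    // never spawn more threads than there are chunks of work
    size_t hw = std::max(1u, std::thread::hardware_concurrency());
    size_t max_threads = std::min((inputs.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK, hw);

    std::atomic<size_t> counter(0);
    auto compute = [&counter, &inputs, &results] () {
        while (true) {
            // claim the next chunk of up to K_TOKEN_CHUNK items
            size_t first = counter.fetch_add(K_TOKEN_CHUNK, std::memory_order_relaxed);
            if (first >= results.size()) break;
            size_t last = std::min(first + K_TOKEN_CHUNK, results.size());
            for (size_t i = first; i < last; ++i) {
                results[i] = 0.5f * inputs[i];   // stand-in for the real per-token work
            }
        }
    };

    std::vector<std::thread> workers;
    for (size_t it = 0; it < max_threads; ++it) {
        workers.emplace_back(compute);
    }
    for (auto & w : workers) {
        w.join();
    }

    printf("processed %zu items on up to %zu threads\n", results.size(), max_threads);
    return 0;
}

Because K_TOKEN_CHUNK is a macro rather than a captured local, each thread simply keeps pulling fixed-size ranges until the counter runs past the end of the work, which balances load even when chunks take uneven time.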