perplexity: avoid unnecessary alloocations and logit copies (#5035)

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
Kawrakow 2024-01-19 11:02:39 +02:00 committed by GitHub
parent 8b20858e5e
commit 993fba8180
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -325,6 +325,13 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
double nll = 0.0; double nll = 0.0;
double nll2 = 0.0; double nll2 = 0.0;
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
std::vector<float> logits;
if (num_batches > 1) {
logits.reserve((size_t)n_ctx * n_vocab);
}
fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1); std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
@ -333,10 +340,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const int start = i * n_ctx; const int start = i * n_ctx;
const int end = start + n_ctx; const int end = start + n_ctx;
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
std::vector<float> logits;
const auto t_start = std::chrono::high_resolution_clock::now(); const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache // clear the KV cache
@ -362,9 +365,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// restore the original token in case it was set to BOS // restore the original token in case it was set to BOS
tokens[batch_start] = token_org; tokens[batch_start] = token_org;
if (num_batches > 1) {
const auto * batch_logits = llama_get_logits(ctx); const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
} }
}
const auto t_end = std::chrono::high_resolution_clock::now(); const auto t_end = std::chrono::high_resolution_clock::now();
@ -392,7 +397,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// last 256 tokens. Then, we split the input up into context window size chunks to // last 256 tokens. Then, we split the input up into context window size chunks to
// process the entire prompt. // process the entire prompt.
const int first = n_ctx/2; const int first = n_ctx/2;
process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first); workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
count += n_ctx - first - 1; count += n_ctx - first - 1;
@ -406,6 +412,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2); printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
} }
fflush(stdout); fflush(stdout);
logits.clear();
} }
printf("\n"); printf("\n");