Faster perplexity computation (#2786)
Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
parent c82742ac9c
commit d046dcee08
@@ -6,6 +6,8 @@
 #include <ctime>
 #include <sstream>
 #include <cstring>
+#include <thread>
+#include <mutex>
 
 #if defined(_MSC_VER)
 #pragma warning(disable: 4244 4267) // possible loss of data
@@ -27,6 +29,40 @@ std::vector<float> softmax(const std::vector<float>& logits) {
     return probs;
 }
 
+float log_softmax(int n_vocab, const float * logits, int tok) {
+    float max_logit = logits[0];
+    for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]);
+    double sum_exp = 0.0;
+    for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit);
+    return logits[tok] - max_logit - log(sum_exp);
+}
+
+void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread>& workers,
+        double& nll, double& nll2) {
+
+    std::mutex mutex;
+    int counter = 0;
+    auto compute = [&mutex, &counter, &nll, &nll2, n_vocab, logits, tokens, n_token] () {
+        double local_nll = 0, local_nll2 = 0;
+        while (true) {
+            std::unique_lock<std::mutex> lock(mutex);
+            int i = counter++;
+            if (i >= n_token) {
+                nll += local_nll; nll2 += local_nll2;
+                break;
+            }
+            lock.unlock();
+            double v = -log_softmax(n_vocab, logits + i*n_vocab, tokens[i+1]);
+            local_nll += v;
+            local_nll2 += v*v;
+        }
+    };
+    for (auto& w : workers) w = std::thread(compute);
+    compute();
+    for (auto& w : workers) w.join();
+
+}
+
 void perplexity_v2(llama_context * ctx, const gpt_params & params) {
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
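Note on the log_softmax added above: evaluating log(softmax(x)[tok]) naively can overflow expf for large logits, so the function subtracts the maximum logit first. A standalone sketch of the same identity, for illustration only; the helper name log_softmax_ref and the sample values are assumptions, not code from this commit:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// log(softmax(x)[tok]) = x[tok] - max(x) - log(sum_i exp(x[i] - max(x)))
static double log_softmax_ref(const std::vector<float> & x, int tok) {
    float max_x = x[0];
    for (float v : x) max_x = std::max(max_x, v);
    double sum_exp = 0.0;
    for (float v : x) sum_exp += expf(v - max_x); // v - max_x <= 0, so expf never overflows
    return x[tok] - max_x - log(sum_exp);
}

int main() {
    const std::vector<float> logits = {1000.0f, 1001.0f, 1002.0f}; // naive expf(1002.f) would be inf
    printf("%f\n", log_softmax_ref(logits, 2)); // prints about -0.407606
    return 0;
}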
@@ -166,9 +202,12 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
 
     int count = 0;
     double nll = 0.0;
+    double nll2 = 0.0;
 
     fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
 
+    std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
+
     for (int i = 0; i < n_chunk; ++i) {
         const int start = i * params.n_ctx;
         const int end = start + params.n_ctx;
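The workers vector above holds hardware_concurrency() - 1 threads because the calling thread also runs the work loop, so every hardware thread participates. A minimal standalone sketch of the shared-counter work distribution that process_logits uses, for illustration only (the item count and per-item work are made up, and a max(1u, ...) guard is added that the commit itself does not need):

#include <algorithm>
#include <cstdio>
#include <mutex>
#include <thread>
#include <vector>

int main() {
    const int n_items = 1000;
    std::mutex mutex;
    int counter = 0;    // next unclaimed work item
    double total = 0.0; // shared result, only touched under the lock

    auto compute = [&]() {
        double local = 0.0;
        while (true) {
            std::unique_lock<std::mutex> lock(mutex);
            const int i = counter++;  // claim the next item
            if (i >= n_items) {
                total += local;       // merge once, still under the lock
                break;
            }
            lock.unlock();            // do the actual work unlocked
            local += 0.5 * i;         // stand-in for the per-token log-prob
        }
    };

    // n-1 helper threads; the calling thread runs compute() too, as in the diff
    const unsigned n_threads = std::max(1u, std::thread::hardware_concurrency());
    std::vector<std::thread> workers(n_threads - 1);
    for (auto & w : workers) w = std::thread(compute);
    compute();
    for (auto & w : workers) w.join();

    printf("total = %.1f\n", total); // 0.5 * 999*1000/2 = 249750.0
    return 0;
}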
@@ -228,26 +267,32 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         // Example, we have a context window of 512, we will compute perplexity for each of the
         // last 256 tokens. Then, we split the input up into context window size chunks to
         // process the entire prompt.
-        for (int j = std::min(512, params.n_ctx / 2); j < params.n_ctx - 1; ++j) {
-            // Calculate probability of next token, given the previous ones.
-            const std::vector<float> tok_logits(
-                logits.begin() + (j + 0) * n_vocab,
-                logits.begin() + (j + 1) * n_vocab);
-
-            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
-
-            nll += -std::log(prob);
-            ++count;
-        }
+        const int first = std::min(512, params.n_ctx/2);
+        process_logits(n_vocab, logits.data() + first*n_vocab, tokens.data() + start + first, params.n_ctx - 1 - first, workers, nll, nll2);
+        count += params.n_ctx - first - 1;
+
         // perplexity is e^(average negative log-likelihood)
         if (params.ppl_output_type == 0) {
             printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
         } else {
-            printf("%8d %.4lf\n", i*params.n_ctx, std::exp(nll / count));
+            double av = nll/count;
+            double av2 = nll2/count - av*av;
+            if (av2 > 0) av2 = sqrt(av2/(count-1));
+            printf("%8d %.4lf %4lf %4lf\n", i*params.n_ctx, std::exp(nll / count), av, av2);
         }
         fflush(stdout);
     }
     printf("\n");
+    nll2 /= count;
+    nll /= count;
+    nll2 -= nll * nll;
+    if (nll2 > 0) {
+        nll2 = sqrt(nll2/(count-1));
+        double ppl = exp(nll);
+        printf("Final estimate: PPL = %.4lf +/- %.5lf\n", ppl, nll2*ppl);
+    } else {
+        printf("Unexpected negative standard deviation of log(prob)\n");
+    }
 }
 
 std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch,
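The final block above turns the accumulated sums into an uncertainty estimate: nll/count is the mean per-token loss, nll2/count minus the squared mean is its variance, and sqrt(var/(count-1)) approximates the standard error of the mean. Since PPL = exp(mean), first-order error propagation gives stderr(PPL) ≈ PPL * stderr(mean), which is what "PPL = x +/- y" reports. A worked standalone sketch, assuming i.i.d. per-token losses; the numeric inputs below are invented for illustration:

#include <cmath>
#include <cstdio>

int main() {
    // hypothetical accumulated values, standing in for nll/nll2/count in the diff
    double nll = 2816.0, nll2 = 8122.0;
    int count = 1024;

    double m = nll / count;                              // mean per-token loss (2.75)
    double var = nll2 / count - m * m;                   // variance of the loss
    double err_m = var > 0 ? sqrt(var / (count - 1)) : 0.0; // standard error of the mean
    double ppl = exp(m);                                 // perplexity
    printf("PPL = %.4lf +/- %.5lf\n", ppl, ppl * err_m); // prints PPL = 15.6426 +/- 0.29716
    return 0;
}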