Repository: https://github.com/ggerganov/llama.cpp.git (mirror)

commit d894f352bf
parent 098dbaab44

perplexity : support using multiple sequences to allow larger batch sizes (#5946)

* perplexity : support using multiple sequences to allow larger batch sizes

ggml-ci

* set cparams.n_parallel to the number of sequences

* print tested n_ctx, add assert
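The gist of the change: instead of evaluating one n_ctx-sized chunk per llama_decode call, the perplexity tool packs several chunks into a single llama_batch as separate sequences, so one call covers n_seq = n_batch / n_ctx chunks and the KV cache is sized to n_seq * n_ctx. The following is only an illustrative sketch of that arithmetic; the struct and function names are hypothetical and not from this commit.

    #include <algorithm>
    #include <cstdio>

    // Hypothetical helper, for illustration only: derive the number of parallel
    // sequences from the user-supplied context size and batch size, as the
    // updated tool does before creating the llama context.
    struct seq_plan {
        int n_seq;   // chunks evaluated per llama_decode call
        int n_kv;    // KV cache size the context needs (n_seq * n_ctx)
        int n_batch; // effective batch size after clamping
    };

    static seq_plan plan_sequences(int n_ctx, int n_batch) {
        const int n_seq = std::max(1, n_batch / n_ctx);
        const int n_kv  = n_seq * n_ctx;
        return { n_seq, n_kv, std::min(n_batch, n_kv) };
    }

    int main() {
        // e.g. a 512-token context with a 2048-token batch gives 4 sequences
        const seq_plan p = plan_sequences(512, 2048);
        std::printf("n_seq=%d n_kv=%d n_batch=%d\n", p.n_seq, p.n_kv, p.n_batch);
        return 0;
    }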
examples/perplexity/perplexity.cpp

@@ -442,7 +442,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
     return {tokens, std::exp(nll / count), logit_history, prob_history};
 }
 
-static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
+static results_perplexity perplexity(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) {
     if (params.ppl_stride > 0) {
         return perplexity_v2(ctx, params);
     }
@@ -453,7 +453,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
     // BOS tokens will be added for each chunk before eval
 
     const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
-    const int n_ctx = llama_n_ctx(ctx);
 
     std::ofstream logits_stream;
     if (!params.logits_file.empty()) {
@@ -499,13 +498,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
     double nll2 = 0.0;
 
     const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+    const int n_seq = std::max(1, n_batch / n_ctx);
+
+    GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
+    GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
+
+    llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
 
     std::vector<float> logits;
     if (num_batches > 1) {
         logits.reserve((size_t)n_ctx * n_vocab);
     }
 
-    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
+    fprintf(stderr, "%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
 
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 
@@ -518,10 +523,26 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
         log_probs.resize(n_ctx * nv);
     }
 
-    for (int i = 0; i < n_chunk; ++i) {
+    // We get the logits for all the tokens in the context window (params.n_ctx)
+    // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
+    // calculate the perplexity over the last half of the window (so the model always has
+    // some context to predict the token).
+    //
+    // We rely on the fact that attention in the forward pass only looks at previous
+    // tokens here, so the logits returned for each token are an accurate representation
+    // of what the model would have predicted at that point.
+    //
+    // Example, we have a context window of 512, we will compute perplexity for each of the
+    // last 256 tokens. Then, we split the input up into context window size chunks to
+    // process the entire prompt.
+    const int first = n_ctx/2;
+
+    for (int i = 0; i < n_chunk; i += n_seq) {
         const int start = i * n_ctx;
         const int end = start + n_ctx;
 
+        const int n_seq_batch = std::min(n_seq, n_chunk - i);
+
         const auto t_start = std::chrono::high_resolution_clock::now();
 
         // clear the KV cache
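For orientation, the comment block above boils down to: scoring starts at first = n_ctx/2, and each chunk contributes n_ctx - first - 1 tokens to the running negative log-likelihood, since the last position has no in-window target. A tiny standalone sketch of that count, not taken from the diff:

    #include <cstdio>

    // Illustration only: number of tokens per chunk that contribute to the
    // perplexity sum when scoring starts at the middle of the window.
    static int scored_tokens_per_chunk(int n_ctx) {
        const int first = n_ctx / 2; // skip the first half so every scored token has context
        return n_ctx - first - 1;    // the final position has no in-window target token
    }

    int main() {
        std::printf("n_ctx=512 -> %d scored tokens per chunk\n", scored_tokens_per_chunk(512));
        return 0;
    }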
@@ -531,22 +552,37 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
             const int batch_start = start + j * n_batch;
             const int batch_size = std::min(end - batch_start, n_batch);
 
-            // save original token and restore it after eval
-            const auto token_org = tokens[batch_start];
-
-            // add BOS token for the first batch of each chunk
-            if (add_bos && j == 0) {
-                tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
-            }
+            batch.n_tokens = 0;
+            for (int seq = 0; seq < n_seq_batch; seq++) {
+                int seq_start = batch_start + seq*n_ctx;
+
+                // save original token and restore it after eval
+                const auto token_org = tokens[seq_start];
+
+                // add BOS token for the first batch of each chunk
+                if (add_bos && j == 0) {
+                    tokens[seq_start] = llama_token_bos(llama_get_model(ctx));
+                }
+
+                for (int k = 0; k < batch_size; ++k) {
+                    const int idx = seq*n_ctx + k;
+                    batch.token[idx] = tokens[seq_start + k];
+                    batch.pos[idx] = j*n_batch + k;
+                    batch.n_seq_id[idx] = 1;
+                    batch.seq_id[idx][0] = seq;
+                    batch.logits[idx] = batch.pos[idx] >= first ? 1 : 0;
+                }
+                batch.n_tokens += batch_size;
+
+                // restore the original token in case it was set to BOS
+                tokens[seq_start] = token_org;
+            }
 
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+            if (llama_decode(ctx, batch)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return {tokens, -1, logit_history, prob_history};
             }
 
-            // restore the original token in case it was set to BOS
-            tokens[batch_start] = token_org;
-
             if (num_batches > 1) {
                 const auto * batch_logits = llama_get_logits(ctx);
                 logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
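A rough standalone sketch of the layout built in the loop above, with a plain struct standing in for llama_batch (hypothetical names, illustration only): each chunk occupies its own slice of the batch under its own seq_id, positions repeat per sequence, and logits are requested only from the scored half of the window.

    #include <cstdio>
    #include <vector>

    // Hypothetical stand-in for llama_batch, just to visualise the layout.
    struct toy_batch {
        std::vector<int> token;   // token id per slot
        std::vector<int> pos;     // position within the sequence
        std::vector<int> seq_id;  // which sequence the slot belongs to
        std::vector<int> logits;  // 1 if logits are requested for this slot
    };

    int main() {
        const int n_ctx = 8;         // toy context size
        const int n_seq = 2;         // two chunks packed into one batch
        const int first = n_ctx / 2; // scoring starts at the middle of the window

        toy_batch batch;
        for (int seq = 0; seq < n_seq; seq++) {
            for (int k = 0; k < n_ctx; k++) {
                batch.token.push_back(1000*seq + k); // dummy token ids
                batch.pos.push_back(k);              // positions repeat per sequence
                batch.seq_id.push_back(seq);
                batch.logits.push_back(k >= first ? 1 : 0);
            }
        }

        for (size_t i = 0; i < batch.token.size(); i++) {
            std::printf("slot %2zu: seq=%d pos=%d logits=%d\n",
                        i, batch.seq_id[i], batch.pos[i], batch.logits[i]);
        }
        return 0;
    }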
@@ -558,7 +594,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
         if (i == 0) {
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
             fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
-            int total_seconds = (int)(t_total * n_chunk);
+            int total_seconds = (int)(t_total*n_chunk/n_seq);
             if (total_seconds >= 60*60) {
                 fprintf(stderr, "%d hours ", total_seconds / (60*60));
                 total_seconds = total_seconds % (60*60);
|
|||||||
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
|
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// We get the logits for all the tokens in the context window (params.n_ctx)
|
for (int seq = 0; seq < n_seq_batch; seq++) {
|
||||||
// from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
|
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx);
|
||||||
// calculate the perplexity over the last half of the window (so the model always has
|
llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
|
||||||
// some context to predict the token).
|
if (!params.logits_file.empty()) {
|
||||||
//
|
process_logits(logits_stream, n_vocab, all_logits + first*n_vocab,
|
||||||
// We rely on the fact that attention in the forward pass only looks at previous
|
tokens_data, n_ctx - 1 - first,
|
||||||
// tokens here, so the logits returned for each token are an accurate representation
|
workers, log_probs, nll, nll2);
|
||||||
// of what the model would have predicted at that point.
|
} else {
|
||||||
//
|
process_logits(n_vocab, all_logits + first*n_vocab,
|
||||||
// Example, we have a context window of 512, we will compute perplexity for each of the
|
tokens_data, n_ctx - 1 - first,
|
||||||
// last 256 tokens. Then, we split the input up into context window size chunks to
|
workers, nll, nll2,
|
||||||
// process the entire prompt.
|
logit_history.data() + start + seq*n_ctx + first,
|
||||||
const int first = n_ctx/2;
|
prob_history.data() + start + seq*n_ctx + first);
|
||||||
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
|
}
|
||||||
if (!params.logits_file.empty()) {
|
count += n_ctx - first - 1;
|
||||||
process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
|
|
||||||
workers, log_probs, nll, nll2);
|
|
||||||
} else {
|
|
||||||
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
|
|
||||||
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
|
|
||||||
}
|
|
||||||
count += n_ctx - first - 1;
|
|
||||||
|
|
||||||
// perplexity is e^(average negative log-likelihood)
|
// perplexity is e^(average negative log-likelihood)
|
||||||
if (params.ppl_output_type == 0) {
|
if (params.ppl_output_type == 0) {
|
||||||
printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
|
printf("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
|
||||||
} else {
|
} else {
|
||||||
double av = nll/count;
|
double av = nll/count;
|
||||||
double av2 = nll2/count - av*av;
|
double av2 = nll2/count - av*av;
|
||||||
if (av2 > 0) av2 = sqrt(av2/(count-1));
|
if (av2 > 0) av2 = sqrt(av2/(count-1));
|
||||||
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
|
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
|
|
||||||
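As the unchanged comment notes, perplexity is e^(average negative log-likelihood). A minimal standalone reminder of that accumulation, separate from the code above:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Illustrative accumulation: given the probabilities the model assigned to
    // each correct next token, perplexity is exp of the mean negative log-likelihood.
    int main() {
        const std::vector<double> p_correct = {0.25, 0.10, 0.50, 0.05};

        double nll = 0.0;
        for (double p : p_correct) {
            nll += -std::log(p);
        }
        const double ppl = std::exp(nll / p_correct.size());
        std::printf("perplexity = %.4f\n", ppl);
        return 0;
    }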
@@ -615,6 +645,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
         printf("Unexpected negative standard deviation of log(prob)\n");
     }
 
+    llama_batch_free(batch);
+
     return {tokens, ppl, logit_history, prob_history};
 }
 
@@ -1782,13 +1814,24 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
 int main(int argc, char ** argv) {
     gpt_params params;
 
-    params.n_batch = 512;
     if (!gpt_params_parse(argc, argv, params)) {
         return 1;
     }
 
     params.logits_all = true;
-    params.n_batch = std::min(params.n_batch, params.n_ctx);
+
+    const int32_t n_ctx = params.n_ctx;
+
+    const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
+    if (ppl) {
+        int n_seq = std::max(1, params.n_batch / n_ctx);
+        int32_t n_kv = n_seq * n_ctx;
+        params.n_parallel = n_seq;
+        params.n_ctx = n_kv;
+        params.n_batch = std::min(params.n_batch, n_kv);
+    } else {
+        params.n_batch = std::min(params.n_batch, params.n_ctx);
+    }
 
     if (params.ppl_stride > 0) {
         fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
@@ -1847,7 +1890,7 @@ int main(int argc, char ** argv) {
     } else if (params.kl_divergence) {
         kl_divergence(ctx, params);
     } else {
-        results = perplexity(ctx, params);
+        results = perplexity(ctx, params, n_ctx);
    }
 
     llama_print_timings(ctx);
llama.cpp
@@ -8925,17 +8925,29 @@ static int llama_decode_internal(
 
         if (batch.logits) {
             logits_out.resize(n_vocab * n_tokens);
+            int32_t i_first = -1;
             for (uint32_t i = 0; i < n_tokens; i++) {
-                if (batch.logits[i] == 0) {
-                    continue;
+                if (batch.logits[i] && i_first == -1) {
+                    i_first = (int32_t) i;
                 }
-                ggml_backend_tensor_get_async(backend_res, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float));
+                if (batch.logits[i] == 0 || i == n_tokens - 1) {
+                    if (i_first != -1) {
+                        int i_last = batch.logits[i] == 0 ? i : i + 1;
+                        // extract logits for the range [i_first, i_last)
+                        // group the requests to minimize the number of calls to the backend
+                        ggml_backend_tensor_get_async(backend_res, res,
+                                logits_out.data() + (n_vocab*i_first),
+                                (n_vocab*i_first)*sizeof(float),
+                                (i_last - i_first)*n_vocab*sizeof(float));
+                        i_first = -1;
+                    }
+                }
 #ifndef NDEBUG
-                logits_valid[i] = true;
+                logits_valid[i] = batch.logits[i] != 0;
 #endif
             }
         } else if (lctx.logits_all) {
-            logits_out.resize(n_vocab * n_tokens);
+            logits_out.resize(n_vocab*n_tokens);
             ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
 #ifndef NDEBUG
             std::fill(logits_valid.begin(), logits_valid.end(), true);
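The reworked loop above groups contiguous logit requests so that each run of requested tokens becomes a single ggml_backend_tensor_get_async call instead of one call per token. A standalone sketch of the same grouping idea on a plain flag array (illustrative, not the llama.cpp code):

    #include <cstdio>
    #include <utility>
    #include <vector>

    // Illustration only: each contiguous run of requested logits becomes a
    // single half-open range [first, last), i.e. one backend read per run.
    static std::vector<std::pair<int, int>> group_ranges(const std::vector<int> & flags) {
        std::vector<std::pair<int, int>> ranges;
        int i_first = -1;
        const int n = (int) flags.size();
        for (int i = 0; i < n; i++) {
            if (flags[i] && i_first == -1) {
                i_first = i;
            }
            if (flags[i] == 0 || i == n - 1) {
                if (i_first != -1) {
                    const int i_last = flags[i] == 0 ? i : i + 1;
                    ranges.push_back({i_first, i_last});
                    i_first = -1;
                }
            }
        }
        return ranges;
    }

    int main() {
        // e.g. logits requested for the last half of two packed 8-token sequences
        const std::vector<int> flags = {0,0,0,0,1,1,1,1, 0,0,0,0,1,1,1,1};
        for (const auto & r : group_ranges(flags)) {
            std::printf("read logits for tokens [%d, %d)\n", r.first, r.second);
        }
        return 0;
    }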