perplexity : support using multiple sequences to allow larger batch sizes (#5946)

* perplexity : support using multiple sequences to allow larger batch sizes

ggml-ci

* set cparams.n_parallel to the number of sequences

* print tested n_ctx, add assert
This commit is contained in:
slaren 2024-03-09 19:55:54 +01:00 committed by GitHub
parent 098dbaab44
commit d894f352bf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 107 additions and 52 deletions

View File

@ -442,7 +442,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
return {tokens, std::exp(nll / count), logit_history, prob_history}; return {tokens, std::exp(nll / count), logit_history, prob_history};
} }
static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) { static results_perplexity perplexity(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) {
if (params.ppl_stride > 0) { if (params.ppl_stride > 0) {
return perplexity_v2(ctx, params); return perplexity_v2(ctx, params);
} }
@ -453,7 +453,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// BOS tokens will be added for each chunk before eval // BOS tokens will be added for each chunk before eval
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
const int n_ctx = llama_n_ctx(ctx);
std::ofstream logits_stream; std::ofstream logits_stream;
if (!params.logits_file.empty()) { if (!params.logits_file.empty()) {
@ -499,13 +498,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
double nll2 = 0.0; double nll2 = 0.0;
const int num_batches = (n_ctx + n_batch - 1) / n_batch; const int num_batches = (n_ctx + n_batch - 1) / n_batch;
const int n_seq = std::max(1, n_batch / n_ctx);
GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
std::vector<float> logits; std::vector<float> logits;
if (num_batches > 1) { if (num_batches > 1) {
logits.reserve((size_t)n_ctx * n_vocab); logits.reserve((size_t)n_ctx * n_vocab);
} }
fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); fprintf(stderr, "%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1); std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
@ -518,54 +523,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
log_probs.resize(n_ctx * nv); log_probs.resize(n_ctx * nv);
} }
for (int i = 0; i < n_chunk; ++i) {
const int start = i * n_ctx;
const int end = start + n_ctx;
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_clear(ctx);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
// save original token and restore it after eval
const auto token_org = tokens[batch_start];
// add BOS token for the first batch of each chunk
if (add_bos && j == 0) {
tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
}
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
// restore the original token in case it was set to BOS
tokens[batch_start] = token_org;
if (num_batches > 1) {
const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
}
}
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) {
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total * n_chunk);
if (total_seconds >= 60*60) {
fprintf(stderr, "%d hours ", total_seconds / (60*60));
total_seconds = total_seconds % (60*60);
}
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
}
// We get the logits for all the tokens in the context window (params.n_ctx) // We get the logits for all the tokens in the context window (params.n_ctx)
// from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity, // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
// calculate the perplexity over the last half of the window (so the model always has // calculate the perplexity over the last half of the window (so the model always has
@ -579,25 +536,98 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
// last 256 tokens. Then, we split the input up into context window size chunks to // last 256 tokens. Then, we split the input up into context window size chunks to
// process the entire prompt. // process the entire prompt.
const int first = n_ctx/2; const int first = n_ctx/2;
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
for (int i = 0; i < n_chunk; i += n_seq) {
const int start = i * n_ctx;
const int end = start + n_ctx;
const int n_seq_batch = std::min(n_seq, n_chunk - i);
const auto t_start = std::chrono::high_resolution_clock::now();
// clear the KV cache
llama_kv_cache_clear(ctx);
for (int j = 0; j < num_batches; ++j) {
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
batch.n_tokens = 0;
for (int seq = 0; seq < n_seq_batch; seq++) {
int seq_start = batch_start + seq*n_ctx;
// save original token and restore it after eval
const auto token_org = tokens[seq_start];
// add BOS token for the first batch of each chunk
if (add_bos && j == 0) {
tokens[seq_start] = llama_token_bos(llama_get_model(ctx));
}
for (int k = 0; k < batch_size; ++k) {
const int idx = seq*n_ctx + k;
batch.token[idx] = tokens[seq_start + k];
batch.pos[idx] = j*n_batch + k;
batch.n_seq_id[idx] = 1;
batch.seq_id[idx][0] = seq;
batch.logits[idx] = batch.pos[idx] >= first ? 1 : 0;
}
batch.n_tokens += batch_size;
// restore the original token in case it was set to BOS
tokens[seq_start] = token_org;
}
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
if (num_batches > 1) {
const auto * batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
}
}
const auto t_end = std::chrono::high_resolution_clock::now();
if (i == 0) {
const float t_total = std::chrono::duration<float>(t_end - t_start).count();
fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
int total_seconds = (int)(t_total*n_chunk/n_seq);
if (total_seconds >= 60*60) {
fprintf(stderr, "%d hours ", total_seconds / (60*60));
total_seconds = total_seconds % (60*60);
}
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
}
for (int seq = 0; seq < n_seq_batch; seq++) {
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx);
llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
if (!params.logits_file.empty()) { if (!params.logits_file.empty()) {
process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, process_logits(logits_stream, n_vocab, all_logits + first*n_vocab,
tokens_data, n_ctx - 1 - first,
workers, log_probs, nll, nll2); workers, log_probs, nll, nll2);
} else { } else {
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, process_logits(n_vocab, all_logits + first*n_vocab,
workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first); tokens_data, n_ctx - 1 - first,
workers, nll, nll2,
logit_history.data() + start + seq*n_ctx + first,
prob_history.data() + start + seq*n_ctx + first);
} }
count += n_ctx - first - 1; count += n_ctx - first - 1;
// perplexity is e^(average negative log-likelihood) // perplexity is e^(average negative log-likelihood)
if (params.ppl_output_type == 0) { if (params.ppl_output_type == 0) {
printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); printf("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
} else { } else {
double av = nll/count; double av = nll/count;
double av2 = nll2/count - av*av; double av2 = nll2/count - av*av;
if (av2 > 0) av2 = sqrt(av2/(count-1)); if (av2 > 0) av2 = sqrt(av2/(count-1));
printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2); printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2);
} }
}
fflush(stdout); fflush(stdout);
logits.clear(); logits.clear();
@ -615,6 +645,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
printf("Unexpected negative standard deviation of log(prob)\n"); printf("Unexpected negative standard deviation of log(prob)\n");
} }
llama_batch_free(batch);
return {tokens, ppl, logit_history, prob_history}; return {tokens, ppl, logit_history, prob_history};
} }
@ -1782,13 +1814,24 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
params.n_batch = 512;
if (!gpt_params_parse(argc, argv, params)) { if (!gpt_params_parse(argc, argv, params)) {
return 1; return 1;
} }
params.logits_all = true; params.logits_all = true;
const int32_t n_ctx = params.n_ctx;
const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence;
if (ppl) {
int n_seq = std::max(1, params.n_batch / n_ctx);
int32_t n_kv = n_seq * n_ctx;
params.n_parallel = n_seq;
params.n_ctx = n_kv;
params.n_batch = std::min(params.n_batch, n_kv);
} else {
params.n_batch = std::min(params.n_batch, params.n_ctx); params.n_batch = std::min(params.n_batch, params.n_ctx);
}
if (params.ppl_stride > 0) { if (params.ppl_stride > 0) {
fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n", fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n",
@ -1847,7 +1890,7 @@ int main(int argc, char ** argv) {
} else if (params.kl_divergence) { } else if (params.kl_divergence) {
kl_divergence(ctx, params); kl_divergence(ctx, params);
} else { } else {
results = perplexity(ctx, params); results = perplexity(ctx, params, n_ctx);
} }
llama_print_timings(ctx); llama_print_timings(ctx);

View File

@ -8925,17 +8925,29 @@ static int llama_decode_internal(
if (batch.logits) { if (batch.logits) {
logits_out.resize(n_vocab * n_tokens); logits_out.resize(n_vocab * n_tokens);
int32_t i_first = -1;
for (uint32_t i = 0; i < n_tokens; i++) { for (uint32_t i = 0; i < n_tokens; i++) {
if (batch.logits[i] == 0) { if (batch.logits[i] && i_first == -1) {
continue; i_first = (int32_t) i;
}
if (batch.logits[i] == 0 || i == n_tokens - 1) {
if (i_first != -1) {
int i_last = batch.logits[i] == 0 ? i : i + 1;
// extract logits for the range [i_first, i_last)
// group the requests to minimize the number of calls to the backend
ggml_backend_tensor_get_async(backend_res, res,
logits_out.data() + (n_vocab*i_first),
(n_vocab*i_first)*sizeof(float),
(i_last - i_first)*n_vocab*sizeof(float));
i_first = -1;
}
} }
ggml_backend_tensor_get_async(backend_res, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float));
#ifndef NDEBUG #ifndef NDEBUG
logits_valid[i] = true; logits_valid[i] = batch.logits[i] != 0;
#endif #endif
} }
} else if (lctx.logits_all) { } else if (lctx.logits_all) {
logits_out.resize(n_vocab * n_tokens); logits_out.resize(n_vocab*n_tokens);
ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float)); ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float));
#ifndef NDEBUG #ifndef NDEBUG
std::fill(logits_valid.begin(), logits_valid.end(), true); std::fill(logits_valid.begin(), logits_valid.end(), true);