imatrix : allow processing multiple chunks per batch

* perplexity : simplify filling the batch
Francis Couture-Harpin 2024-08-20 15:17:24 -04:00
parent 90db8146d5
commit bce54642c8
2 changed files with 75 additions and 35 deletions
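
For orientation before the diff: the change packs several imatrix chunks into a single `llama_batch`, one sequence id per chunk, so a single `llama_decode()` call can evaluate `n_seq = n_batch / n_ctx` chunks at once. Below is a minimal standalone sketch of that chunk-to-sequence mapping; the numeric values are hypothetical and only the expressions mirror the diff.

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    const int n_ctx   = 512;   // tokens per imatrix chunk (--ctx-size)
    const int n_batch = 2048;  // logical batch size (--batch-size)
    const int n_chunk = 10;    // chunks in the tokenized input (made-up value)

    // same derivation as in the patched compute_imatrix()/main()
    const int n_seq = std::max(1, n_batch / n_ctx);

    for (int i = 0; i < n_chunk; i += n_seq) {
        const int n_seq_batch = std::min(n_seq, n_chunk - i);
        // chunks i .. i+n_seq_batch-1 become sequences 0 .. n_seq_batch-1
        // of one llama_batch, each with positions 0 .. n_ctx-1
        printf("decode call: chunks %d..%d as %d sequence(s)\n", i, i + n_seq_batch - 1, n_seq_batch);
    }
    return 0;
}
```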

examples/imatrix/imatrix.cpp

@@ -432,10 +432,9 @@ static void process_logits(
     }
 }
 
-static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
+static bool compute_imatrix(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) {
     const bool add_bos = llama_add_bos_token(llama_get_model(ctx));
     GGML_ASSERT(!llama_add_eos_token(llama_get_model(ctx)));
-    const int n_ctx = llama_n_ctx(ctx);
 
     auto tim1 = std::chrono::high_resolution_clock::now();
     fprintf(stderr, "%s: tokenizing the input ..\n", __func__);
@@ -479,22 +478,28 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
     double nll = 0.0;
     double nll2 = 0.0;
 
-    fprintf(stderr, "%s: computing over %d chunks with batch_size %d\n", __func__, n_chunk, n_batch);
-
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 
     const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+    const int n_seq = std::max(1, n_batch / n_ctx);
+
+    GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
+    GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
+
+    llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
 
     std::vector<float> logits;
     if (params.compute_ppl && num_batches > 1) {
         logits.reserve((size_t)n_ctx * n_vocab);
     }
 
-    for (int i = 0; i < n_chunk; ++i) {
+    fprintf(stderr, "%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
+
+    for (int i = 0; i < n_chunk; i += n_seq) {
         const int start = i * n_ctx;
         const int end   = start + n_ctx;
 
-        std::vector<float> logits;
+        const int n_seq_batch = std::min(n_seq, n_chunk - i);
 
         const auto t_start = std::chrono::high_resolution_clock::now();
@@ -505,35 +510,50 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
 
-            // save original token and restore it after eval
-            const auto token_org = tokens[batch_start];
-
-            // add BOS token for the first batch of each chunk
-            if (add_bos && j == 0) {
-                tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
+            // clear the batch
+            llama_batch_clear(batch);
+
+            for (int seq = 0; seq < n_seq_batch; seq++) {
+                int seq_start = batch_start + seq*n_ctx;
+
+                // save original token and restore it after eval
+                const auto token_org = tokens[seq_start];
+
+                // add BOS token for the first batch of each chunk
+                if (add_bos && j == 0) {
+                    tokens[seq_start] = llama_token_bos(llama_get_model(ctx));
+                }
+
+                for (int k = 0; k < batch_size; ++k) {
+                    // NOTE: specifying all logits to get activations for the output.weight tensor
+                    //       and also for the perplexity calculation.
+                    // TODO: only get outputs when (params.process_output || params.compute_ppl)
+                    //       (not possible when this skips FFN computation of the last layer)
+                    llama_batch_add(batch, tokens[seq_start + k], j*n_batch + k, { seq }, true);
+                }
+
+                // restore the original token in case it was set to BOS
+                tokens[seq_start] = token_org;
             }
 
-            // TODO: use batch.logits to save computations instead of relying on logits_all == true
-            if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) {
+            if (llama_decode(ctx, batch)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return false;
             }
 
-            // restore the original token in case it was set to BOS
-            tokens[batch_start] = token_org;
-
             if (params.compute_ppl && num_batches > 1) {
                 const auto * batch_logits = llama_get_logits(ctx);
                 logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
             }
         }
 
-        const auto t_end = std::chrono::high_resolution_clock::now();
-
         if (i == 0) {
+            llama_synchronize(ctx);
+            const auto t_end = std::chrono::high_resolution_clock::now();
             const float t_total = std::chrono::duration<float>(t_end - t_start).count();
             fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
-            int total_seconds = (int)(t_total * n_chunk);
+            int total_seconds = (int)(t_total*n_chunk/n_seq);
             if (total_seconds >= 60*60) {
                 fprintf(stderr, "%d hours ", total_seconds / (60*60));
                 total_seconds = total_seconds % (60*60);
@@ -543,12 +563,21 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) {
 
         if (params.compute_ppl) {
             const int first = n_ctx/2;
-            const auto all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
-            process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
-                    workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first);
-            count += n_ctx - first - 1;
-            printf("[%d]%.4lf,", i + 1, std::exp(nll / count));
+
+            for (int seq = 0; seq < n_seq_batch; seq++) {
+                const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
+
+                llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
+
+                process_logits(n_vocab, all_logits + first*n_vocab,
+                        tokens_data, n_ctx - 1 - first,
+                        workers, nll, nll2,
+                        logit_history.data() + start + seq*n_ctx + first,
+                        prob_history.data()  + start + seq*n_ctx + first);
+
+                count += n_ctx - first - 1;
+                printf("[%d]%.4lf,", i + seq + 1, std::exp(nll / count));
+            }
             fflush(stdout);
 
             logits.clear();
@@ -584,7 +613,22 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    params.n_batch = std::min(params.n_batch, params.n_ctx);
+    const int32_t n_ctx = params.n_ctx;
+
+    if (n_ctx <= 0) {
+        fprintf(stderr, "%s: imatrix tool requires '--ctx-size' > 0\n", __func__);
+        return 1;
+    }
+
+    {
+        const int32_t n_seq = std::max(1, params.n_batch / n_ctx);
+        const int32_t n_kv  = n_seq * n_ctx;
+
+        params.n_parallel = n_seq;
+        params.n_ctx      = n_kv;
+
+        params.n_batch = std::min(params.n_batch, n_kv);
+    }
 
     g_collector.set_params(params);
@@ -632,7 +676,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s\n", gpt_params_get_system_info(params).c_str());
     }
 
-    if (!compute_imatrix(ctx, params)) {
+    if (!compute_imatrix(ctx, params, n_ctx)) {
         return 1;
     }
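
Worth noting about the `main()` hunk above: the tool now requests a KV cache large enough to hold every sequence of one batch (`n_seq * n_ctx` cells, with `n_parallel = n_seq`), while each chunk still sees a per-sequence context of `n_ctx`. A small worked example of that arithmetic follows; the flag values are hypothetical, only the expressions match the diff.

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    // hypothetical invocation: llama-imatrix --ctx-size 512 --batch-size 2048 ...
    int32_t n_ctx   = 512;
    int32_t n_batch = 2048;

    const int32_t n_seq = std::max(1, n_batch / n_ctx);  // 4 chunks per batch
    const int32_t n_kv  = n_seq * n_ctx;                 // 2048 KV cache cells

    n_batch = std::min(n_batch, n_kv);                   // stays 2048 here

    // in the diff these become params.n_parallel, params.n_ctx and params.n_batch
    printf("n_seq=%d n_kv=%d n_batch=%d\n", n_seq, n_kv, n_batch);
    return 0;
}
```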

examples/perplexity/perplexity.cpp

@@ -583,7 +583,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
             int n_outputs = 0;
 
-            batch.n_tokens = 0;
+            // clear the batch
+            llama_batch_clear(batch);
+
             for (int seq = 0; seq < n_seq_batch; seq++) {
                 int seq_start = batch_start + seq*n_ctx;
@@ -596,16 +598,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
                 }
 
                 for (int k = 0; k < batch_size; ++k) {
-                    const int idx = seq*n_ctx + k;
-                    batch.token   [idx]    = tokens[seq_start + k];
-                    batch.pos     [idx]    = j*n_batch + k;
-                    batch.n_seq_id[idx]    = 1;
-                    batch.seq_id  [idx][0] = seq;
-                    batch.logits  [idx]    = batch.pos[idx] >= first ? 1 : 0;
-
-                    n_outputs += batch.logits[idx] != 0;
+                    llama_pos pos = j*n_batch + k;
+                    llama_batch_add(batch, tokens[seq_start + k], pos, { seq }, pos >= first);
+                    n_outputs += (int) (pos >= first);
                 }
-                batch.n_tokens += batch_size;
 
                 // restore the original token in case it was set to BOS
                 tokens[seq_start] = token_org;
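
For readers unfamiliar with the `common` helpers used above: `llama_batch_clear()` simply resets `batch.n_tokens`, and `llama_batch_add()` fills the same per-token fields that the removed lines wrote by hand. The sketch below is a rough paraphrase of what such helpers do, under those assumptions; the `*_sketch` names are made up here and the bodies are not taken from this commit.

```cpp
#include <vector>
#include "llama.h"

// roughly what llama_batch_clear() does: drop all queued tokens
static void batch_clear_sketch(struct llama_batch & batch) {
    batch.n_tokens = 0;
}

// roughly what llama_batch_add() does: append one token with its position,
// sequence ids, and output flag, advancing n_tokens
static void batch_add_sketch(struct llama_batch & batch, llama_token id, llama_pos pos,
                             const std::vector<llama_seq_id> & seq_ids, bool logits) {
    const int i = batch.n_tokens;

    batch.token   [i] = id;
    batch.pos     [i] = pos;
    batch.n_seq_id[i] = (int32_t) seq_ids.size();
    for (size_t s = 0; s < seq_ids.size(); ++s) {
        batch.seq_id[i][s] = seq_ids[s];
    }
    batch.logits  [i] = logits;

    batch.n_tokens++;
}
```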