perplexity : faster HellaSwag via batching (#5017)

* perplexity : faster HellaSwag

ggml-ci

* perplexity : clean-up

ggml-ci

* perplexity : no need for decode_helper

ggml-ci

* perplexity : add comments

* perplexity : option to specify max batched tasks via `n_parallel`

* perplexity : remove HellaSwag restruction for n_batch
This commit is contained in:
Georgi Gerganov 2024-01-18 15:33:01 +02:00 committed by GitHub
parent 682986a08e
commit ad19812cda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -502,27 +502,54 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
std::string ending[4]; std::string ending[4];
size_t ending_logprob_count[4]; size_t ending_logprob_count[4];
double ending_logprob[4]; double ending_logprob[4];
size_t i_batch; // starting index in the llama_batch
size_t common_prefix; // max number of initial tokens that are the same in all sentences
size_t required_tokens; // needed number of tokens to evaluate all 4 endings
std::vector<llama_token> seq_tokens[4];
}; };
fprintf(stderr, "%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first") ); fprintf(stderr, "%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first") );
// Select and read data from prompt lines // Select and read data from prompt lines
hs_data_t *hs_data = new hs_data_t[hs_task_count]; std::vector<hs_data_t> hs_data(hs_task_count);
for (size_t i = 0; i < hs_task_count; i++) { for (size_t i = 0; i < hs_task_count; i++) {
size_t idx = i; size_t idx = i;
auto & hs_cur = hs_data[i];
// Select a random example of those left in the prompt // Select a random example of those left in the prompt
if (randomize_tasks) { if (randomize_tasks) {
std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6-1 ) ; std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6-1 ) ;
idx = dist(rng); idx = dist(rng);
} }
hs_data[i].context = prompt_lines[idx*6]; hs_cur.context = prompt_lines[idx*6];
hs_data[i].gold_ending_idx = std::stoi( prompt_lines[idx*6+1] ); hs_cur.gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
for (size_t j = 0; j < 4; j++) { for (size_t j = 0; j < 4; j++) {
hs_data[i].ending[j] = prompt_lines[idx*6+2+j]; hs_cur.ending[j] = prompt_lines[idx*6+2+j];
hs_cur.seq_tokens[j] = ::llama_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], add_bos);
} }
// determine the common prefix of the endings
hs_cur.common_prefix = 0;
hs_cur.required_tokens = 0;
for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[3][k]) {
break;
}
hs_cur.common_prefix++;
}
hs_cur.required_tokens = hs_cur.common_prefix +
hs_cur.seq_tokens[0].size() - hs_cur.common_prefix +
hs_cur.seq_tokens[1].size() - hs_cur.common_prefix +
hs_cur.seq_tokens[2].size() - hs_cur.common_prefix +
hs_cur.seq_tokens[3].size() - hs_cur.common_prefix;
//GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, add_bos).size());
// Delete the selected random example from the prompt // Delete the selected random example from the prompt
if (randomize_tasks) { if (randomize_tasks) {
prompt_lines.erase( std::next(prompt_lines.begin(),idx*6) , std::next(prompt_lines.begin(),idx*6+6) ); prompt_lines.erase( std::next(prompt_lines.begin(),idx*6) , std::next(prompt_lines.begin(),idx*6+6) );
@ -530,150 +557,160 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
} }
fprintf(stderr, "%s : calculating hellaswag score over selected tasks.\n", __func__); fprintf(stderr, "%s : calculating hellaswag score over selected tasks.\n", __func__);
printf("\ntask\tacc_norm\n"); printf("\ntask\tacc_norm\n");
double acc = 0.0f; double acc = 0.0f;
const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch;
std::vector<std::vector<int>> ending_tokens(4); const int max_tasks_per_batch = params.n_parallel;
const int max_seq = 4*max_tasks_per_batch;
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
std::vector<float> tok_logits(n_vocab); std::vector<float> tok_logits(n_vocab);
std::vector<float> batch_logits(n_ctx*n_vocab);
for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) { auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
// Tokenize the context to count tokens for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos); const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
size_t context_size = context_embd.size();
for (int i = 0; i < 4; ++i) { llama_batch batch_view = {
ending_tokens[i] = ::llama_tokenize(ctx, hs_data[task_idx].context + " " + hs_data[task_idx].ending[i], add_bos); n_tokens,
for (int k = 0; k < int(context_size); ++k) { batch.token + i,
if (ending_tokens[i][k] != context_embd[k]) { nullptr,
fprintf(stderr, "Oops: ending %d of task %d differs from context at position %d\n",i,int(task_idx),k); batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
0, 0, 0, // unused
};
const int ret = llama_decode(ctx, batch_view);
if (ret != 0) {
LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
return false;
}
memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
}
return true;
};
for (size_t i0 = 0; i0 < hs_task_count; i0++) {
int n_cur = 0;
size_t i1 = i0;
size_t i_batch = 0; // this tells us where in `llama_batch` we are currently
llama_batch_clear(batch);
// batch as much tasks as possible into the available context
// each task has 4 unique seuqnce ids - one for each ending
// the common prefix is shared among the 4 sequences to save tokens
// we extract logits only from the last common token and from all ending tokens of each sequence
while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
auto & hs_cur = hs_data[i1];
const int s0 = 4*(i1 - i0);
if (s0 + 4 > max_seq) {
break;
}
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
}
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
for (int s = 0; s < 4; ++s) {
for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) {
llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, true);
}
}
hs_cur.i_batch = i_batch;
i_batch += hs_cur.required_tokens;
n_cur += hs_data[i1].required_tokens;
if (++i1 == hs_task_count) {
break; break;
} }
} }
}
// Do the 1st ending if (i0 == i1) {
// In this case we include the context when evaluating fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
//auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
auto query_embd = ending_tokens[0];
auto query_size = query_embd.size();
// Stop if query wont fit the ctx window
if (query_size > (size_t)n_ctx) {
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
return; return;
} }
// Speedup small evaluations by evaluating atleast 32 tokens
if (query_size < 32) {
query_embd.resize(32);
}
// clear the KV cache
llama_kv_cache_clear(ctx); llama_kv_cache_clear(ctx);
auto logits = evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab); // decode all tasks [i0, i1)
if (logits.empty()) { if (!decode_helper(ctx, batch, n_batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s: llama_decode() failed\n", __func__);
return; return;
} }
std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float)); // compute the logprobs for each ending of the decoded tasks
for (size_t i = i0; i < i1; ++i) {
auto & hs_cur = hs_data[i];
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + hs_cur.common_prefix - 1), n_vocab*sizeof(float));
const auto first_probs = softmax(tok_logits); const auto first_probs = softmax(tok_logits);
hs_data[task_idx].ending_logprob_count[0] = 1; size_t li = hs_cur.common_prefix; // logits index in the batch
hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]);
for (int s = 0; s < 4; ++s) {
hs_cur.ending_logprob_count[s] = 1;
hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]);
// Calculate the logprobs over the ending // Calculate the logprobs over the ending
for (size_t j = context_size; j < query_size - 1; j++) { for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + li++), n_vocab*sizeof(float));
std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float)); const float prob = softmax(tok_logits)[hs_cur.seq_tokens[s][j + 1]];
const float prob = softmax(tok_logits)[query_embd[j + 1]]; hs_cur.ending_logprob[s] += std::log(prob);
hs_cur.ending_logprob_count[s]++;
hs_data[task_idx].ending_logprob[0] += std::log(prob);
hs_data[task_idx].ending_logprob_count[0]++;
} }
// account that we skip the last token in the ending
++li;
// Calculate the mean token logprob for acc_norm // Calculate the mean token logprob for acc_norm
hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0]; hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s];
// Do the remaining endings
// For these, we use the bare ending with n_past = context_size
//
for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {
// Tokenize the query
query_embd.resize(ending_tokens[ending_idx].size() - context_size);
std::memcpy(query_embd.data(), ending_tokens[ending_idx].data() + context_size, query_embd.size()*sizeof(int));
query_size = query_embd.size();
// Stop if query wont fit the ctx window
if (context_size + query_size > (size_t)n_ctx) {
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
return;
}
// Speedup small evaluations by evaluating atleast 32 tokens
// No, resizing to 32 is actually slightly slower (at least on CUDA)
//if (query_size < 32) {
// query_embd.resize(32);
//}
// Evaluate the query
logits = evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab);
if (logits.empty()) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
}
hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
// Calculate the logprobs over the ending
for (size_t j = 0; j < query_size - 1; j++) {
std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
const float prob = softmax(tok_logits)[query_embd[j + 1]];
hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
hs_data[task_idx].ending_logprob_count[ending_idx]++;
}
// Calculate the mean token logprob for acc_norm
hs_data[task_idx].ending_logprob[ending_idx] /= hs_data[task_idx].ending_logprob_count[ending_idx];
// printf("task %lu, ending %lu, whole_len %lu, context_len %lu, ending_logprob_count %lu, ending_logprob %.4f\n",
// task_idx,ending_idx,whole_size,context_size, hs_data[task_idx].ending_logprob_count[ending_idx], hs_data[task_idx].ending_logprob[ending_idx] );
} }
// Find the ending with maximum logprob // Find the ending with maximum logprob
size_t ending_logprob_max_idx = 0; size_t ending_logprob_max_idx = 0;
double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0]; double ending_logprob_max_val = hs_cur.ending_logprob[0];
for (size_t j = 1; j < 4; j++) { for (size_t s = 1; s < 4; s++) {
if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) { if (hs_cur.ending_logprob[s] > ending_logprob_max_val) {
ending_logprob_max_idx = j; ending_logprob_max_idx = s;
ending_logprob_max_val = hs_data[task_idx].ending_logprob[j]; ending_logprob_max_val = hs_cur.ending_logprob[s];
} }
} }
// printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_data[task_idx].gold_ending_idx); //printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_cur.gold_ending_idx);
// If the gold ending got the maximum logprobe add one accuracy point // If the gold ending got the maximum logprobe add one accuracy point
if (ending_logprob_max_idx == hs_data[task_idx].gold_ending_idx) { if (ending_logprob_max_idx == hs_cur.gold_ending_idx) {
acc += 1.0; acc += 1.0;
} }
// Print the accumulated accuracy mean x 100 // Print the accumulated accuracy mean x 100
printf("%zu\t%.8lf\n",task_idx+1, acc/double(task_idx+1)*100.0); printf("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0);
fflush(stdout); fflush(stdout);
} }
delete [] hs_data; i0 = i1 - 1;
}
llama_batch_free(batch);
printf("\n"); printf("\n");
} }