mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
perplexity : faster HellaSwag via batching (#5017)
* perplexity : faster HellaSwag ggml-ci * perplexity : clean-up ggml-ci * perplexity : no need for decode_helper ggml-ci * perplexity : add comments * perplexity : option to specify max batched tasks via `n_parallel` * perplexity : remove HellaSwag restruction for n_batch
This commit is contained in:
parent
682986a08e
commit
ad19812cda
@ -470,7 +470,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
prompt_lines.push_back(line);
|
prompt_lines.push_back(line);
|
||||||
}
|
}
|
||||||
|
|
||||||
if( prompt_lines.size() % 6 != 0) {
|
if (prompt_lines.size() % 6 != 0) {
|
||||||
fprintf(stderr, "%s : number of lines in prompt not a multiple of 6.\n", __func__);
|
fprintf(stderr, "%s : number of lines in prompt not a multiple of 6.\n", __func__);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@ -485,7 +485,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
|
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
|
||||||
|
|
||||||
// Number of tasks to use when computing the score
|
// Number of tasks to use when computing the score
|
||||||
if ( params.hellaswag_tasks < hs_task_count ) {
|
if (params.hellaswag_tasks < hs_task_count) {
|
||||||
hs_task_count = params.hellaswag_tasks;
|
hs_task_count = params.hellaswag_tasks;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -502,27 +502,54 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
std::string ending[4];
|
std::string ending[4];
|
||||||
size_t ending_logprob_count[4];
|
size_t ending_logprob_count[4];
|
||||||
double ending_logprob[4];
|
double ending_logprob[4];
|
||||||
|
|
||||||
|
size_t i_batch; // starting index in the llama_batch
|
||||||
|
size_t common_prefix; // max number of initial tokens that are the same in all sentences
|
||||||
|
size_t required_tokens; // needed number of tokens to evaluate all 4 endings
|
||||||
|
std::vector<llama_token> seq_tokens[4];
|
||||||
};
|
};
|
||||||
|
|
||||||
fprintf(stderr, "%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first") );
|
fprintf(stderr, "%s : selecting %zu %s tasks.\n", __func__, hs_task_count, (randomize_tasks?"randomized":"the first") );
|
||||||
|
|
||||||
// Select and read data from prompt lines
|
// Select and read data from prompt lines
|
||||||
hs_data_t *hs_data = new hs_data_t[hs_task_count];
|
std::vector<hs_data_t> hs_data(hs_task_count);
|
||||||
for (size_t i=0; i < hs_task_count; i++) {
|
for (size_t i = 0; i < hs_task_count; i++) {
|
||||||
size_t idx = i;
|
size_t idx = i;
|
||||||
|
|
||||||
|
auto & hs_cur = hs_data[i];
|
||||||
|
|
||||||
// Select a random example of those left in the prompt
|
// Select a random example of those left in the prompt
|
||||||
if (randomize_tasks) {
|
if (randomize_tasks) {
|
||||||
std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6-1 ) ;
|
std::uniform_int_distribution<size_t> dist(0, prompt_lines.size()/6-1 ) ;
|
||||||
idx = dist(rng);
|
idx = dist(rng);
|
||||||
}
|
}
|
||||||
|
|
||||||
hs_data[i].context = prompt_lines[idx*6];
|
hs_cur.context = prompt_lines[idx*6];
|
||||||
hs_data[i].gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
|
hs_cur.gold_ending_idx = std::stoi( prompt_lines[idx*6+1] );
|
||||||
for (size_t j=0; j < 4; j++) {
|
for (size_t j = 0; j < 4; j++) {
|
||||||
hs_data[i].ending[j] = prompt_lines[idx*6+2+j];
|
hs_cur.ending[j] = prompt_lines[idx*6+2+j];
|
||||||
|
hs_cur.seq_tokens[j] = ::llama_tokenize(ctx, hs_cur.context + " " + hs_cur.ending[j], add_bos);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// determine the common prefix of the endings
|
||||||
|
hs_cur.common_prefix = 0;
|
||||||
|
hs_cur.required_tokens = 0;
|
||||||
|
for (size_t k = 0; k < hs_cur.seq_tokens[0].size(); k++) {
|
||||||
|
if (hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[1][k] ||
|
||||||
|
hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[2][k] ||
|
||||||
|
hs_cur.seq_tokens[0][k] != hs_cur.seq_tokens[3][k]) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
hs_cur.common_prefix++;
|
||||||
|
}
|
||||||
|
hs_cur.required_tokens = hs_cur.common_prefix +
|
||||||
|
hs_cur.seq_tokens[0].size() - hs_cur.common_prefix +
|
||||||
|
hs_cur.seq_tokens[1].size() - hs_cur.common_prefix +
|
||||||
|
hs_cur.seq_tokens[2].size() - hs_cur.common_prefix +
|
||||||
|
hs_cur.seq_tokens[3].size() - hs_cur.common_prefix;
|
||||||
|
|
||||||
|
//GGML_ASSERT(hs_cur.common_prefix >= ::llama_tokenize(ctx, hs_cur.context, add_bos).size());
|
||||||
|
|
||||||
// Delete the selected random example from the prompt
|
// Delete the selected random example from the prompt
|
||||||
if (randomize_tasks) {
|
if (randomize_tasks) {
|
||||||
prompt_lines.erase( std::next(prompt_lines.begin(),idx*6) , std::next(prompt_lines.begin(),idx*6+6) );
|
prompt_lines.erase( std::next(prompt_lines.begin(),idx*6) , std::next(prompt_lines.begin(),idx*6+6) );
|
||||||
@ -530,150 +557,160 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fprintf(stderr, "%s : calculating hellaswag score over selected tasks.\n", __func__);
|
fprintf(stderr, "%s : calculating hellaswag score over selected tasks.\n", __func__);
|
||||||
|
|
||||||
printf("\ntask\tacc_norm\n");
|
printf("\ntask\tacc_norm\n");
|
||||||
|
|
||||||
double acc = 0.0f;
|
double acc = 0.0f;
|
||||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
|
||||||
const int n_ctx = llama_n_ctx(ctx);
|
|
||||||
|
|
||||||
std::vector<std::vector<int>> ending_tokens(4);
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||||
|
const int n_ctx = llama_n_ctx(ctx);
|
||||||
|
const int n_batch = params.n_batch;
|
||||||
|
|
||||||
|
const int max_tasks_per_batch = params.n_parallel;
|
||||||
|
const int max_seq = 4*max_tasks_per_batch;
|
||||||
|
|
||||||
|
llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
|
||||||
|
|
||||||
std::vector<float> tok_logits(n_vocab);
|
std::vector<float> tok_logits(n_vocab);
|
||||||
|
std::vector<float> batch_logits(n_ctx*n_vocab);
|
||||||
|
|
||||||
for (size_t task_idx = 0; task_idx < hs_task_count; task_idx++) {
|
auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
|
||||||
// Tokenize the context to count tokens
|
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
|
||||||
std::vector<int> context_embd = ::llama_tokenize(ctx, hs_data[task_idx].context, add_bos);
|
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||||
size_t context_size = context_embd.size();
|
|
||||||
|
|
||||||
for (int i = 0; i < 4; ++i) {
|
llama_batch batch_view = {
|
||||||
ending_tokens[i] = ::llama_tokenize(ctx, hs_data[task_idx].context + " " + hs_data[task_idx].ending[i], add_bos);
|
n_tokens,
|
||||||
for (int k = 0; k < int(context_size); ++k) {
|
batch.token + i,
|
||||||
if (ending_tokens[i][k] != context_embd[k]) {
|
nullptr,
|
||||||
fprintf(stderr, "Oops: ending %d of task %d differs from context at position %d\n",i,int(task_idx),k);
|
batch.pos + i,
|
||||||
break;
|
batch.n_seq_id + i,
|
||||||
|
batch.seq_id + i,
|
||||||
|
batch.logits + i,
|
||||||
|
0, 0, 0, // unused
|
||||||
|
};
|
||||||
|
|
||||||
|
const int ret = llama_decode(ctx, batch_view);
|
||||||
|
if (ret != 0) {
|
||||||
|
LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
for (size_t i0 = 0; i0 < hs_task_count; i0++) {
|
||||||
|
int n_cur = 0;
|
||||||
|
|
||||||
|
size_t i1 = i0;
|
||||||
|
size_t i_batch = 0; // this tells us where in `llama_batch` we are currently
|
||||||
|
|
||||||
|
llama_batch_clear(batch);
|
||||||
|
|
||||||
|
// batch as much tasks as possible into the available context
|
||||||
|
// each task has 4 unique seuqnce ids - one for each ending
|
||||||
|
// the common prefix is shared among the 4 sequences to save tokens
|
||||||
|
// we extract logits only from the last common token and from all ending tokens of each sequence
|
||||||
|
while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
|
||||||
|
auto & hs_cur = hs_data[i1];
|
||||||
|
|
||||||
|
const int s0 = 4*(i1 - i0);
|
||||||
|
if (s0 + 4 > max_seq) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
|
||||||
|
llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
|
||||||
|
}
|
||||||
|
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
|
||||||
|
|
||||||
|
for (int s = 0; s < 4; ++s) {
|
||||||
|
for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) {
|
||||||
|
llama_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, true);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
hs_cur.i_batch = i_batch;
|
||||||
|
i_batch += hs_cur.required_tokens;
|
||||||
|
|
||||||
|
n_cur += hs_data[i1].required_tokens;
|
||||||
|
if (++i1 == hs_task_count) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Do the 1st ending
|
if (i0 == i1) {
|
||||||
// In this case we include the context when evaluating
|
fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0);
|
||||||
//auto query_embd = ::llama_tokenize(ctx, hs_data[task_idx].context + hs_data[task_idx].ending[0], add_bos);
|
|
||||||
auto query_embd = ending_tokens[0];
|
|
||||||
auto query_size = query_embd.size();
|
|
||||||
|
|
||||||
// Stop if query wont fit the ctx window
|
|
||||||
if (query_size > (size_t)n_ctx) {
|
|
||||||
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Speedup small evaluations by evaluating atleast 32 tokens
|
|
||||||
if (query_size < 32) {
|
|
||||||
query_embd.resize(32);
|
|
||||||
}
|
|
||||||
|
|
||||||
// clear the KV cache
|
|
||||||
llama_kv_cache_clear(ctx);
|
llama_kv_cache_clear(ctx);
|
||||||
|
|
||||||
auto logits = evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
|
// decode all tasks [i0, i1)
|
||||||
if (logits.empty()) {
|
if (!decode_helper(ctx, batch, n_batch)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::memcpy(tok_logits.data(), logits.data() + (context_size-1)*n_vocab, n_vocab*sizeof(float));
|
// compute the logprobs for each ending of the decoded tasks
|
||||||
const auto first_probs = softmax(tok_logits);
|
for (size_t i = i0; i < i1; ++i) {
|
||||||
|
auto & hs_cur = hs_data[i];
|
||||||
|
|
||||||
hs_data[task_idx].ending_logprob_count[0] = 1;
|
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + hs_cur.common_prefix - 1), n_vocab*sizeof(float));
|
||||||
hs_data[task_idx].ending_logprob[0] = std::log(first_probs[query_embd[context_size]]);
|
|
||||||
|
|
||||||
// Calculate the logprobs over the ending
|
const auto first_probs = softmax(tok_logits);
|
||||||
for (size_t j = context_size; j < query_size - 1; j++) {
|
|
||||||
|
|
||||||
std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
|
size_t li = hs_cur.common_prefix; // logits index in the batch
|
||||||
|
|
||||||
const float prob = softmax(tok_logits)[query_embd[j + 1]];
|
for (int s = 0; s < 4; ++s) {
|
||||||
|
hs_cur.ending_logprob_count[s] = 1;
|
||||||
|
hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]);
|
||||||
|
|
||||||
hs_data[task_idx].ending_logprob[0] += std::log(prob);
|
// Calculate the logprobs over the ending
|
||||||
hs_data[task_idx].ending_logprob_count[0]++;
|
for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
|
||||||
}
|
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + li++), n_vocab*sizeof(float));
|
||||||
|
|
||||||
// Calculate the mean token logprob for acc_norm
|
const float prob = softmax(tok_logits)[hs_cur.seq_tokens[s][j + 1]];
|
||||||
hs_data[task_idx].ending_logprob[0] /= hs_data[task_idx].ending_logprob_count[0];
|
|
||||||
|
|
||||||
// Do the remaining endings
|
hs_cur.ending_logprob[s] += std::log(prob);
|
||||||
// For these, we use the bare ending with n_past = context_size
|
hs_cur.ending_logprob_count[s]++;
|
||||||
//
|
}
|
||||||
for (size_t ending_idx = 1; ending_idx < 4; ending_idx++) {
|
|
||||||
|
|
||||||
// Tokenize the query
|
// account that we skip the last token in the ending
|
||||||
query_embd.resize(ending_tokens[ending_idx].size() - context_size);
|
++li;
|
||||||
std::memcpy(query_embd.data(), ending_tokens[ending_idx].data() + context_size, query_embd.size()*sizeof(int));
|
|
||||||
query_size = query_embd.size();
|
|
||||||
|
|
||||||
// Stop if query wont fit the ctx window
|
// Calculate the mean token logprob for acc_norm
|
||||||
if (context_size + query_size > (size_t)n_ctx) {
|
hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s];
|
||||||
fprintf(stderr, "%s : number of tokens in query %zu > n_ctxl\n", __func__, query_size);
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Speedup small evaluations by evaluating atleast 32 tokens
|
// Find the ending with maximum logprob
|
||||||
// No, resizing to 32 is actually slightly slower (at least on CUDA)
|
size_t ending_logprob_max_idx = 0;
|
||||||
//if (query_size < 32) {
|
double ending_logprob_max_val = hs_cur.ending_logprob[0];
|
||||||
// query_embd.resize(32);
|
for (size_t s = 1; s < 4; s++) {
|
||||||
//}
|
if (hs_cur.ending_logprob[s] > ending_logprob_max_val) {
|
||||||
|
ending_logprob_max_idx = s;
|
||||||
// Evaluate the query
|
ending_logprob_max_val = hs_cur.ending_logprob[s];
|
||||||
logits = evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab);
|
}
|
||||||
if (logits.empty()) {
|
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
hs_data[task_idx].ending_logprob_count[ending_idx] = 1;
|
//printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_cur.gold_ending_idx);
|
||||||
hs_data[task_idx].ending_logprob[ending_idx] = std::log(first_probs[query_embd[0]]);
|
|
||||||
|
|
||||||
// Calculate the logprobs over the ending
|
// If the gold ending got the maximum logprobe add one accuracy point
|
||||||
for (size_t j = 0; j < query_size - 1; j++) {
|
if (ending_logprob_max_idx == hs_cur.gold_ending_idx) {
|
||||||
std::memcpy(tok_logits.data(), logits.data() + j*n_vocab, n_vocab*sizeof(float));
|
acc += 1.0;
|
||||||
|
|
||||||
const float prob = softmax(tok_logits)[query_embd[j + 1]];
|
|
||||||
|
|
||||||
hs_data[task_idx].ending_logprob[ending_idx] += std::log(prob);
|
|
||||||
hs_data[task_idx].ending_logprob_count[ending_idx]++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate the mean token logprob for acc_norm
|
// Print the accumulated accuracy mean x 100
|
||||||
hs_data[task_idx].ending_logprob[ending_idx] /= hs_data[task_idx].ending_logprob_count[ending_idx];
|
printf("%zu\t%.8lf\n", i + 1, acc/double(i + 1)*100.0);
|
||||||
|
fflush(stdout);
|
||||||
|
|
||||||
// printf("task %lu, ending %lu, whole_len %lu, context_len %lu, ending_logprob_count %lu, ending_logprob %.4f\n",
|
|
||||||
// task_idx,ending_idx,whole_size,context_size, hs_data[task_idx].ending_logprob_count[ending_idx], hs_data[task_idx].ending_logprob[ending_idx] );
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find the ending with maximum logprob
|
i0 = i1 - 1;
|
||||||
size_t ending_logprob_max_idx = 0;
|
|
||||||
double ending_logprob_max_val = hs_data[task_idx].ending_logprob[0];
|
|
||||||
for (size_t j = 1; j < 4; j++) {
|
|
||||||
if (hs_data[task_idx].ending_logprob[j] > ending_logprob_max_val) {
|
|
||||||
ending_logprob_max_idx = j;
|
|
||||||
ending_logprob_max_val = hs_data[task_idx].ending_logprob[j];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// printf("max logprob ending idx %lu, gold ending idx %lu\n", ending_logprob_max_idx, hs_data[task_idx].gold_ending_idx);
|
|
||||||
|
|
||||||
// If the gold ending got the maximum logprobe add one accuracy point
|
|
||||||
if (ending_logprob_max_idx == hs_data[task_idx].gold_ending_idx) {
|
|
||||||
acc += 1.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print the accumulated accuracy mean x 100
|
|
||||||
printf("%zu\t%.8lf\n",task_idx+1, acc/double(task_idx+1)*100.0);
|
|
||||||
fflush(stdout);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
delete [] hs_data;
|
llama_batch_free(batch);
|
||||||
|
|
||||||
printf("\n");
|
printf("\n");
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user