mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 02:14:35 +00:00
Add Winogrande evaluation (#5015)
* winogrande: simple implementation It doesn't look like it is working - why? For Mistral-7B it is barely better than random chance (score ~60% for 1267 tasks), while I see Mistral-7B scoring 78.4% on the HF leader board. 1-sigma statistical uncertainty for 1267 tasks is ~1.4, so no way the difference is due to statistics. * winogrande: somewhat better Score for Mistrali7-B is now 68.9 on the validation set of winogrande_debiased. Still far from the reported 78.4, but better than what I had before. * winogrande: improving Mistral-7B score is now 73.56. Still not quite 78.4 but getting there. We are also getting a lower score on HellaSwag compared to HF leader board, so I'm not expecting we will get up to 78.4 anyway. It looks like it is better to skip the choice word(s) when evaluating the average log-likelihood. This kind of makes sense because a more common word (in Winogrande this is often a name) will have a higher probability without knowing about the follow up context, and this will skew the log-likelihood towards the more common word. We can only do this if the choice words are not last in the sentence. It also looks like it is better to skip the punctuation at the end of the sentence, provided the choice words are not last. * winogrande: add dataset instructions --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
parent
dcad445d0c
commit
682986a08e
@ -681,6 +681,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
||||
break;
|
||||
}
|
||||
params.hellaswag_tasks = std::stoi(argv[i]);
|
||||
} else if (arg == "--winogrande") {
|
||||
params.winogrande = true;
|
||||
} else if (arg == "--winogrande-tasks") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params.winogrande_tasks = std::stoi(argv[i]);
|
||||
} else if (arg == "--ignore-eos") {
|
||||
params.ignore_eos = true;
|
||||
} else if (arg == "--no-penalize-nl") {
|
||||
@ -926,6 +934,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||
printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n");
|
||||
printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
|
||||
printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
|
||||
printf(" --winogrande compute Winogrande score over random tasks from datafile supplied with -f\n");
|
||||
printf(" --winogrande-tasks N number of tasks to use when computing the Winogrande score (default: %zu)\n", params.winogrande_tasks);
|
||||
printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
|
||||
printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
|
||||
printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
|
||||
|
@ -105,6 +105,9 @@ struct gpt_params {
|
||||
bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt
|
||||
size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score
|
||||
|
||||
bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt
|
||||
size_t winogrande_tasks= 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed
|
||||
|
||||
bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS
|
||||
bool random_prompt = false; // do not randomize prompt if none provided
|
||||
bool use_color = false; // use color to distinguish generations and inputs
|
||||
|
@ -9,6 +9,9 @@
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <vector>
|
||||
#include <array>
|
||||
#include <fstream>
|
||||
#include <sstream>
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
@ -419,9 +422,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
||||
return {tokens, ppl, logit_history, prob_history};
|
||||
}
|
||||
|
||||
static std::vector<float> hellaswag_evaluate_tokens(
|
||||
llama_context * ctx, std::vector<int> & tokens, int n_past, int n_batch, int n_vocab
|
||||
) {
|
||||
static std::vector<float> evaluate_tokens(llama_context * ctx, std::vector<int> & tokens,
|
||||
int n_past, int n_batch, int n_vocab) {
|
||||
std::vector<float> result;
|
||||
result.reserve(tokens.size() * n_vocab);
|
||||
size_t n_chunk = (tokens.size() + n_batch - 1)/n_batch;
|
||||
@ -573,7 +575,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||
// clear the KV cache
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
|
||||
auto logits = evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
|
||||
if (logits.empty()) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return;
|
||||
@ -622,7 +624,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||
//}
|
||||
|
||||
// Evaluate the query
|
||||
logits = hellaswag_evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab);
|
||||
logits = evaluate_tokens(ctx, query_embd, context_size, params.n_batch, n_vocab);
|
||||
if (logits.empty()) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return;
|
||||
@ -676,6 +678,235 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
struct winogrande_entry {
|
||||
std::string first;
|
||||
std::string second;
|
||||
std::array<std::string, 2> choices;
|
||||
int answer;
|
||||
};
|
||||
|
||||
static std::vector<winogrande_entry> load_winogrande_from_csv(const std::string& prompt) {
|
||||
std::vector<winogrande_entry> result;
|
||||
std::istringstream in(prompt);
|
||||
std::string line;
|
||||
std::array<int, 4> comma_pos;
|
||||
while (true) {
|
||||
std::getline(in, line);
|
||||
if (in.fail() || in.eof()) break;
|
||||
int ipos = 0;
|
||||
bool quote_open = false;
|
||||
for (int i = 0; i < int(line.size()); ++i) {
|
||||
if (!quote_open) {
|
||||
if (line[i] == ',') {
|
||||
comma_pos[ipos++] = i;
|
||||
if (ipos == 4) break;
|
||||
}
|
||||
else if (line[i] == '"') {
|
||||
quote_open = true;
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (line[i] == '"') {
|
||||
quote_open = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (ipos != 4) {
|
||||
printf("%s: failed to find comma separators in <%s>\n", __func__, line.c_str());
|
||||
continue;
|
||||
}
|
||||
auto sentence = line[comma_pos[0]+1] == '"' ? line.substr(comma_pos[0]+2, comma_pos[1] - comma_pos[0] - 3)
|
||||
: line.substr(comma_pos[0]+1, comma_pos[1] - comma_pos[0] - 1);
|
||||
auto choice1 = line.substr(comma_pos[1]+1, comma_pos[2] - comma_pos[1] - 1);
|
||||
auto choice2 = line.substr(comma_pos[2]+1, comma_pos[3] - comma_pos[2] - 1);
|
||||
auto answer = line.substr(comma_pos[3]+1, line.size() - comma_pos[3] - 1);
|
||||
auto index = line.substr(0, comma_pos[0]);
|
||||
int where = 0;
|
||||
for ( ; where < int(sentence.size()); ++where) {
|
||||
if (sentence[where] == '_') break;
|
||||
}
|
||||
if (where == int(sentence.size())) {
|
||||
printf("%s: no _ in <%s>\n", __func__, sentence.c_str());
|
||||
continue;
|
||||
}
|
||||
std::istringstream stream(answer.c_str());
|
||||
int i_answer; stream >> i_answer;
|
||||
if (stream.fail() || i_answer < 1 || i_answer > 2) {
|
||||
printf("%s: failed to parse answer <%s>\n", __func__, answer.c_str());
|
||||
continue;
|
||||
}
|
||||
result.emplace_back();
|
||||
auto& wg = result.back();
|
||||
wg.first = sentence.substr(0, where);
|
||||
wg.second = sentence.substr(where + 1, sentence.size() - where - 1);
|
||||
wg.choices[0] = std::move(choice1);
|
||||
wg.choices[1] = std::move(choice2);
|
||||
wg.answer = i_answer;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/*
|
||||
* Evaluates the Winogrande score.
|
||||
* Uses a CSV containing task index, dentence, choice 1, choice 2, answer (1 or 2)
|
||||
* You can get one such dataset from e.g. https://huggingface.co/datasets/ikawrakow/winogrande-eval-for-llama.cpp
|
||||
* As an example, the 1st row in the above dataset is
|
||||
*
|
||||
* 0,Sarah was a much better surgeon than Maria so _ always got the easier cases.,Sarah,Maria,2
|
||||
*
|
||||
*/
|
||||
static void winogrande_score(llama_context * ctx, const gpt_params & params) {
|
||||
|
||||
constexpr int k_min_trailing_ctx = 3;
|
||||
|
||||
auto data = load_winogrande_from_csv(params.prompt);
|
||||
if (data.empty()) {
|
||||
fprintf(stderr, "%s: no tasks\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s : loaded %zu tasks from prompt.\n", __func__, data.size());
|
||||
|
||||
if (params.winogrande_tasks > 0 && params.winogrande_tasks < data.size()) {
|
||||
fprintf(stderr, "%s : selecting %zu random tasks\n", __func__, params.winogrande_tasks);
|
||||
std::mt19937 rng(1);
|
||||
std::vector<int> aux(data.size());
|
||||
for (int i = 0; i < int(data.size()); ++i) {
|
||||
aux[i] = i;
|
||||
}
|
||||
float scale = 1/(1.f + (float)rng.max());
|
||||
std::vector<winogrande_entry> selected;
|
||||
selected.reserve(params.winogrande_tasks);
|
||||
for (int i = 0; i < int(params.winogrande_tasks); ++i) {
|
||||
int j = int(scale*rng()*aux.size());
|
||||
selected[i] = std::move(data[aux[j]]);
|
||||
aux[j] = aux.back();
|
||||
aux.pop_back();
|
||||
}
|
||||
data = std::move(selected);
|
||||
}
|
||||
|
||||
// This is needed as usual for LLaMA models
|
||||
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
|
||||
|
||||
fprintf(stderr, "%s : calculating winogrande score over selected tasks.\n", __func__);
|
||||
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
|
||||
std::vector<float> tok_logits(n_vocab);
|
||||
|
||||
int n_correct = 0;
|
||||
int n_done = 0;
|
||||
|
||||
for (size_t task_idx = 0; task_idx < data.size(); task_idx++) {
|
||||
const auto& task = data[task_idx];
|
||||
|
||||
auto base_context = ::llama_tokenize(ctx, task.first, add_bos);
|
||||
auto base_ctx_1st = ::llama_tokenize(ctx, task.first + task.choices[0], add_bos);
|
||||
auto base_ctx_2nd = ::llama_tokenize(ctx, task.first + task.choices[1], add_bos);
|
||||
|
||||
auto sentence_1st = task.first + task.choices[0] + task.second;
|
||||
auto sentence_2nd = task.first + task.choices[1] + task.second;
|
||||
auto query_1st = ::llama_tokenize(ctx, sentence_1st, add_bos);
|
||||
auto query_2nd = ::llama_tokenize(ctx, sentence_2nd, add_bos);
|
||||
|
||||
if (query_1st.size() > (size_t)n_ctx || query_2nd.size() > (size_t)n_ctx) {
|
||||
fprintf(stderr, "%s : number of tokens in queries %zu, %zu > n_ctxl\n", __func__, query_1st.size(), query_2nd.size());
|
||||
return;
|
||||
}
|
||||
|
||||
auto query_1st_size = query_1st.size();
|
||||
auto query_2nd_size = query_2nd.size();
|
||||
|
||||
// Speedup small evaluations by evaluating atleast 32 tokens
|
||||
// For Winogrande this seems to slow it down rather than speed it up.
|
||||
//if (query_1st.size() < 32) query_1st.resize(32);
|
||||
//if (query_2nd.size() < 32) query_2nd.resize(32);
|
||||
|
||||
llama_kv_cache_clear(ctx);
|
||||
auto logits_1st = evaluate_tokens(ctx, query_1st, 0, params.n_batch, n_vocab);
|
||||
|
||||
llama_kv_cache_clear(ctx);
|
||||
auto logits_2nd = evaluate_tokens(ctx, query_2nd, 0, params.n_batch, n_vocab);
|
||||
|
||||
if (logits_1st.empty() || logits_2nd.empty()) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
||||
bool skip_choice = query_1st_size - base_ctx_1st.size() > k_min_trailing_ctx &&
|
||||
query_2nd_size - base_ctx_2nd.size() > k_min_trailing_ctx;
|
||||
|
||||
float score_1st = 0;
|
||||
bool is_nan_1st = false;
|
||||
const auto& base_1 = skip_choice ? base_ctx_1st : base_context;
|
||||
const int last_1st = query_1st_size - base_1.size() > 1 ? 1 : 0;
|
||||
for (size_t j = base_1.size()-1; j < query_1st_size-1-last_1st; ++j) {
|
||||
std::memcpy(tok_logits.data(), logits_1st.data() + j*n_vocab, n_vocab*sizeof(float));
|
||||
const float prob = softmax(tok_logits)[query_1st[j+1]];
|
||||
if (std::isnan(prob) || !prob) {
|
||||
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
|
||||
prob, j, sentence_1st.c_str(), base_context.size());
|
||||
is_nan_1st = true;
|
||||
break;
|
||||
}
|
||||
score_1st += std::log(prob);
|
||||
}
|
||||
score_1st /= (query_1st_size - base_1.size() - last_1st);
|
||||
|
||||
float score_2nd = 0;
|
||||
bool is_nan_2nd = false;
|
||||
const auto& base_2 = skip_choice ? base_ctx_2nd : base_context;
|
||||
const int last_2nd = query_2nd_size - base_2.size() > 1 ? 1 : 0;
|
||||
for (size_t j = base_2.size()-1; j < query_2nd_size-1-last_2nd; ++j) {
|
||||
std::memcpy(tok_logits.data(), logits_2nd.data() + j*n_vocab, n_vocab*sizeof(float));
|
||||
const float prob = softmax(tok_logits)[query_2nd[j+1]];
|
||||
if (std::isnan(prob) || !prob) {
|
||||
fprintf(stderr, "%s: %g probability for token %zu when evaluating <%s>. Base context has %zu tokens\n", __func__,
|
||||
prob, j, sentence_2nd.c_str(), base_context.size());
|
||||
is_nan_2nd = true;
|
||||
break;
|
||||
}
|
||||
score_2nd += std::log(prob);
|
||||
}
|
||||
score_2nd /= (query_2nd_size - base_2.size() - last_2nd);
|
||||
|
||||
if (is_nan_1st || is_nan_2nd) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (std::isnan(score_1st) || std::isnan(score_2nd)) {
|
||||
printf("================== NaN score %g, %g) for:\n", score_1st, score_2nd);
|
||||
printf("Q1: <%s> - %zu tokens\n", sentence_1st.c_str(), query_1st_size);
|
||||
printf("Q2: <%s> - %zu tokens\n", sentence_2nd.c_str(), query_2nd_size);
|
||||
printf("B : <%s> - %zu tokens\n", task.first.c_str(), base_context.size());
|
||||
printf("base_1 has %zu tokens, base_2 has %zu tokens, skip_choice = %d\n", base_1.size(), base_2.size(), skip_choice);
|
||||
continue;
|
||||
}
|
||||
|
||||
int result = score_1st > score_2nd ? 1 : 2;
|
||||
|
||||
if (result == task.answer) {
|
||||
++n_correct;
|
||||
}
|
||||
++n_done;
|
||||
|
||||
// Print the accumulated accuracy mean x 100
|
||||
printf("%zu\t%.4lf\t%10.6f %10.6f %d %d\n",task_idx+1, 100.0 * n_correct/n_done,score_1st,score_2nd,result,task.answer);
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
printf("\n");
|
||||
|
||||
if (n_done < 100) return;
|
||||
|
||||
const float p = 1.f*n_correct/n_done;
|
||||
const float sigma = 100.f*sqrt(p*(1-p)/(n_done-1));
|
||||
printf("Final Winogrande score(%d tasks): %.4lf +/- %.4lf\n", n_done, 100*p, sigma);
|
||||
}
|
||||
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
|
||||
@ -733,6 +964,8 @@ int main(int argc, char ** argv) {
|
||||
struct results_perplexity results;
|
||||
if (params.hellaswag) {
|
||||
hellaswag_score(ctx, params);
|
||||
} else if (params.winogrande) {
|
||||
winogrande_score(ctx, params);
|
||||
} else {
|
||||
results = perplexity(ctx, params);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user