diff --git a/common/common.cpp b/common/common.cpp index ce20360a4..0e4b8bab2 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -203,6 +203,25 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { params.prompt_cache_all = true; } else if (arg == "--prompt-cache-ro") { params.prompt_cache_ro = true; + } else if (arg == "-bf" || arg == "--binary-file") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::ifstream file(argv[i], std::ios::binary); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + // store the external file name in params + params.prompt_file = argv[i]; + file.seekg(0, std::ios::end); + size_t size = file.tellg(); + file.seekg(0, std::ios::beg); + params.prompt.resize(size); + file.read((char *)params.prompt.data(), size); + fprintf(stderr, "Read %zu bytes from binary file %s\n", size, argv[i]); } else if (arg == "-f" || arg == "--file") { if (++i >= argc) { invalid_param = true; @@ -689,6 +708,14 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) { break; } params.winogrande_tasks = std::stoi(argv[i]); + } else if (arg == "--multiple-choice") { + params.multiple_choice = true; + } else if (arg == "--multiple-choice-tasks") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.multiple_choice_tasks = std::stoi(argv[i]); } else if (arg == "--ignore-eos") { params.ignore_eos = true; } else if (arg == "--no-penalize-nl") { @@ -888,6 +915,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --in-suffix STRING string to suffix after user inputs with (default: empty)\n"); printf(" -f FNAME, --file FNAME\n"); printf(" prompt file to start generation.\n"); + printf(" -bf FNAME, --binary-file FNAME\n"); + printf(" binary file containing multiple choice tasks.\n"); printf(" -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity, -2 = until context filled)\n", params.n_predict); printf(" -c N, --ctx-size N size of the prompt context (default: %d, 0 = loaded from model)\n", params.n_ctx); printf(" -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch); @@ -936,6 +965,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks); printf(" --winogrande compute Winogrande score over random tasks from datafile supplied with -f\n"); printf(" --winogrande-tasks N number of tasks to use when computing the Winogrande score (default: %zu)\n", params.winogrande_tasks); + printf(" --multiple-choice compute multiple choice score over random tasks from datafile supplied with -f\n"); + printf(" --multiple-choice-tasks N number of tasks to use when computing the multiple choice score (default: %zu)\n", params.winogrande_tasks); printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); printf(" --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft); printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks); diff --git a/common/common.h b/common/common.h index 0ae9c18b3..c69ad7e94 100644 --- a/common/common.h +++ b/common/common.h @@ -108,6 +108,9 @@ struct gpt_params { bool winogrande = false; // compute Winogrande score over random tasks from datafile supplied in prompt size_t winogrande_tasks= 0; // number of tasks to use when computing the Winogrande score. If 0, all tasks will be computed + bool multiple_choice = false; // compute TruthfulQA score over random tasks from datafile supplied in prompt + size_t multiple_choice_tasks = 0; // number of tasks to use when computing the TruthfulQA score. If 0, all tasks will be computed + bool mul_mat_q = true; // if true, use mul_mat_q kernels instead of cuBLAS bool random_prompt = false; // do not randomize prompt if none provided bool use_color = false; // use color to distinguish generations and inputs diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index f91f5795a..b7ef9a084 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -540,14 +540,14 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { // This is needed as usual for LLaMA models const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); + // The tasks should be randomized so the score stabilizes quickly. + bool randomize_tasks = true; + // Number of tasks to use when computing the score if (params.hellaswag_tasks < hs_task_count) { hs_task_count = params.hellaswag_tasks; } - // The tasks should be randomized so the score stabilizes quickly. - bool randomize_tasks = true; - // The random seed should not impact the final result if the computation is done over enough tasks, so kept hardcoded for now std::mt19937 rng(1); @@ -1031,6 +1031,389 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { printf("Final Winogrande score(%d tasks): %.4lf +/- %.4lf\n", n_done, 100*p, sigma); } +static bool deserialize_string(std::istream& in, std::string& str) { + uint32_t size; + if (!in.read((char *)&size, sizeof(size)).fail()) { + str.resize(size); + if (!in.read((char *)str.data(), size).fail()) return true; + } + return false; +} + +struct multiple_choice_answers { + std::vector answers; + std::vector labels; + bool deserialize(std::istream& in) { + uint32_t n; + in.read((char *)&n, sizeof(n)); + if (in.fail() || n > 100) return false; // 100 as max. number of answers should be good enough for any practical purpose + answers.resize(n); + labels.resize(n); + for (auto& a : answers) { + if (!deserialize_string(in, a)) return false; + } + in.read((char *)labels.data(), n*sizeof(int)); + return !in.fail(); + } +}; + +struct multiple_choice_task { + std::string question; // the question (or context that needs to be continued) + multiple_choice_answers mc1; // possible answers (continuations) with a single correct answer + multiple_choice_answers mc2; // possible answers (continuations) with multiple correct answers - not handled yet + bool deserialize(std::istream& in) { + if (!deserialize_string(in, question)) return false; + return mc1.deserialize(in) && mc2.deserialize(in); + } + + // For evaluation + size_t i_batch; // starting index in the llama_batch + size_t common_prefix; // max number of initial tokens that are the same in all sentences + size_t required_tokens; // needed number of tokens to evaluate all answers + std::vector> seq_tokens; + std::vector log_probs; +}; + +static bool multiple_choice_prepare_one_task(llama_context * ctx, bool add_bos, multiple_choice_task& task, bool log_error) { + if (task.question.empty() || task.mc1.answers.empty()) { + if (log_error) { + printf("%s: found bad task with empty question and/or answers\n", __func__); + } + return false; + } + task.seq_tokens.reserve(task.mc1.answers.size()); + for (auto& answer : task.mc1.answers) { + if (answer.empty()) { + if (log_error) { + printf("%s: found empty answer\n", __func__); + } + return false; + } + task.seq_tokens.emplace_back(::llama_tokenize(ctx, task.question + " " + answer, add_bos)); + } + auto min_len = task.seq_tokens.front().size(); + for (auto& seq : task.seq_tokens) { + min_len = std::min(min_len, seq.size()); + } + task.common_prefix = 0; + for (size_t k = 0; k < min_len; ++k) { + auto token = task.seq_tokens[0][k]; + bool all_same = true; + for (size_t i = 1; i < task.seq_tokens.size(); ++i) { + if (task.seq_tokens[i][k] != token) { + all_same = false; + break; + } + } + if (!all_same) { + break; + } + ++task.common_prefix; + } + task.required_tokens = task.common_prefix; + for (auto& seq : task.seq_tokens) { + task.required_tokens += seq.size() - task.common_prefix; + } + return true; +} + +// +// Calculates score for multiple choice tasks with single correct answer from prompt. +// Commonly used LLM evaluation metrics of this type are +// * ARC +// * HellaSwag +// * MMLU +// * TruthfulQA +// +// Validation datasets for these 4 tests can be found at +// https://huggingface.co/datasets/ikawrakow/validation-datasets-for-llama.cpp +// The data for these datasets was extracted from +// git@hf.co:datasets/allenai/ai2_arc +// https://github.com/rowanz/hellaswag/blob/master/data/hellaswag_val.jsonl +// git@hf.co:datasets/Stevross/mmlu +// https://huggingface.co/datasets/truthful_qa +// +static void multiple_choice_score(llama_context * ctx, const gpt_params & params) { + + std::istringstream strstream(params.prompt); + uint32_t n_task; + strstream.read((char *)&n_task, sizeof(n_task)); + if (strstream.fail() || n_task == 0) { + printf("%s: no tasks\n", __func__); + return; + } + printf("%s: there are %u tasks in prompt\n", __func__, n_task); + std::vector task_pos(n_task); + strstream.read((char *)task_pos.data(), task_pos.size()*sizeof(uint32_t)); + if (strstream.fail()) { + printf("%s: failed to raad task positions from prompt\n", __func__); + return; + } + + std::vector tasks; + if (params.multiple_choice_tasks == 0 || params.multiple_choice_tasks >= (size_t)n_task) { + // Use all tasks + tasks.resize(n_task); + printf("%s: reading tasks", __func__); + int n_dot = n_task/100; + int i = 0; + for (auto& task : tasks) { + ++i; + if (!task.deserialize(strstream)) { + printf("%s: failed to read task %d of %u\n", __func__, i, n_task); + return; + } + if (i%n_dot == 0) printf("."); + } + printf("done\n"); + } + else { + printf("%s: selecting %zu random tasks from %u tasks available\n", __func__, params.multiple_choice_tasks, n_task); + std::mt19937 rng(1); + std::vector aux(n_task); + for (uint32_t i = 0; i < n_task; ++i) aux[i] = i; + float scale = 1.f/(1.f + (float)std::mt19937::max()); + tasks.resize(params.multiple_choice_tasks); + for (auto& task : tasks) { + int j = (int)(scale * rng() * aux.size()); + int idx = aux[j]; + aux[j] = aux.back(); + aux.pop_back(); + strstream.seekg(task_pos[idx], std::ios::beg); + if (!task.deserialize(strstream)) { + printf("%s: failed to read task %d at position %u\n", __func__, idx, task_pos[idx]); + return; + } + } + n_task = params.multiple_choice_tasks; + } + + // This is needed as usual for LLaMA models + const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); + + printf("%s: preparing task data", __func__); + fflush(stdout); + if (n_task > 500) { + printf("..."); + fflush(stdout); + std::atomic counter(0); + std::atomic n_bad(0); + auto prepare = [&counter, &n_bad, &tasks, ctx, add_bos] () { + int num_tasks = tasks.size(); + int n_bad_local = 0; + while (true) { + int first = counter.fetch_add(K_TOKEN_CHUNK); + if (first >= num_tasks) { + if (n_bad_local > 0) n_bad += n_bad_local; + break; + } + int last = std::min(first + K_TOKEN_CHUNK, num_tasks); + for (int i = first; i < last; ++i) { + if (!multiple_choice_prepare_one_task(ctx, add_bos, tasks[i], false)) ++n_bad_local; + } + } + }; + size_t max_thread = std::thread::hardware_concurrency(); + max_thread = std::min(max_thread, (tasks.size() + K_TOKEN_CHUNK - 1)/K_TOKEN_CHUNK); + std::vector workers(max_thread-1); + for (auto& w : workers) w = std::thread(prepare); + prepare(); + for (auto& w : workers) w.join(); + printf("done\n"); + fflush(stdout); + int nbad = n_bad; + if (nbad > 0) { + printf("%s: found %d malformed tasks\n", __func__, nbad); + return; + } + } else { + int n_dot = n_task/100; + int i_task = 0; + for (auto& task : tasks) { + ++i_task; + if (!multiple_choice_prepare_one_task(ctx, add_bos, task, true)) { + return; + } + if (i_task%n_dot == 0) { + printf("."); + fflush(stdout); + } + } + printf("done\n"); + } + + printf("%s : calculating TruthfulQA score over %zu tasks.\n", __func__, tasks.size()); + + printf("\ntask\tacc_norm\n"); + + const int n_vocab = llama_n_vocab(llama_get_model(ctx)); + const int n_ctx = llama_n_ctx(ctx); + const int n_batch = params.n_batch; + + const int max_tasks_per_batch = 32; + const int max_seq = 4*max_tasks_per_batch; + + llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); + + std::vector tok_logits(n_vocab); + std::vector batch_logits(n_vocab*n_ctx); + + std::vector> eval_pairs; + std::vector eval_results; + std::vector workers(std::thread::hardware_concurrency()); + std::vector batch_indeces; + + int n_done = 0; + int n_correct = 0; + int n_tot_answers = 0; + + for (size_t i0 = 0; i0 < tasks.size(); i0++) { + int n_cur = 0; + + size_t i1 = i0; + size_t i_batch = 0; // this tells us where in `llama_batch` we are currently + + llama_batch_clear(batch); + + // batch as much tasks as possible into the available context + // each task has 4 unique seuqnce ids - one for each ending + // the common prefix is shared among the 4 sequences to save tokens + // we extract logits only from the last common token and from all ending tokens of each sequence + int s0 = 0; + while (n_cur + (int) tasks[i1].required_tokens <= n_ctx) { + auto& cur_task = tasks[i1]; + + int num_answers = cur_task.seq_tokens.size(); + if (s0 + num_answers > max_seq) { + break; + } + + if (int(batch_indeces.size()) != num_answers) { + batch_indeces.resize(num_answers); + } + for (int s = 0; s < num_answers; ++s) batch_indeces[s] = s0 + s; + + for (size_t i = 0; i < cur_task.common_prefix; ++i) { + //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); + llama_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false); + } + batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix + + for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) { + for (size_t i = cur_task.common_prefix; i < cur_task.seq_tokens[s].size(); ++i) { + llama_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, true); + } + } + + s0 += num_answers; + + cur_task.i_batch = i_batch; + i_batch += cur_task.required_tokens; + + n_cur += cur_task.required_tokens; + if (++i1 == tasks.size()) { + break; + } + } + + if (i0 == i1) { + fprintf(stderr, "%s : task %zu does not fit in the context window\n", __func__, i0); + return; + } + + llama_kv_cache_clear(ctx); + + // decode all tasks [i0, i1) + if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { + fprintf(stderr, "%s: llama_decode() failed\n", __func__); + return; + } + + // Compute log-probs in parallel + // First we collect all tasks + eval_pairs.clear(); + for (size_t i = i0; i < i1; ++i) { + auto& cur_task = tasks[i]; + size_t li = cur_task.common_prefix; + for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) { + for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) { + eval_pairs.push_back(std::make_pair(cur_task.i_batch + li++, cur_task.seq_tokens[s][j + 1])); + } + ++li; + } + } + // Then we do the actual calculation + compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results); + + size_t ir = 0; + + // compute the logprobs for each ending of the decoded tasks + for (size_t i = i0; i < i1; ++i) { + auto & cur_task = tasks[i]; + //printf("==== Evaluating <%s> with correct answer ", cur_task.question.c_str()); + //for (int j = 0; j < int(cur_task.mc1.labels.size()); ++j) { + // if (cur_task.mc1.labels[j] == 1) { + // printf("%d", j+1); + // } + //} + //printf("\n common_prefix: %zu\n", cur_task.common_prefix); + + std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(cur_task.i_batch + cur_task.common_prefix - 1), n_vocab*sizeof(float)); + + const auto first_probs = softmax(tok_logits); + + cur_task.log_probs.resize(cur_task.seq_tokens.size()); + for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) { + size_t count = 1; + float log_prob = std::log(first_probs[cur_task.seq_tokens[s][cur_task.common_prefix]]); + for (size_t j = cur_task.common_prefix; j < cur_task.seq_tokens[s].size() - 1; j++) { + //printf(" %zu %g\n", ir, eval_results[ir]); + ++count; + log_prob += eval_results[ir++]; + } + cur_task.log_probs[s] = log_prob / count; + //printf(" Final: %g\n", log_prob / count); + //printf(" <%s> : %g\n", cur_task.mc1.answers[s].c_str(), log_prob/count); + } + + // Find the ending with maximum logprob + size_t logprob_max_idx = 0; + float logprob_max_val = cur_task.log_probs[0]; + for (size_t s = 1; s < cur_task.log_probs.size(); s++) { + if (cur_task.log_probs[s] > logprob_max_val) { + logprob_max_val = cur_task.log_probs[s]; + logprob_max_idx = s; + } + } + + n_tot_answers += cur_task.log_probs.size(); + if (cur_task.mc1.labels[logprob_max_idx] == 1) { + ++n_correct; + } + ++n_done; + + // Print the accumulated accuracy mean x 100 + printf("%d\t%.8lf\n", n_done, 100.*n_correct/n_done); + fflush(stdout); + } + + i0 = i1 - 1; + } + + llama_batch_free(batch); + + if (n_done < 100) return; + + float p = 1.f*n_correct/n_done; + float sigma = sqrt(p*(1-p)/(n_done-1)); + printf("\n Final result: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma); + p = 1.f*n_done/n_tot_answers; + sigma = sqrt(p*(1-p)/(n_done-1)); + printf("Random chance: %.4f +/- %.4f\n", 100.f*p, 100.f*sigma); + + printf("\n"); +} + int main(int argc, char ** argv) { gpt_params params; @@ -1091,6 +1474,8 @@ int main(int argc, char ** argv) { hellaswag_score(ctx, params); } else if (params.winogrande) { winogrande_score(ctx, params); + } else if (params.multiple_choice) { + multiple_choice_score(ctx, params); } else { results = perplexity(ctx, params); }