perplexity : fix integer overflow

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-10-08 09:13:54 +03:00
parent 6374743747
commit 22cc760dba
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -10,6 +10,7 @@
#include <cstdio> #include <cstdio>
#include <cstring> #include <cstring>
#include <ctime> #include <ctime>
#include <cinttypes>
#include <fstream> #include <fstream>
#include <mutex> #include <mutex>
#include <random> #include <random>
@ -103,7 +104,7 @@ static std::vector<float> softmax(const std::vector<float>& logits) {
return probs; return probs;
} }
static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) { static results_log_softmax log_softmax(int64_t n_vocab, const float * logits, int tok) {
float max_logit = logits[0]; float max_logit = logits[0];
for (int i = 1; i < n_vocab; ++i) { for (int i = 1; i < n_vocab; ++i) {
max_logit = std::max(max_logit, logits[i]); max_logit = std::max(max_logit, logits[i]);
@ -122,7 +123,7 @@ static inline int nearest_int(float fval) {
return (i & 0x007fffff) - 0x00400000; return (i & 0x007fffff) - 0x00400000;
} }
static double log_softmax(int n_vocab, const float * logits, uint16_t * log_prob, int tok) { static double log_softmax(int64_t n_vocab, const float * logits, uint16_t * log_prob, int tok) {
float max_logit = logits[0]; float max_logit = logits[0];
float min_logit = logits[0]; float min_logit = logits[0];
for (int i = 1; i < n_vocab; ++i) { for (int i = 1; i < n_vocab; ++i) {
@ -153,7 +154,7 @@ static double log_softmax(int n_vocab, const float * logits, uint16_t * log_prob
} }
static void process_logits( static void process_logits(
int n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers, int64_t n_vocab, const float * logits, const int * tokens, int n_token, std::vector<std::thread> & workers,
double & nll, double & nll2, float * logit_history, float * prob_history double & nll, double & nll2, float * logit_history, float * prob_history
) { ) {
std::mutex mutex; std::mutex mutex;
@ -187,7 +188,7 @@ static void process_logits(
} }
} }
static void process_logits(std::ostream& out, int n_vocab, const float * logits, const int * tokens, int n_token, static void process_logits(std::ostream& out, int64_t n_vocab, const float * logits, const int * tokens, int n_token,
std::vector<std::thread> & workers, std::vector<uint16_t> & log_probs, double & nll, double & nll2) { std::vector<std::thread> & workers, std::vector<uint16_t> & log_probs, double & nll, double & nll2) {
std::mutex mutex; std::mutex mutex;
const int nv = 2*((n_vocab + 1)/2) + 4; const int nv = 2*((n_vocab + 1)/2) + 4;
@ -234,7 +235,7 @@ struct kl_divergence_result {
size_t count = 0.0; size_t count = 0.0;
}; };
static std::pair<double, float> log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) { static std::pair<double, float> log_softmax(int64_t n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
float max_logit = logits[0]; float max_logit = logits[0];
int imax = 0; int imax = 0;
for (int i = 1; i < n_vocab; ++i) { for (int i = 1; i < n_vocab; ++i) {
@ -281,7 +282,9 @@ static std::pair<double, float> log_softmax(int n_vocab, const float * logits, c
kld.sum_kld += sum; kld.sum_kld += sum;
kld.sum_kld2 += sum*sum; kld.sum_kld2 += sum*sum;
++kld.count; ++kld.count;
if (imax == imax_base) ++kld.n_same_top; if (imax == imax_base) {
++kld.n_same_top;
}
const float p_base = expf(-nll_base); const float p_base = expf(-nll_base);
const float p = expf(-nll); const float p = expf(-nll);
@ -295,7 +298,7 @@ static std::pair<double, float> log_softmax(int n_vocab, const float * logits, c
return std::make_pair(sum, p_diff); return std::make_pair(sum, p_diff);
} }
static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token, static void process_logits(int64_t n_vocab, const float * logits, const int * tokens, int n_token,
std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld, std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld,
float * kld_values, float * p_diff_values) { float * kld_values, float * p_diff_values) {
std::mutex mutex; std::mutex mutex;
@ -383,9 +386,10 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride; const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int64_t n_vocab = llama_n_vocab(llama_get_model(ctx));
int count = 0; int count = 0;
double nll = 0.0; double nll = 0.0;
@ -521,9 +525,10 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
const int n_chunk_max = tokens.size() / n_ctx; const int n_chunk_max = tokens.size() / n_ctx;
const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int64_t n_vocab = llama_n_vocab(llama_get_model(ctx));
int count = 0; int count = 0;
double nll = 0.0; double nll = 0.0;
double nll2 = 0.0; double nll2 = 0.0;
@ -723,7 +728,7 @@ static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector<
#define K_TOKEN_CHUNK 4 #define K_TOKEN_CHUNK 4
static void compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers, static void compute_logprobs(const float * batch_logits, int64_t n_vocab, std::vector<std::thread>& workers,
const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) { const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
if (eval_results.size() != eval_pairs.size()) { if (eval_results.size() != eval_pairs.size()) {
eval_results.resize(eval_pairs.size()); eval_results.resize(eval_pairs.size());
@ -877,10 +882,11 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
double acc = 0.0f; double acc = 0.0f;
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int64_t n_vocab = llama_n_vocab(llama_get_model(ctx));
const int max_tasks_per_batch = 32; const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
@ -1158,10 +1164,11 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
LOG_INF("%s : calculating winogrande score over selected tasks.\n", __func__); LOG_INF("%s : calculating winogrande score over selected tasks.\n", __func__);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int64_t n_vocab = llama_n_vocab(llama_get_model(ctx));
const int max_tasks_per_batch = 128; const int max_tasks_per_batch = 128;
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
@ -1509,10 +1516,11 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
LOG("\ntask\tacc_norm\n"); LOG("\ntask\tacc_norm\n");
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
const int n_ctx = llama_n_ctx(ctx); const int n_ctx = llama_n_ctx(ctx);
const int n_batch = params.n_batch; const int n_batch = params.n_batch;
const int64_t n_vocab = llama_n_vocab(llama_get_model(ctx));
const int max_tasks_per_batch = 32; const int max_tasks_per_batch = 32;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx));
@ -1709,7 +1717,8 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
__func__, params.logits_file.c_str(), n_ctx, params.n_ctx); __func__, params.logits_file.c_str(), n_ctx, params.n_ctx);
} }
int n_vocab, n_chunk; int64_t n_vocab;
int64_t n_chunk;
in.read((char *)&n_vocab, sizeof(n_vocab)); in.read((char *)&n_vocab, sizeof(n_vocab));
in.read((char *)&n_chunk, sizeof(n_chunk)); in.read((char *)&n_chunk, sizeof(n_chunk));
if (in.fail()) { if (in.fail()) {
@ -1717,7 +1726,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
return; return;
} }
if (n_vocab != llama_n_vocab(llama_get_model(ctx))) { if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
LOG_ERR("%s: inconsistent vocabulary (%d vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx))); LOG_ERR("%s: inconsistent vocabulary (%" PRId64 " vs %d)\n", __func__, n_vocab, llama_n_vocab(llama_get_model(ctx)));
} }
std::vector<llama_token> tokens(n_ctx * n_chunk); std::vector<llama_token> tokens(n_ctx * n_chunk);