Add back top_k (#56)

* Add back top_k

* Update utils.cpp

* Update utils.h

---------

Co-authored-by: Bill Hamilton <bill.hamilton@shopify.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Author: beiller, committed by GitHub
Date: 2023-03-12 16:23:15 -04:00
Parent: eb062bb012
Commit: 02f0c6fe7f
3 changed files with 12 additions and 89 deletions

main.cpp

@@ -825,6 +825,7 @@ int main(int argc, char ** argv) {
         if (i >= embd_inp.size()) {
             // sample next token
+            const float top_k = params.top_k;
             const float top_p = params.top_p;
             const float temp  = params.temp;
             const float repeat_penalty = params.repeat_penalty;
@@ -836,7 +837,7 @@ int main(int argc, char ** argv) {
             {
                 const int64_t t_start_sample_us = ggml_time_us();

-                id = llama_sample_top_p(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_p, temp, rng);
+                id = llama_sample_top_p_top_k(vocab, logits.data() + (logits.size() - n_vocab), last_n_tokens, repeat_penalty, top_k, top_p, temp, rng);

                 last_n_tokens.erase(last_n_tokens.begin());
                 last_n_tokens.push_back(id);
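
[Editor's note] For orientation, the temp argument passed here rescales every logit by 1/temp before the top-k/top-p filters run, via the same logits[i]*scale loop that appears in the utils.cpp diff below. A minimal self-contained sketch of that stage, using toy values rather than the repo's code:

// Temperature-scaling stage, illustrative only: each logit is multiplied
// by 1/temp before filtering. temp < 1 sharpens the distribution,
// temp > 1 flattens it.
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    const std::vector<float> logits = {2.0f, 1.0f, 0.5f};
    const double temp  = 0.80;       // default from gpt_params
    const double scale = 1.0 / temp;

    std::vector<std::pair<double, int>> logits_id; // (scaled logit, token id)
    for (int i = 0; i < (int) logits.size(); ++i) {
        logits_id.push_back(std::make_pair(logits[i] * scale, i));
    }

    for (const auto & kv : logits_id) {
        std::printf("token %d: %.3f\n", kv.second, kv.first);
    }
    return 0;
}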

utils.cpp

@@ -301,25 +301,8 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
     return true;
 }

-gpt_vocab::id gpt_sample_top_k_top_p(
-        const gpt_vocab & vocab,
-        const float * logits,
-        int    top_k,
-        double top_p,
-        double temp,
-        std::mt19937 & rng) {
-    int n_logits = vocab.id_to_token.size();
-
-    std::vector<std::pair<double, gpt_vocab::id>> logits_id;
-    logits_id.reserve(n_logits);
-
-    {
-        const double scale = 1.0/temp;
-        for (int i = 0; i < n_logits; ++i) {
-            logits_id.push_back(std::make_pair(logits[i]*scale, i));
-        }
-    }
-
+void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k) {
     // find the top K tokens
     std::partial_sort(
             logits_id.begin(),
@@ -329,63 +312,14 @@ gpt_vocab::id gpt_sample_top_k_top_p(
             });

     logits_id.resize(top_k);
-
-    double maxl = -INFINITY;
-    for (const auto & kv : logits_id) {
-        maxl = std::max(maxl, kv.first);
-    }
-
-    // compute probs for the top K tokens
-    std::vector<double> probs;
-    probs.reserve(logits_id.size());
-
-    double sum = 0.0;
-    for (const auto & kv : logits_id) {
-        double p = exp(kv.first - maxl);
-        probs.push_back(p);
-        sum += p;
-    }
-
-    // normalize the probs
-    for (auto & p : probs) {
-        p /= sum;
-    }
-
-    if (top_p < 1.0f) {
-        double cumsum = 0.0f;
-        for (int i = 0; i < top_k; i++) {
-            cumsum += probs[i];
-            if (cumsum >= top_p) {
-                top_k = i + 1;
-                probs.resize(top_k);
-                logits_id.resize(top_k);
-                break;
-            }
-        }
-
-        cumsum = 1.0/cumsum;
-        for (int i = 0; i < (int) probs.size(); i++) {
-            probs[i] *= cumsum;
-        }
-    }
-
-    //printf("\n");
-    //for (int i = 0; i < (int) probs.size(); i++) {
-    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
-    //}
-    //exit(0);
-
-    std::discrete_distribution<> dist(probs.begin(), probs.end());
-    int idx = dist(rng);
-
-    return logits_id[idx].second;
-}
-
-gpt_vocab::id llama_sample_top_p(
+}
+
+gpt_vocab::id llama_sample_top_p_top_k(
         const gpt_vocab & vocab,
         const float * logits,
         std::vector<gpt_vocab::id> & last_n_tokens,
         double repeat_penalty,
+        int    top_k,
         double top_p,
         double temp,
         std::mt19937 & rng) {
@@ -412,12 +346,7 @@ gpt_vocab::id llama_sample_top_p(
         }
     }

-    std::sort(
-            logits_id.begin(),
-            logits_id.end(),
-            [](const std::pair<double, gpt_vocab::id> & a, const std::pair<double, gpt_vocab::id> & b) {
-        return a.first > b.first;
-    });
+    sample_top_k(logits_id, top_k);

     double maxl = -INFINITY;
     for (const auto & kv : logits_id) {
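
[Editor's note] The new sample_top_k helper relies on std::partial_sort, which orders only the first top_k elements (descending by logit) in O(n log k) and leaves the remainder unspecified; the resize then discards them. A self-contained sketch of the same technique with toy data, not the repo's code:

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

// Keep only the top_k pairs with the largest first component, sorted
// descending, mirroring sample_top_k in utils.cpp.
static void keep_top_k(std::vector<std::pair<double, int>> & logits_id, int top_k) {
    std::partial_sort(
            logits_id.begin(),
            logits_id.begin() + top_k, logits_id.end(),
            [](const std::pair<double, int> & a, const std::pair<double, int> & b) {
        return a.first > b.first;
    });
    logits_id.resize(top_k); // drop everything outside the top K
}

int main() {
    std::vector<std::pair<double, int>> logits_id = {
        {0.1, 0}, {2.5, 1}, {1.7, 2}, {-0.3, 3}, {0.9, 4},
    };
    keep_top_k(logits_id, 3);
    for (const auto & kv : logits_id) {
        std::printf("token %d: %.1f\n", kv.second, kv.first); // tokens 1, 2, 4
    }
    return 0;
}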

utils.h

@@ -19,7 +19,7 @@ struct gpt_params {
     int32_t repeat_last_n = 64;   // last n tokens to penalize

     // sampling parameters
-    int32_t top_k = 40; // unused
+    int32_t top_k = 40;
     float   top_p = 0.95f;
     float   temp  = 0.80f;
     float   repeat_penalty  = 1.30f;
@@ -77,25 +77,18 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
 //   - consider only the top K tokens
 //   - from them, consider only the top tokens with cumulative probability > P
 //
-// TODO: not sure if this implementation is correct
-// TODO: temperature is not implemented
-//
-gpt_vocab::id gpt_sample_top_k_top_p(
+gpt_vocab::id llama_sample_top_p_top_k(
         const gpt_vocab & vocab,
         const float * logits,
+        std::vector<gpt_vocab::id> & last_n_tokens,
+        double repeat_penalty,
         int    top_k,
         double top_p,
         double temp,
         std::mt19937 & rng);

-gpt_vocab::id llama_sample_top_p(
-        const gpt_vocab & vocab,
-        const float * logits,
-        std::vector<gpt_vocab::id> & last_n_tokens,
-        double repeat_penalty,
-        double top_p,
-        double temp,
-        std::mt19937 & rng);
+// filter to top K tokens from list of logits
+void sample_top_k(std::vector<std::pair<double, gpt_vocab::id>> & logits_id, int top_k);

 //
 // Quantization
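
[Editor's note] After the top-k cut, llama_sample_top_p_top_k softmaxes the surviving logits, truncates to the smallest prefix whose cumulative probability reaches top_p, renormalizes, and draws a token with std::discrete_distribution. A simplified self-contained sketch of those remaining stages, with the repeat penalty omitted and toy data in place of real logits; not the repo's code:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>
#include <vector>

// Softmax the (already top-k-filtered, descending) logits, apply the
// top_p cumulative cutoff, and sample an index into the candidate list.
static int sample_top_p(std::vector<double> logits, double top_p, std::mt19937 & rng) {
    double maxl = -INFINITY;
    for (double l : logits) maxl = std::max(maxl, l);

    // softmax with the max subtracted for numerical stability
    std::vector<double> probs;
    double sum = 0.0;
    for (double l : logits) {
        const double p = std::exp(l - maxl);
        probs.push_back(p);
        sum += p;
    }
    for (double & p : probs) p /= sum;

    // keep the smallest prefix with cumulative probability >= top_p
    if (top_p < 1.0) {
        double cumsum = 0.0;
        for (size_t i = 0; i < probs.size(); ++i) {
            cumsum += probs[i];
            if (cumsum >= top_p) {
                probs.resize(i + 1);
                break;
            }
        }
        // no explicit renormalization needed: discrete_distribution
        // normalizes the truncated weights itself
    }

    std::discrete_distribution<> dist(probs.begin(), probs.end());
    return dist(rng); // index into the filtered candidate list
}

int main() {
    std::mt19937 rng(1234);
    std::vector<double> logits = {3.0, 2.0, 1.0}; // descending, post top-k
    std::printf("sampled index: %d\n", sample_top_p(logits, 0.95, rng));
    return 0;
}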