From 6e6587656fef25bd2e411c0af5708ab46df0864d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 20 Oct 2023 17:05:46 +0300 Subject: [PATCH] llama : combine repetition, frequency and presence penalties in 1 call --- common/sampling.cpp | 26 ++++++-------- common/sampling.h | 8 ++--- llama.cpp | 50 +++++++++------------------ llama.h | 10 ++---- tests/test-sampling.cpp | 75 +++++++++++------------------------------ 5 files changed, 51 insertions(+), 118 deletions(-) diff --git a/common/sampling.cpp b/common/sampling.cpp index 7a10ddbd0..58c9d0dd2 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -3,7 +3,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) { struct llama_sampling_context * result = new llama_sampling_context(); - result->params = params; + result->params = params; result->grammar = nullptr; // if there is a grammar, parse it @@ -71,17 +71,16 @@ llama_token llama_sampling_sample( struct llama_context * ctx_main, struct llama_context * ctx_cfg, const int idx) { - const int n_ctx = llama_n_ctx(ctx_main); - const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); - const llama_sampling_params & params = ctx_sampling->params; + const int n_vocab = llama_n_vocab(llama_get_model(ctx_main)); + const float temp = params.temp; const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k; const float top_p = params.top_p; const float tfs_z = params.tfs_z; const float typical_p = params.typical_p; - const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n; + const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_prev : params.repeat_last_n; const float repeat_penalty = params.repeat_penalty; const float alpha_presence = params.presence_penalty; const float alpha_frequency = params.frequency_penalty; @@ -97,7 +96,7 @@ llama_token llama_sampling_sample( float * logits = llama_get_logits_ith(ctx_main, idx); - // Apply params.logit_bias map + // apply params.logit_bias map for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) { logits[it->first] += it->second; } @@ -117,14 +116,10 @@ llama_token llama_sampling_sample( // apply penalties if (!prev.empty()) { const float nl_logit = logits[llama_token_nl(ctx_main)]; - const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx); - llama_sample_repetition_penalty(ctx_main, &cur_p, - prev.data() + prev.size() - last_n_repeat, - last_n_repeat, repeat_penalty); - llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p, - prev.data() + prev.size() - last_n_repeat, - last_n_repeat, alpha_frequency, alpha_presence); + llama_sample_repetition_penalties(ctx_main, &cur_p, + prev.data() + prev.size() - repeat_last_n, + repeat_last_n, repeat_penalty, alpha_frequency, alpha_presence); if (!penalize_nl) { for (size_t idx = 0; idx < cur_p.size; idx++) { @@ -141,7 +136,7 @@ llama_token llama_sampling_sample( } if (temp <= 0) { - // Greedy sampling + // greedy sampling id = llama_sample_token_greedy(ctx_main, &cur_p); } else { if (mirostat == 1) { @@ -152,8 +147,9 @@ llama_token llama_sampling_sample( llama_sample_temp(ctx_main, &cur_p, temp); id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu); } else { - // Temperature sampling + // temperature sampling size_t min_keep = std::max(1, params.n_probs); + llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep); llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep); llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep); diff --git a/common/sampling.h b/common/sampling.h index 80b6a6b57..227c611d7 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -10,6 +10,8 @@ // sampling parameters typedef struct llama_sampling_params { + int32_t n_prev = 256; // number of previous tokens to remember + int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. int32_t top_k = 40; // <= 0 to use vocab size float top_p = 0.95f; // 1.0 = disabled float tfs_z = 1.00f; // 1.0 = disabled @@ -22,11 +24,9 @@ typedef struct llama_sampling_params { int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 float mirostat_tau = 5.00f; // target entropy float mirostat_eta = 0.10f; // learning rate - int32_t n_prev = 256; // number of previous tokens to remember - bool penalize_nl = true; // consider newlines as a repeatable token - int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. + std::string grammar; // optional BNF-like grammar to constrain sampling // Classifier-Free Guidance // https://arxiv.org/abs/2306.17806 @@ -35,8 +35,6 @@ typedef struct llama_sampling_params { std::unordered_map logit_bias; // logit bias for specific tokens - std::string grammar = ""; // optional BNF-like grammar to constrain sampling - } llama_sampling_params; // general sampler context diff --git a/llama.cpp b/llama.cpp index a2391e44e..d19d4222c 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1018,8 +1018,8 @@ enum e_model { }; static const size_t kB = 1024; -static const size_t MB = kB*kB; -static const size_t GB = kB*kB*kB; +static const size_t MB = 1024*kB; +static const size_t GB = 1024*MB; struct llama_hparams { bool vocab_only; @@ -7414,37 +7414,8 @@ void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array llama_sample_temp(ctx, candidates_p, temp); } -void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty) { - if (last_tokens_size == 0 || penalty == 1.0f) { - return; - } - - const int64_t t_start_sample_us = ggml_time_us(); - - for (size_t i = 0; i < candidates->size; ++i) { - const auto * token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id); - if (token_iter == last_tokens + last_tokens_size) { - continue; - } - - // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. - // This is common fix for this problem, which is to multiply by the penalty instead of dividing. - if (candidates->data[i].logit <= 0) { - candidates->data[i].logit *= penalty; - } else { - candidates->data[i].logit /= penalty; - } - } - - candidates->sorted = false; - - if (ctx) { - ctx->t_sample_us += ggml_time_us() - t_start_sample_us; - } -} - -void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) { - if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) { +void llama_sample_repetition_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, float repeat_penalty, float alpha_frequency, float alpha_presence) { + if (last_tokens_size == 0 || (repeat_penalty == 1.0f && alpha_frequency == 0.0f && alpha_presence == 0.0f)) { return; } @@ -7458,12 +7429,21 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l // Apply frequency and presence penalties to the candidates for (size_t i = 0; i < candidates->size; ++i) { - auto token_iter = token_count.find(candidates->data[i].id); + const auto token_iter = token_count.find(candidates->data[i].id); if (token_iter == token_count.end()) { continue; } - int count = token_iter->second; + const int count = token_iter->second; + + // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong. + // This is common fix for this problem, which is to multiply by the penalty instead of dividing. + if (candidates->data[i].logit <= 0) { + candidates->data[i].logit *= repeat_penalty; + } else { + candidates->data[i].logit /= repeat_penalty; + } + candidates->data[i].logit -= float(count) * alpha_frequency + float(count > 0) * alpha_presence; } diff --git a/llama.h b/llama.h index 51010e037..4bdd6b098 100644 --- a/llama.h +++ b/llama.h @@ -560,19 +560,13 @@ extern "C" { LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. - LLAMA_API void llama_sample_repetition_penalty( - struct llama_context * ctx, - llama_token_data_array * candidates, - const llama_token * last_tokens, - size_t last_tokens_size, - float penalty); - /// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. - LLAMA_API void llama_sample_frequency_and_presence_penalties( + LLAMA_API void llama_sample_repetition_penalties( struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, + float repeat_penalty, float alpha_frequency, float alpha_presence); diff --git a/tests/test-sampling.cpp b/tests/test-sampling.cpp index 019c0d462..32e58941c 100644 --- a/tests/test-sampling.cpp +++ b/tests/test-sampling.cpp @@ -8,11 +8,9 @@ #include #include #include -#include #include #include - static void dump(const llama_token_data_array * candidates) { for (size_t i = 0; i < candidates->size; i++) { printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit); @@ -21,7 +19,6 @@ static void dump(const llama_token_data_array * candidates) { #define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0) - static void test_top_k(const std::vector & probs, const std::vector & expected_probs, int k) { size_t n_vocab = probs.size(); std::vector candidates; @@ -37,13 +34,12 @@ static void test_top_k(const std::vector & probs, const std::vector & probs, const std::vector & expected_probs, float p) { size_t n_vocab = probs.size(); std::vector candidates; @@ -59,13 +55,12 @@ static void test_top_p(const std::vector & probs, const std::vector & probs, const std::vector & expected_probs, float z) { size_t n_vocab = probs.size(); std::vector candidates; @@ -80,13 +75,12 @@ static void test_tfs(const std::vector & probs, const std::vector llama_sample_tail_free(nullptr, &candidates_p, z, 1); DUMP(&candidates_p); - assert(candidates_p.size == expected_probs.size()); + GGML_ASSERT(candidates_p.size == expected_probs.size()); for (size_t i = 0; i < candidates_p.size; i++) { - assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); + GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); } } - static void test_typical(const std::vector & probs, const std::vector & expected_probs, float p) { size_t n_vocab = probs.size(); std::vector candidates; @@ -101,18 +95,17 @@ static void test_typical(const std::vector & probs, const std::vector & probs, const std::vector & last_tokens, - const std::vector & expected_probs, float penalty + const std::vector & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence ) { - assert(probs.size() == expected_probs.size()); + GGML_ASSERT(probs.size() == expected_probs.size()); size_t n_vocab = probs.size(); std::vector candidates; @@ -125,41 +118,13 @@ static void test_repetition_penalty( llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; llama_sample_softmax(nullptr, &candidates_p); DUMP(&candidates_p); - llama_sample_repetition_penalty(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), penalty); + llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence); llama_sample_softmax(nullptr, &candidates_p); DUMP(&candidates_p); - assert(candidates_p.size == expected_probs.size()); + GGML_ASSERT(candidates_p.size == expected_probs.size()); for (size_t i = 0; i < candidates_p.size; i++) { - assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6); - } -} - - -static void test_frequency_presence_penalty( - const std::vector & probs, const std::vector & last_tokens, - const std::vector & expected_probs, float alpha_frequency, float alpha_presence -) { - assert(probs.size() == expected_probs.size()); - - size_t n_vocab = probs.size(); - std::vector candidates; - candidates.reserve(n_vocab); - for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) { - float logit = log(probs[token_id]); - candidates.emplace_back(llama_token_data{token_id, logit, 0.0f}); - } - - llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; - llama_sample_softmax(nullptr, &candidates_p); - // DUMP(&candidates_p); - llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), alpha_frequency, alpha_presence); - llama_sample_softmax(nullptr, &candidates_p); - // DUMP(&candidates_p); - - assert(candidates_p.size == expected_probs.size()); - for (size_t i = 0; i < candidates_p.size; i++) { - assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); + GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3); } } @@ -181,13 +146,13 @@ int main(void) { test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f); test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f); - test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f); - test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f); - test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f); + test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f); + test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f); + test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f); - test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 5.0f, 5.0f); - test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 5.0f, 5.0f); - test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 5.0f, 5.0f); + test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f); + test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f); + test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f); printf("OK\n");