mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 11:24:35 +00:00
llama : combine repetition, frequency and presence penalties in 1 call
This commit is contained in:
parent
cd1e937821
commit
6e6587656f
@ -3,7 +3,7 @@
|
||||
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params) {
|
||||
struct llama_sampling_context * result = new llama_sampling_context();
|
||||
|
||||
result->params = params;
|
||||
result->params = params;
|
||||
result->grammar = nullptr;
|
||||
|
||||
// if there is a grammar, parse it
|
||||
@ -71,17 +71,16 @@ llama_token llama_sampling_sample(
|
||||
struct llama_context * ctx_main,
|
||||
struct llama_context * ctx_cfg,
|
||||
const int idx) {
|
||||
const int n_ctx = llama_n_ctx(ctx_main);
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||
|
||||
const llama_sampling_params & params = ctx_sampling->params;
|
||||
|
||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||
|
||||
const float temp = params.temp;
|
||||
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
|
||||
const float top_p = params.top_p;
|
||||
const float tfs_z = params.tfs_z;
|
||||
const float typical_p = params.typical_p;
|
||||
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
|
||||
const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_prev : params.repeat_last_n;
|
||||
const float repeat_penalty = params.repeat_penalty;
|
||||
const float alpha_presence = params.presence_penalty;
|
||||
const float alpha_frequency = params.frequency_penalty;
|
||||
@ -97,7 +96,7 @@ llama_token llama_sampling_sample(
|
||||
|
||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||
|
||||
// Apply params.logit_bias map
|
||||
// apply params.logit_bias map
|
||||
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
||||
logits[it->first] += it->second;
|
||||
}
|
||||
@ -117,14 +116,10 @@ llama_token llama_sampling_sample(
|
||||
// apply penalties
|
||||
if (!prev.empty()) {
|
||||
const float nl_logit = logits[llama_token_nl(ctx_main)];
|
||||
const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx);
|
||||
|
||||
llama_sample_repetition_penalty(ctx_main, &cur_p,
|
||||
prev.data() + prev.size() - last_n_repeat,
|
||||
last_n_repeat, repeat_penalty);
|
||||
llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p,
|
||||
prev.data() + prev.size() - last_n_repeat,
|
||||
last_n_repeat, alpha_frequency, alpha_presence);
|
||||
llama_sample_repetition_penalties(ctx_main, &cur_p,
|
||||
prev.data() + prev.size() - repeat_last_n,
|
||||
repeat_last_n, repeat_penalty, alpha_frequency, alpha_presence);
|
||||
|
||||
if (!penalize_nl) {
|
||||
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
||||
@ -141,7 +136,7 @@ llama_token llama_sampling_sample(
|
||||
}
|
||||
|
||||
if (temp <= 0) {
|
||||
// Greedy sampling
|
||||
// greedy sampling
|
||||
id = llama_sample_token_greedy(ctx_main, &cur_p);
|
||||
} else {
|
||||
if (mirostat == 1) {
|
||||
@ -152,8 +147,9 @@ llama_token llama_sampling_sample(
|
||||
llama_sample_temp(ctx_main, &cur_p, temp);
|
||||
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
|
||||
} else {
|
||||
// Temperature sampling
|
||||
// temperature sampling
|
||||
size_t min_keep = std::max(1, params.n_probs);
|
||||
|
||||
llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
|
||||
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
|
||||
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
|
||||
|
@ -10,6 +10,8 @@
|
||||
|
||||
// sampling parameters
|
||||
typedef struct llama_sampling_params {
|
||||
int32_t n_prev = 256; // number of previous tokens to remember
|
||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||
int32_t top_k = 40; // <= 0 to use vocab size
|
||||
float top_p = 0.95f; // 1.0 = disabled
|
||||
float tfs_z = 1.00f; // 1.0 = disabled
|
||||
@ -22,11 +24,9 @@ typedef struct llama_sampling_params {
|
||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||
float mirostat_tau = 5.00f; // target entropy
|
||||
float mirostat_eta = 0.10f; // learning rate
|
||||
int32_t n_prev = 256; // number of previous tokens to remember
|
||||
|
||||
bool penalize_nl = true; // consider newlines as a repeatable token
|
||||
|
||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||
|
||||
// Classifier-Free Guidance
|
||||
// https://arxiv.org/abs/2306.17806
|
||||
@ -35,8 +35,6 @@ typedef struct llama_sampling_params {
|
||||
|
||||
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
|
||||
|
||||
std::string grammar = ""; // optional BNF-like grammar to constrain sampling
|
||||
|
||||
} llama_sampling_params;
|
||||
|
||||
// general sampler context
|
||||
|
50
llama.cpp
50
llama.cpp
@ -1018,8 +1018,8 @@ enum e_model {
|
||||
};
|
||||
|
||||
static const size_t kB = 1024;
|
||||
static const size_t MB = kB*kB;
|
||||
static const size_t GB = kB*kB*kB;
|
||||
static const size_t MB = 1024*kB;
|
||||
static const size_t GB = 1024*MB;
|
||||
|
||||
struct llama_hparams {
|
||||
bool vocab_only;
|
||||
@ -7414,37 +7414,8 @@ void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array
|
||||
llama_sample_temp(ctx, candidates_p, temp);
|
||||
}
|
||||
|
||||
void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty) {
|
||||
if (last_tokens_size == 0 || penalty == 1.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
const auto * token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id);
|
||||
if (token_iter == last_tokens + last_tokens_size) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
|
||||
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
|
||||
if (candidates->data[i].logit <= 0) {
|
||||
candidates->data[i].logit *= penalty;
|
||||
} else {
|
||||
candidates->data[i].logit /= penalty;
|
||||
}
|
||||
}
|
||||
|
||||
candidates->sorted = false;
|
||||
|
||||
if (ctx) {
|
||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
}
|
||||
|
||||
void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) {
|
||||
if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) {
|
||||
void llama_sample_repetition_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, float repeat_penalty, float alpha_frequency, float alpha_presence) {
|
||||
if (last_tokens_size == 0 || (repeat_penalty == 1.0f && alpha_frequency == 0.0f && alpha_presence == 0.0f)) {
|
||||
return;
|
||||
}
|
||||
|
||||
@ -7458,12 +7429,21 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
|
||||
|
||||
// Apply frequency and presence penalties to the candidates
|
||||
for (size_t i = 0; i < candidates->size; ++i) {
|
||||
auto token_iter = token_count.find(candidates->data[i].id);
|
||||
const auto token_iter = token_count.find(candidates->data[i].id);
|
||||
if (token_iter == token_count.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
int count = token_iter->second;
|
||||
const int count = token_iter->second;
|
||||
|
||||
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
|
||||
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
|
||||
if (candidates->data[i].logit <= 0) {
|
||||
candidates->data[i].logit *= repeat_penalty;
|
||||
} else {
|
||||
candidates->data[i].logit /= repeat_penalty;
|
||||
}
|
||||
|
||||
candidates->data[i].logit -= float(count) * alpha_frequency + float(count > 0) * alpha_presence;
|
||||
}
|
||||
|
||||
|
10
llama.h
10
llama.h
@ -560,19 +560,13 @@ extern "C" {
|
||||
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
|
||||
|
||||
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||
LLAMA_API void llama_sample_repetition_penalty(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
const llama_token * last_tokens,
|
||||
size_t last_tokens_size,
|
||||
float penalty);
|
||||
|
||||
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
||||
LLAMA_API void llama_sample_frequency_and_presence_penalties(
|
||||
LLAMA_API void llama_sample_repetition_penalties(
|
||||
struct llama_context * ctx,
|
||||
llama_token_data_array * candidates,
|
||||
const llama_token * last_tokens,
|
||||
size_t last_tokens_size,
|
||||
float repeat_penalty,
|
||||
float alpha_frequency,
|
||||
float alpha_presence);
|
||||
|
||||
|
@ -8,11 +8,9 @@
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
#include <cassert>
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
||||
|
||||
static void dump(const llama_token_data_array * candidates) {
|
||||
for (size_t i = 0; i < candidates->size; i++) {
|
||||
printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
|
||||
@ -21,7 +19,6 @@ static void dump(const llama_token_data_array * candidates) {
|
||||
|
||||
#define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0)
|
||||
|
||||
|
||||
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
|
||||
size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
@ -37,13 +34,12 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
|
||||
llama_sample_top_k(nullptr, &candidates_p, k, 1);
|
||||
DUMP(&candidates_p);
|
||||
|
||||
assert(candidates_p.size == expected_probs.size());
|
||||
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
|
||||
GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||
size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
@ -59,13 +55,12 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
|
||||
llama_sample_top_p(nullptr, &candidates_p, p, 1);
|
||||
DUMP(&candidates_p);
|
||||
|
||||
assert(candidates_p.size == expected_probs.size());
|
||||
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
||||
GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
|
||||
size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
@ -80,13 +75,12 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
|
||||
llama_sample_tail_free(nullptr, &candidates_p, z, 1);
|
||||
DUMP(&candidates_p);
|
||||
|
||||
assert(candidates_p.size == expected_probs.size());
|
||||
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
||||
GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||
size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
@ -101,18 +95,17 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
|
||||
llama_sample_typical(nullptr, &candidates_p, p, 1);
|
||||
DUMP(&candidates_p);
|
||||
|
||||
assert(candidates_p.size == expected_probs.size());
|
||||
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
||||
GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void test_repetition_penalty(
|
||||
static void test_repetition_penalties(
|
||||
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
|
||||
const std::vector<float> & expected_probs, float penalty
|
||||
const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
|
||||
) {
|
||||
assert(probs.size() == expected_probs.size());
|
||||
GGML_ASSERT(probs.size() == expected_probs.size());
|
||||
|
||||
size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
@ -125,41 +118,13 @@ static void test_repetition_penalty(
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
llama_sample_softmax(nullptr, &candidates_p);
|
||||
DUMP(&candidates_p);
|
||||
llama_sample_repetition_penalty(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), penalty);
|
||||
llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
|
||||
llama_sample_softmax(nullptr, &candidates_p);
|
||||
DUMP(&candidates_p);
|
||||
|
||||
assert(candidates_p.size == expected_probs.size());
|
||||
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void test_frequency_presence_penalty(
|
||||
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
|
||||
const std::vector<float> & expected_probs, float alpha_frequency, float alpha_presence
|
||||
) {
|
||||
assert(probs.size() == expected_probs.size());
|
||||
|
||||
size_t n_vocab = probs.size();
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
||||
float logit = log(probs[token_id]);
|
||||
candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
llama_sample_softmax(nullptr, &candidates_p);
|
||||
// DUMP(&candidates_p);
|
||||
llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), alpha_frequency, alpha_presence);
|
||||
llama_sample_softmax(nullptr, &candidates_p);
|
||||
// DUMP(&candidates_p);
|
||||
|
||||
assert(candidates_p.size == expected_probs.size());
|
||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
||||
GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
||||
}
|
||||
}
|
||||
|
||||
@ -181,13 +146,13 @@ int main(void) {
|
||||
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
|
||||
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
|
||||
|
||||
test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f);
|
||||
test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f);
|
||||
test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f);
|
||||
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
|
||||
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
|
||||
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
|
||||
|
||||
test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 5.0f, 5.0f);
|
||||
test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 5.0f, 5.0f);
|
||||
test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 5.0f, 5.0f);
|
||||
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
|
||||
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
|
||||
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
|
||||
|
||||
printf("OK\n");
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user