mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-27 03:44:35 +00:00
llama : combine repetition, frequency and presence penalties in 1 call
This commit is contained in:
parent
cd1e937821
commit
6e6587656f
@ -71,17 +71,16 @@ llama_token llama_sampling_sample(
|
|||||||
struct llama_context * ctx_main,
|
struct llama_context * ctx_main,
|
||||||
struct llama_context * ctx_cfg,
|
struct llama_context * ctx_cfg,
|
||||||
const int idx) {
|
const int idx) {
|
||||||
const int n_ctx = llama_n_ctx(ctx_main);
|
|
||||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
|
||||||
|
|
||||||
const llama_sampling_params & params = ctx_sampling->params;
|
const llama_sampling_params & params = ctx_sampling->params;
|
||||||
|
|
||||||
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||||
|
|
||||||
const float temp = params.temp;
|
const float temp = params.temp;
|
||||||
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
|
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
|
||||||
const float top_p = params.top_p;
|
const float top_p = params.top_p;
|
||||||
const float tfs_z = params.tfs_z;
|
const float tfs_z = params.tfs_z;
|
||||||
const float typical_p = params.typical_p;
|
const float typical_p = params.typical_p;
|
||||||
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
|
const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_prev : params.repeat_last_n;
|
||||||
const float repeat_penalty = params.repeat_penalty;
|
const float repeat_penalty = params.repeat_penalty;
|
||||||
const float alpha_presence = params.presence_penalty;
|
const float alpha_presence = params.presence_penalty;
|
||||||
const float alpha_frequency = params.frequency_penalty;
|
const float alpha_frequency = params.frequency_penalty;
|
||||||
@ -97,7 +96,7 @@ llama_token llama_sampling_sample(
|
|||||||
|
|
||||||
float * logits = llama_get_logits_ith(ctx_main, idx);
|
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||||
|
|
||||||
// Apply params.logit_bias map
|
// apply params.logit_bias map
|
||||||
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
||||||
logits[it->first] += it->second;
|
logits[it->first] += it->second;
|
||||||
}
|
}
|
||||||
@ -117,14 +116,10 @@ llama_token llama_sampling_sample(
|
|||||||
// apply penalties
|
// apply penalties
|
||||||
if (!prev.empty()) {
|
if (!prev.empty()) {
|
||||||
const float nl_logit = logits[llama_token_nl(ctx_main)];
|
const float nl_logit = logits[llama_token_nl(ctx_main)];
|
||||||
const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx);
|
|
||||||
|
|
||||||
llama_sample_repetition_penalty(ctx_main, &cur_p,
|
llama_sample_repetition_penalties(ctx_main, &cur_p,
|
||||||
prev.data() + prev.size() - last_n_repeat,
|
prev.data() + prev.size() - repeat_last_n,
|
||||||
last_n_repeat, repeat_penalty);
|
repeat_last_n, repeat_penalty, alpha_frequency, alpha_presence);
|
||||||
llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p,
|
|
||||||
prev.data() + prev.size() - last_n_repeat,
|
|
||||||
last_n_repeat, alpha_frequency, alpha_presence);
|
|
||||||
|
|
||||||
if (!penalize_nl) {
|
if (!penalize_nl) {
|
||||||
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
||||||
@ -141,7 +136,7 @@ llama_token llama_sampling_sample(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (temp <= 0) {
|
if (temp <= 0) {
|
||||||
// Greedy sampling
|
// greedy sampling
|
||||||
id = llama_sample_token_greedy(ctx_main, &cur_p);
|
id = llama_sample_token_greedy(ctx_main, &cur_p);
|
||||||
} else {
|
} else {
|
||||||
if (mirostat == 1) {
|
if (mirostat == 1) {
|
||||||
@ -152,8 +147,9 @@ llama_token llama_sampling_sample(
|
|||||||
llama_sample_temp(ctx_main, &cur_p, temp);
|
llama_sample_temp(ctx_main, &cur_p, temp);
|
||||||
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
|
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
|
||||||
} else {
|
} else {
|
||||||
// Temperature sampling
|
// temperature sampling
|
||||||
size_t min_keep = std::max(1, params.n_probs);
|
size_t min_keep = std::max(1, params.n_probs);
|
||||||
|
|
||||||
llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
|
llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
|
||||||
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
|
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
|
||||||
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
|
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
|
||||||
|
@ -10,6 +10,8 @@
|
|||||||
|
|
||||||
// sampling parameters
|
// sampling parameters
|
||||||
typedef struct llama_sampling_params {
|
typedef struct llama_sampling_params {
|
||||||
|
int32_t n_prev = 256; // number of previous tokens to remember
|
||||||
|
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
||||||
int32_t top_k = 40; // <= 0 to use vocab size
|
int32_t top_k = 40; // <= 0 to use vocab size
|
||||||
float top_p = 0.95f; // 1.0 = disabled
|
float top_p = 0.95f; // 1.0 = disabled
|
||||||
float tfs_z = 1.00f; // 1.0 = disabled
|
float tfs_z = 1.00f; // 1.0 = disabled
|
||||||
@ -22,11 +24,9 @@ typedef struct llama_sampling_params {
|
|||||||
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
|
||||||
float mirostat_tau = 5.00f; // target entropy
|
float mirostat_tau = 5.00f; // target entropy
|
||||||
float mirostat_eta = 0.10f; // learning rate
|
float mirostat_eta = 0.10f; // learning rate
|
||||||
int32_t n_prev = 256; // number of previous tokens to remember
|
|
||||||
|
|
||||||
bool penalize_nl = true; // consider newlines as a repeatable token
|
bool penalize_nl = true; // consider newlines as a repeatable token
|
||||||
|
|
||||||
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
|
std::string grammar; // optional BNF-like grammar to constrain sampling
|
||||||
|
|
||||||
// Classifier-Free Guidance
|
// Classifier-Free Guidance
|
||||||
// https://arxiv.org/abs/2306.17806
|
// https://arxiv.org/abs/2306.17806
|
||||||
@ -35,8 +35,6 @@ typedef struct llama_sampling_params {
|
|||||||
|
|
||||||
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
|
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
|
||||||
|
|
||||||
std::string grammar = ""; // optional BNF-like grammar to constrain sampling
|
|
||||||
|
|
||||||
} llama_sampling_params;
|
} llama_sampling_params;
|
||||||
|
|
||||||
// general sampler context
|
// general sampler context
|
||||||
|
50
llama.cpp
50
llama.cpp
@ -1018,8 +1018,8 @@ enum e_model {
|
|||||||
};
|
};
|
||||||
|
|
||||||
static const size_t kB = 1024;
|
static const size_t kB = 1024;
|
||||||
static const size_t MB = kB*kB;
|
static const size_t MB = 1024*kB;
|
||||||
static const size_t GB = kB*kB*kB;
|
static const size_t GB = 1024*MB;
|
||||||
|
|
||||||
struct llama_hparams {
|
struct llama_hparams {
|
||||||
bool vocab_only;
|
bool vocab_only;
|
||||||
@ -7414,37 +7414,8 @@ void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array
|
|||||||
llama_sample_temp(ctx, candidates_p, temp);
|
llama_sample_temp(ctx, candidates_p, temp);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty) {
|
void llama_sample_repetition_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, float repeat_penalty, float alpha_frequency, float alpha_presence) {
|
||||||
if (last_tokens_size == 0 || penalty == 1.0f) {
|
if (last_tokens_size == 0 || (repeat_penalty == 1.0f && alpha_frequency == 0.0f && alpha_presence == 0.0f)) {
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
|
||||||
|
|
||||||
for (size_t i = 0; i < candidates->size; ++i) {
|
|
||||||
const auto * token_iter = std::find(last_tokens, last_tokens + last_tokens_size, candidates->data[i].id);
|
|
||||||
if (token_iter == last_tokens + last_tokens_size) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
|
|
||||||
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
|
|
||||||
if (candidates->data[i].logit <= 0) {
|
|
||||||
candidates->data[i].logit *= penalty;
|
|
||||||
} else {
|
|
||||||
candidates->data[i].logit /= penalty;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
candidates->sorted = false;
|
|
||||||
|
|
||||||
if (ctx) {
|
|
||||||
ctx->t_sample_us += ggml_time_us() - t_start_sample_us;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, float alpha_frequency, float alpha_presence) {
|
|
||||||
if (last_tokens_size == 0 || (alpha_frequency == 0.0f && alpha_presence == 0.0f)) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -7458,12 +7429,21 @@ void llama_sample_frequency_and_presence_penalties(struct llama_context * ctx, l
|
|||||||
|
|
||||||
// Apply frequency and presence penalties to the candidates
|
// Apply frequency and presence penalties to the candidates
|
||||||
for (size_t i = 0; i < candidates->size; ++i) {
|
for (size_t i = 0; i < candidates->size; ++i) {
|
||||||
auto token_iter = token_count.find(candidates->data[i].id);
|
const auto token_iter = token_count.find(candidates->data[i].id);
|
||||||
if (token_iter == token_count.end()) {
|
if (token_iter == token_count.end()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
int count = token_iter->second;
|
const int count = token_iter->second;
|
||||||
|
|
||||||
|
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
|
||||||
|
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
|
||||||
|
if (candidates->data[i].logit <= 0) {
|
||||||
|
candidates->data[i].logit *= repeat_penalty;
|
||||||
|
} else {
|
||||||
|
candidates->data[i].logit /= repeat_penalty;
|
||||||
|
}
|
||||||
|
|
||||||
candidates->data[i].logit -= float(count) * alpha_frequency + float(count > 0) * alpha_presence;
|
candidates->data[i].logit -= float(count) * alpha_frequency + float(count > 0) * alpha_presence;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
10
llama.h
10
llama.h
@ -560,19 +560,13 @@ extern "C" {
|
|||||||
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
|
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
|
||||||
|
|
||||||
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
/// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
|
||||||
LLAMA_API void llama_sample_repetition_penalty(
|
|
||||||
struct llama_context * ctx,
|
|
||||||
llama_token_data_array * candidates,
|
|
||||||
const llama_token * last_tokens,
|
|
||||||
size_t last_tokens_size,
|
|
||||||
float penalty);
|
|
||||||
|
|
||||||
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
/// @details Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
|
||||||
LLAMA_API void llama_sample_frequency_and_presence_penalties(
|
LLAMA_API void llama_sample_repetition_penalties(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
llama_token_data_array * candidates,
|
llama_token_data_array * candidates,
|
||||||
const llama_token * last_tokens,
|
const llama_token * last_tokens,
|
||||||
size_t last_tokens_size,
|
size_t last_tokens_size,
|
||||||
|
float repeat_penalty,
|
||||||
float alpha_frequency,
|
float alpha_frequency,
|
||||||
float alpha_presence);
|
float alpha_presence);
|
||||||
|
|
||||||
|
@ -8,11 +8,9 @@
|
|||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <numeric>
|
#include <numeric>
|
||||||
#include <cassert>
|
#include <cassert>
|
||||||
#include <iostream>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
|
|
||||||
|
|
||||||
static void dump(const llama_token_data_array * candidates) {
|
static void dump(const llama_token_data_array * candidates) {
|
||||||
for (size_t i = 0; i < candidates->size; i++) {
|
for (size_t i = 0; i < candidates->size; i++) {
|
||||||
printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
|
printf("%d: %f (%f)\n", candidates->data[i].id, candidates->data[i].p, candidates->data[i].logit);
|
||||||
@ -21,7 +19,6 @@ static void dump(const llama_token_data_array * candidates) {
|
|||||||
|
|
||||||
#define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0)
|
#define DUMP(__candidates) do { printf("%s:%d (%s)\n", __FILE__, __LINE__, __func__); dump((__candidates)); printf("-\n"); } while(0)
|
||||||
|
|
||||||
|
|
||||||
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
|
static void test_top_k(const std::vector<float> & probs, const std::vector<float> & expected_probs, int k) {
|
||||||
size_t n_vocab = probs.size();
|
size_t n_vocab = probs.size();
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
@ -37,13 +34,12 @@ static void test_top_k(const std::vector<float> & probs, const std::vector<float
|
|||||||
llama_sample_top_k(nullptr, &candidates_p, k, 1);
|
llama_sample_top_k(nullptr, &candidates_p, k, 1);
|
||||||
DUMP(&candidates_p);
|
DUMP(&candidates_p);
|
||||||
|
|
||||||
assert(candidates_p.size == expected_probs.size());
|
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
|
GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-5);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
static void test_top_p(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||||
size_t n_vocab = probs.size();
|
size_t n_vocab = probs.size();
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
@ -59,13 +55,12 @@ static void test_top_p(const std::vector<float> & probs, const std::vector<float
|
|||||||
llama_sample_top_p(nullptr, &candidates_p, p, 1);
|
llama_sample_top_p(nullptr, &candidates_p, p, 1);
|
||||||
DUMP(&candidates_p);
|
DUMP(&candidates_p);
|
||||||
|
|
||||||
assert(candidates_p.size == expected_probs.size());
|
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
|
static void test_tfs(const std::vector<float> & probs, const std::vector<float> & expected_probs, float z) {
|
||||||
size_t n_vocab = probs.size();
|
size_t n_vocab = probs.size();
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
@ -80,13 +75,12 @@ static void test_tfs(const std::vector<float> & probs, const std::vector<float>
|
|||||||
llama_sample_tail_free(nullptr, &candidates_p, z, 1);
|
llama_sample_tail_free(nullptr, &candidates_p, z, 1);
|
||||||
DUMP(&candidates_p);
|
DUMP(&candidates_p);
|
||||||
|
|
||||||
assert(candidates_p.size == expected_probs.size());
|
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
static void test_typical(const std::vector<float> & probs, const std::vector<float> & expected_probs, float p) {
|
||||||
size_t n_vocab = probs.size();
|
size_t n_vocab = probs.size();
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
@ -101,18 +95,17 @@ static void test_typical(const std::vector<float> & probs, const std::vector<flo
|
|||||||
llama_sample_typical(nullptr, &candidates_p, p, 1);
|
llama_sample_typical(nullptr, &candidates_p, p, 1);
|
||||||
DUMP(&candidates_p);
|
DUMP(&candidates_p);
|
||||||
|
|
||||||
assert(candidates_p.size == expected_probs.size());
|
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void test_repetition_penalties(
|
||||||
static void test_repetition_penalty(
|
|
||||||
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
|
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
|
||||||
const std::vector<float> & expected_probs, float penalty
|
const std::vector<float> & expected_probs, float repeat_penalty, float alpha_frequency, float alpha_presence
|
||||||
) {
|
) {
|
||||||
assert(probs.size() == expected_probs.size());
|
GGML_ASSERT(probs.size() == expected_probs.size());
|
||||||
|
|
||||||
size_t n_vocab = probs.size();
|
size_t n_vocab = probs.size();
|
||||||
std::vector<llama_token_data> candidates;
|
std::vector<llama_token_data> candidates;
|
||||||
@ -125,41 +118,13 @@ static void test_repetition_penalty(
|
|||||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||||
llama_sample_softmax(nullptr, &candidates_p);
|
llama_sample_softmax(nullptr, &candidates_p);
|
||||||
DUMP(&candidates_p);
|
DUMP(&candidates_p);
|
||||||
llama_sample_repetition_penalty(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), penalty);
|
llama_sample_repetition_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), repeat_penalty, alpha_frequency, alpha_presence);
|
||||||
llama_sample_softmax(nullptr, &candidates_p);
|
llama_sample_softmax(nullptr, &candidates_p);
|
||||||
DUMP(&candidates_p);
|
DUMP(&candidates_p);
|
||||||
|
|
||||||
assert(candidates_p.size == expected_probs.size());
|
GGML_ASSERT(candidates_p.size == expected_probs.size());
|
||||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
for (size_t i = 0; i < candidates_p.size; i++) {
|
||||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-6);
|
GGML_ASSERT(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
static void test_frequency_presence_penalty(
|
|
||||||
const std::vector<float> & probs, const std::vector<llama_token> & last_tokens,
|
|
||||||
const std::vector<float> & expected_probs, float alpha_frequency, float alpha_presence
|
|
||||||
) {
|
|
||||||
assert(probs.size() == expected_probs.size());
|
|
||||||
|
|
||||||
size_t n_vocab = probs.size();
|
|
||||||
std::vector<llama_token_data> candidates;
|
|
||||||
candidates.reserve(n_vocab);
|
|
||||||
for (llama_token token_id = 0; token_id < (llama_token)n_vocab; token_id++) {
|
|
||||||
float logit = log(probs[token_id]);
|
|
||||||
candidates.emplace_back(llama_token_data{token_id, logit, 0.0f});
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
|
||||||
llama_sample_softmax(nullptr, &candidates_p);
|
|
||||||
// DUMP(&candidates_p);
|
|
||||||
llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p, (const llama_token *) last_tokens.data(), last_tokens.size(), alpha_frequency, alpha_presence);
|
|
||||||
llama_sample_softmax(nullptr, &candidates_p);
|
|
||||||
// DUMP(&candidates_p);
|
|
||||||
|
|
||||||
assert(candidates_p.size == expected_probs.size());
|
|
||||||
for (size_t i = 0; i < candidates_p.size; i++) {
|
|
||||||
assert(fabs(candidates_p.data[i].p - expected_probs[i]) < 1e-3);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -181,13 +146,13 @@ int main(void) {
|
|||||||
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
|
test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
|
||||||
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
|
test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
|
||||||
|
|
||||||
test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f);
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f, 0.0f, 0.0f);
|
||||||
test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f);
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
|
||||||
test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f);
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f, 0.0f, 0.0f);
|
||||||
|
|
||||||
test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 5.0f, 5.0f);
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 1.0f, 5.0f, 5.0f);
|
||||||
test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 5.0f, 5.0f);
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 1.0f, 5.0f, 5.0f);
|
||||||
test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 5.0f, 5.0f);
|
test_repetition_penalties({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 1.0f, 5.0f, 5.0f);
|
||||||
|
|
||||||
printf("OK\n");
|
printf("OK\n");
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user