sampling : rename penalty params + reduce size of "prev" vector

This commit is contained in:
Georgi Gerganov 2023-10-20 17:47:13 +03:00
parent 84ed48b473
commit b526561583
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
7 changed files with 86 additions and 79 deletions

View File

@ -241,25 +241,26 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true;
break;
}
sparams.repeat_last_n = std::stoi(argv[i]);
sparams.penalty_last_n = std::stoi(argv[i]);
sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
} else if (arg == "--repeat-penalty") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.repeat_penalty = std::stof(argv[i]);
sparams.penalty_repeat = std::stof(argv[i]);
} else if (arg == "--frequency-penalty") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.frequency_penalty = std::stof(argv[i]);
sparams.penalty_freq = std::stof(argv[i]);
} else if (arg == "--presence-penalty") {
if (++i >= argc) {
invalid_param = true;
break;
}
sparams.presence_penalty = std::stof(argv[i]);
sparams.penalty_present = std::stof(argv[i]);
} else if (arg == "--mirostat") {
if (++i >= argc) {
invalid_param = true;
@ -678,10 +679,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.repeat_last_n);
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.repeat_penalty);
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.presence_penalty);
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.frequency_penalty);
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
printf(" --mirostat N use Mirostat sampling.\n");
printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);
@ -1178,7 +1179,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty);
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str());
fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
@ -1238,14 +1239,14 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false");
fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.presence_penalty);
fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present);
dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str());
fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens);
fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false");
fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.repeat_penalty);
fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat);
fprintf(stream, "reverse_prompt:\n");
for (std::string ap : params.antiprompt) {

View File

@ -80,10 +80,10 @@ llama_token llama_sampling_sample(
const float top_p = params.top_p;
const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_prev : params.repeat_last_n;
const float repeat_penalty = params.repeat_penalty;
const float alpha_presence = params.presence_penalty;
const float alpha_frequency = params.frequency_penalty;
const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
const float penalty_repeat = params.penalty_repeat;
const float penalty_freq = params.penalty_freq;
const float penalty_present = params.penalty_present;
const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta;
@ -118,8 +118,8 @@ llama_token llama_sampling_sample(
const float nl_logit = logits[llama_token_nl(ctx_main)];
llama_sample_repetition_penalties(ctx_main, &cur_p,
prev.data() + prev.size() - repeat_last_n,
repeat_last_n, repeat_penalty, alpha_frequency, alpha_presence);
prev.data() + prev.size() - penalty_last_n,
penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {

View File

@ -10,17 +10,17 @@
// sampling parameters
typedef struct llama_sampling_params {
int32_t n_prev = 256; // number of previous tokens to remember
int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // 1.0 = disabled
float repeat_penalty = 1.10f; // 1.0 = disabled
int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float frequency_penalty = 0.00f; // 0.0 = disabled
float presence_penalty = 0.00f; // 0.0 = disabled
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.10f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
float penalty_present = 0.00f; // 0.0 = disabled
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate
@ -34,7 +34,6 @@ typedef struct llama_sampling_params {
float cfg_scale = 1.f; // How strong is guidance
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
} llama_sampling_params;
// general sampler context

View File

@ -415,8 +415,8 @@ int main(int argc, char ** argv) {
}
}
}
LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, frequency_penalty = %f, presence_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
sparams.penalty_last_n, sparams.penalty_repeat, sparams.penalty_freq, sparams.penalty_present, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n");

View File

@ -1018,10 +1018,10 @@ static json format_generation_settings(llama_server_context &llama)
{"top_p", sparams.top_p},
{"tfs_z", sparams.tfs_z},
{"typical_p", sparams.typical_p},
{"repeat_last_n", sparams.repeat_last_n},
{"repeat_penalty", sparams.repeat_penalty},
{"presence_penalty", sparams.presence_penalty},
{"frequency_penalty", sparams.frequency_penalty},
{"repeat_last_n", sparams.penalty_last_n},
{"repeat_penalty", sparams.penalty_repeat},
{"frequency_penalty", sparams.penalty_freq},
{"presence_penalty", sparams.penalty_present},
{"mirostat", sparams.mirostat},
{"mirostat_tau", sparams.mirostat_tau},
{"mirostat_eta", sparams.mirostat_eta},
@ -1140,11 +1140,11 @@ static void parse_options_completion(const json &body, llama_server_context &lla
sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z);
sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p);
sparams.repeat_last_n = json_value(body, "repeat_last_n", default_sparams.repeat_last_n);
sparams.temp = json_value(body, "temperature", default_sparams.temp);
sparams.repeat_penalty = json_value(body, "repeat_penalty", default_sparams.repeat_penalty);
sparams.presence_penalty = json_value(body, "presence_penalty", default_sparams.presence_penalty);
sparams.frequency_penalty = json_value(body, "frequency_penalty", default_sparams.frequency_penalty);
sparams.penalty_last_n = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
sparams.penalty_repeat = json_value(body, "repeat_penalty", default_sparams.penalty_repeat);
sparams.penalty_freq = json_value(body, "frequency_penalty", default_sparams.penalty_freq);
sparams.penalty_present = json_value(body, "presence_penalty", default_sparams.penalty_present);
sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat);
sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);

View File

@ -7414,8 +7414,15 @@ void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array
llama_sample_temp(ctx, candidates_p, temp);
}
void llama_sample_repetition_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, float repeat_penalty, float alpha_frequency, float alpha_presence) {
if (last_tokens_size == 0 || (repeat_penalty == 1.0f && alpha_frequency == 0.0f && alpha_presence == 0.0f)) {
void llama_sample_repetition_penalties(
struct llama_context * ctx,
llama_token_data_array * candidates,
const llama_token * last_tokens,
size_t penalty_last_n,
float penalty_repeat,
float penalty_freq,
float penalty_present) {
if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
return;
}
@ -7423,8 +7430,8 @@ void llama_sample_repetition_penalties(struct llama_context * ctx, llama_token_d
// Create a frequency map to count occurrences of each token in last_tokens
std::unordered_map<llama_token, int> token_count;
for (size_t i = 0; i < last_tokens_size; ++i) {
token_count[last_tokens_p[i]]++;
for (size_t i = 0; i < penalty_last_n; ++i) {
token_count[last_tokens[i]]++;
}
// Apply frequency and presence penalties to the candidates
@ -7439,12 +7446,12 @@ void llama_sample_repetition_penalties(struct llama_context * ctx, llama_token_d
// The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
// This is common fix for this problem, which is to multiply by the penalty instead of dividing.
if (candidates->data[i].logit <= 0) {
candidates->data[i].logit *= repeat_penalty;
candidates->data[i].logit *= penalty_repeat;
} else {
candidates->data[i].logit /= repeat_penalty;
candidates->data[i].logit /= penalty_repeat;
}
candidates->data[i].logit -= float(count) * alpha_frequency + float(count > 0) * alpha_presence;
candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
}
candidates->sorted = false;

View File

@ -565,10 +565,10 @@ extern "C" {
struct llama_context * ctx,
llama_token_data_array * candidates,
const llama_token * last_tokens,
size_t last_tokens_size,
float repeat_penalty,
float alpha_frequency,
float alpha_presence);
size_t penalty_last_n,
float penalty_repeat,
float penalty_freq,
float penalty_present);
/// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.