sampling : rename penalty params + reduce size of "prev" vector

This commit is contained in:
Georgi Gerganov 2023-10-20 17:47:13 +03:00
parent 84ed48b473
commit b526561583
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
7 changed files with 86 additions and 79 deletions

View File

@ -241,25 +241,26 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.repeat_last_n = std::stoi(argv[i]); sparams.penalty_last_n = std::stoi(argv[i]);
sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
} else if (arg == "--repeat-penalty") { } else if (arg == "--repeat-penalty") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.repeat_penalty = std::stof(argv[i]); sparams.penalty_repeat = std::stof(argv[i]);
} else if (arg == "--frequency-penalty") { } else if (arg == "--frequency-penalty") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.frequency_penalty = std::stof(argv[i]); sparams.penalty_freq = std::stof(argv[i]);
} else if (arg == "--presence-penalty") { } else if (arg == "--presence-penalty") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
sparams.presence_penalty = std::stof(argv[i]); sparams.penalty_present = std::stof(argv[i]);
} else if (arg == "--mirostat") { } else if (arg == "--mirostat") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -678,10 +679,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p); printf(" --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z); printf(" --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p); printf(" --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.repeat_last_n); printf(" --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.repeat_penalty); printf(" --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.presence_penalty); printf(" --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.frequency_penalty); printf(" --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
printf(" --mirostat N use Mirostat sampling.\n"); printf(" --mirostat N use Mirostat sampling.\n");
printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"); printf(" Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat); printf(" (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);
@ -1178,7 +1179,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty); fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str()); dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str());
fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
@ -1238,14 +1239,14 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false"); fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false");
fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.presence_penalty); fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present);
dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str()); dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str());
fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str()); fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false"); fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false"); fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens); dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens);
fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false"); fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false");
fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.repeat_penalty); fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat);
fprintf(stream, "reverse_prompt:\n"); fprintf(stream, "reverse_prompt:\n");
for (std::string ap : params.antiprompt) { for (std::string ap : params.antiprompt) {

View File

@ -80,10 +80,10 @@ llama_token llama_sampling_sample(
const float top_p = params.top_p; const float top_p = params.top_p;
const float tfs_z = params.tfs_z; const float tfs_z = params.tfs_z;
const float typical_p = params.typical_p; const float typical_p = params.typical_p;
const int32_t repeat_last_n = params.repeat_last_n < 0 ? params.n_prev : params.repeat_last_n; const int32_t penalty_last_n = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
const float repeat_penalty = params.repeat_penalty; const float penalty_repeat = params.penalty_repeat;
const float alpha_presence = params.presence_penalty; const float penalty_freq = params.penalty_freq;
const float alpha_frequency = params.frequency_penalty; const float penalty_present = params.penalty_present;
const int mirostat = params.mirostat; const int mirostat = params.mirostat;
const float mirostat_tau = params.mirostat_tau; const float mirostat_tau = params.mirostat_tau;
const float mirostat_eta = params.mirostat_eta; const float mirostat_eta = params.mirostat_eta;
@ -118,8 +118,8 @@ llama_token llama_sampling_sample(
const float nl_logit = logits[llama_token_nl(ctx_main)]; const float nl_logit = logits[llama_token_nl(ctx_main)];
llama_sample_repetition_penalties(ctx_main, &cur_p, llama_sample_repetition_penalties(ctx_main, &cur_p,
prev.data() + prev.size() - repeat_last_n, prev.data() + prev.size() - penalty_last_n,
repeat_last_n, repeat_penalty, alpha_frequency, alpha_presence); penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
if (!penalize_nl) { if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) { for (size_t idx = 0; idx < cur_p.size; idx++) {

View File

@ -10,17 +10,17 @@
// sampling parameters // sampling parameters
typedef struct llama_sampling_params { typedef struct llama_sampling_params {
int32_t n_prev = 256; // number of previous tokens to remember int32_t n_prev = 64; // number of previous tokens to remember
int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens. int32_t n_probs = 0; // if greater than 0, output the probabilities of top n_probs tokens.
int32_t top_k = 40; // <= 0 to use vocab size int32_t top_k = 40; // <= 0 to use vocab size
float top_p = 0.95f; // 1.0 = disabled float top_p = 0.95f; // 1.0 = disabled
float tfs_z = 1.00f; // 1.0 = disabled float tfs_z = 1.00f; // 1.0 = disabled
float typical_p = 1.00f; // 1.0 = disabled float typical_p = 1.00f; // 1.0 = disabled
float temp = 0.80f; // 1.0 = disabled float temp = 0.80f; // 1.0 = disabled
float repeat_penalty = 1.10f; // 1.0 = disabled int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
int32_t repeat_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size) float penalty_repeat = 1.10f; // 1.0 = disabled
float frequency_penalty = 0.00f; // 0.0 = disabled float penalty_freq = 0.00f; // 0.0 = disabled
float presence_penalty = 0.00f; // 0.0 = disabled float penalty_present = 0.00f; // 0.0 = disabled
int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 int32_t mirostat = 0; // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
float mirostat_tau = 5.00f; // target entropy float mirostat_tau = 5.00f; // target entropy
float mirostat_eta = 0.10f; // learning rate float mirostat_eta = 0.10f; // learning rate
@ -34,7 +34,6 @@ typedef struct llama_sampling_params {
float cfg_scale = 1.f; // How strong is guidance float cfg_scale = 1.f; // How strong is guidance
std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
} llama_sampling_params; } llama_sampling_params;
// general sampler context // general sampler context

View File

@ -415,8 +415,8 @@ int main(int argc, char ** argv) {
} }
} }
} }
LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n", LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, frequency_penalty = %f, presence_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau); sparams.penalty_last_n, sparams.penalty_repeat, sparams.penalty_freq, sparams.penalty_present, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n"); LOG_TEE("\n\n");

View File

@ -1010,30 +1010,30 @@ static json format_generation_settings(llama_server_context &llama)
eos_bias->second < 0.0f && std::isinf(eos_bias->second); eos_bias->second < 0.0f && std::isinf(eos_bias->second);
return json{ return json{
{"n_ctx", llama.n_ctx}, {"n_ctx", llama.n_ctx},
{"model", llama.params.model_alias}, {"model", llama.params.model_alias},
{"seed", llama.params.seed}, {"seed", llama.params.seed},
{"temp", sparams.temp}, {"temp", sparams.temp},
{"top_k", sparams.top_k}, {"top_k", sparams.top_k},
{"top_p", sparams.top_p}, {"top_p", sparams.top_p},
{"tfs_z", sparams.tfs_z}, {"tfs_z", sparams.tfs_z},
{"typical_p", sparams.typical_p}, {"typical_p", sparams.typical_p},
{"repeat_last_n", sparams.repeat_last_n}, {"repeat_last_n", sparams.penalty_last_n},
{"repeat_penalty", sparams.repeat_penalty}, {"repeat_penalty", sparams.penalty_repeat},
{"presence_penalty", sparams.presence_penalty}, {"frequency_penalty", sparams.penalty_freq},
{"frequency_penalty", sparams.frequency_penalty}, {"presence_penalty", sparams.penalty_present},
{"mirostat", sparams.mirostat}, {"mirostat", sparams.mirostat},
{"mirostat_tau", sparams.mirostat_tau}, {"mirostat_tau", sparams.mirostat_tau},
{"mirostat_eta", sparams.mirostat_eta}, {"mirostat_eta", sparams.mirostat_eta},
{"penalize_nl", sparams.penalize_nl}, {"penalize_nl", sparams.penalize_nl},
{"stop", llama.params.antiprompt}, {"stop", llama.params.antiprompt},
{"n_predict", llama.params.n_predict}, {"n_predict", llama.params.n_predict},
{"n_keep", llama.params.n_keep}, {"n_keep", llama.params.n_keep},
{"ignore_eos", ignore_eos}, {"ignore_eos", ignore_eos},
{"stream", llama.stream}, {"stream", llama.stream},
{"logit_bias", sparams.logit_bias}, {"logit_bias", sparams.logit_bias},
{"n_probs", sparams.n_probs}, {"n_probs", sparams.n_probs},
{"grammar", llama.params.sparams.grammar}, {"grammar", llama.params.sparams.grammar},
}; };
} }
@ -1134,25 +1134,25 @@ static void parse_options_completion(const json &body, llama_server_context &lla
auto & params = llama.params; auto & params = llama.params;
auto & sparams = llama.params.sparams; auto & sparams = llama.params.sparams;
llama.stream = json_value(body, "stream", false); llama.stream = json_value(body, "stream", false);
params.n_predict = json_value(body, "n_predict", default_params.n_predict); params.n_predict = json_value(body, "n_predict", default_params.n_predict);
sparams.top_k = json_value(body, "top_k", default_sparams.top_k); sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
sparams.top_p = json_value(body, "top_p", default_sparams.top_p); sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z); sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z);
sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p); sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p);
sparams.repeat_last_n = json_value(body, "repeat_last_n", default_sparams.repeat_last_n); sparams.temp = json_value(body, "temperature", default_sparams.temp);
sparams.temp = json_value(body, "temperature", default_sparams.temp); sparams.penalty_last_n = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
sparams.repeat_penalty = json_value(body, "repeat_penalty", default_sparams.repeat_penalty); sparams.penalty_repeat = json_value(body, "repeat_penalty", default_sparams.penalty_repeat);
sparams.presence_penalty = json_value(body, "presence_penalty", default_sparams.presence_penalty); sparams.penalty_freq = json_value(body, "frequency_penalty", default_sparams.penalty_freq);
sparams.frequency_penalty = json_value(body, "frequency_penalty", default_sparams.frequency_penalty); sparams.penalty_present = json_value(body, "presence_penalty", default_sparams.penalty_present);
sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat); sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat);
sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau); sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta); sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl); sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
params.n_keep = json_value(body, "n_keep", default_params.n_keep); params.n_keep = json_value(body, "n_keep", default_params.n_keep);
params.seed = json_value(body, "seed", default_params.seed); params.seed = json_value(body, "seed", default_params.seed);
sparams.grammar = json_value(body, "grammar", default_sparams.grammar); sparams.grammar = json_value(body, "grammar", default_sparams.grammar);
sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs); sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
if (body.count("prompt") != 0) if (body.count("prompt") != 0)
{ {

View File

@ -7414,8 +7414,15 @@ void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array
llama_sample_temp(ctx, candidates_p, temp); llama_sample_temp(ctx, candidates_p, temp);
} }
void llama_sample_repetition_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, float repeat_penalty, float alpha_frequency, float alpha_presence) { void llama_sample_repetition_penalties(
if (last_tokens_size == 0 || (repeat_penalty == 1.0f && alpha_frequency == 0.0f && alpha_presence == 0.0f)) { struct llama_context * ctx,
llama_token_data_array * candidates,
const llama_token * last_tokens,
size_t penalty_last_n,
float penalty_repeat,
float penalty_freq,
float penalty_present) {
if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
return; return;
} }
@ -7423,8 +7430,8 @@ void llama_sample_repetition_penalties(struct llama_context * ctx, llama_token_d
// Create a frequency map to count occurrences of each token in last_tokens // Create a frequency map to count occurrences of each token in last_tokens
std::unordered_map<llama_token, int> token_count; std::unordered_map<llama_token, int> token_count;
for (size_t i = 0; i < last_tokens_size; ++i) { for (size_t i = 0; i < penalty_last_n; ++i) {
token_count[last_tokens_p[i]]++; token_count[last_tokens[i]]++;
} }
// Apply frequency and presence penalties to the candidates // Apply frequency and presence penalties to the candidates
@ -7439,12 +7446,12 @@ void llama_sample_repetition_penalties(struct llama_context * ctx, llama_token_d
// The academic publication that described this technique only divided by the penalty, but that would cause tokens with negative logits to become more likely, which is obviously wrong. // The academic publication that described this technique only divided by the penalty, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
// The common fix for this problem is to multiply by the penalty instead of dividing when the logit is not positive. // The common fix for this problem is to multiply by the penalty instead of dividing when the logit is not positive.
if (candidates->data[i].logit <= 0) { if (candidates->data[i].logit <= 0) {
candidates->data[i].logit *= repeat_penalty; candidates->data[i].logit *= penalty_repeat;
} else { } else {
candidates->data[i].logit /= repeat_penalty; candidates->data[i].logit /= penalty_repeat;
} }
candidates->data[i].logit -= float(count) * alpha_frequency + float(count > 0) * alpha_presence; candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
} }
candidates->sorted = false; candidates->sorted = false;

View File

@ -565,10 +565,10 @@ extern "C" {
struct llama_context * ctx, struct llama_context * ctx,
llama_token_data_array * candidates, llama_token_data_array * candidates,
const llama_token * last_tokens, const llama_token * last_tokens,
size_t last_tokens_size, size_t penalty_last_n,
float repeat_penalty, float penalty_repeat,
float alpha_frequency, float penalty_freq,
float alpha_presence); float penalty_present);
/// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted. /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.