From b52656158352c1b38d6c2780eb7ebb1b411b2f5e Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 20 Oct 2023 17:47:13 +0300
Subject: [PATCH] sampling : rename penalty params + reduce size of "prev" vector

---
 common/common.cpp          | 23 +++++-----
 common/sampling.cpp        | 12 +++---
 common/sampling.h          | 11 +++--
 examples/main/main.cpp     |  4 +-
 examples/server/server.cpp | 86 +++++++++++++++++++-------------------
 llama.cpp                  | 21 ++++++----
 llama.h                    |  8 ++--
 7 files changed, 86 insertions(+), 79 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index eb5ebb20f..2ef902bd5 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -241,25 +241,26 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 invalid_param = true;
                 break;
             }
-            sparams.repeat_last_n = std::stoi(argv[i]);
+            sparams.penalty_last_n = std::stoi(argv[i]);
+            sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
         } else if (arg == "--repeat-penalty") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            sparams.repeat_penalty = std::stof(argv[i]);
+            sparams.penalty_repeat = std::stof(argv[i]);
         } else if (arg == "--frequency-penalty") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            sparams.frequency_penalty = std::stof(argv[i]);
+            sparams.penalty_freq = std::stof(argv[i]);
         } else if (arg == "--presence-penalty") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            sparams.presence_penalty = std::stof(argv[i]);
+            sparams.penalty_present = std::stof(argv[i]);
         } else if (arg == "--mirostat") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -678,10 +679,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
     printf("  --tfs N               tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
     printf("  --typical N           locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
-    printf("  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.repeat_last_n);
-    printf("  --repeat-penalty N    penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.repeat_penalty);
-    printf("  --presence-penalty N  repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.presence_penalty);
-    printf("  --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.frequency_penalty);
+    printf("  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
+    printf("  --repeat-penalty N    penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
+    printf("  --presence-penalty N  repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
+    printf("  --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
     printf("  --mirostat N          use Mirostat sampling.\n");
     printf("                        Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
     printf("                        (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);
@@ -1178,7 +1179,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
     fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
     fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
-    fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty);
+    fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
     dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str());
     fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
     fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
@@ -1238,14 +1239,14 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false");
     fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
     fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
-    fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.presence_penalty);
+    fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present);
     dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str());
     fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
     fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
     fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
     dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens);
     fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false");
-    fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.repeat_penalty);
+    fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat);

     fprintf(stream, "reverse_prompt:\n");
     for (std::string ap : params.antiprompt) {

diff --git a/common/sampling.cpp b/common/sampling.cpp
index 58c9d0dd2..3db2cede8 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -80,10 +80,10 @@ llama_token llama_sampling_sample(
     const float   top_p           = params.top_p;
     const float   tfs_z           = params.tfs_z;
     const float   typical_p       = params.typical_p;
-    const int32_t repeat_last_n   = params.repeat_last_n < 0 ? params.n_prev : params.repeat_last_n;
-    const float   repeat_penalty  = params.repeat_penalty;
-    const float   alpha_presence  = params.presence_penalty;
-    const float   alpha_frequency = params.frequency_penalty;
+    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
+    const float   penalty_repeat  = params.penalty_repeat;
+    const float   penalty_freq    = params.penalty_freq;
+    const float   penalty_present = params.penalty_present;
     const int     mirostat        = params.mirostat;
     const float   mirostat_tau    = params.mirostat_tau;
     const float   mirostat_eta    = params.mirostat_eta;
@@ -118,8 +118,8 @@ llama_token llama_sampling_sample(
         const float nl_logit = logits[llama_token_nl(ctx_main)];

         llama_sample_repetition_penalties(ctx_main, &cur_p,
-                prev.data() + prev.size() - repeat_last_n,
-                repeat_last_n, repeat_penalty, alpha_frequency, alpha_presence);
+                prev.data() + prev.size() - penalty_last_n,
+                penalty_last_n, penalty_repeat, penalty_freq, penalty_present);

         if (!penalize_nl) {
             for (size_t idx = 0; idx < cur_p.size; idx++) {
diff --git a/common/sampling.h b/common/sampling.h
index 227c611d7..ccfeb13af 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -10,17 +10,17 @@

 // sampling parameters
 typedef struct llama_sampling_params {
-    int32_t n_prev            = 256;   // number of previous tokens to remember
+    int32_t n_prev            = 64;    // number of previous tokens to remember
     int32_t n_probs           = 0;     // if greater than 0, output the probabilities of top n_probs tokens.
     int32_t top_k             = 40;    // <= 0 to use vocab size
     float   top_p             = 0.95f; // 1.0 = disabled
     float   tfs_z             = 1.00f; // 1.0 = disabled
     float   typical_p         = 1.00f; // 1.0 = disabled
     float   temp              = 0.80f; // 1.0 = disabled
-    float   repeat_penalty    = 1.10f; // 1.0 = disabled
-    int32_t repeat_last_n     = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float   frequency_penalty = 0.00f; // 0.0 = disabled
-    float   presence_penalty  = 0.00f; // 0.0 = disabled
+    int32_t penalty_last_n    = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float   penalty_repeat    = 1.10f; // 1.0 = disabled
+    float   penalty_freq      = 0.00f; // 0.0 = disabled
+    float   penalty_present   = 0.00f; // 0.0 = disabled
     int32_t mirostat          = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float   mirostat_tau      = 5.00f; // target entropy
     float   mirostat_eta      = 0.10f; // learning rate
@@ -34,7 +34,6 @@ typedef struct llama_sampling_params {
     float       cfg_scale     = 1.f; // How strong is guidance

     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
-
 } llama_sampling_params;

 // general sampler context

diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index 540708ef1..cc9974b28 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -415,8 +415,8 @@ int main(int argc, char ** argv) {
             }
         }
     }

-    LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
-            sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
+    LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, frequency_penalty = %f, presence_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
+            sparams.penalty_last_n, sparams.penalty_repeat, sparams.penalty_freq, sparams.penalty_present, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");
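
With the renamed struct, switching the penalties on means setting the penalty_* fields before sampling. A usage sketch (the helper name make_sampling_params, the include path, and the non-default values are illustrative, not from the patch):

    #include "sampling.h" // common/sampling.h

    llama_sampling_params make_sampling_params() {
        llama_sampling_params sparams;
        sparams.penalty_last_n  = 64;    // recent tokens to penalize (-1 = context size)
        sparams.penalty_repeat  = 1.10f; // classic repetition penalty (1.0 = off)
        sparams.penalty_freq    = 0.20f; // grows with how often a token occurred
        sparams.penalty_present = 0.05f; // flat penalty once a token has occurred
        return sparams;
    }
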
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 18fe901dc..50fdb2d3a 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1010,30 +1010,30 @@ static json format_generation_settings(llama_server_context &llama)
         eos_bias->second < 0.0f && std::isinf(eos_bias->second);

     return json{
-        {"n_ctx", llama.n_ctx},
-        {"model", llama.params.model_alias},
-        {"seed", llama.params.seed},
-        {"temp", sparams.temp},
-        {"top_k", sparams.top_k},
-        {"top_p", sparams.top_p},
-        {"tfs_z", sparams.tfs_z},
-        {"typical_p", sparams.typical_p},
-        {"repeat_last_n", sparams.repeat_last_n},
-        {"repeat_penalty", sparams.repeat_penalty},
-        {"presence_penalty", sparams.presence_penalty},
-        {"frequency_penalty", sparams.frequency_penalty},
-        {"mirostat", sparams.mirostat},
-        {"mirostat_tau", sparams.mirostat_tau},
-        {"mirostat_eta", sparams.mirostat_eta},
-        {"penalize_nl", sparams.penalize_nl},
-        {"stop", llama.params.antiprompt},
-        {"n_predict", llama.params.n_predict},
-        {"n_keep", llama.params.n_keep},
-        {"ignore_eos", ignore_eos},
-        {"stream", llama.stream},
-        {"logit_bias", sparams.logit_bias},
-        {"n_probs", sparams.n_probs},
-        {"grammar", llama.params.sparams.grammar},
+        {"n_ctx",             llama.n_ctx},
+        {"model",             llama.params.model_alias},
+        {"seed",              llama.params.seed},
+        {"temp",              sparams.temp},
+        {"top_k",             sparams.top_k},
+        {"top_p",             sparams.top_p},
+        {"tfs_z",             sparams.tfs_z},
+        {"typical_p",         sparams.typical_p},
+        {"repeat_last_n",     sparams.penalty_last_n},
+        {"repeat_penalty",    sparams.penalty_repeat},
+        {"frequency_penalty", sparams.penalty_freq},
+        {"presence_penalty",  sparams.penalty_present},
+        {"mirostat",          sparams.mirostat},
+        {"mirostat_tau",      sparams.mirostat_tau},
+        {"mirostat_eta",      sparams.mirostat_eta},
+        {"penalize_nl",       sparams.penalize_nl},
+        {"stop",              llama.params.antiprompt},
+        {"n_predict",         llama.params.n_predict},
+        {"n_keep",            llama.params.n_keep},
+        {"ignore_eos",        ignore_eos},
+        {"stream",            llama.stream},
+        {"logit_bias",        sparams.logit_bias},
+        {"n_probs",           sparams.n_probs},
+        {"grammar",           llama.params.sparams.grammar},
     };
 }

@@ -1134,25 +1134,25 @@ static void parse_options_completion(const json &body, llama_server_context &lla
     auto & params  = llama.params;
     auto & sparams = llama.params.sparams;

-    llama.stream = json_value(body, "stream", false);
-    params.n_predict = json_value(body, "n_predict", default_params.n_predict);
-    sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
-    sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
-    sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z);
-    sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p);
-    sparams.repeat_last_n = json_value(body, "repeat_last_n", default_sparams.repeat_last_n);
-    sparams.temp = json_value(body, "temperature", default_sparams.temp);
-    sparams.repeat_penalty = json_value(body, "repeat_penalty", default_sparams.repeat_penalty);
-    sparams.presence_penalty = json_value(body, "presence_penalty", default_sparams.presence_penalty);
-    sparams.frequency_penalty = json_value(body, "frequency_penalty", default_sparams.frequency_penalty);
-    sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat);
-    sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
-    sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
-    sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
-    params.n_keep = json_value(body, "n_keep", default_params.n_keep);
-    params.seed = json_value(body, "seed", default_params.seed);
-    sparams.grammar = json_value(body, "grammar", default_sparams.grammar);
-    sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);
+    llama.stream            = json_value(body, "stream", false);
+    params.n_predict        = json_value(body, "n_predict", default_params.n_predict);
+    sparams.top_k           = json_value(body, "top_k", default_sparams.top_k);
+    sparams.top_p           = json_value(body, "top_p", default_sparams.top_p);
+    sparams.tfs_z           = json_value(body, "tfs_z", default_sparams.tfs_z);
+    sparams.typical_p       = json_value(body, "typical_p", default_sparams.typical_p);
+    sparams.temp            = json_value(body, "temperature", default_sparams.temp);
+    sparams.penalty_last_n  = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
+    sparams.penalty_repeat  = json_value(body, "repeat_penalty", default_sparams.penalty_repeat);
+    sparams.penalty_freq    = json_value(body, "frequency_penalty", default_sparams.penalty_freq);
+    sparams.penalty_present = json_value(body, "presence_penalty", default_sparams.penalty_present);
+    sparams.mirostat        = json_value(body, "mirostat", default_sparams.mirostat);
+    sparams.mirostat_tau    = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
+    sparams.mirostat_eta    = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
+    sparams.penalize_nl     = json_value(body, "penalize_nl", default_sparams.penalize_nl);
+    params.n_keep           = json_value(body, "n_keep", default_params.n_keep);
+    params.seed             = json_value(body, "seed", default_params.seed);
+    sparams.grammar         = json_value(body, "grammar", default_sparams.grammar);
+    sparams.n_probs         = json_value(body, "n_probs", default_sparams.n_probs);

     if (body.count("prompt") != 0) {
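
The server's wire format is unchanged: /completion requests still use the old key names, and parse_options_completion above maps them onto the renamed penalty_* fields. An illustrative request body built with nlohmann::json, which server.cpp already uses (the prompt text and values are made up):

    #include <nlohmann/json.hpp>

    nlohmann::json body = {
        {"prompt",            "Once upon a time"},
        {"repeat_last_n",     64},  // -> sparams.penalty_last_n
        {"repeat_penalty",    1.1}, // -> sparams.penalty_repeat
        {"frequency_penalty", 0.0}, // -> sparams.penalty_freq
        {"presence_penalty",  0.0}, // -> sparams.penalty_present
    };
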
diff --git a/llama.cpp b/llama.cpp
index d19d4222c..365349335 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -7414,8 +7414,15 @@ void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array
     llama_sample_temp(ctx, candidates_p, temp);
 }

-void llama_sample_repetition_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, float repeat_penalty, float alpha_frequency, float alpha_presence) {
-    if (last_tokens_size == 0 || (repeat_penalty == 1.0f && alpha_frequency == 0.0f && alpha_presence == 0.0f)) {
+void llama_sample_repetition_penalties(
+            struct llama_context * ctx,
+          llama_token_data_array * candidates,
+               const llama_token * last_tokens,
+                          size_t   penalty_last_n,
+                           float   penalty_repeat,
+                           float   penalty_freq,
+                           float   penalty_present) {
+    if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
         return;
     }
@@ -7423,8 +7430,8 @@ void llama_sample_repetition_penalties(struct llama_context * ctx, llama_token_d

     // Create a frequency map to count occurrences of each token in last_tokens
     std::unordered_map<llama_token, int> token_count;
-    for (size_t i = 0; i < last_tokens_size; ++i) {
-        token_count[last_tokens_p[i]]++;
+    for (size_t i = 0; i < penalty_last_n; ++i) {
+        token_count[last_tokens[i]]++;
     }

     // Apply frequency and presence penalties to the candidates
@@ -7439,12 +7446,12 @@ void llama_sample_repetition_penalties(struct llama_context * ctx, llama_token_d
         // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
         if (candidates->data[i].logit <= 0) {
-            candidates->data[i].logit *= repeat_penalty;
+            candidates->data[i].logit *= penalty_repeat;
         } else {
-            candidates->data[i].logit /= repeat_penalty;
+            candidates->data[i].logit /= penalty_repeat;
         }

-        candidates->data[i].logit -= float(count) * alpha_frequency + float(count > 0) * alpha_presence;
+        candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
     }

     candidates->sorted = false;
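
The sign-dependent multiply/divide above is easy to get backwards: both branches shrink the logit toward zero, i.e. make the token less likely. A self-contained sketch of the same arithmetic on scalars (the helper name penalize and the values are illustrative; this is not the library API):

    #include <cstdio>

    // Mirrors the penalty arithmetic in llama_sample_repetition_penalties:
    // shrink the logit toward zero by penalty_repeat, then subtract the
    // frequency and presence terms based on the occurrence count.
    static float penalize(float logit, int count, float penalty_repeat,
                          float penalty_freq, float penalty_present) {
        if (logit <= 0) {
            logit *= penalty_repeat; // e.g. -2.0 * 1.1 = -2.2  (less likely)
        } else {
            logit /= penalty_repeat; // e.g.  2.0 / 1.1 ~= 1.82 (less likely)
        }
        logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
        return logit;
    }

    int main() {
        // Token seen 3 times, logit 2.0: 2.0/1.1 - 3*0.2 - 1*0.1 ~= 1.12
        printf("%f\n", penalize(2.0f, 3, 1.1f, 0.2f, 0.1f));
        return 0;
    }
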
diff --git a/llama.h b/llama.h
index 4bdd6b098..306f5b383 100644
--- a/llama.h
+++ b/llama.h
@@ -565,10 +565,10 @@ extern "C" {
             struct llama_context * ctx,
           llama_token_data_array * candidates,
                const llama_token * last_tokens,
-                          size_t   last_tokens_size,
-                           float   repeat_penalty,
-                           float   alpha_frequency,
-                           float   alpha_presence);
+                          size_t   penalty_last_n,
+                           float   penalty_repeat,
+                           float   penalty_freq,
+                           float   penalty_present);

     /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
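
For callers migrating to the renamed declaration, a minimal call-site sketch that mirrors the call in common/sampling.cpp (the wrapper name apply_penalties and the parameter values are illustrative; setup of the context and candidate array is assumed, and the history must hold at least penalty_last_n tokens):

    #include <vector>
    #include "llama.h"

    void apply_penalties(llama_context * ctx, llama_token_data_array & cur_p,
                         const std::vector<llama_token> & prev) {
        const int32_t penalty_last_n  = 64;    // assumed <= prev.size()
        const float   penalty_repeat  = 1.10f;
        const float   penalty_freq    = 0.00f;
        const float   penalty_present = 0.00f;

        // Penalize candidates using only the last penalty_last_n history tokens.
        llama_sample_repetition_penalties(ctx, &cur_p,
                prev.data() + prev.size() - penalty_last_n,
                penalty_last_n, penalty_repeat, penalty_freq, penalty_present);
    }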