mirror of https://github.com/ggerganov/llama.cpp.git
sampling : rename penalty params + reduce size of "prev" vector
This commit is contained in:
parent 84ed48b473
commit b526561583
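In short: the repetition-penalty sampling parameters are renamed to a consistent penalty_* scheme (repeat_last_n -> penalty_last_n, repeat_penalty -> penalty_repeat, frequency_penalty -> penalty_freq, presence_penalty -> penalty_present), the default size of the "prev" token-history vector drops from 256 to 64 entries, and llama_sample_repetition_penalties gets a multi-line signature using the new names. Externally visible names are untouched: the CLI flags (--repeat-penalty, --frequency-penalty, ...) and the server's JSON keys keep their old spellings.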
@@ -241,25 +241,26 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 invalid_param = true;
                 break;
             }
-            sparams.repeat_last_n = std::stoi(argv[i]);
+            sparams.penalty_last_n = std::stoi(argv[i]);
+            sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
         } else if (arg == "--repeat-penalty") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            sparams.repeat_penalty = std::stof(argv[i]);
+            sparams.penalty_repeat = std::stof(argv[i]);
         } else if (arg == "--frequency-penalty") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            sparams.frequency_penalty = std::stof(argv[i]);
+            sparams.penalty_freq = std::stof(argv[i]);
         } else if (arg == "--presence-penalty") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            sparams.presence_penalty = std::stof(argv[i]);
+            sparams.penalty_present = std::stof(argv[i]);
         } else if (arg == "--mirostat") {
             if (++i >= argc) {
                 invalid_param = true;
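The added std::max clamp is the functional change here: if the user asks to penalize more tokens than the sampler remembers, n_prev is grown so the penalty window always fits inside the stored history. A minimal standalone sketch of the same clamp (names match the diff; values are hypothetical):

    #include <algorithm>
    #include <cstdint>

    int main() {
        int32_t n_prev         = 64;  // new default history size
        int32_t penalty_last_n = 256; // user asked to penalize more tokens than are remembered

        // grow the history so the penalty window always fits inside it
        n_prev = std::max(n_prev, penalty_last_n); // n_prev is now 256
    }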
@@ -678,10 +679,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --top-p N             top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)sparams.top_p);
     printf("  --tfs N               tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)sparams.tfs_z);
     printf("  --typical N           locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)sparams.typical_p);
-    printf("  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.repeat_last_n);
-    printf("  --repeat-penalty N    penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.repeat_penalty);
-    printf("  --presence-penalty N  repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.presence_penalty);
-    printf("  --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.frequency_penalty);
+    printf("  --repeat-last-n N     last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", sparams.penalty_last_n);
+    printf("  --repeat-penalty N    penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)sparams.penalty_repeat);
+    printf("  --presence-penalty N  repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_present);
+    printf("  --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)sparams.penalty_freq);
     printf("  --mirostat N          use Mirostat sampling.\n");
     printf("                        Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
     printf("                        (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", sparams.mirostat);
@@ -1178,7 +1179,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
     fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
     fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
-    fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.frequency_penalty);
+    fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq);
     dump_string_yaml_multiline(stream, "grammar", sparams.grammar.c_str());
     fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n");
     fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false");
@@ -1238,14 +1239,14 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "numa: %s # default: false\n", params.numa ? "true" : "false");
     fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type);
     fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride);
-    fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.presence_penalty);
+    fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present);
     dump_string_yaml_multiline(stream, "prompt", params.prompt.c_str());
     fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str());
     fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false");
     fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false");
     dump_vector_int_yaml(stream, "prompt_tokens", prompt_tokens);
     fprintf(stream, "random_prompt: %s # default: false\n", params.random_prompt ? "true" : "false");
-    fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.repeat_penalty);
+    fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat);

     fprintf(stream, "reverse_prompt:\n");
     for (std::string ap : params.antiprompt) {
@@ -80,10 +80,10 @@ llama_token llama_sampling_sample(
     const float   top_p     = params.top_p;
     const float   tfs_z     = params.tfs_z;
     const float   typical_p = params.typical_p;
-    const int32_t repeat_last_n   = params.repeat_last_n < 0 ? params.n_prev : params.repeat_last_n;
-    const float   repeat_penalty  = params.repeat_penalty;
-    const float   alpha_presence  = params.presence_penalty;
-    const float   alpha_frequency = params.frequency_penalty;
+    const int32_t penalty_last_n  = params.penalty_last_n < 0 ? params.n_prev : params.penalty_last_n;
+    const float   penalty_repeat  = params.penalty_repeat;
+    const float   penalty_freq    = params.penalty_freq;
+    const float   penalty_present = params.penalty_present;
     const int     mirostat     = params.mirostat;
     const float   mirostat_tau = params.mirostat_tau;
     const float   mirostat_eta = params.mirostat_eta;
@@ -118,8 +118,8 @@ llama_token llama_sampling_sample(
     const float nl_logit = logits[llama_token_nl(ctx_main)];

     llama_sample_repetition_penalties(ctx_main, &cur_p,
-            prev.data() + prev.size() - repeat_last_n,
-            repeat_last_n, repeat_penalty, alpha_frequency, alpha_presence);
+            prev.data() + prev.size() - penalty_last_n,
+            penalty_last_n, penalty_repeat, penalty_freq, penalty_present);

     if (!penalize_nl) {
         for (size_t idx = 0; idx < cur_p.size; idx++) {
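The pointer arithmetic above selects the last penalty_last_n entries of the history: prev.data() + prev.size() - penalty_last_n is the start of the penalty window. This stays in bounds because penalty_last_n falls back to n_prev when negative (previous hunk) and the CLI parser clamps n_prev to at least penalty_last_n, assuming prev is kept filled to n_prev entries.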
@@ -10,17 +10,17 @@

 // sampling parameters
 typedef struct llama_sampling_params {
-    int32_t n_prev            = 256;   // number of previous tokens to remember
+    int32_t n_prev            = 64;    // number of previous tokens to remember
     int32_t n_probs           = 0;     // if greater than 0, output the probabilities of top n_probs tokens.
     int32_t top_k             = 40;    // <= 0 to use vocab size
     float   top_p             = 0.95f; // 1.0 = disabled
     float   tfs_z             = 1.00f; // 1.0 = disabled
     float   typical_p         = 1.00f; // 1.0 = disabled
     float   temp              = 0.80f; // 1.0 = disabled
-    float   repeat_penalty    = 1.10f; // 1.0 = disabled
-    int32_t repeat_last_n     = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
-    float   frequency_penalty = 0.00f; // 0.0 = disabled
-    float   presence_penalty  = 0.00f; // 0.0 = disabled
+    int32_t penalty_last_n    = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
+    float   penalty_repeat    = 1.10f; // 1.0 = disabled
+    float   penalty_freq      = 0.00f; // 0.0 = disabled
+    float   penalty_present   = 0.00f; // 0.0 = disabled
     int32_t mirostat          = 0;     // 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
     float   mirostat_tau      = 5.00f; // target entropy
     float   mirostat_eta      = 0.10f; // learning rate
@@ -34,7 +34,6 @@ typedef struct llama_sampling_params {
     float cfg_scale = 1.f; // How strong is guidance

     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
-
 } llama_sampling_params;

 // general sampler context
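For callers of the common sampling helpers, only the field names change. A minimal configuration sketch, assuming just the fields shown above (the trimmed struct is a stand-in for the real llama_sampling_params):

    #include <algorithm>
    #include <cstdint>

    // trimmed stand-in for llama_sampling_params, enough to compile this sketch
    struct sampling_params_sketch {
        int32_t n_prev          = 64;
        int32_t penalty_last_n  = 64;    // 0 = disable penalty, -1 = context size
        float   penalty_repeat  = 1.10f; // 1.0 = disabled
        float   penalty_freq    = 0.00f; // 0.0 = disabled
        float   penalty_present = 0.00f; // 0.0 = disabled
    };

    int main() {
        sampling_params_sketch sparams;
        sparams.penalty_last_n = 128;   // penalize the last 128 tokens
        sparams.penalty_repeat = 1.15f; // mild repetition penalty
        // keep the history at least as large as the penalty window, mirroring the CLI clamp
        sparams.n_prev = std::max(sparams.n_prev, sparams.penalty_last_n);
    }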
@@ -415,8 +415,8 @@ int main(int argc, char ** argv) {
             }
         }
     }
-    LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
-            sparams.repeat_last_n, sparams.repeat_penalty, sparams.presence_penalty, sparams.frequency_penalty, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
+    LOG_TEE("sampling: repeat_last_n = %d, repeat_penalty = %f, frequency_penalty = %f, presence_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
+            sparams.penalty_last_n, sparams.penalty_repeat, sparams.penalty_freq, sparams.penalty_present, sparams.top_k, sparams.tfs_z, sparams.top_p, sparams.typical_p, sparams.temp, sparams.mirostat, sparams.mirostat_eta, sparams.mirostat_tau);
     LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     LOG_TEE("\n\n");
@@ -1010,30 +1010,30 @@ static json format_generation_settings(llama_server_context &llama)
         eos_bias->second < 0.0f && std::isinf(eos_bias->second);

     return json{
         {"n_ctx", llama.n_ctx},
         {"model", llama.params.model_alias},
         {"seed", llama.params.seed},
         {"temp", sparams.temp},
         {"top_k", sparams.top_k},
         {"top_p", sparams.top_p},
         {"tfs_z", sparams.tfs_z},
         {"typical_p", sparams.typical_p},
-        {"repeat_last_n", sparams.repeat_last_n},
-        {"repeat_penalty", sparams.repeat_penalty},
-        {"presence_penalty", sparams.presence_penalty},
-        {"frequency_penalty", sparams.frequency_penalty},
+        {"repeat_last_n", sparams.penalty_last_n},
+        {"repeat_penalty", sparams.penalty_repeat},
+        {"frequency_penalty", sparams.penalty_freq},
+        {"presence_penalty", sparams.penalty_present},
         {"mirostat", sparams.mirostat},
         {"mirostat_tau", sparams.mirostat_tau},
         {"mirostat_eta", sparams.mirostat_eta},
         {"penalize_nl", sparams.penalize_nl},
         {"stop", llama.params.antiprompt},
         {"n_predict", llama.params.n_predict},
         {"n_keep", llama.params.n_keep},
         {"ignore_eos", ignore_eos},
         {"stream", llama.stream},
         {"logit_bias", sparams.logit_bias},
         {"n_probs", sparams.n_probs},
         {"grammar", llama.params.sparams.grammar},
     };
 }
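Note that the server's public JSON stays stable: the keys "repeat_last_n", "repeat_penalty", "frequency_penalty" and "presence_penalty" are unchanged, and only the sparams fields behind them are renamed (the frequency/presence pair is now emitted in swapped order, matching the new field order).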
@@ -1134,25 +1134,25 @@ static void parse_options_completion(const json &body, llama_server_context &lla
     auto & params = llama.params;
     auto & sparams = llama.params.sparams;

     llama.stream = json_value(body, "stream", false);
     params.n_predict = json_value(body, "n_predict", default_params.n_predict);
     sparams.top_k = json_value(body, "top_k", default_sparams.top_k);
     sparams.top_p = json_value(body, "top_p", default_sparams.top_p);
     sparams.tfs_z = json_value(body, "tfs_z", default_sparams.tfs_z);
     sparams.typical_p = json_value(body, "typical_p", default_sparams.typical_p);
-    sparams.repeat_last_n = json_value(body, "repeat_last_n", default_sparams.repeat_last_n);
-    sparams.temp = json_value(body, "temperature", default_sparams.temp);
-    sparams.repeat_penalty = json_value(body, "repeat_penalty", default_sparams.repeat_penalty);
-    sparams.presence_penalty = json_value(body, "presence_penalty", default_sparams.presence_penalty);
-    sparams.frequency_penalty = json_value(body, "frequency_penalty", default_sparams.frequency_penalty);
+    sparams.temp = json_value(body, "temperature", default_sparams.temp);
+    sparams.penalty_last_n = json_value(body, "repeat_last_n", default_sparams.penalty_last_n);
+    sparams.penalty_repeat = json_value(body, "repeat_penalty", default_sparams.penalty_repeat);
+    sparams.penalty_freq = json_value(body, "frequency_penalty", default_sparams.penalty_freq);
+    sparams.penalty_present = json_value(body, "presence_penalty", default_sparams.penalty_present);
     sparams.mirostat = json_value(body, "mirostat", default_sparams.mirostat);
     sparams.mirostat_tau = json_value(body, "mirostat_tau", default_sparams.mirostat_tau);
     sparams.mirostat_eta = json_value(body, "mirostat_eta", default_sparams.mirostat_eta);
     sparams.penalize_nl = json_value(body, "penalize_nl", default_sparams.penalize_nl);
     params.n_keep = json_value(body, "n_keep", default_params.n_keep);
     params.seed = json_value(body, "seed", default_params.seed);
     sparams.grammar = json_value(body, "grammar", default_sparams.grammar);
     sparams.n_probs = json_value(body, "n_probs", default_sparams.n_probs);

     if (body.count("prompt") != 0)
     {
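Clients are unaffected for the same reason: a completion request still uses the old key names. A hedged sketch of building such a request body with nlohmann::json, the JSON library the server uses (transport and endpoint handling omitted):

    #include <nlohmann/json.hpp>

    int main() {
        nlohmann::json body = {
            {"prompt",            "Once upon a time"},
            {"temperature",       0.8},
            {"repeat_last_n",     64},   // parsed into sparams.penalty_last_n
            {"repeat_penalty",    1.1},  // parsed into sparams.penalty_repeat
            {"frequency_penalty", 0.0},  // parsed into sparams.penalty_freq
            {"presence_penalty",  0.0},  // parsed into sparams.penalty_present
        };
        // POST body.dump() to the server's completion endpoint
        return 0;
    }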
llama.cpp
@@ -7414,8 +7414,15 @@ void llama_sample_temperature(struct llama_context * ctx, llama_token_data_array
     llama_sample_temp(ctx, candidates_p, temp);
 }

-void llama_sample_repetition_penalties(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens_p, size_t last_tokens_size, float repeat_penalty, float alpha_frequency, float alpha_presence) {
-    if (last_tokens_size == 0 || (repeat_penalty == 1.0f && alpha_frequency == 0.0f && alpha_presence == 0.0f)) {
+void llama_sample_repetition_penalties(
+            struct llama_context * ctx,
+          llama_token_data_array * candidates,
+               const llama_token * last_tokens,
+                          size_t   penalty_last_n,
+                           float   penalty_repeat,
+                           float   penalty_freq,
+                           float   penalty_present) {
+    if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
         return;
     }
@@ -7423,8 +7430,8 @@ void llama_sample_repetition_penalties(struct llama_context * ctx, llama_token_d

     // Create a frequency map to count occurrences of each token in last_tokens
     std::unordered_map<llama_token, int> token_count;
-    for (size_t i = 0; i < last_tokens_size; ++i) {
-        token_count[last_tokens_p[i]]++;
+    for (size_t i = 0; i < penalty_last_n; ++i) {
+        token_count[last_tokens[i]]++;
     }

     // Apply frequency and presence penalties to the candidates
@@ -7439,12 +7446,12 @@ void llama_sample_repetition_penalties(struct llama_context * ctx, llama_token_d
         // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
         // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
         if (candidates->data[i].logit <= 0) {
-            candidates->data[i].logit *= repeat_penalty;
+            candidates->data[i].logit *= penalty_repeat;
         } else {
-            candidates->data[i].logit /= repeat_penalty;
+            candidates->data[i].logit /= penalty_repeat;
         }

-        candidates->data[i].logit -= float(count) * alpha_frequency + float(count > 0) * alpha_presence;
+        candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
     }

     candidates->sorted = false;
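The arithmetic in this hunk is self-contained enough to demonstrate in isolation. A standalone sketch (not llama.cpp code) of applying all three penalties to one candidate logit, given how often its token occurred in the penalty window:

    #include <cstdio>

    // apply repetition, frequency and presence penalties to a single logit;
    // count = occurrences of this token among the last penalty_last_n tokens
    float penalize(float logit, int count,
                   float penalty_repeat, float penalty_freq, float penalty_present) {
        if (count > 0) {
            // multiply negative logits by the penalty (instead of dividing)
            // so unlikely tokens do not become *more* likely
            if (logit <= 0) {
                logit *= penalty_repeat;
            } else {
                logit /= penalty_repeat;
            }
        }
        logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
        return logit;
    }

    int main() {
        // token seen 3 times; repeat 1.1, freq 0.05, presence 0.2:
        // 2.0/1.1 - 3*0.05 - 0.2 is roughly 1.468
        printf("%f\n", penalize(2.0f, 3, 1.1f, 0.05f, 0.2f));
    }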
llama.h
@@ -565,10 +565,10 @@ extern "C" {
             struct llama_context * ctx,
           llama_token_data_array * candidates,
                const llama_token * last_tokens,
-                          size_t   last_tokens_size,
-                           float   repeat_penalty,
-                           float   alpha_frequency,
-                           float   alpha_presence);
+                          size_t   penalty_last_n,
+                           float   penalty_repeat,
+                           float   penalty_freq,
+                           float   penalty_present);

     /// @details Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
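A hedged call-site sketch against the renamed public API, mirroring the sampling-helper call shown earlier (assumes an initialized context and candidate array; the wrapper name is illustrative):

    #include <vector>
    #include "llama.h"

    // penalize repeats over the last penalty_last_n tokens of `history`
    void apply_penalties(llama_context * ctx,
                         llama_token_data_array * candidates,
                         const std::vector<llama_token> & history) {
        const size_t penalty_last_n = 64; // assumes history.size() >= 64
        llama_sample_repetition_penalties(
                ctx, candidates,
                history.data() + history.size() - penalty_last_n,
                penalty_last_n,
                1.10f,  // penalty_repeat
                0.00f,  // penalty_freq
                0.00f); // penalty_present
    }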