mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-24 18:34:36 +00:00
llama : remove cfg smooth factor as it is only a reparameterization of the guidance scale (#2280)
This commit is contained in:
parent
73643f5fb1
commit
ab0e26bdfb
@ -260,12 +260,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params.cfg_scale = std::stof(argv[i]);
|
params.cfg_scale = std::stof(argv[i]);
|
||||||
} else if (arg == "--cfg-smooth-factor") {
|
|
||||||
if (++i >= argc) {
|
|
||||||
invalid_param = true;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
params.cfg_smooth_factor = std::stof(argv[i]);
|
|
||||||
} else if (arg == "-b" || arg == "--batch-size") {
|
} else if (arg == "-b" || arg == "--batch-size") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
@ -509,7 +503,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||||||
fprintf(stderr, " --cfg-negative-prompt PROMPT \n");
|
fprintf(stderr, " --cfg-negative-prompt PROMPT \n");
|
||||||
fprintf(stderr, " negative prompt to use for guidance. (default: empty)\n");
|
fprintf(stderr, " negative prompt to use for guidance. (default: empty)\n");
|
||||||
fprintf(stderr, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
|
fprintf(stderr, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
|
||||||
fprintf(stderr, " --cfg-smooth-factor N smooth factor between old and new logits (default: %f, 1.0 = no smoothing)\n", params.cfg_smooth_factor);
|
|
||||||
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
|
fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
|
||||||
fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
|
fprintf(stderr, " --rope-freq-base N RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
|
||||||
fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
|
fprintf(stderr, " --rope-freq-scale N RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
|
||||||
|
@ -55,7 +55,6 @@ struct gpt_params {
|
|||||||
// https://arxiv.org/abs/2306.17806
|
// https://arxiv.org/abs/2306.17806
|
||||||
std::string cfg_negative_prompt; // string to help guidance
|
std::string cfg_negative_prompt; // string to help guidance
|
||||||
float cfg_scale = 1.f; // How strong is guidance
|
float cfg_scale = 1.f; // How strong is guidance
|
||||||
float cfg_smooth_factor = 1.f; // Smooth factor between old and new logits
|
|
||||||
|
|
||||||
std::string model = "models/7B/ggml-model.bin"; // model path
|
std::string model = "models/7B/ggml-model.bin"; // model path
|
||||||
std::string model_alias = "unknown"; // model alias
|
std::string model_alias = "unknown"; // model alias
|
||||||
|
@ -557,7 +557,7 @@ int main(int argc, char ** argv) {
|
|||||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||||
|
|
||||||
if (ctx_guidance) {
|
if (ctx_guidance) {
|
||||||
llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale, params.cfg_smooth_factor);
|
llama_sample_classifier_free_guidance(ctx, &candidates_p, ctx_guidance, params.cfg_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply penalties
|
// Apply penalties
|
||||||
|
14
llama.cpp
14
llama.cpp
@ -2218,8 +2218,7 @@ void llama_sample_classifier_free_guidance(
|
|||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
llama_token_data_array * candidates,
|
llama_token_data_array * candidates,
|
||||||
struct llama_context * guidance_ctx,
|
struct llama_context * guidance_ctx,
|
||||||
float scale,
|
float scale) {
|
||||||
float smooth_factor) {
|
|
||||||
int64_t t_start_sample_us = ggml_time_us();
|
int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
assert(ctx);
|
assert(ctx);
|
||||||
@ -2240,16 +2239,7 @@ void llama_sample_classifier_free_guidance(
|
|||||||
for (int i = 0; i < n_vocab; ++i) {
|
for (int i = 0; i < n_vocab; ++i) {
|
||||||
float logit_guidance = logits_guidance[i];
|
float logit_guidance = logits_guidance[i];
|
||||||
float logit_base = logits_base[i];
|
float logit_base = logits_base[i];
|
||||||
logits_guidance[i] = scale * (logit_base - logit_guidance) + logit_guidance;
|
candidates->data[i].logit = scale * (logit_base - logit_guidance) + logit_guidance;
|
||||||
}
|
|
||||||
|
|
||||||
llama_log_softmax(logits_guidance, n_vocab);
|
|
||||||
|
|
||||||
for (int i = 0; i < n_vocab; ++i) {
|
|
||||||
float logit_base = logits_base[i];
|
|
||||||
float logit_guidance = logits_guidance[i];
|
|
||||||
|
|
||||||
candidates->data[i].logit = smooth_factor * logit_guidance + (1.f - smooth_factor) * logit_base;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ctx) {
|
if (ctx) {
|
||||||
|
4
llama.h
4
llama.h
@ -344,13 +344,11 @@ extern "C" {
|
|||||||
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
|
/// @param candidates A vector of `llama_token_data` containing the candidate tokens, the logits must be directly extracted from the original generation context without being sorted.
|
||||||
/// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
|
/// @params guidance_ctx A separate context from the same model. Other than a negative prompt at the beginning, it should have all generated and user input tokens copied from the main context.
|
||||||
/// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
|
/// @params scale Guidance strength. 1.0f means no guidance. Higher values mean stronger guidance.
|
||||||
/// @params smooth_factor Smooth factor between guidance logits and original logits. 1.0f means only use guidance logits. 0.0f means only original logits.
|
|
||||||
LLAMA_API void llama_sample_classifier_free_guidance(
|
LLAMA_API void llama_sample_classifier_free_guidance(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
llama_token_data_array * candidates,
|
llama_token_data_array * candidates,
|
||||||
struct llama_context * guidance_ctx,
|
struct llama_context * guidance_ctx,
|
||||||
float scale,
|
float scale);
|
||||||
float smooth_factor);
|
|
||||||
|
|
||||||
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
/// @details Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
|
||||||
LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
|
LLAMA_API void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates);
|
||||||
|
Loading…
Reference in New Issue
Block a user