From 8dae7ce68437faf1fa96ec0e7687b8700956ef20 Mon Sep 17 00:00:00 2001 From: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com> Date: Thu, 17 Aug 2023 07:29:44 -0600 Subject: [PATCH] Add --cfg-negative-prompt-file option for examples (#2591) Add --cfg-negative-prompt-file option for examples --- examples/common.cpp | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/examples/common.cpp b/examples/common.cpp index 9f8aab9a2..bd39d9220 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -274,6 +274,21 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.cfg_negative_prompt = argv[i]; + } else if (arg == "--cfg-negative-prompt-file") { + if (++i >= argc) { + invalid_param = true; + break; + } + std::ifstream file(argv[i]); + if (!file) { + fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); + invalid_param = true; + break; + } + std::copy(std::istreambuf_iterator(file), std::istreambuf_iterator(), back_inserter(params.cfg_negative_prompt)); + if (params.cfg_negative_prompt.back() == '\n') { + params.cfg_negative_prompt.pop_back(); + } } else if (arg == "--cfg-scale") { if (++i >= argc) { invalid_param = true; @@ -567,8 +582,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stdout, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); fprintf(stdout, " --grammar GRAMMAR BNF-like grammar to constrain generations (see samples in grammars/ dir)\n"); fprintf(stdout, " --grammar-file FNAME file to read grammar from\n"); - fprintf(stdout, " --cfg-negative-prompt PROMPT \n"); + fprintf(stdout, " --cfg-negative-prompt PROMPT\n"); fprintf(stdout, " negative prompt to use for guidance. (default: empty)\n"); + fprintf(stdout, " --cfg-negative-prompt-file FNAME\n"); + fprintf(stdout, " negative prompt file to use for guidance. (default: empty)\n"); fprintf(stdout, " --cfg-scale N strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale); fprintf(stdout, " --rope-scale N RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale); fprintf(stdout, " --rope-freq-base N RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base);