Add flag to make reverse prompt case insensitive

This commit is contained in:
Dewi Jones 2023-07-15 19:56:57 +00:00
parent 6e7cca4047
commit 39edee5136
3 changed files with 13 additions and 1 deletions

View File

@ -379,7 +379,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.export_cgraph = true; params.export_cgraph = true;
} else if (arg == "--verbose-prompt") { } else if (arg == "--verbose-prompt") {
params.verbose_prompt = true; params.verbose_prompt = true;
} else if (arg == "-r" || arg == "--reverse-prompt") { } else if (arg == "-r_ci" || arg == "--reverse_prompt_case_insensitive") {
params.reverse_prompt_case_insensitive = true;
}
else if (arg == "-r" || arg == "--reverse-prompt") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -465,6 +468,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n");
fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
fprintf(stderr, " -r_ci make the check for the reverse prompt case insensitive\n");
fprintf(stderr, " halt generation at PROMPT, return control in interactive mode\n"); fprintf(stderr, " halt generation at PROMPT, return control in interactive mode\n");
fprintf(stderr, " (can be specified more than once for multiple prompts).\n"); fprintf(stderr, " (can be specified more than once for multiple prompts).\n");
fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n"); fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n");

View File

@ -88,6 +88,7 @@ struct gpt_params {
bool numa = false; // attempt optimizations that help on some NUMA systems bool numa = false; // attempt optimizations that help on some NUMA systems
bool export_cgraph = false; // export the computation graph bool export_cgraph = false; // export the computation graph
bool verbose_prompt = false; // print prompt tokens before generation bool verbose_prompt = false; // print prompt tokens before generation
bool reverse_prompt_case_insensitive = false; // make the reverse prompt case insensitive
}; };
bool gpt_params_parse(int argc, char ** argv, gpt_params & params); bool gpt_params_parse(int argc, char ** argv, gpt_params & params);

View File

@ -7,6 +7,7 @@
#include "llama.h" #include "llama.h"
#include "build-info.h" #include "build-info.h"
#include <algorithm>
#include <cassert> #include <cassert>
#include <cinttypes> #include <cinttypes>
#include <cmath> #include <cmath>
@ -665,6 +666,12 @@ int main(int argc, char ** argv) {
? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding) ? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding)
: 0; : 0;
if (params.reverse_prompt_case_insensitive) {
// convert both strings to lower case
std::transform(antiprompt.begin(), antiprompt.end(), antiprompt.begin(), ::tolower);
std::transform(last_output.begin(), last_output.end(), last_output.begin(), ::tolower);
}
if (last_output.find(antiprompt.c_str(), search_start_pos) != std::string::npos) { if (last_output.find(antiprompt.c_str(), search_start_pos) != std::string::npos) {
if (params.interactive) { if (params.interactive) {
is_interacting = true; is_interacting = true;