From 39edee5136d1fad5b1b810fe66bee630d0e055a4 Mon Sep 17 00:00:00 2001 From: Dewi Jones Date: Sat, 15 Jul 2023 19:56:57 +0000 Subject: [PATCH] Add flag to make reverse prompt case insensitive --- examples/common.cpp | 6 +++++- examples/common.h | 1 + examples/main/main.cpp | 7 +++++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/examples/common.cpp b/examples/common.cpp index 8705127cb..b5d09f12b 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -379,7 +379,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.export_cgraph = true; } else if (arg == "--verbose-prompt") { params.verbose_prompt = true; - } else if (arg == "-r" || arg == "--reverse-prompt") { + } else if (arg == "-r_ci" || arg == "--reverse_prompt_case_insensitive") { + params.reverse_prompt_case_insensitive = true; + } + else if (arg == "-r" || arg == "--reverse-prompt") { if (++i >= argc) { invalid_param = true; break; @@ -465,6 +468,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); + fprintf(stderr, " -r_ci make the check for the reverse prompt case insensitive\n"); fprintf(stderr, " halt generation at PROMPT, return control in interactive mode\n"); fprintf(stderr, " (can be specified more than once for multiple prompts).\n"); fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n"); diff --git a/examples/common.h b/examples/common.h index f52fef629..47a0846ae 100644 --- a/examples/common.h +++ b/examples/common.h @@ -88,6 +88,7 @@ struct gpt_params { bool numa = false; // attempt optimizations that help on some NUMA systems bool export_cgraph = false; // export the computation graph bool verbose_prompt = false; 
// print prompt tokens before generation + bool reverse_prompt_case_insensitive = false; // make the reverse prompt case insensitive }; bool gpt_params_parse(int argc, char ** argv, gpt_params & params); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index bcbcf12b0..01d7b90c2 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -7,6 +7,7 @@ #include "llama.h" #include "build-info.h" +#include <algorithm> #include <cassert> #include <cinttypes> #include <cmath> @@ -665,6 +666,12 @@ int main(int argc, char ** argv) { ? last_output.length() - static_cast<size_t>(antiprompt.length() + extra_padding) : 0; + if (params.reverse_prompt_case_insensitive) { + // convert both strings to lower case + std::transform(antiprompt.begin(), antiprompt.end(), antiprompt.begin(), ::tolower); + std::transform(last_output.begin(), last_output.end(), last_output.begin(), ::tolower); + } + if (last_output.find(antiprompt.c_str(), search_start_pos) != std::string::npos) { if (params.interactive) { is_interacting = true;