From 2485d7a4d39406cd0f468e35551b472cceb5bd61 Mon Sep 17 00:00:00 2001 From: DannyDaemonic Date: Tue, 2 May 2023 18:46:20 -0700 Subject: [PATCH] Process escape sequences given in prompts (#1173) --- examples/common.cpp | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/examples/common.cpp b/examples/common.cpp index 222b4fa73..1a2f4743a 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -66,6 +66,33 @@ int32_t get_num_physical_cores() { return n_threads > 0 ? (n_threads <= 4 ? n_threads : n_threads / 2) : 4; } +std::string process_escapes(const char* input) { + std::string output; + + if (input != nullptr) { + std::size_t input_len = std::strlen(input); + output.reserve(input_len); + + for (std::size_t i = 0; i < input_len; ++i) { + if (input[i] == '\\' && i + 1 < input_len) { + switch (input[++i]) { + case 'n': output.push_back('\n'); break; + case 't': output.push_back('\t'); break; + case '\'': output.push_back('\''); break; + case '\"': output.push_back('\"'); break; + case '\\': output.push_back('\\'); break; + default: output.push_back('\\'); + output.push_back(input[i]); break; + } + } else { + output.push_back(input[i]); + } + } + } + + return output; +} + bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { bool invalid_param = false; std::string arg; @@ -91,7 +118,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { invalid_param = true; break; } - params.prompt = argv[i]; + params.prompt = process_escapes(argv[i]); } else if (arg == "--session") { if (++i >= argc) { invalid_param = true;