Fix: possible out-of-bounds error,

remove default_params
This commit is contained in:
Mathias Bachmann 2023-07-31 13:19:13 +02:00
parent a9a2647536
commit 204f76d52e

View File

@ -92,7 +92,7 @@ void process_escapes(std::string& input) {
input.resize(output_idx); input.resize(output_idx);
} }
bool validate_params(const std::string& arg, int argc, int& i, char** argv, const gpt_params& default_params, bool optional = false) { bool validate_params(const std::string& arg, int argc, int& i, char** argv, bool optional = false) {
if (++i >= argc) { if (++i >= argc) {
if (optional) { if (optional) {
// Argument is optional and not present, return false // Argument is optional and not present, return false
@ -105,7 +105,7 @@ bool validate_params(const std::string& arg, int argc, int& i, char** argv, cons
const std::string& nextArg = argv[i]; const std::string& nextArg = argv[i];
if (nextArg.empty() || (nextArg[0] == '-' && !std::isdigit(nextArg[1]))) { if (nextArg.empty() || (nextArg.size() >= 2 && nextArg[0] == '-' && !std::isdigit(nextArg[1]))) {
throw std::runtime_error("Missing value for parameter: " + arg); throw std::runtime_error("Missing value for parameter: " + arg);
} }
@ -129,28 +129,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} }
if (arg == "-s" || arg == "--seed") { if (arg == "-s" || arg == "--seed") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.seed = std::stoul(argv[i]); params.seed = std::stoul(argv[i]);
} else if (arg == "-t" || arg == "--threads") { } else if (arg == "-t" || arg == "--threads") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.n_threads = std::stoi(argv[i]); params.n_threads = std::stoi(argv[i]);
if (params.n_threads <= 0) { if (params.n_threads <= 0) {
params.n_threads = std::thread::hardware_concurrency(); params.n_threads = std::thread::hardware_concurrency();
} }
} else if (arg == "-p" || arg == "--prompt") { } else if (arg == "-p" || arg == "--prompt") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.prompt = argv[i]; params.prompt = argv[i];
} else if (arg == "-e") { } else if (arg == "-e") {
escape_prompt = true; escape_prompt = true;
} else if (arg == "--prompt-cache") { } else if (arg == "--prompt-cache") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.path_prompt_cache = argv[i]; params.path_prompt_cache = argv[i];
} else if (arg == "--prompt-cache-all") { } else if (arg == "--prompt-cache-all") {
params.prompt_cache_all = true; params.prompt_cache_all = true;
} else if (arg == "--prompt-cache-ro") { } else if (arg == "--prompt-cache-ro") {
params.prompt_cache_ro = true; params.prompt_cache_ro = true;
} else if (arg == "-f" || arg == "--file") { } else if (arg == "-f" || arg == "--file") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
std::ifstream file(argv[i]); std::ifstream file(argv[i]);
if (!file) { if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@ -162,89 +162,89 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.prompt.pop_back(); params.prompt.pop_back();
} }
} else if (arg == "-n" || arg == "--n-predict") { } else if (arg == "-n" || arg == "--n-predict") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.n_predict = std::stoi(argv[i]); params.n_predict = std::stoi(argv[i]);
} else if (arg == "--top-k") { } else if (arg == "--top-k") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.top_k = std::stoi(argv[i]); params.top_k = std::stoi(argv[i]);
} else if (arg == "-c" || arg == "--ctx-size") { } else if (arg == "-c" || arg == "--ctx-size") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.n_ctx = std::stoi(argv[i]); params.n_ctx = std::stoi(argv[i]);
} else if (arg == "-gqa" || arg == "--gqa") { } else if (arg == "-gqa" || arg == "--gqa") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.n_gqa = std::stoi(argv[i]); params.n_gqa = std::stoi(argv[i]);
} else if (arg == "-eps" || arg == "--rms-norm-eps") { } else if (arg == "-eps" || arg == "--rms-norm-eps") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.rms_norm_eps = std::stof(argv[i]); params.rms_norm_eps = std::stof(argv[i]);
} else if (arg == "--rope-freq-base") { } else if (arg == "--rope-freq-base") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.rope_freq_base = std::stof(argv[i]); params.rope_freq_base = std::stof(argv[i]);
} else if (arg == "--rope-freq-scale") { } else if (arg == "--rope-freq-scale") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.rope_freq_scale = std::stof(argv[i]); params.rope_freq_scale = std::stof(argv[i]);
} else if (arg == "--memory-f32") { } else if (arg == "--memory-f32") {
params.memory_f16 = false; params.memory_f16 = false;
} else if (arg == "--top-p") { } else if (arg == "--top-p") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.top_p = std::stof(argv[i]); params.top_p = std::stof(argv[i]);
} else if (arg == "--temp") { } else if (arg == "--temp") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.temp = std::stof(argv[i]); params.temp = std::stof(argv[i]);
} else if (arg == "--tfs") { } else if (arg == "--tfs") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.tfs_z = std::stof(argv[i]); params.tfs_z = std::stof(argv[i]);
} else if (arg == "--typical") { } else if (arg == "--typical") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.typical_p = std::stof(argv[i]); params.typical_p = std::stof(argv[i]);
} else if (arg == "--repeat-last-n") { } else if (arg == "--repeat-last-n") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.repeat_last_n = std::stoi(argv[i]); params.repeat_last_n = std::stoi(argv[i]);
} else if (arg == "--repeat-penalty") { } else if (arg == "--repeat-penalty") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.repeat_penalty = std::stof(argv[i]); params.repeat_penalty = std::stof(argv[i]);
} else if (arg == "--frequency-penalty") { } else if (arg == "--frequency-penalty") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.frequency_penalty = std::stof(argv[i]); params.frequency_penalty = std::stof(argv[i]);
} else if (arg == "--presence-penalty") { } else if (arg == "--presence-penalty") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.presence_penalty = std::stof(argv[i]); params.presence_penalty = std::stof(argv[i]);
} else if (arg == "--mirostat") { } else if (arg == "--mirostat") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.mirostat = std::stoi(argv[i]); params.mirostat = std::stoi(argv[i]);
} else if (arg == "--mirostat-lr") { } else if (arg == "--mirostat-lr") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.mirostat_eta = std::stof(argv[i]); params.mirostat_eta = std::stof(argv[i]);
} else if (arg == "--mirostat-ent") { } else if (arg == "--mirostat-ent") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.mirostat_tau = std::stof(argv[i]); params.mirostat_tau = std::stof(argv[i]);
} else if (arg == "--cfg-negative-prompt") { } else if (arg == "--cfg-negative-prompt") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.cfg_negative_prompt = argv[i]; params.cfg_negative_prompt = argv[i];
} else if (arg == "--cfg-scale") { } else if (arg == "--cfg-scale") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.cfg_scale = std::stof(argv[i]); params.cfg_scale = std::stof(argv[i]);
} else if (arg == "-b" || arg == "--batch-size") { } else if (arg == "-b" || arg == "--batch-size") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.n_batch = std::stoi(argv[i]); params.n_batch = std::stoi(argv[i]);
params.n_batch = std::min(512, params.n_batch); params.n_batch = std::min(512, params.n_batch);
} else if (arg == "--keep") { } else if (arg == "--keep") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.n_keep = std::stoi(argv[i]); params.n_keep = std::stoi(argv[i]);
} else if (arg == "--chunks") { } else if (arg == "--chunks") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.n_chunks = std::stoi(argv[i]); params.n_chunks = std::stoi(argv[i]);
} else if (arg == "-m" || arg == "--model") { } else if (arg == "-m" || arg == "--model") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.model = argv[i]; params.model = argv[i];
} else if (arg == "-a" || arg == "--alias") { } else if (arg == "-a" || arg == "--alias") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.model_alias = argv[i]; params.model_alias = argv[i];
} else if (arg == "--lora") { } else if (arg == "--lora") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.lora_adapter = argv[i]; params.lora_adapter = argv[i];
params.use_mmap = false; params.use_mmap = false;
} else if (arg == "--lora-base") { } else if (arg == "--lora-base") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.lora_base = argv[i]; params.lora_base = argv[i];
} else if (arg == "-i" || arg == "--interactive") { } else if (arg == "-i" || arg == "--interactive") {
params.interactive = true; params.interactive = true;
@ -261,7 +261,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} else if (arg == "--mlock") { } else if (arg == "--mlock") {
params.use_mlock = true; params.use_mlock = true;
} else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD #ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
params.n_gpu_layers = std::stoi(argv[i]); params.n_gpu_layers = std::stoi(argv[i]);
#else #else
@ -269,14 +269,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
#endif #endif
} else if (arg == "--main-gpu" || arg == "-mg") { } else if (arg == "--main-gpu" || arg == "-mg") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
#ifdef GGML_USE_CUBLAS #ifdef GGML_USE_CUBLAS
params.main_gpu = std::stoi(argv[i]); params.main_gpu = std::stoi(argv[i]);
#else #else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n"); fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
#endif #endif
} else if (arg == "--tensor-split" || arg == "-ts") { } else if (arg == "--tensor-split" || arg == "-ts") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
#ifdef GGML_USE_CUBLAS #ifdef GGML_USE_CUBLAS
std::string arg_next = argv[i]; std::string arg_next = argv[i];
@ -313,21 +313,21 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} else if (arg == "--verbose-prompt") { } else if (arg == "--verbose-prompt") {
params.verbose_prompt = true; params.verbose_prompt = true;
} else if (arg == "-r" || arg == "--reverse-prompt") { } else if (arg == "-r" || arg == "--reverse-prompt") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.antiprompt.push_back(argv[i]); params.antiprompt.push_back(argv[i]);
} else if (arg == "--perplexity") { } else if (arg == "--perplexity") {
params.perplexity = true; params.perplexity = true;
} else if (arg == "--hellaswag") { } else if (arg == "--hellaswag") {
params.hellaswag = true; params.hellaswag = true;
} else if (arg == "--hellaswag-tasks") { } else if (arg == "--hellaswag-tasks") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.hellaswag_tasks = std::stoi(argv[i]); params.hellaswag_tasks = std::stoi(argv[i]);
} else if (arg == "--ignore-eos") { } else if (arg == "--ignore-eos") {
params.logit_bias[llama_token_eos()] = -INFINITY; params.logit_bias[llama_token_eos()] = -INFINITY;
} else if (arg == "--no-penalize-nl") { } else if (arg == "--no-penalize-nl") {
params.penalize_nl = false; params.penalize_nl = false;
} else if (arg == "-l" || arg == "--logit-bias") { } else if (arg == "-l" || arg == "--logit-bias") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
std::stringstream ss(argv[i]); std::stringstream ss(argv[i]);
llama_token key; llama_token key;
char sign; char sign;
@ -350,16 +350,16 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} else if (arg == "--in-prefix-bos") { } else if (arg == "--in-prefix-bos") {
params.input_prefix_bos = true; params.input_prefix_bos = true;
} else if (arg == "--in-prefix") { } else if (arg == "--in-prefix") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.input_prefix = argv[i]; params.input_prefix = argv[i];
} else if (arg == "--in-suffix") { } else if (arg == "--in-suffix") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.input_suffix = argv[i]; params.input_suffix = argv[i];
} else if (arg == "--grammar") { } else if (arg == "--grammar") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
params.grammar = argv[i]; params.grammar = argv[i];
} else if (arg == "--grammar-file") { } else if (arg == "--grammar-file") {
validate_params(arg, argc, i, argv, default_params); validate_params(arg, argc, i, argv);
std::ifstream file(argv[i]); std::ifstream file(argv[i]);
if (!file) { if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);