From 204f76d52e3ad576113e7aab59bcbc052607cdd7 Mon Sep 17 00:00:00 2001 From: Mathias Bachmann Date: Mon, 31 Jul 2023 13:19:13 +0200 Subject: [PATCH] Fix: possible out-of-bounds error, remove default_params --- examples/common.cpp | 88 ++++++++++++++++++++++----------------------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 52ca32372..c82aafbbc 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -92,7 +92,7 @@ void process_escapes(std::string& input) { input.resize(output_idx); } -bool validate_params(const std::string& arg, int argc, int& i, char** argv, const gpt_params& default_params, bool optional = false) { +bool validate_params(const std::string& arg, int argc, int& i, char** argv, bool optional = false) { if (++i >= argc) { if (optional) { // Argument is optional and not present, return false @@ -105,7 +105,7 @@ bool validate_params(const std::string& arg, int argc, int& i, char** argv, cons const std::string& nextArg = argv[i]; - if (nextArg.empty() || (nextArg[0] == '-' && !std::isdigit(nextArg[1]))) { + if (nextArg.empty() || (nextArg.size() >= 2 && nextArg[0] == '-' && !std::isdigit(nextArg[1]))) { throw std::runtime_error("Missing value for parameter: " + arg); } @@ -129,28 +129,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } if (arg == "-s" || arg == "--seed") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.seed = std::stoul(argv[i]); } else if (arg == "-t" || arg == "--threads") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.n_threads = std::stoi(argv[i]); if (params.n_threads <= 0) { params.n_threads = std::thread::hardware_concurrency(); } } else if (arg == "-p" || arg == "--prompt") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.prompt = argv[i]; } else if (arg == "-e") { escape_prompt = true; } else if (arg == "--prompt-cache") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.path_prompt_cache = argv[i]; } else if (arg == "--prompt-cache-all") { params.prompt_cache_all = true; } else if (arg == "--prompt-cache-ro") { params.prompt_cache_ro = true; } else if (arg == "-f" || arg == "--file") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); std::ifstream file(argv[i]); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); @@ -162,89 +162,89 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.prompt.pop_back(); } } else if (arg == "-n" || arg == "--n-predict") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.n_predict = std::stoi(argv[i]); } else if (arg == "--top-k") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.top_k = std::stoi(argv[i]); } else if (arg == "-c" || arg == "--ctx-size") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.n_ctx = std::stoi(argv[i]); } else if (arg == "-gqa" || arg == "--gqa") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.n_gqa = std::stoi(argv[i]); } else if (arg == "-eps" || arg == "--rms-norm-eps") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.rms_norm_eps = std::stof(argv[i]); } else if (arg == "--rope-freq-base") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.rope_freq_base = std::stof(argv[i]); } else if (arg == "--rope-freq-scale") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.rope_freq_scale = std::stof(argv[i]); } else if (arg == "--memory-f32") { params.memory_f16 = false; } else if (arg == "--top-p") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.top_p = std::stof(argv[i]); } else if (arg == "--temp") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.temp = std::stof(argv[i]); } else if (arg == "--tfs") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.tfs_z = std::stof(argv[i]); } else if (arg == "--typical") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.typical_p = std::stof(argv[i]); } else if (arg == "--repeat-last-n") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.repeat_last_n = std::stoi(argv[i]); } else if (arg == "--repeat-penalty") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.repeat_penalty = std::stof(argv[i]); } else if (arg == "--frequency-penalty") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.frequency_penalty = std::stof(argv[i]); } else if (arg == "--presence-penalty") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.presence_penalty = std::stof(argv[i]); } else if (arg == "--mirostat") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.mirostat = std::stoi(argv[i]); } else if (arg == "--mirostat-lr") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.mirostat_eta = std::stof(argv[i]); } else if (arg == "--mirostat-ent") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.mirostat_tau = std::stof(argv[i]); } else if (arg == "--cfg-negative-prompt") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.cfg_negative_prompt = argv[i]; } else if (arg == "--cfg-scale") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.cfg_scale = std::stof(argv[i]); } else if (arg == "-b" || arg == "--batch-size") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.n_batch = std::stoi(argv[i]); params.n_batch = std::min(512, params.n_batch); } else if (arg == "--keep") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.n_keep = std::stoi(argv[i]); } else if (arg == "--chunks") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.n_chunks = std::stoi(argv[i]); } else if (arg == "-m" || arg == "--model") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.model = argv[i]; } else if (arg == "-a" || arg == "--alias") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.model_alias = argv[i]; } else if (arg == "--lora") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.lora_adapter = argv[i]; params.use_mmap = false; } else if (arg == "--lora-base") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.lora_base = argv[i]; } else if (arg == "-i" || arg == "--interactive") { params.interactive = true; @@ -261,7 +261,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } else if (arg == "--mlock") { params.use_mlock = true; } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); #ifdef LLAMA_SUPPORTS_GPU_OFFLOAD params.n_gpu_layers = std::stoi(argv[i]); #else @@ -269,14 +269,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); #endif } else if (arg == "--main-gpu" || arg == "-mg") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); #ifdef GGML_USE_CUBLAS params.main_gpu = std::stoi(argv[i]); #else fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n"); #endif } else if (arg == "--tensor-split" || arg == "-ts") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); #ifdef GGML_USE_CUBLAS std::string arg_next = argv[i]; @@ -313,21 +313,21 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } else if (arg == "--verbose-prompt") { params.verbose_prompt = true; } else if (arg == "-r" || arg == "--reverse-prompt") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.antiprompt.push_back(argv[i]); } else if (arg == "--perplexity") { params.perplexity = true; } else if (arg == "--hellaswag") { params.hellaswag = true; } else if (arg == "--hellaswag-tasks") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.hellaswag_tasks = std::stoi(argv[i]); } else if (arg == "--ignore-eos") { params.logit_bias[llama_token_eos()] = -INFINITY; } else if (arg == "--no-penalize-nl") { params.penalize_nl = false; } else if (arg == "-l" || arg == "--logit-bias") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); std::stringstream ss(argv[i]); llama_token key; char sign; @@ -350,16 +350,16 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { } else if (arg == "--in-prefix-bos") { params.input_prefix_bos = true; } else if (arg == "--in-prefix") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.input_prefix = argv[i]; } else if (arg == "--in-suffix") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.input_suffix = argv[i]; } else if (arg == "--grammar") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); params.grammar = argv[i]; } else if (arg == "--grammar-file") { - validate_params(arg, argc, i, argv, default_params); + validate_params(arg, argc, i, argv); std::ifstream file(argv[i]); if (!file) { fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);