introduce validate_params,

use it in gpt_params_parse.

Parameter validation was not working in most cases
This commit is contained in:
maddes8cht 2023-07-30 15:28:23 +02:00
parent 11f3ca06b8
commit 9df732dae4

View File

@ -92,6 +92,27 @@ void process_escapes(std::string& input) {
input.resize(output_idx); input.resize(output_idx);
} }
// Checks that option `arg` is followed by a value and, if so, advances `i` onto it.
//
//   arg            - the option being parsed; used only in error messages
//   argc, argv     - the command line being parsed
//   i              - index of `arg` in argv on entry; on a `true` return it is the
//                    index of the value. It is left untouched when the function
//                    returns false or throws.
//   default_params - not used by the check; kept so existing call sites compile
//   optional       - when true, an absent value is reported by returning false
//                    instead of throwing
//
// A value counts as absent when the command line ends after `arg`, or (for optional
// values) when the next token looks like another option. A token looks like an option
// when it starts with '-' and is not followed by a digit, so negative numbers such as
// "-1" are still accepted as values.
//
// Returns true when argv[i] is a usable value, false when an optional value is absent.
// Throws std::runtime_error when a required value is missing or the value is empty.
bool validate_params(const std::string& arg, int argc, int& i, char** argv, const gpt_params& default_params, bool optional = false) {
    (void) default_params; // intentionally unused

    const int value_idx = i + 1;
    if (value_idx >= argc) {
        if (optional) {
            // Argument is optional and not present
            return false;
        }
        fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
        throw std::runtime_error("Invalid parameter for argument: " + arg);
    }

    const std::string value = argv[value_idx];
    if (value.empty()) {
        throw std::runtime_error("Missing value for parameter: " + arg);
    }

    // std::isdigit requires a value representable as unsigned char; passing a plain
    // (possibly negative) char is undefined behaviour.
    const bool negative_number = value.size() > 1 &&
                                 std::isdigit(static_cast<unsigned char>(value[1])) != 0;
    const bool looks_like_option = value[0] == '-' && !negative_number;
    if (looks_like_option) {
        if (optional) {
            // The next token is another option: the optional value was simply omitted.
            // Do not consume it so the caller's loop parses it normally.
            return false;
        }
        throw std::runtime_error("Missing value for parameter: " + arg);
    }

    // Validation succeeded: consume the value.
    i = value_idx;
    return true;
}
bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
bool invalid_param = false; bool invalid_param = false;
bool escape_prompt = false; bool escape_prompt = false;
@ -99,6 +120,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
gpt_params default_params; gpt_params default_params;
const std::string arg_prefix = "--"; const std::string arg_prefix = "--";
try{
for (int i = 1; i < argc; i++) { for (int i = 1; i < argc; i++) {
arg = argv[i]; arg = argv[i];
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) { if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
@ -106,43 +129,28 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} }
if (arg == "-s" || arg == "--seed") { if (arg == "-s" || arg == "--seed") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.seed = std::stoul(argv[i]); params.seed = std::stoul(argv[i]);
} else if (arg == "-t" || arg == "--threads") { } else if (arg == "-t" || arg == "--threads") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.n_threads = std::stoi(argv[i]); params.n_threads = std::stoi(argv[i]);
if (params.n_threads <= 0) { if (params.n_threads <= 0) {
params.n_threads = std::thread::hardware_concurrency(); params.n_threads = std::thread::hardware_concurrency();
} }
} else if (arg == "-p" || arg == "--prompt") { } else if (arg == "-p" || arg == "--prompt") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.prompt = argv[i]; params.prompt = argv[i];
} else if (arg == "-e") { } else if (arg == "-e") {
escape_prompt = true; escape_prompt = true;
} else if (arg == "--prompt-cache") { } else if (arg == "--prompt-cache") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.path_prompt_cache = argv[i]; params.path_prompt_cache = argv[i];
} else if (arg == "--prompt-cache-all") { } else if (arg == "--prompt-cache-all") {
params.prompt_cache_all = true; params.prompt_cache_all = true;
} else if (arg == "--prompt-cache-ro") { } else if (arg == "--prompt-cache-ro") {
params.prompt_cache_ro = true; params.prompt_cache_ro = true;
} else if (arg == "-f" || arg == "--file") { } else if (arg == "-f" || arg == "--file") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
std::ifstream file(argv[i]); std::ifstream file(argv[i]);
if (!file) { if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@ -154,170 +162,89 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.prompt.pop_back(); params.prompt.pop_back();
} }
} else if (arg == "-n" || arg == "--n-predict") { } else if (arg == "-n" || arg == "--n-predict") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.n_predict = std::stoi(argv[i]); params.n_predict = std::stoi(argv[i]);
} else if (arg == "--top-k") { } else if (arg == "--top-k") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.top_k = std::stoi(argv[i]); params.top_k = std::stoi(argv[i]);
} else if (arg == "-c" || arg == "--ctx-size") { } else if (arg == "-c" || arg == "--ctx-size") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.n_ctx = std::stoi(argv[i]); params.n_ctx = std::stoi(argv[i]);
} else if (arg == "-gqa" || arg == "--gqa") { } else if (arg == "-gqa" || arg == "--gqa") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.n_gqa = std::stoi(argv[i]); params.n_gqa = std::stoi(argv[i]);
} else if (arg == "-eps" || arg == "--rms-norm-eps") { } else if (arg == "-eps" || arg == "--rms-norm-eps") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.rms_norm_eps = std::stof(argv[i]); params.rms_norm_eps = std::stof(argv[i]);
} else if (arg == "--rope-freq-base") { } else if (arg == "--rope-freq-base") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.rope_freq_base = std::stof(argv[i]); params.rope_freq_base = std::stof(argv[i]);
} else if (arg == "--rope-freq-scale") { } else if (arg == "--rope-freq-scale") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.rope_freq_scale = std::stof(argv[i]); params.rope_freq_scale = std::stof(argv[i]);
} else if (arg == "--memory-f32") { } else if (arg == "--memory-f32") {
params.memory_f16 = false; params.memory_f16 = false;
} else if (arg == "--top-p") { } else if (arg == "--top-p") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.top_p = std::stof(argv[i]); params.top_p = std::stof(argv[i]);
} else if (arg == "--temp") { } else if (arg == "--temp") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.temp = std::stof(argv[i]); params.temp = std::stof(argv[i]);
} else if (arg == "--tfs") { } else if (arg == "--tfs") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.tfs_z = std::stof(argv[i]); params.tfs_z = std::stof(argv[i]);
} else if (arg == "--typical") { } else if (arg == "--typical") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.typical_p = std::stof(argv[i]); params.typical_p = std::stof(argv[i]);
} else if (arg == "--repeat-last-n") { } else if (arg == "--repeat-last-n") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.repeat_last_n = std::stoi(argv[i]); params.repeat_last_n = std::stoi(argv[i]);
} else if (arg == "--repeat-penalty") { } else if (arg == "--repeat-penalty") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.repeat_penalty = std::stof(argv[i]); params.repeat_penalty = std::stof(argv[i]);
} else if (arg == "--frequency-penalty") { } else if (arg == "--frequency-penalty") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.frequency_penalty = std::stof(argv[i]); params.frequency_penalty = std::stof(argv[i]);
} else if (arg == "--presence-penalty") { } else if (arg == "--presence-penalty") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.presence_penalty = std::stof(argv[i]); params.presence_penalty = std::stof(argv[i]);
} else if (arg == "--mirostat") { } else if (arg == "--mirostat") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.mirostat = std::stoi(argv[i]); params.mirostat = std::stoi(argv[i]);
} else if (arg == "--mirostat-lr") { } else if (arg == "--mirostat-lr") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.mirostat_eta = std::stof(argv[i]); params.mirostat_eta = std::stof(argv[i]);
} else if (arg == "--mirostat-ent") { } else if (arg == "--mirostat-ent") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.mirostat_tau = std::stof(argv[i]); params.mirostat_tau = std::stof(argv[i]);
} else if (arg == "--cfg-negative-prompt") { } else if (arg == "--cfg-negative-prompt") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.cfg_negative_prompt = argv[i]; params.cfg_negative_prompt = argv[i];
} else if (arg == "--cfg-scale") { } else if (arg == "--cfg-scale") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.cfg_scale = std::stof(argv[i]); params.cfg_scale = std::stof(argv[i]);
} else if (arg == "-b" || arg == "--batch-size") { } else if (arg == "-b" || arg == "--batch-size") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.n_batch = std::stoi(argv[i]); params.n_batch = std::stoi(argv[i]);
params.n_batch = std::min(512, params.n_batch); params.n_batch = std::min(512, params.n_batch);
} else if (arg == "--keep") { } else if (arg == "--keep") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.n_keep = std::stoi(argv[i]); params.n_keep = std::stoi(argv[i]);
} else if (arg == "--chunks") { } else if (arg == "--chunks") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.n_chunks = std::stoi(argv[i]); params.n_chunks = std::stoi(argv[i]);
} else if (arg == "-m" || arg == "--model") { } else if (arg == "-m" || arg == "--model") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.model = argv[i]; params.model = argv[i];
} else if (arg == "-a" || arg == "--alias") { } else if (arg == "-a" || arg == "--alias") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.model_alias = argv[i]; params.model_alias = argv[i];
} else if (arg == "--lora") { } else if (arg == "--lora") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.lora_adapter = argv[i]; params.lora_adapter = argv[i];
params.use_mmap = false; params.use_mmap = false;
} else if (arg == "--lora-base") { } else if (arg == "--lora-base") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.lora_base = argv[i]; params.lora_base = argv[i];
} else if (arg == "-i" || arg == "--interactive") { } else if (arg == "-i" || arg == "--interactive") {
params.interactive = true; params.interactive = true;
@ -334,10 +261,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} else if (arg == "--mlock") { } else if (arg == "--mlock") {
params.use_mlock = true; params.use_mlock = true;
} else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") { } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD #ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
params.n_gpu_layers = std::stoi(argv[i]); params.n_gpu_layers = std::stoi(argv[i]);
#else #else
@ -345,20 +269,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n"); fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
#endif #endif
} else if (arg == "--main-gpu" || arg == "-mg") { } else if (arg == "--main-gpu" || arg == "-mg") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
#ifdef GGML_USE_CUBLAS #ifdef GGML_USE_CUBLAS
params.main_gpu = std::stoi(argv[i]); params.main_gpu = std::stoi(argv[i]);
#else #else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n"); fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
#endif #endif
} else if (arg == "--tensor-split" || arg == "-ts") { } else if (arg == "--tensor-split" || arg == "-ts") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
#ifdef GGML_USE_CUBLAS #ifdef GGML_USE_CUBLAS
std::string arg_next = argv[i]; std::string arg_next = argv[i];
@ -395,30 +313,21 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} else if (arg == "--verbose-prompt") { } else if (arg == "--verbose-prompt") {
params.verbose_prompt = true; params.verbose_prompt = true;
} else if (arg == "-r" || arg == "--reverse-prompt") { } else if (arg == "-r" || arg == "--reverse-prompt") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.antiprompt.push_back(argv[i]); params.antiprompt.push_back(argv[i]);
} else if (arg == "--perplexity") { } else if (arg == "--perplexity") {
params.perplexity = true; params.perplexity = true;
} else if (arg == "--hellaswag") { } else if (arg == "--hellaswag") {
params.hellaswag = true; params.hellaswag = true;
} else if (arg == "--hellaswag-tasks") { } else if (arg == "--hellaswag-tasks") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.hellaswag_tasks = std::stoi(argv[i]); params.hellaswag_tasks = std::stoi(argv[i]);
} else if (arg == "--ignore-eos") { } else if (arg == "--ignore-eos") {
params.logit_bias[llama_token_eos()] = -INFINITY; params.logit_bias[llama_token_eos()] = -INFINITY;
} else if (arg == "--no-penalize-nl") { } else if (arg == "--no-penalize-nl") {
params.penalize_nl = false; params.penalize_nl = false;
} else if (arg == "-l" || arg == "--logit-bias") { } else if (arg == "-l" || arg == "--logit-bias") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
std::stringstream ss(argv[i]); std::stringstream ss(argv[i]);
llama_token key; llama_token key;
char sign; char sign;
@ -441,28 +350,16 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} else if (arg == "--in-prefix-bos") { } else if (arg == "--in-prefix-bos") {
params.input_prefix_bos = true; params.input_prefix_bos = true;
} else if (arg == "--in-prefix") { } else if (arg == "--in-prefix") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.input_prefix = argv[i]; params.input_prefix = argv[i];
} else if (arg == "--in-suffix") { } else if (arg == "--in-suffix") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.input_suffix = argv[i]; params.input_suffix = argv[i];
} else if (arg == "--grammar") { } else if (arg == "--grammar") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
params.grammar = argv[i]; params.grammar = argv[i];
} else if (arg == "--grammar-file") { } else if (arg == "--grammar-file") {
if (++i >= argc) { validate_params(arg, argc, i, argv, default_params);
invalid_param = true;
break;
}
std::ifstream file(argv[i]); std::ifstream file(argv[i]);
if (!file) { if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]); fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
@ -475,21 +372,27 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
std::back_inserter(params.grammar) std::back_inserter(params.grammar)
); );
} else { } else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); // fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
gpt_print_usage(argc, argv, default_params); // gpt_print_usage(argc, argv, default_params);
exit(1); std::string errorMessage = "Unknown argument: " + arg;
throw std::runtime_error(errorMessage);
} }
} }
/* Code block obsolete, as check is now performed in void validate_params
if (invalid_param) { if (invalid_param) {
fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
gpt_print_usage(argc, argv, default_params); gpt_print_usage(argc, argv, default_params);
exit(1); exit(1);
} }
*/
if (params.prompt_cache_all && if (params.prompt_cache_all &&
(params.interactive || params.interactive_first || (params.interactive || params.interactive_first ||
params.instruct)) { params.instruct)) {
fprintf(stderr, "error: --prompt-cache-all not supported in interactive mode yet\n"); // fprintf(stderr, "error: --prompt-cache-all not supported in interactive mode yet\n");
gpt_print_usage(argc, argv, default_params); throw std::runtime_error("--prompt-cache-all not supported in interactive mode yet. Exiting.");
// gpt_print_usage(argc, argv, default_params);
exit(1); exit(1);
} }
@ -500,6 +403,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} }
return true; return true;
} catch (const std::exception& e) {
// Handle exceptions thrown by validate_params or other parts of the function
fprintf(stderr, "\nError: Parameter '%s':\n %s\n\n", arg.c_str(), e.what());
fprintf(stderr, "For detailed help, use: -h or --help\n");
return false; // Return false to indicate an error in parameter parsing
}
} }
void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {