Command line args bounds checking (#424)

* command line args bounds checking

* unknown and invalid param exit codes 0 -> 1
This commit is contained in:
anzz1 2023-03-23 19:54:28 +02:00 committed by GitHub
parent a18c19259a
commit ea10d3ded2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

101
utils.cpp
View File

@ -26,41 +26,95 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.n_threads = std::max(1, (int32_t) std::thread::hardware_concurrency()); params.n_threads = std::max(1, (int32_t) std::thread::hardware_concurrency());
} }
bool invalid_param = false;
std::string arg;
for (int i = 1; i < argc; i++) { for (int i = 1; i < argc; i++) {
std::string arg = argv[i]; arg = argv[i];
if (arg == "-s" || arg == "--seed") { if (arg == "-s" || arg == "--seed") {
params.seed = std::stoi(argv[++i]); if (++i >= argc) {
invalid_param = true;
break;
}
params.seed = std::stoi(argv[i]);
} else if (arg == "-t" || arg == "--threads") { } else if (arg == "-t" || arg == "--threads") {
params.n_threads = std::stoi(argv[++i]); if (++i >= argc) {
invalid_param = true;
break;
}
params.n_threads = std::stoi(argv[i]);
} else if (arg == "-p" || arg == "--prompt") { } else if (arg == "-p" || arg == "--prompt") {
params.prompt = argv[++i]; if (++i >= argc) {
invalid_param = true;
break;
}
params.prompt = argv[i];
} else if (arg == "-f" || arg == "--file") { } else if (arg == "-f" || arg == "--file") {
std::ifstream file(argv[++i]); if (++i >= argc) {
invalid_param = true;
break;
}
std::ifstream file(argv[i]);
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt)); std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
if (params.prompt.back() == '\n') { if (params.prompt.back() == '\n') {
params.prompt.pop_back(); params.prompt.pop_back();
} }
} else if (arg == "-n" || arg == "--n_predict") { } else if (arg == "-n" || arg == "--n_predict") {
params.n_predict = std::stoi(argv[++i]); if (++i >= argc) {
invalid_param = true;
break;
}
params.n_predict = std::stoi(argv[i]);
} else if (arg == "--top_k") { } else if (arg == "--top_k") {
params.top_k = std::stoi(argv[++i]); if (++i >= argc) {
invalid_param = true;
break;
}
params.top_k = std::stoi(argv[i]);
} else if (arg == "-c" || arg == "--ctx_size") { } else if (arg == "-c" || arg == "--ctx_size") {
params.n_ctx = std::stoi(argv[++i]); if (++i >= argc) {
invalid_param = true;
break;
}
params.n_ctx = std::stoi(argv[i]);
} else if (arg == "--memory_f16") { } else if (arg == "--memory_f16") {
params.memory_f16 = true; params.memory_f16 = true;
} else if (arg == "--top_p") { } else if (arg == "--top_p") {
params.top_p = std::stof(argv[++i]); if (++i >= argc) {
invalid_param = true;
break;
}
params.top_p = std::stof(argv[i]);
} else if (arg == "--temp") { } else if (arg == "--temp") {
params.temp = std::stof(argv[++i]); if (++i >= argc) {
invalid_param = true;
break;
}
params.temp = std::stof(argv[i]);
} else if (arg == "--repeat_last_n") { } else if (arg == "--repeat_last_n") {
params.repeat_last_n = std::stoi(argv[++i]); if (++i >= argc) {
invalid_param = true;
break;
}
params.repeat_last_n = std::stoi(argv[i]);
} else if (arg == "--repeat_penalty") { } else if (arg == "--repeat_penalty") {
params.repeat_penalty = std::stof(argv[++i]); if (++i >= argc) {
invalid_param = true;
break;
}
params.repeat_penalty = std::stof(argv[i]);
} else if (arg == "-b" || arg == "--batch_size") { } else if (arg == "-b" || arg == "--batch_size") {
params.n_batch = std::stoi(argv[++i]); if (++i >= argc) {
invalid_param = true;
break;
}
params.n_batch = std::stoi(argv[i]);
} else if (arg == "-m" || arg == "--model") { } else if (arg == "-m" || arg == "--model") {
params.model = argv[++i]; if (++i >= argc) {
invalid_param = true;
break;
}
params.model = argv[i];
} else if (arg == "-i" || arg == "--interactive") { } else if (arg == "-i" || arg == "--interactive") {
params.interactive = true; params.interactive = true;
} else if (arg == "--interactive-first") { } else if (arg == "--interactive-first") {
@ -70,13 +124,21 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} else if (arg == "--color") { } else if (arg == "--color") {
params.use_color = true; params.use_color = true;
} else if (arg == "-r" || arg == "--reverse-prompt") { } else if (arg == "-r" || arg == "--reverse-prompt") {
params.antiprompt.push_back(argv[++i]); if (++i >= argc) {
invalid_param = true;
break;
}
params.antiprompt.push_back(argv[i]);
} else if (arg == "--perplexity") { } else if (arg == "--perplexity") {
params.perplexity = true; params.perplexity = true;
} else if (arg == "--ignore-eos") { } else if (arg == "--ignore-eos") {
params.ignore_eos = true; params.ignore_eos = true;
} else if (arg == "--n_parts") { } else if (arg == "--n_parts") {
params.n_parts = std::stoi(argv[++i]); if (++i >= argc) {
invalid_param = true;
break;
}
params.n_parts = std::stoi(argv[i]);
} else if (arg == "-h" || arg == "--help") { } else if (arg == "-h" || arg == "--help") {
gpt_print_usage(argc, argv, params); gpt_print_usage(argc, argv, params);
exit(0); exit(0);
@ -85,9 +147,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} else { } else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
gpt_print_usage(argc, argv, params); gpt_print_usage(argc, argv, params);
exit(0); exit(1);
} }
} }
if (invalid_param) {
fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
gpt_print_usage(argc, argv, params);
exit(1);
}
return true; return true;
} }