CLI args use - instead of _, backwards compatible (#1416)

This commit is contained in:
Johannes Gäßler 2023-05-12 16:34:55 +02:00 committed by GitHub
parent 553fd4d4b5
commit 773ee249fb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -91,9 +91,13 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
bool escape_prompt = false; bool escape_prompt = false;
std::string arg; std::string arg;
gpt_params default_params; gpt_params default_params;
const std::string arg_prefix = "--";
for (int i = 1; i < argc; i++) { for (int i = 1; i < argc; i++) {
arg = argv[i]; arg = argv[i];
if (arg.compare(0, arg_prefix.size(), arg_prefix) == 0) {
std::replace(arg.begin(), arg.end(), '_', '-');
}
if (arg == "-s" || arg == "--seed") { if (arg == "-s" || arg == "--seed") {
#if defined(GGML_USE_CUBLAS) #if defined(GGML_USE_CUBLAS)
@ -141,27 +145,27 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
if (params.prompt.back() == '\n') { if (params.prompt.back() == '\n') {
params.prompt.pop_back(); params.prompt.pop_back();
} }
} else if (arg == "-n" || arg == "--n_predict") { } else if (arg == "-n" || arg == "--n-predict") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_predict = std::stoi(argv[i]); params.n_predict = std::stoi(argv[i]);
} else if (arg == "--top_k") { } else if (arg == "--top-k") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.top_k = std::stoi(argv[i]); params.top_k = std::stoi(argv[i]);
} else if (arg == "-c" || arg == "--ctx_size") { } else if (arg == "-c" || arg == "--ctx-size") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.n_ctx = std::stoi(argv[i]); params.n_ctx = std::stoi(argv[i]);
} else if (arg == "--memory_f32") { } else if (arg == "--memory-f32") {
params.memory_f16 = false; params.memory_f16 = false;
} else if (arg == "--top_p") { } else if (arg == "--top-p") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -185,25 +189,25 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break; break;
} }
params.typical_p = std::stof(argv[i]); params.typical_p = std::stof(argv[i]);
} else if (arg == "--repeat_last_n") { } else if (arg == "--repeat-last-n") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.repeat_last_n = std::stoi(argv[i]); params.repeat_last_n = std::stoi(argv[i]);
} else if (arg == "--repeat_penalty") { } else if (arg == "--repeat-penalty") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.repeat_penalty = std::stof(argv[i]); params.repeat_penalty = std::stof(argv[i]);
} else if (arg == "--frequency_penalty") { } else if (arg == "--frequency-penalty") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.frequency_penalty = std::stof(argv[i]); params.frequency_penalty = std::stof(argv[i]);
} else if (arg == "--presence_penalty") { } else if (arg == "--presence-penalty") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -215,19 +219,19 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break; break;
} }
params.mirostat = std::stoi(argv[i]); params.mirostat = std::stoi(argv[i]);
} else if (arg == "--mirostat_lr") { } else if (arg == "--mirostat-lr") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.mirostat_eta = std::stof(argv[i]); params.mirostat_eta = std::stof(argv[i]);
} else if (arg == "--mirostat_ent") { } else if (arg == "--mirostat-ent") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
} }
params.mirostat_tau = std::stof(argv[i]); params.mirostat_tau = std::stof(argv[i]);
} else if (arg == "-b" || arg == "--batch_size") { } else if (arg == "-b" || arg == "--batch-size") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -310,7 +314,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
invalid_param = true; invalid_param = true;
break; break;
} }
} else if (arg == "--n_parts") { } else if (arg == "--n-parts") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
break; break;
@ -384,31 +388,31 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --in-suffix STRING string to suffix after user inputs with (default: empty)\n"); fprintf(stderr, " --in-suffix STRING string to suffix after user inputs with (default: empty)\n");
fprintf(stderr, " -f FNAME, --file FNAME\n"); fprintf(stderr, " -f FNAME, --file FNAME\n");
fprintf(stderr, " prompt file to start generation.\n"); fprintf(stderr, " prompt file to start generation.\n");
fprintf(stderr, " -n N, --n_predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict); fprintf(stderr, " -n N, --n-predict N number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
fprintf(stderr, " --top_k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k); fprintf(stderr, " --top-k N top-k sampling (default: %d, 0 = disabled)\n", params.top_k);
fprintf(stderr, " --top_p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p); fprintf(stderr, " --top-p N top-p sampling (default: %.1f, 1.0 = disabled)\n", (double)params.top_p);
fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z); fprintf(stderr, " --tfs N tail free sampling, parameter z (default: %.1f, 1.0 = disabled)\n", (double)params.tfs_z);
fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p); fprintf(stderr, " --typical N locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)\n", (double)params.typical_p);
fprintf(stderr, " --repeat_last_n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n); fprintf(stderr, " --repeat-last-n N last n tokens to consider for penalize (default: %d, 0 = disabled, -1 = ctx_size)\n", params.repeat_last_n);
fprintf(stderr, " --repeat_penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty); fprintf(stderr, " --repeat-penalty N penalize repeat sequence of tokens (default: %.1f, 1.0 = disabled)\n", (double)params.repeat_penalty);
fprintf(stderr, " --presence_penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty); fprintf(stderr, " --presence-penalty N repeat alpha presence penalty (default: %.1f, 0.0 = disabled)\n", (double)params.presence_penalty);
fprintf(stderr, " --frequency_penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty); fprintf(stderr, " --frequency-penalty N repeat alpha frequency penalty (default: %.1f, 0.0 = disabled)\n", (double)params.frequency_penalty);
fprintf(stderr, " --mirostat N use Mirostat sampling.\n"); fprintf(stderr, " --mirostat N use Mirostat sampling.\n");
fprintf(stderr, " Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n"); fprintf(stderr, " Top K, Nucleus, Tail Free and Locally Typical samplers are ignored if used.\n");
fprintf(stderr, " (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat); fprintf(stderr, " (default: %d, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)\n", params.mirostat);
fprintf(stderr, " --mirostat_lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta); fprintf(stderr, " --mirostat-lr N Mirostat learning rate, parameter eta (default: %.1f)\n", (double)params.mirostat_eta);
fprintf(stderr, " --mirostat_ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau); fprintf(stderr, " --mirostat-ent N Mirostat target entropy, parameter tau (default: %.1f)\n", (double)params.mirostat_tau);
fprintf(stderr, " -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n"); fprintf(stderr, " -l TOKEN_ID(+/-)BIAS, --logit-bias TOKEN_ID(+/-)BIAS\n");
fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n"); fprintf(stderr, " modifies the likelihood of token appearing in the completion,\n");
fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n"); fprintf(stderr, " i.e. `--logit-bias 15043+1` to increase likelihood of token ' Hello',\n");
fprintf(stderr, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n"); fprintf(stderr, " or `--logit-bias 15043-1` to decrease likelihood of token ' Hello'\n");
fprintf(stderr, " -c N, --ctx_size N size of the prompt context (default: %d)\n", params.n_ctx); fprintf(stderr, " -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); fprintf(stderr, " --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
fprintf(stderr, " --no-penalize-nl do not penalize newline token\n"); fprintf(stderr, " --no-penalize-nl do not penalize newline token\n");
fprintf(stderr, " --memory_f32 use f32 instead of f16 for memory key+value\n"); fprintf(stderr, " --memory-f32 use f32 instead of f16 for memory key+value\n");
fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp); fprintf(stderr, " --temp N temperature (default: %.1f)\n", (double)params.temp);
fprintf(stderr, " --n_parts N number of model parts (default: -1 = determine from dimensions)\n"); fprintf(stderr, " --n-parts N number of model parts (default: -1 = determine from dimensions)\n");
fprintf(stderr, " -b N, --batch_size N batch size for prompt processing (default: %d)\n", params.n_batch); fprintf(stderr, " -b N, --batch-size N batch size for prompt processing (default: %d)\n", params.n_batch);
fprintf(stderr, " --perplexity compute perplexity over the prompt\n"); fprintf(stderr, " --perplexity compute perplexity over the prompt\n");
fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); fprintf(stderr, " --keep number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
if (llama_mlock_supported()) { if (llama_mlock_supported()) {