mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 19:50:17 +00:00
speculative : threading options (#4959)
* speculative: expose draft threading * fix usage format * accept -td and -tbd args * speculative: revert default behavior when -td is unspecified * fix trailing whitespace
This commit is contained in:
parent
3e5ca7931c
commit
e0324285a5
@ -167,6 +167,24 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
|||||||
if (params.n_threads_batch <= 0) {
|
if (params.n_threads_batch <= 0) {
|
||||||
params.n_threads_batch = std::thread::hardware_concurrency();
|
params.n_threads_batch = std::thread::hardware_concurrency();
|
||||||
}
|
}
|
||||||
|
} else if (arg == "-td" || arg == "--threads-draft") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.n_threads_draft = std::stoi(argv[i]);
|
||||||
|
if (params.n_threads_draft <= 0) {
|
||||||
|
params.n_threads_draft = std::thread::hardware_concurrency();
|
||||||
|
}
|
||||||
|
} else if (arg == "-tbd" || arg == "--threads-batch-draft") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.n_threads_batch_draft = std::stoi(argv[i]);
|
||||||
|
if (params.n_threads_batch_draft <= 0) {
|
||||||
|
params.n_threads_batch_draft = std::thread::hardware_concurrency();
|
||||||
|
}
|
||||||
} else if (arg == "-p" || arg == "--prompt") {
|
} else if (arg == "-p" || arg == "--prompt") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
@ -845,6 +863,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||||||
printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads);
|
printf(" -t N, --threads N number of threads to use during generation (default: %d)\n", params.n_threads);
|
||||||
printf(" -tb N, --threads-batch N\n");
|
printf(" -tb N, --threads-batch N\n");
|
||||||
printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n");
|
printf(" number of threads to use during batch and prompt processing (default: same as --threads)\n");
|
||||||
|
printf(" -td N, --threads-draft N");
|
||||||
|
printf(" number of threads to use during generation (default: same as --threads)");
|
||||||
|
printf(" -tbd N, --threads-batch-draft N\n");
|
||||||
|
printf(" number of threads to use during batch and prompt processing (default: same as --threads-draft)\n");
|
||||||
printf(" -p PROMPT, --prompt PROMPT\n");
|
printf(" -p PROMPT, --prompt PROMPT\n");
|
||||||
printf(" prompt to start generation with (default: empty)\n");
|
printf(" prompt to start generation with (default: empty)\n");
|
||||||
printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
|
printf(" -e, --escape process prompt escapes sequences (\\n, \\r, \\t, \\', \\\", \\\\)\n");
|
||||||
|
@ -46,7 +46,9 @@ struct gpt_params {
|
|||||||
uint32_t seed = -1; // RNG seed
|
uint32_t seed = -1; // RNG seed
|
||||||
|
|
||||||
int32_t n_threads = get_num_physical_cores();
|
int32_t n_threads = get_num_physical_cores();
|
||||||
|
int32_t n_threads_draft = -1;
|
||||||
int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
|
int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
|
||||||
|
int32_t n_threads_batch_draft = -1;
|
||||||
int32_t n_predict = -1; // new tokens to predict
|
int32_t n_predict = -1; // new tokens to predict
|
||||||
int32_t n_ctx = 512; // context size
|
int32_t n_ctx = 512; // context size
|
||||||
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
|
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
|
||||||
|
@ -65,6 +65,10 @@ int main(int argc, char ** argv) {
|
|||||||
// load the draft model
|
// load the draft model
|
||||||
params.model = params.model_draft;
|
params.model = params.model_draft;
|
||||||
params.n_gpu_layers = params.n_gpu_layers_draft;
|
params.n_gpu_layers = params.n_gpu_layers_draft;
|
||||||
|
if (params.n_threads_draft > 0) {
|
||||||
|
params.n_threads = params.n_threads_draft;
|
||||||
|
}
|
||||||
|
params.n_threads_batch = params.n_threads_batch_draft;
|
||||||
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
|
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
|
||||||
|
|
||||||
{
|
{
|
||||||
|
Loading…
Reference in New Issue
Block a user