diff --git a/common/arg.cpp b/common/arg.cpp
index 4115b2f75..35500670f 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -609,7 +609,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, int value) {
             params.n_draft = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP}));
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--draft-min"}, "N",
+        string_format("minimum number of draft tokens to use for speculative decoding (default: %d)", params.n_draft_min),
+        [](common_params & params, int value) {
+            params.n_draft_min = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-ps", "--p-split"}, "N",
         string_format("speculative decoding split probability (default: %.1f)", (double)params.p_split),
@@ -1454,7 +1461,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
                 fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
             }
         }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-sm", "--split-mode"}, "{none,layer,row}",
         "how to split the model across multiple GPUs, one of:\n"
@@ -1599,7 +1606,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, const std::string & value) {
             params.model_draft = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE}));
+    ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-mu", "--model-url"}, "MODEL_URL",
         "model download url (default: unused)",
diff --git a/common/common.h b/common/common.h
index 29d678c7b..42c17ed3c 100644
--- a/common/common.h
+++ b/common/common.h
@@ -162,6 +162,7 @@ struct common_params {
     int32_t n_ubatch    = 512; // physical batch size for prompt processing (must be >=32 to use BLAS)
     int32_t n_keep      = 0;   // number of tokens to keep from initial prompt
     int32_t n_draft     = 5;   // number of tokens to draft during speculative decoding
+    int32_t n_draft_min = 0;   // minimum number of draft tokens to use for speculative decoding
     int32_t n_chunks    = -1;  // max number of chunks to process (-1 = unlimited)
     int32_t n_parallel  = 1;   // number of parallel sequences to decode
     int32_t n_sequences = 1;   // number of sequences to decode
diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
index d7e572cf8..6dee64834 100644
--- a/examples/speculative-simple/speculative-simple.cpp
+++ b/examples/speculative-simple/speculative-simple.cpp
@@ -13,9 +13,6 @@
 int main(int argc, char ** argv) {
     common_params params;
 
-    // minimum size of the draft to use
-    const int n_min = 5;
-
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
         return 1;
     }
@@ -142,7 +139,7 @@ int main(int argc, char ** argv) {
         // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
         {
             // do not waste time on small drafts
-            if (draft.size() < n_min) {
+            if (draft.size() < params.n_draft_min) {
                 draft.clear();
             }
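
One caveat worth noting on the last hunk: `draft.size()` returns a `size_t`, while the new `n_draft_min` is a signed `int32_t`, so `draft.size() < params.n_draft_min` is a signed/unsigned comparison (the removed `const int n_min` had the same quirk). It warns under `-Wsign-compare`, and a negative `--draft-min` would convert to a huge unsigned value and clear every draft. Below is a minimal, self-contained sketch of the gating behavior this patch introduces, with an explicit cast added; `params_sketch` and `drop_small_draft` are hypothetical stand-ins for illustration, not names from the patch or from llama.cpp.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the two common_params fields the patch uses.
struct params_sketch {
    int32_t n_draft     = 5; // maximum number of tokens to draft
    int32_t n_draft_min = 0; // drafts smaller than this are discarded
};

// Drop drafts that are too small to be worth a verification pass.
// The explicit cast keeps the comparison signed on both sides.
static void drop_small_draft(std::vector<int> & draft, const params_sketch & params) {
    if ((int32_t) draft.size() < params.n_draft_min) {
        draft.clear();
    }
}

int main() {
    params_sketch params;
    params.n_draft_min = 4;

    std::vector<int> draft = {101, 102, 103}; // 3 drafted token ids
    drop_small_draft(draft, params);
    std::printf("3-token draft -> size %zu\n", draft.size()); // 0: gated out

    draft = {101, 102, 103, 104, 105};
    drop_small_draft(draft, params);
    std::printf("5-token draft -> size %zu\n", draft.size()); // 5: kept

    return 0;
}
```

Also note the default: with `n_draft_min = 0` the gate never fires unless requested, whereas speculative-simple previously hard-coded a minimum of 5, so the example's out-of-the-box behavior changes slightly; passing `--draft-min 5` restores the old threshold.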