mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-12 03:31:46 +00:00
speculative : change default p_accept to 0.5 + CLI args (#3919)
ggml-ci
This commit is contained in:
parent
05816027d6
commit
8f961abdc4
@ -403,6 +403,18 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params.n_sequences = std::stoi(argv[i]);
|
params.n_sequences = std::stoi(argv[i]);
|
||||||
|
} else if (arg == "--p-accept" || arg == "-pa") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.p_accept = std::stof(argv[i]);
|
||||||
|
} else if (arg == "--p-split" || arg == "-ps") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.p_split = std::stof(argv[i]);
|
||||||
} else if (arg == "-m" || arg == "--model") {
|
} else if (arg == "-m" || arg == "--model") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
@ -778,6 +790,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||||||
printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
|
printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
|
||||||
printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel);
|
printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel);
|
||||||
printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
|
printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences);
|
||||||
|
printf(" -pa N, --p-accept N speculative decoding accept probability (default: %.1f)\n", (double)params.p_accept);
|
||||||
|
printf(" -ps N, --p-split N speculative decoding split probability (default: %.1f)\n", (double)params.p_split);
|
||||||
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
|
printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
|
||||||
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
|
printf(" --mmproj MMPROJ_FILE path to a multimodal projector file for LLaVA. see examples/llava/README.md\n");
|
||||||
printf(" --image IMAGE_FILE path to an image file. use with multimodal models\n");
|
printf(" --image IMAGE_FILE path to an image file. use with multimodal models\n");
|
||||||
|
@ -44,6 +44,7 @@ int32_t get_num_physical_cores();
|
|||||||
|
|
||||||
struct gpt_params {
|
struct gpt_params {
|
||||||
uint32_t seed = -1; // RNG seed
|
uint32_t seed = -1; // RNG seed
|
||||||
|
|
||||||
int32_t n_threads = get_num_physical_cores();
|
int32_t n_threads = get_num_physical_cores();
|
||||||
int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
|
int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
|
||||||
int32_t n_predict = -1; // new tokens to predict
|
int32_t n_predict = -1; // new tokens to predict
|
||||||
@ -54,6 +55,8 @@ struct gpt_params {
|
|||||||
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
|
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
|
||||||
int32_t n_parallel = 1; // number of parallel sequences to decode
|
int32_t n_parallel = 1; // number of parallel sequences to decode
|
||||||
int32_t n_sequences = 1; // number of sequences to decode
|
int32_t n_sequences = 1; // number of sequences to decode
|
||||||
|
float p_accept = 0.5f; // speculative decoding accept probability
|
||||||
|
float p_split = 0.1f; // speculative decoding split probability
|
||||||
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
int32_t n_gpu_layers = -1; // number of layers to store in VRAM (-1 - use default)
|
||||||
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
int32_t n_gpu_layers_draft = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
|
||||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||||
@ -66,7 +69,8 @@ struct gpt_params {
|
|||||||
float yarn_beta_fast = 32.0f; // YaRN low correction dim
|
float yarn_beta_fast = 32.0f; // YaRN low correction dim
|
||||||
float yarn_beta_slow = 1.0f; // YaRN high correction dim
|
float yarn_beta_slow = 1.0f; // YaRN high correction dim
|
||||||
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
||||||
int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED;
|
int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment
|
||||||
|
// pinging @cebtenzzre
|
||||||
|
|
||||||
// // sampling parameters
|
// // sampling parameters
|
||||||
struct llama_sampling_params sparams;
|
struct llama_sampling_params sparams;
|
||||||
|
@ -37,9 +37,11 @@ int main(int argc, char ** argv) {
|
|||||||
// max number of parallel drafting sequences (i.e. tree branches)
|
// max number of parallel drafting sequences (i.e. tree branches)
|
||||||
const int n_seq_dft = params.n_parallel;
|
const int n_seq_dft = params.n_parallel;
|
||||||
|
|
||||||
// TODO: make this configurable
|
// probability threshold for accepting a token from the draft model
|
||||||
const float p_accept = 0.80f;
|
const float p_accept = params.p_accept;
|
||||||
const float p_split = 0.10f;
|
|
||||||
|
// probability threshold for splitting a draft branch (only for n_seq_dft > 1)
|
||||||
|
const float p_split = params.p_split;
|
||||||
|
|
||||||
#ifndef LOG_DISABLE_LOGS
|
#ifndef LOG_DISABLE_LOGS
|
||||||
log_set_target(log_filename_generator("speculative", "log"));
|
log_set_target(log_filename_generator("speculative", "log"));
|
||||||
|
Loading…
Reference in New Issue
Block a user