mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 21:39:52 +00:00
speculative : PoC for speeding-up inference via speculative sampling (#2926)
* speculative : initial example * speculative : print encoding speed * speculative : add --draft CLI arg
This commit is contained in:
parent
8f429fa511
commit
47068e5170
@ -305,6 +305,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params.n_keep = std::stoi(argv[i]);
|
params.n_keep = std::stoi(argv[i]);
|
||||||
|
} else if (arg == "--draft") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.n_draft = std::stoi(argv[i]);
|
||||||
} else if (arg == "--chunks") {
|
} else if (arg == "--chunks") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
@ -317,6 +323,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params.model = argv[i];
|
params.model = argv[i];
|
||||||
|
} else if (arg == "-md" || arg == "--model-draft") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.model_draft = argv[i];
|
||||||
} else if (arg == "-a" || arg == "--alias") {
|
} else if (arg == "-a" || arg == "--alias") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
@ -638,6 +650,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||||||
fprintf(stdout, " --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
|
fprintf(stdout, " --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
|
||||||
fprintf(stdout, " --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
|
fprintf(stdout, " --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
|
||||||
fprintf(stdout, " --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
|
fprintf(stdout, " --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
|
||||||
|
fprintf(stdout, " --draft N number of tokens to draft for speculative decoding (default: %d)\n", params.n_draft);
|
||||||
fprintf(stdout, " --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
|
fprintf(stdout, " --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks);
|
||||||
if (llama_mlock_supported()) {
|
if (llama_mlock_supported()) {
|
||||||
fprintf(stdout, " --mlock force system to keep model in RAM rather than swapping or compressing\n");
|
fprintf(stdout, " --mlock force system to keep model in RAM rather than swapping or compressing\n");
|
||||||
@ -669,6 +682,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||||||
fprintf(stdout, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
|
fprintf(stdout, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
|
||||||
fprintf(stdout, " -m FNAME, --model FNAME\n");
|
fprintf(stdout, " -m FNAME, --model FNAME\n");
|
||||||
fprintf(stdout, " model path (default: %s)\n", params.model.c_str());
|
fprintf(stdout, " model path (default: %s)\n", params.model.c_str());
|
||||||
|
fprintf(stdout, " -md FNAME, --model-draft FNAME\n");
|
||||||
|
fprintf(stdout, " draft model for speculative decoding (default: %s)\n", params.model.c_str());
|
||||||
fprintf(stdout, " -ld LOGDIR, --logdir LOGDIR\n");
|
fprintf(stdout, " -ld LOGDIR, --logdir LOGDIR\n");
|
||||||
fprintf(stdout, " path under which to save YAML logs (no logging if unset)\n");
|
fprintf(stdout, " path under which to save YAML logs (no logging if unset)\n");
|
||||||
fprintf(stdout, "\n");
|
fprintf(stdout, "\n");
|
||||||
@ -832,6 +847,130 @@ std::string llama_detokenize_bpe(llama_context * ctx, const std::vector<llama_to
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// Sampling utils
|
||||||
|
//
|
||||||
|
|
||||||
|
llama_token llama_sample_token(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
struct llama_context * ctx_guidance,
|
||||||
|
struct llama_grammar * grammar,
|
||||||
|
const struct gpt_params & params,
|
||||||
|
const std::vector<llama_token> & last_tokens,
|
||||||
|
std::vector<llama_token_data> & candidates,
|
||||||
|
int idx) {
|
||||||
|
const int n_ctx = llama_n_ctx(ctx);
|
||||||
|
const int n_vocab = llama_n_vocab(ctx);
|
||||||
|
|
||||||
|
const float temp = params.temp;
|
||||||
|
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
|
||||||
|
const float top_p = params.top_p;
|
||||||
|
const float tfs_z = params.tfs_z;
|
||||||
|
const float typical_p = params.typical_p;
|
||||||
|
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
|
||||||
|
const float repeat_penalty = params.repeat_penalty;
|
||||||
|
const float alpha_presence = params.presence_penalty;
|
||||||
|
const float alpha_frequency = params.frequency_penalty;
|
||||||
|
const int mirostat = params.mirostat;
|
||||||
|
const float mirostat_tau = params.mirostat_tau;
|
||||||
|
const float mirostat_eta = params.mirostat_eta;
|
||||||
|
const bool penalize_nl = params.penalize_nl;
|
||||||
|
|
||||||
|
llama_token id = 0;
|
||||||
|
|
||||||
|
float * logits = llama_get_logits(ctx) + idx * n_vocab;
|
||||||
|
|
||||||
|
// Apply params.logit_bias map
|
||||||
|
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
||||||
|
logits[it->first] += it->second;
|
||||||
|
}
|
||||||
|
|
||||||
|
candidates.clear();
|
||||||
|
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||||
|
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
|
||||||
|
|
||||||
|
if (ctx_guidance) {
|
||||||
|
llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
|
||||||
|
}
|
||||||
|
|
||||||
|
// apply penalties
|
||||||
|
if (!last_tokens.empty()) {
|
||||||
|
const float nl_logit = logits[llama_token_nl(ctx)];
|
||||||
|
const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
|
||||||
|
|
||||||
|
llama_sample_repetition_penalty(ctx, &cur_p,
|
||||||
|
last_tokens.data() + last_tokens.size() - last_n_repeat,
|
||||||
|
last_n_repeat, repeat_penalty);
|
||||||
|
llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
|
||||||
|
last_tokens.data() + last_tokens.size() - last_n_repeat,
|
||||||
|
last_n_repeat, alpha_frequency, alpha_presence);
|
||||||
|
|
||||||
|
if (!penalize_nl) {
|
||||||
|
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
||||||
|
if (cur_p.data[idx].id == llama_token_nl(ctx)) {
|
||||||
|
cur_p.data[idx].logit = nl_logit;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (grammar != NULL) {
|
||||||
|
llama_sample_grammar(ctx, &cur_p, grammar);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (temp <= 0) {
|
||||||
|
// Greedy sampling
|
||||||
|
id = llama_sample_token_greedy(ctx, &cur_p);
|
||||||
|
} else {
|
||||||
|
if (mirostat == 1) {
|
||||||
|
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||||
|
const int mirostat_m = 100;
|
||||||
|
llama_sample_temperature(ctx, &cur_p, temp);
|
||||||
|
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
|
||||||
|
} else if (mirostat == 2) {
|
||||||
|
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||||
|
llama_sample_temperature(ctx, &cur_p, temp);
|
||||||
|
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
|
||||||
|
} else {
|
||||||
|
// Temperature sampling
|
||||||
|
llama_sample_top_k (ctx, &cur_p, top_k, 1);
|
||||||
|
llama_sample_tail_free (ctx, &cur_p, tfs_z, 1);
|
||||||
|
llama_sample_typical (ctx, &cur_p, typical_p, 1);
|
||||||
|
llama_sample_top_p (ctx, &cur_p, top_p, 1);
|
||||||
|
llama_sample_temperature(ctx, &cur_p, temp);
|
||||||
|
|
||||||
|
{
|
||||||
|
const int n_top = 10;
|
||||||
|
LOG("top %d candidates:\n", n_top);
|
||||||
|
|
||||||
|
for (int i = 0; i < n_top; i++) {
|
||||||
|
const llama_token id = cur_p.data[i].id;
|
||||||
|
LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
id = llama_sample_token(ctx, &cur_p);
|
||||||
|
|
||||||
|
LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// printf("`%d`", candidates_p.size);
|
||||||
|
|
||||||
|
if (grammar != NULL) {
|
||||||
|
llama_grammar_accept_token(ctx, grammar, id);
|
||||||
|
}
|
||||||
|
|
||||||
|
return id;
|
||||||
|
}
|
||||||
|
|
||||||
|
//
|
||||||
|
// YAML utils
|
||||||
|
//
|
||||||
|
|
||||||
// returns true if successful, false otherwise
|
// returns true if successful, false otherwise
|
||||||
bool create_directory_with_parents(const std::string & path) {
|
bool create_directory_with_parents(const std::string & path) {
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
@ -1070,6 +1209,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
|
|||||||
fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta);
|
fprintf(stream, "mirostat_lr: %f # default: 0.1\n", params.mirostat_eta);
|
||||||
fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
|
fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false");
|
||||||
fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
|
fprintf(stream, "model: %s # default: models/7B/ggml-model.bin\n", params.model.c_str());
|
||||||
|
fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str());
|
||||||
fprintf(stream, "mtest: %s # default: false\n", params.mem_test ? "true" : "false");
|
fprintf(stream, "mtest: %s # default: false\n", params.mem_test ? "true" : "false");
|
||||||
fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
|
fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false");
|
||||||
fprintf(stream, "n_gpu_layers: %d # default: 0\n", params.n_gpu_layers);
|
fprintf(stream, "n_gpu_layers: %d # default: 0\n", params.n_gpu_layers);
|
||||||
|
@ -32,6 +32,7 @@ struct gpt_params {
|
|||||||
int32_t n_ctx = 512; // context size
|
int32_t n_ctx = 512; // context size
|
||||||
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
|
int32_t n_batch = 512; // batch size for prompt processing (must be >=32 to use BLAS)
|
||||||
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
||||||
|
int32_t n_draft = 16; // number of tokens to draft during speculative decoding
|
||||||
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
|
int32_t n_chunks = -1; // max number of chunks to process (-1 = unlimited)
|
||||||
int32_t n_gpu_layers = 0; // number of layers to store in VRAM
|
int32_t n_gpu_layers = 0; // number of layers to store in VRAM
|
||||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||||
@ -63,6 +64,7 @@ struct gpt_params {
|
|||||||
float cfg_scale = 1.f; // How strong is guidance
|
float cfg_scale = 1.f; // How strong is guidance
|
||||||
|
|
||||||
std::string model = "models/7B/ggml-model-f16.gguf"; // model path
|
std::string model = "models/7B/ggml-model-f16.gguf"; // model path
|
||||||
|
std::string model_draft = ""; // draft model for speculative decoding
|
||||||
std::string model_alias = "unknown"; // model alias
|
std::string model_alias = "unknown"; // model alias
|
||||||
std::string prompt = "";
|
std::string prompt = "";
|
||||||
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
|
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
|
||||||
@ -156,6 +158,40 @@ std::string llama_detokenize_bpe(
|
|||||||
llama_context * ctx,
|
llama_context * ctx,
|
||||||
const std::vector<llama_token> & tokens);
|
const std::vector<llama_token> & tokens);
|
||||||
|
|
||||||
|
//
|
||||||
|
// Sampling utils
|
||||||
|
//
|
||||||
|
|
||||||
|
// this is a common sampling function used across the examples for convenience
|
||||||
|
// it can serve as a starting point for implementing your own sampling function
|
||||||
|
//
|
||||||
|
// required:
|
||||||
|
// - ctx: context to use for sampling
|
||||||
|
// - params: sampling parameters
|
||||||
|
//
|
||||||
|
// optional:
|
||||||
|
// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
|
||||||
|
// - grammar: grammar to use for sampling, ignore if NULL
|
||||||
|
// - last_tokens: needed for repetition penalty, ignore if empty
|
||||||
|
// - idx: sample from llama_get_logits(ctx) + idx * n_vocab
|
||||||
|
//
|
||||||
|
// returns:
|
||||||
|
// - token: sampled token
|
||||||
|
// - candidates: vector of candidate tokens
|
||||||
|
//
|
||||||
|
llama_token llama_sample_token(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
struct llama_context * ctx_guidance,
|
||||||
|
struct llama_grammar * grammar,
|
||||||
|
const struct gpt_params & params,
|
||||||
|
const std::vector<llama_token> & last_tokens,
|
||||||
|
std::vector<llama_token_data> & candidates,
|
||||||
|
int idx = 0);
|
||||||
|
|
||||||
|
//
|
||||||
|
// YAML utils
|
||||||
|
//
|
||||||
|
|
||||||
bool create_directory_with_parents(const std::string & path);
|
bool create_directory_with_parents(const std::string & path);
|
||||||
void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data);
|
void dump_vector_float_yaml(FILE * stream, const char * prop_name, const std::vector<float> & data);
|
||||||
void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data);
|
void dump_vector_int_yaml(FILE * stream, const char * prop_name, const std::vector<int> & data);
|
||||||
|
@ -23,6 +23,7 @@ else()
|
|||||||
add_subdirectory(train-text-from-scratch)
|
add_subdirectory(train-text-from-scratch)
|
||||||
add_subdirectory(convert-llama2c-to-ggml)
|
add_subdirectory(convert-llama2c-to-ggml)
|
||||||
add_subdirectory(simple)
|
add_subdirectory(simple)
|
||||||
|
add_subdirectory(speculative)
|
||||||
add_subdirectory(embd-input)
|
add_subdirectory(embd-input)
|
||||||
add_subdirectory(llama-bench)
|
add_subdirectory(llama-bench)
|
||||||
add_subdirectory(beam-search)
|
add_subdirectory(beam-search)
|
||||||
|
@ -116,7 +116,7 @@ int main(int argc, char ** argv) {
|
|||||||
#ifndef LOG_DISABLE_LOGS
|
#ifndef LOG_DISABLE_LOGS
|
||||||
log_set_target(log_filename_generator("main", "log"));
|
log_set_target(log_filename_generator("main", "log"));
|
||||||
LOG_TEE("Log start\n");
|
LOG_TEE("Log start\n");
|
||||||
log_dump_cmdline(argc,argv);
|
log_dump_cmdline(argc, argv);
|
||||||
#endif // LOG_DISABLE_LOGS
|
#endif // LOG_DISABLE_LOGS
|
||||||
|
|
||||||
// TODO: Dump params ?
|
// TODO: Dump params ?
|
||||||
@ -425,8 +425,9 @@ int main(int argc, char ** argv) {
|
|||||||
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
||||||
LOG_TEE("\n\n");
|
LOG_TEE("\n\n");
|
||||||
|
|
||||||
|
struct llama_grammar * grammar = NULL;
|
||||||
grammar_parser::parse_state parsed_grammar;
|
grammar_parser::parse_state parsed_grammar;
|
||||||
llama_grammar * grammar = NULL;
|
|
||||||
if (!params.grammar.empty()) {
|
if (!params.grammar.empty()) {
|
||||||
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
|
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
|
||||||
// will be empty (default) if there are parse errors
|
// will be empty (default) if there are parse errors
|
||||||
@ -450,8 +451,8 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TODO: replace with ring-buffer
|
// TODO: replace with ring-buffer
|
||||||
std::vector<llama_token> last_n_tokens(n_ctx);
|
std::vector<llama_token> last_tokens(n_ctx);
|
||||||
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
|
std::fill(last_tokens.begin(), last_tokens.end(), 0);
|
||||||
|
|
||||||
if (params.interactive) {
|
if (params.interactive) {
|
||||||
const char *control_message;
|
const char *control_message;
|
||||||
@ -492,6 +493,11 @@ int main(int argc, char ** argv) {
|
|||||||
std::vector<llama_token> embd;
|
std::vector<llama_token> embd;
|
||||||
std::vector<llama_token> embd_guidance;
|
std::vector<llama_token> embd_guidance;
|
||||||
|
|
||||||
|
const int n_vocab = llama_n_vocab(ctx);
|
||||||
|
|
||||||
|
std::vector<llama_token_data> candidates;
|
||||||
|
candidates.reserve(n_vocab);
|
||||||
|
|
||||||
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
while ((n_remain != 0 && !is_antiprompt) || params.interactive) {
|
||||||
// predict
|
// predict
|
||||||
if (embd.size() > 0) {
|
if (embd.size() > 0) {
|
||||||
@ -529,8 +535,8 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
|
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
|
||||||
|
|
||||||
// insert n_left/2 tokens at the start of embd from last_n_tokens
|
// insert n_left/2 tokens at the start of embd from last_tokens
|
||||||
embd.insert(embd.begin(), last_n_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_n_tokens.end() - embd.size());
|
embd.insert(embd.begin(), last_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_tokens.end() - embd.size());
|
||||||
|
|
||||||
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
|
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
|
||||||
|
|
||||||
@ -629,20 +635,6 @@ int main(int argc, char ** argv) {
|
|||||||
embd_guidance.clear();
|
embd_guidance.clear();
|
||||||
|
|
||||||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||||
const float temp = params.temp;
|
|
||||||
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
|
|
||||||
const float top_p = params.top_p;
|
|
||||||
const float tfs_z = params.tfs_z;
|
|
||||||
const float typical_p = params.typical_p;
|
|
||||||
const int32_t repeat_last_n = params.repeat_last_n < 0 ? n_ctx : params.repeat_last_n;
|
|
||||||
const float repeat_penalty = params.repeat_penalty;
|
|
||||||
const float alpha_presence = params.presence_penalty;
|
|
||||||
const float alpha_frequency = params.frequency_penalty;
|
|
||||||
const int mirostat = params.mirostat;
|
|
||||||
const float mirostat_tau = params.mirostat_tau;
|
|
||||||
const float mirostat_eta = params.mirostat_eta;
|
|
||||||
const bool penalize_nl = params.penalize_nl;
|
|
||||||
|
|
||||||
// optionally save the session on first sample (for faster prompt loading next time)
|
// optionally save the session on first sample (for faster prompt loading next time)
|
||||||
if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
|
if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
|
||||||
need_to_save_session = false;
|
need_to_save_session = false;
|
||||||
@ -651,98 +643,12 @@ int main(int argc, char ** argv) {
|
|||||||
LOG("saved session to %s\n", path_session.c_str());
|
LOG("saved session to %s\n", path_session.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token id = 0;
|
const llama_token id = llama_sample_token(ctx, ctx_guidance, grammar, params, last_tokens, candidates);
|
||||||
|
|
||||||
{
|
last_tokens.erase(last_tokens.begin());
|
||||||
auto logits = llama_get_logits(ctx);
|
last_tokens.push_back(id);
|
||||||
auto n_vocab = llama_n_vocab(ctx);
|
|
||||||
|
|
||||||
// Apply params.logit_bias map
|
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_tokens));
|
||||||
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
|
||||||
logits[it->first] += it->second;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<llama_token_data> candidates;
|
|
||||||
candidates.reserve(n_vocab);
|
|
||||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
|
||||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
|
|
||||||
|
|
||||||
if (ctx_guidance) {
|
|
||||||
llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply penalties
|
|
||||||
float nl_logit = logits[llama_token_nl(ctx)];
|
|
||||||
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), repeat_last_n), n_ctx);
|
|
||||||
llama_sample_repetition_penalty(ctx, &cur_p,
|
|
||||||
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
|
||||||
last_n_repeat, repeat_penalty);
|
|
||||||
llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
|
|
||||||
last_n_tokens.data() + last_n_tokens.size() - last_n_repeat,
|
|
||||||
last_n_repeat, alpha_frequency, alpha_presence);
|
|
||||||
if (!penalize_nl) {
|
|
||||||
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
|
||||||
if (cur_p.data[idx].id == llama_token_nl(ctx)) {
|
|
||||||
cur_p.data[idx].logit = nl_logit;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (grammar != NULL) {
|
|
||||||
llama_sample_grammar(ctx, &cur_p, grammar);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (temp <= 0) {
|
|
||||||
// Greedy sampling
|
|
||||||
id = llama_sample_token_greedy(ctx, &cur_p);
|
|
||||||
} else {
|
|
||||||
if (mirostat == 1) {
|
|
||||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
|
||||||
const int mirostat_m = 100;
|
|
||||||
llama_sample_temperature(ctx, &cur_p, temp);
|
|
||||||
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
|
|
||||||
} else if (mirostat == 2) {
|
|
||||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
|
||||||
llama_sample_temperature(ctx, &cur_p, temp);
|
|
||||||
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &mirostat_mu);
|
|
||||||
} else {
|
|
||||||
// Temperature sampling
|
|
||||||
llama_sample_top_k (ctx, &cur_p, top_k, 1);
|
|
||||||
llama_sample_tail_free (ctx, &cur_p, tfs_z, 1);
|
|
||||||
llama_sample_typical (ctx, &cur_p, typical_p, 1);
|
|
||||||
llama_sample_top_p (ctx, &cur_p, top_p, 1);
|
|
||||||
llama_sample_temperature(ctx, &cur_p, temp);
|
|
||||||
|
|
||||||
{
|
|
||||||
const int n_top = 10;
|
|
||||||
LOG("top %d candidates:\n", n_top);
|
|
||||||
|
|
||||||
for (int i = 0; i < n_top; i++) {
|
|
||||||
const llama_token id = cur_p.data[i].id;
|
|
||||||
LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
id = llama_sample_token(ctx, &cur_p);
|
|
||||||
|
|
||||||
LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// printf("`%d`", candidates_p.size);
|
|
||||||
|
|
||||||
if (grammar != NULL) {
|
|
||||||
llama_grammar_accept_token(ctx, grammar, id);
|
|
||||||
}
|
|
||||||
|
|
||||||
last_n_tokens.erase(last_n_tokens.begin());
|
|
||||||
last_n_tokens.push_back(id);
|
|
||||||
|
|
||||||
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, last_n_tokens));
|
|
||||||
}
|
|
||||||
|
|
||||||
embd.push_back(id);
|
embd.push_back(id);
|
||||||
|
|
||||||
@ -758,8 +664,8 @@ int main(int argc, char ** argv) {
|
|||||||
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
|
LOG("embd_inp.size(): %d, n_consumed: %d\n", (int) embd_inp.size(), n_consumed);
|
||||||
while ((int) embd_inp.size() > n_consumed) {
|
while ((int) embd_inp.size() > n_consumed) {
|
||||||
embd.push_back(embd_inp[n_consumed]);
|
embd.push_back(embd_inp[n_consumed]);
|
||||||
last_n_tokens.erase(last_n_tokens.begin());
|
last_tokens.erase(last_tokens.begin());
|
||||||
last_n_tokens.push_back(embd_inp[n_consumed]);
|
last_tokens.push_back(embd_inp[n_consumed]);
|
||||||
++n_consumed;
|
++n_consumed;
|
||||||
if ((int) embd.size() >= params.n_batch) {
|
if ((int) embd.size() >= params.n_batch) {
|
||||||
break;
|
break;
|
||||||
@ -792,7 +698,7 @@ int main(int argc, char ** argv) {
|
|||||||
// check for reverse prompt
|
// check for reverse prompt
|
||||||
if (params.antiprompt.size()) {
|
if (params.antiprompt.size()) {
|
||||||
std::string last_output;
|
std::string last_output;
|
||||||
for (auto id : last_n_tokens) {
|
for (auto id : last_tokens) {
|
||||||
last_output += llama_token_to_piece(ctx, id);
|
last_output += llama_token_to_piece(ctx, id);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -823,7 +729,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// deal with end of text token in interactive mode
|
// deal with end of text token in interactive mode
|
||||||
if (last_n_tokens.back() == llama_token_eos(ctx)) {
|
if (last_tokens.back() == llama_token_eos(ctx)) {
|
||||||
LOG("found EOS token\n");
|
LOG("found EOS token\n");
|
||||||
|
|
||||||
if (params.interactive) {
|
if (params.interactive) {
|
||||||
@ -925,7 +831,7 @@ int main(int argc, char ** argv) {
|
|||||||
if (grammar != NULL) {
|
if (grammar != NULL) {
|
||||||
llama_grammar_free(grammar);
|
llama_grammar_free(grammar);
|
||||||
|
|
||||||
std::vector<const llama_grammar_element *> grammar_rules( parsed_grammar.c_rules());
|
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
||||||
grammar = llama_grammar_init(
|
grammar = llama_grammar_init(
|
||||||
grammar_rules.data(), grammar_rules.size(),
|
grammar_rules.data(), grammar_rules.size(),
|
||||||
parsed_grammar.symbol_ids.at("root"));
|
parsed_grammar.symbol_ids.at("root"));
|
||||||
|
8
examples/speculative/CMakeLists.txt
Normal file
8
examples/speculative/CMakeLists.txt
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
set(TARGET speculative)
|
||||||
|
add_executable(${TARGET} speculative.cpp)
|
||||||
|
install(TARGETS ${TARGET} RUNTIME)
|
||||||
|
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||||
|
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||||
|
if(TARGET BUILD_INFO)
|
||||||
|
add_dependencies(${TARGET} BUILD_INFO)
|
||||||
|
endif()
|
234
examples/speculative/speculative.cpp
Normal file
234
examples/speculative/speculative.cpp
Normal file
@ -0,0 +1,234 @@
|
|||||||
|
#ifndef _GNU_SOURCE
|
||||||
|
#define _GNU_SOURCE
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "build-info.h"
|
||||||
|
|
||||||
|
#include "common.h"
|
||||||
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include <cmath>
|
||||||
|
#include <cstdio>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
int main(int argc, char ** argv) {
|
||||||
|
gpt_params params;
|
||||||
|
|
||||||
|
if (gpt_params_parse(argc, argv, params) == false) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.model_draft.empty()) {
|
||||||
|
fprintf(stderr, "%s: error: --model-draft is required\n", __func__);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifndef LOG_DISABLE_LOGS
|
||||||
|
log_set_target(log_filename_generator("speculative", "log"));
|
||||||
|
LOG_TEE("Log start\n");
|
||||||
|
log_dump_cmdline(argc, argv);
|
||||||
|
#endif // LOG_DISABLE_LOGS
|
||||||
|
|
||||||
|
// init llama.cpp
|
||||||
|
llama_backend_init(params.numa);
|
||||||
|
|
||||||
|
llama_model * model_tgt = NULL;
|
||||||
|
llama_model * model_dft = NULL;
|
||||||
|
|
||||||
|
llama_context * ctx_tgt = NULL;
|
||||||
|
llama_context * ctx_dft = NULL;
|
||||||
|
|
||||||
|
// load the target model
|
||||||
|
params.perplexity = true; // HACK: enable logits_all = true
|
||||||
|
std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
|
||||||
|
|
||||||
|
// load the draft model
|
||||||
|
params.model = params.model_draft;
|
||||||
|
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
|
||||||
|
|
||||||
|
// tokenize the prompt
|
||||||
|
std::vector<llama_token> inp;
|
||||||
|
inp = ::llama_tokenize(ctx_tgt, params.prompt, true);
|
||||||
|
|
||||||
|
const int max_context_size = llama_n_ctx(ctx_tgt);
|
||||||
|
const int max_tokens_list_size = max_context_size - 4;
|
||||||
|
|
||||||
|
if ((int) inp.size() > max_tokens_list_size) {
|
||||||
|
fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
fprintf(stderr, "\n\n");
|
||||||
|
|
||||||
|
for (auto id : inp) {
|
||||||
|
fprintf(stderr, "%s", llama_token_to_piece(ctx_tgt, id).c_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
fflush(stderr);
|
||||||
|
|
||||||
|
const int n_input = inp.size();
|
||||||
|
|
||||||
|
const auto t_enc_start = ggml_time_us();
|
||||||
|
|
||||||
|
// eval the prompt with both models
|
||||||
|
llama_eval(ctx_tgt, inp.data(), int(inp.size() - 1), 0, params.n_threads);
|
||||||
|
llama_eval(ctx_tgt, &inp.back(), 1, inp.size() - 1, params.n_threads);
|
||||||
|
llama_eval(ctx_dft, inp.data(), int(inp.size()), 0, params.n_threads);
|
||||||
|
|
||||||
|
const auto t_enc_end = ggml_time_us();
|
||||||
|
|
||||||
|
// the 2 models should have the same vocab
|
||||||
|
const int n_ctx = llama_n_ctx(ctx_tgt);
|
||||||
|
const int n_vocab = llama_n_vocab(ctx_tgt);
|
||||||
|
//GGML_ASSERT(n_vocab == llama_n_vocab(ctx_dft));
|
||||||
|
|
||||||
|
// how many tokens to draft each time
|
||||||
|
const int n_draft = params.n_draft;
|
||||||
|
|
||||||
|
int n_predict = 0;
|
||||||
|
int n_drafted = 0;
|
||||||
|
int n_accept = 0;
|
||||||
|
|
||||||
|
int n_past_tgt = inp.size();
|
||||||
|
int n_past_dft = inp.size();
|
||||||
|
|
||||||
|
std::vector<llama_token> drafted;
|
||||||
|
|
||||||
|
std::vector<llama_token> last_tokens(n_ctx);
|
||||||
|
std::fill(last_tokens.begin(), last_tokens.end(), 0);
|
||||||
|
|
||||||
|
for (auto & id : inp) {
|
||||||
|
last_tokens.erase(last_tokens.begin());
|
||||||
|
last_tokens.push_back(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<llama_token_data> candidates;
|
||||||
|
candidates.reserve(n_vocab);
|
||||||
|
|
||||||
|
// used to determine end of generation
|
||||||
|
bool has_eos = false;
|
||||||
|
|
||||||
|
const auto t_dec_start = ggml_time_us();
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
LOG("drafted: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_dft, drafted));
|
||||||
|
|
||||||
|
// sample from the drafted tokens if any
|
||||||
|
int i_dft = 0;
|
||||||
|
while (true) {
|
||||||
|
const llama_token id = llama_sample_token(ctx_tgt, NULL, NULL, params, last_tokens, candidates, i_dft);
|
||||||
|
|
||||||
|
last_tokens.erase(last_tokens.begin());
|
||||||
|
last_tokens.push_back(id);
|
||||||
|
|
||||||
|
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));
|
||||||
|
|
||||||
|
const std::string token_str = llama_token_to_piece(ctx_tgt, id);
|
||||||
|
printf("%s", token_str.c_str());
|
||||||
|
fflush(stdout);
|
||||||
|
|
||||||
|
if (id == llama_token_eos(ctx_tgt)) {
|
||||||
|
has_eos = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
++n_predict;
|
||||||
|
|
||||||
|
if (i_dft < (int) drafted.size() && id == drafted[i_dft]) {
|
||||||
|
LOG("drafted token %d accepted\n", id);
|
||||||
|
++n_accept;
|
||||||
|
++n_past_tgt;
|
||||||
|
++n_past_dft;
|
||||||
|
++i_dft;
|
||||||
|
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// the drafted token was rejected or we are out of drafted tokens
|
||||||
|
llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
|
||||||
|
++n_past_dft;
|
||||||
|
|
||||||
|
drafted.clear();
|
||||||
|
drafted.push_back(id);
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (n_predict > params.n_predict || has_eos) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
// sample n_draft tokens from the draft model picking the best token
|
||||||
|
int n_past_cur = n_past_dft;
|
||||||
|
for (int i = 0; i < n_draft; ++i) {
|
||||||
|
float * logits = llama_get_logits(ctx_dft);
|
||||||
|
|
||||||
|
candidates.clear();
|
||||||
|
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||||
|
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
|
||||||
|
|
||||||
|
// computes softmax and sorts the candidates
|
||||||
|
llama_sample_softmax(ctx_dft, &cur_p);
|
||||||
|
|
||||||
|
for (int i = 0; i < 3; ++i) {
|
||||||
|
LOG(" - draft candidate %d: %d (%.3f)\n", i, cur_p.data[i].id, cur_p.data[i].p);
|
||||||
|
}
|
||||||
|
|
||||||
|
// too low probability, stop drafting
|
||||||
|
if (cur_p.data[0].p < 2*cur_p.data[1].p) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
drafted.push_back(cur_p.data[0].id);
|
||||||
|
++n_drafted;
|
||||||
|
|
||||||
|
if (i < n_draft - 1) {
|
||||||
|
// evaluate the drafted token on the draft model
|
||||||
|
llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
|
||||||
|
++n_past_cur;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// evaluate the target model on the drafted tokens
|
||||||
|
llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads);
|
||||||
|
++n_past_tgt;
|
||||||
|
|
||||||
|
drafted.erase(drafted.begin());
|
||||||
|
}
|
||||||
|
|
||||||
|
auto t_dec_end = ggml_time_us();
|
||||||
|
|
||||||
|
LOG_TEE("\n\n");
|
||||||
|
|
||||||
|
LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
|
||||||
|
LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
|
||||||
|
|
||||||
|
// TODO: make sure these numbers are computed correctly
|
||||||
|
LOG_TEE("\n");
|
||||||
|
LOG_TEE("n_draft = %d\n", n_draft);
|
||||||
|
LOG_TEE("n_predict = %d\n", n_predict);
|
||||||
|
LOG_TEE("n_drafted = %d\n", n_drafted);
|
||||||
|
LOG_TEE("n_accept = %d\n", n_accept);
|
||||||
|
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
|
||||||
|
|
||||||
|
LOG_TEE("\ndraft:\n");
|
||||||
|
llama_print_timings(ctx_dft);
|
||||||
|
|
||||||
|
LOG_TEE("\ntarget:\n");
|
||||||
|
llama_print_timings(ctx_tgt);
|
||||||
|
|
||||||
|
llama_free(ctx_tgt);
|
||||||
|
llama_free_model(model_tgt);
|
||||||
|
|
||||||
|
llama_free(ctx_dft);
|
||||||
|
llama_free_model(model_dft);
|
||||||
|
|
||||||
|
llama_backend_free();
|
||||||
|
|
||||||
|
fprintf(stderr, "\n\n");
|
||||||
|
|
||||||
|
return 0;
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user