From fe043ff1ff07fdef1899778e52dafbad26037d38 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 17 Nov 2024 18:55:27 +0200 Subject: [PATCH] speculative : clean-up and add comments and TODOs [no ci] --- common/sampling.h | 11 ++++ common/speculative.cpp | 7 +-- common/speculative.h | 13 ++++ .../speculative-simple/speculative-simple.cpp | 61 ++++++++++++++----- 4 files changed, 70 insertions(+), 22 deletions(-) diff --git a/common/sampling.h b/common/sampling.h index 9e61690aa..23cfae1ac 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -60,6 +60,17 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam // llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false); +// generalized version of common_sampler_sample +// +// will cross-reference the sampled tokens with a batch of draft tokens +// if the sampler disagrees at some point, we stop and return the sampled tokens up to now +// +// `common_sampler_sample_n(gsmpl, ctx, { idx }, {})` is equivalent to `common_sampler_sample(gsmpl, ctx, idx)` +// +// requires: idxs.size() == draft.size() + 1 +// +// returns at least 1 token, up to idxs.size() +// std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const std::vector & draft, bool grammar_first = false); uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); diff --git a/common/speculative.cpp b/common/speculative.cpp index d16cc3c8e..2726760ad 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -4,11 +4,6 @@ #include "common.h" #include "sampling.h" -#include - -struct seq_draft { -}; - struct common_speculative { struct common_speculative_params params; @@ -140,7 +135,7 @@ void common_speculative_add_draft( } // don't waste time on small batches - // TODO: do not evaluate the draft model for tha many rounds + // TODO: do not evaluate the draft model for that many rounds if (batch_tgt.n_tokens < spec->params.n_min) { batch_tgt.n_tokens = 1; spec->tokens.resize(0); diff --git a/common/speculative.h b/common/speculative.h index 0952e5e70..a2df2667a 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -19,8 +19,21 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa void common_speculative_free(struct common_speculative * spec); +// TODO: remove void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens); +// sample up to n_draft tokens and add them to the batch using the draft model +// +// TODO: change to: +// +// void common_speculative_add_draft( +// struct common_speculative * spec, +// struct llama_batch & batch_tgt, +// llama_token * tokens, +// int32_t n_tokens); +// +// and update the internal logic to compute only the new tokens +// void common_speculative_add_draft( struct common_speculative * spec, struct llama_batch & batch_tgt, diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 31a09e61d..aeccfd369 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -120,7 +120,6 @@ int main(int argc, char ** argv) { } } - // Tokenize the prompt std::vector inp; inp = common_tokenize(ctx_tgt, params.prompt, true, true); @@ -139,18 +138,6 @@ int main(int argc, char ** argv) { LOG("%s", common_token_to_piece(ctx_tgt, id).c_str()); } - const int n_input = inp.size(); - - const auto t_enc_start = ggml_time_us(); - - // eval the prompt - llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1)); - - // note: keep the last token separate! - llama_token id_last = inp.back(); - - int n_past = inp.size() - 1; - // how many tokens to draft each time int n_draft = params.n_draft; @@ -161,9 +148,25 @@ int main(int argc, char ** argv) { // used to determine end of generation bool has_eos = false; + // ================================================ + // everything until here is standard initialization + // the relevant stuff for speculative decoding starts here + + const int n_input = inp.size(); + + const auto t_enc_start = ggml_time_us(); + // target model sampling context struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams); + // eval the prompt + llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1)); + + // note: keep the last token separate! + llama_token id_last = inp.back(); + + int n_past = inp.size() - 1; + // init the speculator struct common_speculative_params params_spec; params_spec.n_draft = n_draft; @@ -174,6 +177,13 @@ int main(int argc, char ** argv) { struct common_speculative * spec = common_speculative_init(params_spec); // feed the prompt to the speculator + // + // this has to be kept synchronized with the target context + // + // TODO: simplify this by moving the context management logic in the common_speculative instance + // for example, the common_speculative_add_draft can pass the entire context (or part of it) and the + // speculator will automatically compute any new tokens that are not present in its context + // common_speculative_set_prompt(spec, inp.data(), n_input - 1); llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); @@ -188,23 +198,41 @@ int main(int argc, char ** argv) { common_batch_add (batch_tgt, id_last, n_past, { 0 }, true); // optionally, append draft tokens to the target batch + // + // this is the most important part of the speculation. the more probable tokens that are provided here + // the better the performance will be. in theory, this computation can be performed asynchronously and even + // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens + // from a cache or lookup tables. + // common_speculative_add_draft(spec, batch_tgt, id_last, n_past); - // evaluate the target model on the drafted tokens + // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] { //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); llama_decode(ctx_tgt, batch_tgt); } - // process the full target batch and return the accepted token based on the target sampler + // sample from the full target batch and return the accepted tokens based on the target sampler + // + // for each token to be accepted, the sampler would have to sample that same token + // in such cases, instead of decoding the sampled token as we normally do, we simply continue with the + // available logits from the batch and sample the next token until we run out of logits or the sampler + // disagrees with the draft + // const auto ids = common_speculative_sample(spec, smpl, ctx_tgt); + GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token + n_past += ids.size(); n_drafted += batch_tgt.n_tokens - 1; n_accept += ids.size() - 1; // process the accepted tokens and update contexts + // + // this is the standard token post-processing that we normally do + // in this case, we do it for a group of accepted tokens at once + // { llama_token id; std::string token_str; @@ -232,7 +260,7 @@ int main(int argc, char ** argv) { break; } - LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str()); + LOG_DBG("accepted %d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, id, token_str.c_str()); { LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past); @@ -241,6 +269,7 @@ int main(int argc, char ** argv) { llama_kv_cache_seq_rm(ctx_dft, 0, n_past, -1); } + // remember the last accepted token for the next iteration id_last = id; } }