diff --git a/common/speculative.cpp b/common/speculative.cpp
index 6acf84a23..810fa93e4 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -4,21 +4,33 @@
 #include "common.h"
 #include "sampling.h"
 
+#include <cstring>
+
+// max allowed difference between the vocab sizes of the target and the draft model
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
+// first token id for which the token texts of the two vocabs are compared
+#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
+
 struct common_speculative {
     struct common_speculative_params params;
 
-    llama_batch batch_dft;
+    llama_batch batch;
 
+    struct llama_context * ctx;
     struct common_sampler * smpl;
 
-    llama_tokens prompt_last;
+    llama_tokens prompt;
 };
 
-struct common_speculative * common_speculative_init(struct common_speculative_params params) {
+struct common_speculative * common_speculative_init(
+        struct common_speculative_params params,
+        struct llama_context * ctx_dft) {
     auto * result = new common_speculative {
-        /* .params    = */ params,
-        /* .batch_dft = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1),
-        /* .smpl      = */ nullptr,
+        /* .params = */ params,
+        /* .batch  = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
+        /* .ctx    = */ ctx_dft,
+        /* .smpl   = */ nullptr,
+        /* .prompt = */ {},
     };
 
     // TODO: optimize or pass from outside?
@@ -36,7 +48,7 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa
             COMMON_SAMPLER_TYPE_INFILL,
         };
 
-        result->smpl = common_sampler_init(params.model_dft, sparams);
+        result->smpl = common_sampler_init(llama_get_model(ctx_dft), sparams);
     }
 #else
     {
@@ -49,46 +61,109 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa
             COMMON_SAMPLER_TYPE_TOP_K,
         };
 
-        result->smpl = common_sampler_init(params.model_dft, sparams);
+        result->smpl = common_sampler_init(llama_get_model(ctx_dft), sparams);
     }
 #endif
 
-    result->batch_dft = llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1);
-
     return result;
 }
 
 void common_speculative_free(struct common_speculative * spec) {
     common_sampler_free(spec->smpl);
 
-    llama_batch_free(spec->batch_dft);
+    llama_batch_free(spec->batch);
 
     delete spec;
 }
 
+// check that the draft model can be used to speculate for the target model:
+// the vocab types and the BOS/EOS settings must be equal, the vocab sizes may differ by at most
+// SPEC_VOCAB_MAX_SIZE_DIFFERENCE and the token texts must be identical for the common token ids
+// starting from SPEC_VOCAB_CHECK_START_TOKEN_ID - logs the first mismatch and returns false
+bool common_speculative_are_compatible(
+        const struct llama_context * ctx_tgt,
+        const struct llama_context * ctx_dft) {
+    const struct llama_model * model_tgt = llama_get_model(ctx_tgt);
+    const struct llama_model * model_dft = llama_get_model(ctx_dft);
+
+    // keep the enum values: a bool would turn every non-zero vocab type into `true` and hide mismatches
+    const auto vocab_type_tgt = llama_vocab_type(model_tgt);
+    LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
+
+    const auto vocab_type_dft = llama_vocab_type(model_dft);
+    LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
+
+    if (vocab_type_tgt != vocab_type_dft) {
+        LOG_ERR("%s: draft model vocab type must match target model to use speculation but "
+                     "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
+        return false;
+    }
+
+    if (llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
+        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
+        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
+        llama_token_eos(model_tgt) != llama_token_eos(model_dft)
+    ) {
+        LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
+        return false;
+    }
+
+    {
+        const int n_vocab_tgt = llama_n_vocab(model_tgt);
+        const int n_vocab_dft = llama_n_vocab(model_dft);
+
+        const int vocab_diff = std::abs(n_vocab_tgt - n_vocab_dft);
+
+        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
+            LOG_ERR("%s: draft model vocab must closely match target model to use speculation but "
+                         "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
+                    __func__, n_vocab_tgt, n_vocab_dft, vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
+            return false;
+        }
+
+        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
+            const char * token_text_tgt = llama_token_get_text(model_tgt, i);
+            const char * token_text_dft = llama_token_get_text(model_dft, i);
+            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
+                LOG_ERR("%s: draft model vocab must match target model to use speculation but "
+                             "token %d content differs - target '%s', draft '%s'\n", __func__, i,
+                        common_token_to_piece(ctx_tgt, i).c_str(),
+                        common_token_to_piece(ctx_dft, i).c_str());
+                return false;
+            }
+        }
+    }
+
+    return true;
+}
+
 void common_speculative_add_draft(
         struct common_speculative * spec,
         struct llama_batch & batch_tgt,
-        const llama_tokens & prompt,
+        const llama_tokens & prompt_tgt,
         llama_token id_last,
         llama_token n_past_tgt) {
+    auto & batch  = spec->batch;
+    auto & ctx    = spec->ctx;
+    auto & smpl   = spec->smpl;
+    auto & prompt = spec->prompt;
 
     int reuse_i = 0;
     int reuse_n = 0;
 
-    const int n_ctx = llama_n_ctx(spec->params.ctx_dft) - spec->params.n_draft;
+    const int n_ctx = llama_n_ctx(ctx) - spec->params.n_draft;
 
-    const int i_start = std::max(0, (int) prompt.size() - n_ctx);
+    const int i_start = std::max(0, (int) prompt_tgt.size() - n_ctx);
 
-    for (int i = 0; i < (int) spec->prompt_last.size(); ++i) {
+    for (int i = 0; i < (int) prompt.size(); ++i) {
         int cur = 0;
-        while (i_start + cur < (int) prompt.size() &&
-               i       + cur < (int) spec->prompt_last.size() &&
-               prompt[i_start + cur] == spec->prompt_last[i + cur]) {
+        while (i_start + cur < (int) prompt_tgt.size() &&
+               i       + cur < (int) prompt.size() &&
+               prompt_tgt[i_start + cur] == prompt[i + cur]) {
             cur++;
         }
 
-        if ((cur >= spec->params.n_reuse || prompt.size() <= n_ctx) && cur > reuse_n) {
+        if ((cur >= spec->params.n_reuse || prompt_tgt.size() <= n_ctx) && cur > reuse_n) {
             reuse_i = i;
             reuse_n = cur;
         }
@@ -97,59 +172,59 @@ void common_speculative_add_draft(
     LOG_DBG("%s: reuse_i = %d, reuse_n = %d\n", __func__, reuse_i, reuse_n);
 
     if (reuse_n == 0) {
-        llama_kv_cache_clear(spec->params.ctx_dft);
+        llama_kv_cache_clear(ctx);
 
-        spec->prompt_last.clear();
+        prompt.clear();
     } else {
-        llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, 0, reuse_i);
-        llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, reuse_i + reuse_n, -1);
-        llama_kv_cache_seq_add(spec->params.ctx_dft, 0, reuse_i, -1, -reuse_i);
+        llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
+        llama_kv_cache_seq_rm (ctx, 0, reuse_i + reuse_n, -1);
+        llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
 
-        spec->prompt_last.erase(spec->prompt_last.begin(), spec->prompt_last.begin() + reuse_i);
-        spec->prompt_last.erase(spec->prompt_last.begin() + reuse_n, spec->prompt_last.end());
+        prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+        prompt.erase(prompt.begin() + reuse_n, prompt.end());
     }
 
-    common_batch_clear(spec->batch_dft);
+    common_batch_clear(batch);
 
-    for (int i = i_start + reuse_n; i < (int) prompt.size(); ++i) {
-        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt[i]);
-        common_batch_add(spec->batch_dft, prompt[i], i - i_start, { 0 }, false);
+    for (int i = i_start + reuse_n; i < (int) prompt_tgt.size(); ++i) {
+        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
+        common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
 
-        spec->prompt_last.push_back(prompt[i]);
+        prompt.push_back(prompt_tgt[i]);
     }
 
-    const llama_pos n_past = prompt.size() - i_start;
+    const llama_pos n_past = prompt_tgt.size() - i_start;
 
     LOG_DBG("%s: n_past = %d\n", __func__, n_past);
 
-    if (spec->batch_dft.n_tokens > 0) {
-        LOG_DBG("%s: draft batch: %s\n", __func__, string_from(spec->params.ctx_dft, spec->batch_dft).c_str());
+    if (batch.n_tokens > 0) {
+        LOG_DBG("%s: draft batch: %s\n", __func__, string_from(ctx, batch).c_str());
 
-        llama_decode(spec->params.ctx_dft, spec->batch_dft);
+        llama_decode(ctx, batch);
     }
 
-    common_batch_clear(spec->batch_dft);
-    common_batch_add  (spec->batch_dft, id_last, n_past, { 0 }, true);
+    common_batch_clear(batch);
+    common_batch_add  (batch, id_last, n_past, { 0 }, true);
 
-    spec->prompt_last.push_back(id_last);
+    prompt.push_back(id_last);
 
-    LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(spec->params.ctx_dft, spec->prompt_last).c_str());
+    LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(ctx, prompt).c_str());
 
-    llama_decode(spec->params.ctx_dft, spec->batch_dft);
+    llama_decode(ctx, batch);
 
-    common_sampler_reset(spec->smpl);
+    common_sampler_reset(smpl);
 
     // sample n_draft tokens from the draft model
     for (int i = 0; i < spec->params.n_draft; ++i) {
-        common_batch_clear(spec->batch_dft);
+        common_batch_clear(batch);
 
-        common_sampler_sample(spec->smpl, spec->params.ctx_dft, 0, true);
+        common_sampler_sample(smpl, ctx, 0, true);
 
-        const auto * cur_p = common_sampler_get_candidates(spec->smpl);
+        const auto * cur_p = common_sampler_get_candidates(smpl);
 
         for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
             LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(spec->params.ctx_dft, cur_p->data[k].id).c_str());
+                    k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
         }
 
         // add drafted token for each sequence
@@ -160,7 +235,7 @@ void common_speculative_add_draft(
             break;
         }
 
-        common_sampler_accept(spec->smpl, id, true);
+        common_sampler_accept(smpl, id, true);
 
         common_batch_add(batch_tgt, id, n_past_tgt + i, { 0 }, true);
 
@@ -168,12 +243,12 @@ void common_speculative_add_draft(
             break;
         }
 
-        common_batch_add(spec->batch_dft, id, n_past + i + 1, { 0 }, true);
+        common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
 
         // evaluate the drafted tokens on the draft model
-        llama_decode(spec->params.ctx_dft, spec->batch_dft);
+        llama_decode(ctx, batch);
 
-        spec->prompt_last.push_back(id);
+        prompt.push_back(id);
     }
 
     // don't waste time on small batches
diff --git a/common/speculative.h b/common/speculative.h
index b3a87e64c..b657b6229 100644
--- a/common/speculative.h
+++ b/common/speculative.h
@@ -11,16 +11,19 @@ struct common_speculative_params {
     int n_reuse = 256;
 
     float p_min = 0.9f;
-
-    struct llama_model * model_dft = nullptr;
-
-    struct llama_context * ctx_dft = nullptr;
 };
 
-struct common_speculative * common_speculative_init(struct common_speculative_params params);
+struct common_speculative * common_speculative_init(
+        struct common_speculative_params params,
+        struct llama_context * ctx_dft);
 
 void common_speculative_free(struct common_speculative * spec);
 
+// returns true if the vocab of the draft model matches the target model closely enough to use speculation
+bool common_speculative_are_compatible(
+        const struct llama_context * ctx_tgt,
+        const struct llama_context * ctx_dft);
+
 // sample up to n_draft tokens and add them to the batch using the draft model
 //
 void common_speculative_add_draft(
diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
index cb6c35ce1..cdfd5b886 100644
--- a/examples/speculative-simple/speculative-simple.cpp
+++ b/examples/speculative-simple/speculative-simple.cpp
@@ -5,21 +5,14 @@
 #include "log.h"
 #include "llama.h"
 
-#include <algorithm>
 #include <cstdio>
 #include <cstring>
 #include <string>
 #include <vector>
 
-#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
-#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
-
 int main(int argc, char ** argv) {
     common_params params;
 
-    // needed to get candidate probs even for temp <= 0.0
-    params.sparams.n_probs = 128;
-
     if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
         return 1;
     }
@@ -63,55 +56,10 @@ int main(int argc, char ** argv) {
     model_dft = llama_init_dft.model;
     ctx_dft   = llama_init_dft.context;
 
-    const bool vocab_type_tgt = llama_vocab_type(model_tgt);
-    LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt);
-
-    const bool vocab_type_dft = llama_vocab_type(model_dft);
-    LOG_DBG("vocab_type dft: %d\n", vocab_type_dft);
-
-    if (vocab_type_tgt != vocab_type_dft) {
-        LOG_ERR("%s: draft model vocab type must match target model to use speculation but ", __func__);
-        LOG_ERR("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
+    if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
         return 1;
     }
 
-    if (
-        llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
-        llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
-        llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
-        llama_token_eos(model_tgt) != llama_token_eos(model_dft)
-    ) {
-        LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
-        return 1;
-    }
-
-    {
-        const int n_vocab_tgt = llama_n_vocab(model_tgt);
-        const int n_vocab_dft = llama_n_vocab(model_dft);
-        const int vocab_diff  = n_vocab_tgt > n_vocab_dft
-            ? n_vocab_tgt - n_vocab_dft
-            : n_vocab_dft - n_vocab_tgt;
-
-        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
-            LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__);
-            LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
-                    n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
-            return 1;
-        }
-
-        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
-            const char * token_text_tgt = llama_token_get_text(model_tgt, i);
-            const char * token_text_dft = llama_token_get_text(model_dft, i);
-            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
-                LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
-                LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
-                        common_token_to_piece(ctx_tgt, i).c_str(),
-                        common_token_to_piece(ctx_dft, i).c_str());
-                return 1;
-            }
-        }
-    }
-
     // Tokenize the prompt
     std::vector<llama_token> inp;
     inp = common_tokenize(ctx_tgt, params.prompt, true, true);
@@ -167,10 +115,8 @@ int main(int argc, char ** argv) {
     params_spec.n_min   = 5;
     params_spec.n_reuse = 256;
     params_spec.p_min   = 0.9f;
-    params_spec.model_dft = model_dft;
-    params_spec.ctx_dft   = ctx_dft;
 
-    struct common_speculative * spec = common_speculative_init(params_spec);
+    struct common_speculative * spec = common_speculative_init(params_spec, ctx_dft);
 
     llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
 