From 8f419181d1c20d8195148680df15b6f093cb1512 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 24 Nov 2024 19:19:12 +0200
Subject: [PATCH] common : final touches

ggml-ci
---
 common/common.h                                |  2 +-
 common/speculative.cpp                         | 37 +++++++++++++------
 common/speculative.h                           | 13 +++----
 .../speculative-simple/speculative-simple.cpp  | 17 +++++++--
 4 files changed, 45 insertions(+), 24 deletions(-)

diff --git a/common/common.h b/common/common.h
index c9fb2b62a..5c579b5ab 100644
--- a/common/common.h
+++ b/common/common.h
@@ -156,7 +156,7 @@ struct common_params_sampling {
 };
 
 struct common_params_speculative {
-    int32_t n_ctx = 4096; // draft context size
+    int32_t n_ctx = 0; // draft context size
     int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding
     int32_t n_min = 5; // minimum number of draft tokens to use for speculative decoding
     int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default)
diff --git a/common/speculative.cpp b/common/speculative.cpp
index 316ea9e1e..fe315a270 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -142,6 +142,8 @@ llama_tokens common_speculative_gen_draft(
 
     const int i_start = std::max(0, (int) prompt_tgt.size() - n_ctx);
 
+    // reuse as much as possible from the old draft context
+    // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt
     for (int i = 0; i < (int) prompt.size(); ++i) {
         int cur = 0;
         while (i_start + cur < (int) prompt_tgt.size() &&
@@ -166,6 +168,8 @@ llama_tokens common_speculative_gen_draft(
 
         prompt.clear();
     } else {
+        // this happens when a previous draft has been discarded (for example, due to being too small), but the
+        // target model agreed with it. in this case, we simply pass back the previous results to save compute
        if (reuse_i + reuse_n < (int) prompt.size() && prompt[reuse_i + reuse_n] == id_last) {
             for (int i = reuse_i + reuse_n + 1; i < (int) prompt.size(); ++i) {
                 result.push_back(prompt[i]);
@@ -174,42 +178,51 @@ llama_tokens common_speculative_gen_draft(
                     break;
                 }
             }
+            return result;
         }
 
-        llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
-        llama_kv_cache_seq_rm (ctx, 0, reuse_i + reuse_n, -1);
-        llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
+        if (reuse_i > 0) {
+            llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
+            llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
 
-        prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
-        prompt.erase(prompt.begin() + reuse_n, prompt.end());
+            prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
+        }
+
+        if (reuse_n < (int) prompt.size()) {
+            llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
+
+            prompt.erase(prompt.begin() + reuse_n, prompt.end());
+        }
     }
 
+    // prepare a batch to evaluate any new tokens in the prompt
     common_batch_clear(batch);
 
-    for (int i = i_start + reuse_n; i < (int) prompt_tgt.size(); ++i) {
+    for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
         //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
         common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
 
         prompt.push_back(prompt_tgt[i]);
     }
 
-    const llama_pos n_past = prompt_tgt.size() - i_start;
-
-    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
-
+    // we should rarely end-up here during normal decoding
     if (batch.n_tokens > 0) {
-        LOG_DBG("%s: draft batch: %s\n", __func__, string_from(ctx, batch).c_str());
+        //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());
 
         llama_decode(ctx, batch);
     }
 
+    const llama_pos n_past = prompt.size();
+
+    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
+
     common_batch_clear(batch);
     common_batch_add (batch, id_last, n_past, { 0 }, true);
 
     prompt.push_back(id_last);
 
-    LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(ctx, prompt).c_str());
+    //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());
 
     llama_decode(ctx, batch);
diff --git a/common/speculative.h b/common/speculative.h
index 9fb669fde..50ec03446 100644
--- a/common/speculative.h
+++ b/common/speculative.h
@@ -6,10 +6,10 @@
 struct common_speculative;
 
 struct common_speculative_params {
-    int n_draft = 16;
+    int n_draft = 16; // max drafted tokens
     int n_reuse = 256;
 
-    float p_min = 0.9f;
+    float p_min = 0.9f; // min probability required to accept a token in the draft
 };
 
 struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
@@ -21,9 +21,8 @@ bool common_speculative_are_compatible(
         const struct llama_context * ctx_dft);
 
 // sample up to n_draft tokens and add them to the batch using the draft model
-
 llama_tokens common_speculative_gen_draft(
-    struct common_speculative * spec,
-    struct common_speculative_params params,
-    const llama_tokens & prompt,
-    llama_token id_last);
+        struct common_speculative * spec,
+        struct common_speculative_params params,
+        const llama_tokens & prompt,
+        llama_token id_last);
diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
index ed3e6a466..1bc7f428c 100644
--- a/examples/speculative-simple/speculative-simple.cpp
+++ b/examples/speculative-simple/speculative-simple.cpp
@@ -46,8 +46,11 @@ int main(int argc, char ** argv) {
     ctx_tgt = llama_init_tgt.context;
 
     // load the draft model
-    params.model = params.speculative.model;
+    params.model        = params.speculative.model;
+    params.n_ctx        = params.speculative.n_ctx;
+    params.n_batch      = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
     params.n_gpu_layers = params.speculative.n_gpu_layers;
+
     if (params.speculative.cpuparams.n_threads > 0) {
         params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
     }
@@ -66,8 +69,14 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> inp;
     inp = common_tokenize(ctx_tgt, params.prompt, true, true);
 
-    if ((int) inp.size() > llama_n_ctx(ctx_tgt)) {
-        LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt));
+    if (llama_n_ctx(ctx_tgt) < (int) inp.size()) {
+        LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt));
+
+        return 1;
+    }
+
+    if (llama_n_batch(ctx_tgt) < (int) inp.size()) {
+        LOG_ERR("%s: the prompt exceeds the batch size (%d tokens, batch %d)\n", __func__, (int) inp.size(), llama_n_batch(ctx_tgt));
 
         return 1;
     }
@@ -114,7 +123,7 @@ int main(int argc, char ** argv) {
     // init the speculator
     struct common_speculative_params params_spec;
     params_spec.n_draft = n_draft;
-    params_spec.n_reuse = 256;
+    params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
     params_spec.p_min = p_min;
 
     struct common_speculative * spec = common_speculative_init(ctx_dft);
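
For context, a minimal sketch of how the API touched by this patch is driven from a caller such as examples/speculative-simple. It assumes ctx_tgt, ctx_dft and the tokenized prompt_tgt already exist, omits the target-side verification loop, and assumes a common_speculative_free() cleanup counterpart to common_speculative_init() (not shown in this diff):

    // sketch only: configure the speculator with the parameters introduced above
    struct common_speculative_params params_spec;
    params_spec.n_draft = 16;                                         // max drafted tokens
    params_spec.n_reuse = llama_n_ctx(ctx_dft) - params_spec.n_draft; // keep room for the draft in the draft context
    params_spec.p_min   = 0.9f;                                       // min probability required to accept a drafted token

    if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
        return 1; // the draft model cannot be used with this target model
    }

    struct common_speculative * spec = common_speculative_init(ctx_dft);

    // draft a continuation of the last accepted token; the target model then
    // verifies the returned tokens (verification loop omitted in this sketch)
    const llama_token  id_last = prompt_tgt.back();
    const llama_tokens draft   = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);

    common_speculative_free(spec); // assumed cleanup counterpart of common_speculative_init()

Setting n_reuse to llama_n_ctx(ctx_dft) - n_draft mirrors the change in speculative-simple.cpp above: the draft context keeps as much of the previous prompt as possible while leaving room for the newly drafted tokens.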