diff --git a/common/common.cpp b/common/common.cpp
index d314523db..43fa8a1ef 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -536,12 +536,12 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
                 [](const unsigned char c) { return !std::isprint(c); }),
             detokenized.end());
 
-        buf << "\n" << std::to_string(i)
-            << ":token '" << detokenized << "'"
-            << ":pos " << std::to_string(batch.pos[i])
-            << ":n_seq_id " << std::to_string(batch.n_seq_id[i])
-            << ":seq_id " << std::to_string(batch.seq_id[i][0])
-            << ":logits " << std::to_string(batch.logits[i]);
+        buf << "\n" << std::to_string(i)
+            << ", token '" << detokenized << "'"
+            << ", pos " << std::to_string(batch.pos[i])
+            << ", n_seq_id " << std::to_string(batch.n_seq_id[i])
+            << ", seq_id " << std::to_string(batch.seq_id[i][0])
+            << ", logits " << std::to_string(batch.logits[i]);
     }
 
     buf << " ]";
@@ -1490,6 +1490,66 @@ void common_batch_add(
     batch.n_tokens++;
 }
 
+//
+// Token utils
+//
+
+size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
+    size_t i;
+    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
+
+    return i;
+}
+
+size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
+    // check for empty sequences
+    if (a.empty() || b.empty()) {
+        return 0;
+    }
+
+    // get the lengths of the input sequences
+    size_t a_len = a.size();
+    size_t b_len = b.size();
+
+    // initialize the maximum length of the longest common subsequence (LCS)
+    size_t max_length = 0;
+
+    // use two rows instead of a 2D matrix to optimize space
+    std::vector<size_t> prev_row(b_len + 1, 0);
+    std::vector<size_t> curr_row(b_len + 1, 0);
+
+    // iterate through the elements of a
+    for (size_t i = 1; i <= a_len; i++) {
+        // iterate through the elements of b
+        for (size_t j = 1; j <= b_len; j++) {
+            // if elements at the current positions match
+            if (a[i - 1] == b[j - 1]) {
+                // if it's the first element of either sequence, set LCS length to 1
+                if (i == 1 || j == 1) {
+                    curr_row[j] = 1;
+                } else {
+                    // increment LCS length by 1 compared to the previous element
+                    curr_row[j] = prev_row[j - 1] + 1;
+                }
+
+                // update max_length if necessary
+                if (curr_row[j] > max_length) {
+                    max_length = curr_row[j];
+                }
+            } else {
+                // reset LCS length if elements don't match
+                curr_row[j] = 0;
+            }
+        }
+
+        // update the previous row for the next iteration
+        prev_row = curr_row;
+    }
+
+    // return the maximum length of the LCS
+    return max_length;
+}
+
 //
 // Vocab utils
 //
diff --git a/common/common.h b/common/common.h
index 7977cc7a9..29d678c7b 100644
--- a/common/common.h
+++ b/common/common.h
@@ -33,6 +33,8 @@ struct common_lora_adapter_container : common_lora_adapter_info {
     struct llama_lora_adapter * adapter;
 };
 
+using llama_tokens = std::vector<llama_token>;
+
 // build info
 extern int LLAMA_BUILD_NUMBER;
 extern char const * LLAMA_COMMIT;
@@ -461,7 +463,9 @@ struct llama_model * common_load_model_from_hf(const char * repo, const char * f
 // clear LoRA adapters from context, then apply new list of adapters
 void common_lora_adapters_apply(struct llama_context * ctx, std::vector<common_lora_adapter_container> & lora_adapters);
 
+//
 // Batch utils
+//
 
 void common_batch_clear(struct llama_batch & batch);
 
@@ -472,6 +476,16 @@ void common_batch_add(
     const std::vector<llama_seq_id> & seq_ids,
                                bool   logits);
 
+//
+// Token utils
+//
+
+// longest common prefix
+size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
+
+// longest common subsequence
+size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
+
 //
 // Vocab utils
 //
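The two helpers above are promoted from the server example (see the examples/server/utils.hpp hunk further down) into common so the speculative code can reuse them. common_lcp matches tokens only from position 0, while common_lcs can match a block anywhere; note that the reset-on-mismatch recurrence computes the length of the longest common contiguous run (a longest common substring) rather than a general subsequence, despite the name. A minimal usage sketch, assuming it is compiled and linked against common; the token IDs are made up:

#include "common.h"

#include <cstdio>

int main() {
    // hypothetical token IDs; only their equality matters here
    const llama_tokens cached = { 10, 20, 30, 40, 50, 60 };
    const llama_tokens prompt = { 10, 20, 99, 40, 50, 60 };

    printf("lcp = %zu\n", common_lcp(cached, prompt)); // 2: the prefix match stops at the first mismatch
    printf("lcs = %zu\n", common_lcs(cached, prompt)); // 3: the longest common contiguous run is { 40, 50, 60 }

    return 0;
}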
diff --git a/common/sampling.cpp b/common/sampling.cpp
index fe1ef5bf9..f90ac8b90 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -342,6 +342,28 @@ std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl,
     return result;
 }
 
+std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first) {
+    std::vector<int> idxs;
+    idxs.reserve(batch.n_tokens);
+
+    std::vector<llama_token> draft;
+    draft.reserve(batch.n_tokens);
+
+    for (int i = 0; i < batch.n_tokens; i++) {
+        if (batch.logits[i] == 0) {
+            continue;
+        }
+
+        if (idxs.size() > 0) {
+            GGML_ASSERT(batch.pos[idxs.back()] + 1 == batch.pos[i]);
+            draft.push_back(batch.token[i]);
+        }
+        idxs.push_back(i);
+    }
+
+    return common_sampler_sample_n(gsmpl, ctx, idxs, draft, grammar_first);
+}
+
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
     return llama_sampler_get_seed(gsmpl->chain);
 }
diff --git a/common/sampling.h b/common/sampling.h
index 23cfae1ac..ba496ac27 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -73,6 +73,8 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
 //
 std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first = false);
 
+std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first = false);
+
 uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
 
 // helpers
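The batch-based overload removes the need for the caller to track sampling indices separately: every token in the target batch with logits enabled becomes an index to sample at, and each such token after the first doubles as the draft continuation to verify (hence the assert that their positions are consecutive). A rough usage sketch; ctx_tgt, smpl, id_last, n_past and the drafted tokens are placeholder variables, not part of the patch:

llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);

common_batch_clear(batch_tgt);
common_batch_add(batch_tgt, id_last, n_past,     { 0 }, true); // logits on -> idxs[0]
common_batch_add(batch_tgt, draft0,  n_past + 1, { 0 }, true); // logits on -> idxs[1], draft[0]
common_batch_add(batch_tgt, draft1,  n_past + 2, { 0 }, true); // logits on -> idxs[2], draft[1]

llama_decode(ctx_tgt, batch_tgt);

// samples at idxs[0], then keeps sampling only while the result agrees with the next draft token;
// the returned vector holds the sampled tokens up to and including the first disagreement
const auto ids = common_sampler_sample_n(smpl, ctx_tgt, batch_tgt);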
diff --git a/common/speculative.cpp b/common/speculative.cpp
index 2726760ad..6acf84a23 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -11,9 +11,7 @@ struct common_speculative {
 
     struct common_sampler * smpl;
 
-    std::vector<int> i_batch_tgt;
-
-    std::vector<llama_token> tokens;
+    llama_tokens prompt_last;
 };
 
 struct common_speculative * common_speculative_init(struct common_speculative_params params) {
@@ -21,12 +19,10 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa
         /* .params    = */ params,
         /* .batch_dft = */ llama_batch_init(llama_n_batch(params.ctx_dft), 0, 1),
         /* .smpl      = */ nullptr,
-        /* .i_batch_tgt = */ {},
-        /* .tokens      = */ {},
     };
 
     // TODO: optimize or pass from outside?
-#if 0
+#if 1
     {
         common_sampler_params sparams;
         sparams.no_perf = false;
@@ -70,30 +66,79 @@ void common_speculative_free(struct common_speculative * spec) {
     delete spec;
 }
 
-void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens) {
-    llama_kv_cache_clear(spec->params.ctx_dft);
-
-    // TODO: error handling
-    llama_decode(spec->params.ctx_dft, llama_batch_get_one(tokens, n_tokens));
-}
-
 void common_speculative_add_draft(
         struct common_speculative * spec,
         struct llama_batch & batch_tgt,
+        const llama_tokens & prompt,
         llama_token id_last,
-        int n_past) {
-    spec->tokens.clear();
+        llama_token n_past_tgt) {
 
-    spec->i_batch_tgt.clear();
-    spec->i_batch_tgt.push_back(0);
+    int reuse_i = 0;
+    int reuse_n = 0;
 
-    common_sampler_reset(spec->smpl);
+    const int n_ctx = llama_n_ctx(spec->params.ctx_dft) - spec->params.n_draft;
+
+    const int i_start = std::max(0, (int) prompt.size() - n_ctx);
+
+    for (int i = 0; i < (int) spec->prompt_last.size(); ++i) {
+        int cur = 0;
+        while (i_start + cur < (int) prompt.size() &&
+               i + cur < (int) spec->prompt_last.size() &&
+               prompt[i_start + cur] == spec->prompt_last[i + cur]) {
+            cur++;
+        }
+
+        if ((cur >= spec->params.n_reuse || prompt.size() <= n_ctx) && cur > reuse_n) {
+            reuse_i = i;
+            reuse_n = cur;
+        }
+    }
+
+    LOG_DBG("%s: reuse_i = %d, reuse_n = %d\n", __func__, reuse_i, reuse_n);
+
+    if (reuse_n == 0) {
+        llama_kv_cache_clear(spec->params.ctx_dft);
+
+        spec->prompt_last.clear();
+    } else {
+        llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, 0, reuse_i);
+        llama_kv_cache_seq_rm (spec->params.ctx_dft, 0, reuse_i + reuse_n, -1);
+        llama_kv_cache_seq_add(spec->params.ctx_dft, 0, reuse_i, -1, -reuse_i);
+
+        spec->prompt_last.erase(spec->prompt_last.begin(), spec->prompt_last.begin() + reuse_i);
+        spec->prompt_last.erase(spec->prompt_last.begin() + reuse_n, spec->prompt_last.end());
+    }
+
+    common_batch_clear(spec->batch_dft);
+
+    for (int i = i_start + reuse_n; i < (int) prompt.size(); ++i) {
+        //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt[i]);
+        common_batch_add(spec->batch_dft, prompt[i], i - i_start, { 0 }, false);
+
+        spec->prompt_last.push_back(prompt[i]);
+    }
+
+    const llama_pos n_past = prompt.size() - i_start;
+
+    LOG_DBG("%s: n_past = %d\n", __func__, n_past);
+
+    if (spec->batch_dft.n_tokens > 0) {
+        LOG_DBG("%s: draft batch: %s\n", __func__, string_from(spec->params.ctx_dft, spec->batch_dft).c_str());
+
+        llama_decode(spec->params.ctx_dft, spec->batch_dft);
+    }
 
     common_batch_clear(spec->batch_dft);
     common_batch_add  (spec->batch_dft, id_last, n_past, { 0 }, true);
 
+    spec->prompt_last.push_back(id_last);
+
+    LOG_DBG("%s: prompt_last: %s\n", __func__, string_from(spec->params.ctx_dft, spec->prompt_last).c_str());
+
     llama_decode(spec->params.ctx_dft, spec->batch_dft);
 
+    common_sampler_reset(spec->smpl);
+
     // sample n_draft tokens from the draft model
     for (int i = 0; i < spec->params.n_draft; ++i) {
         common_batch_clear(spec->batch_dft);
@@ -111,18 +156,13 @@ void common_speculative_add_draft(
         const llama_token id = cur_p->data[0].id;
 
         // only collect very high-confidence draft tokens
-        if (cur_p->data[0].p < 0.75 && spec->tokens.size() >= 0) {
+        if (cur_p->data[0].p < spec->params.p_min) {
            break;
         }
 
         common_sampler_accept(spec->smpl, id, true);
 
-        spec->tokens.push_back(id);
-
-        // add unique drafted tokens to the target batch
-        spec->i_batch_tgt.push_back(batch_tgt.n_tokens);
-
-        common_batch_add(batch_tgt, id, n_past + i + 1, { 0 }, true);
+        common_batch_add(batch_tgt, id, n_past_tgt + i, { 0 }, true);
 
         if (batch_tgt.n_tokens > spec->params.n_draft) {
             break;
@@ -132,23 +172,13 @@ void common_speculative_add_draft(
 
         // evaluate the drafted tokens on the draft model
         llama_decode(spec->params.ctx_dft, spec->batch_dft);
+
+        spec->prompt_last.push_back(id);
     }
 
     // don't waste time on small batches
     // TODO: do not evaluate the draft model for that many rounds
     if (batch_tgt.n_tokens < spec->params.n_min) {
         batch_tgt.n_tokens = 1;
-        spec->tokens.resize(0);
-        spec->i_batch_tgt.resize(1);
     }
-
-    // print current draft sequences
-    LOG_DBG("draft %s\n", string_from(spec->params.ctx_dft, spec->tokens).c_str());
-}
-
-std::vector<llama_token> common_speculative_sample(
-        struct common_speculative * spec,
-        struct common_sampler * smpl,
-        struct llama_context * ctx_tgt) {
-    return common_sampler_sample_n(smpl, ctx_tgt, spec->i_batch_tgt, spec->tokens);
 }
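The main change above: instead of re-feeding the whole prompt to the draft model on every call (the removed common_speculative_set_prompt), common_speculative_add_draft now compares the incoming prompt against prompt_last, keeps the longest matching run in the draft KV cache, shifts it to the front, and decodes only the tokens that are actually new. A worked example with made-up tokens and sizes (not taken from the patch) to illustrate the indices:

// prompt_last (what the draft context currently holds): A B C D E F       (6 tokens)
// prompt      (new prompt passed by the caller)        : C D E F G H       (6 tokens)
// n_ctx is assumed large enough, so i_start = 0 and the n_reuse threshold is bypassed
// (prompt.size() <= n_ctx).
//
// scanning prompt_last for the longest run matching prompt[i_start..]:
//   i = 2 gives cur = 4 (C D E F)  ->  reuse_i = 2, reuse_n = 4
//
// the KV cache is then trimmed and shifted so the reused run starts at position 0:
//   llama_kv_cache_seq_rm (ctx_dft, 0, 0, 2);      // drop A B
//   llama_kv_cache_seq_rm (ctx_dft, 0, 6, -1);     // drop anything after the reused run
//   llama_kv_cache_seq_add(ctx_dft, 0, 2, -1, -2); // shift C D E F to positions 0..3
//
// prompt_last becomes C D E F; only G (pos 4), H (pos 5) and finally id_last (pos 6)
// still have to be decoded by the draft model before the n_draft sampling loop starts.

Separately, the acceptance test for drafted tokens now uses the configurable p_min threshold (0.9f by default, declared in speculative.h below) instead of the previous hard-coded 0.75.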
diff --git a/common/speculative.h b/common/speculative.h
index a2df2667a..b3a87e64c 100644
--- a/common/speculative.h
+++ b/common/speculative.h
@@ -1,14 +1,16 @@
 #pragma once
 
 #include "llama.h"
-
-#include <vector>
+#include "common.h"
 
 struct common_speculative;
 
 struct common_speculative_params {
     int n_draft = 16;
     int n_min   = 5; // do not add drafts smaller than this, TODO: leave this to user?
+    int n_reuse = 256;
+
+    float p_min = 0.9f;
 
     struct llama_model * model_dft = nullptr;
 
@@ -19,28 +21,11 @@ struct common_speculative * common_speculative_init(struct common_speculative_pa
 
 void common_speculative_free(struct common_speculative * spec);
 
-// TODO: remove
-void common_speculative_set_prompt(struct common_speculative * spec, llama_token * tokens, int32_t n_tokens);
-
 // sample up to n_draft tokens and add them to the batch using the draft model
 //
-// TODO: change to:
-//
-// void common_speculative_add_draft(
-//         struct common_speculative * spec,
-//         struct llama_batch & batch_tgt,
-//         llama_token * tokens,
-//         int32_t n_tokens);
-//
-// and update the internal logic to compute only the new tokens
-//
 void common_speculative_add_draft(
         struct common_speculative * spec,
         struct llama_batch & batch_tgt,
+        const llama_tokens & prompt,
         llama_token id_last,
-        int n_past);
-
-std::vector<llama_token> common_speculative_sample(
-        struct common_speculative * spec,
-        struct common_sampler * smpl,
-        struct llama_context * ctx_tgt);
+        llama_token n_past_tgt);
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index b8e003be9..b7b2cbe5a 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -743,7 +743,7 @@ struct server_context {
             }
 
             // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
-            int cur_lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens);
+            int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens);
 
             // fraction of the common subsequence length compared to the current slot's prompt length
             float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<float>(slot.cache_tokens.size());
@@ -1960,7 +1960,7 @@ struct server_context {
 
                     if (slot.params.cache_prompt) {
                         // reuse any previously computed tokens that are common with the new prompt
-                        slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
+                        slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
 
                         // reuse chunks from the cached prompt by shifting their KV cache in the new position
                         if (params.n_cache_reuse > 0) {
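On the server side this is a pure rename: common_lcs still ranks slots by how much of their cached prompt overlaps the new request, and common_lcp still decides how many cached tokens can be kept verbatim. For illustration only, with invented numbers:

// slot.cache_tokens.size()                          = 200
// common_lcs(slot.cache_tokens, task.prompt_tokens) = 150
// cur_similarity = 150.0f / 200.0f                  = 0.75f  -> used to pick the best-matching slot
//
// once a slot is selected and prompt caching is enabled:
// slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens); // tokens reused as-is from the KV cache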
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index c47ed3e47..1665e9dc3 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -24,7 +24,6 @@
 #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
 
 using json = nlohmann::ordered_json;
-using llama_tokens = std::vector<llama_token>;
 
 #define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
 #define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
@@ -439,62 +438,6 @@ static std::string gen_chatcmplid() {
 // other common utils
 //
 
-static size_t longest_common_prefix(const llama_tokens & a, const llama_tokens & b) {
-    size_t i;
-    for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
-
-    return i;
-}
-
-static size_t longest_common_subsequence(const llama_tokens & a, const llama_tokens & b) {
-    // check for empty sequences
-    if (a.empty() || b.empty()) {
-        return 0;
-    }
-
-    // get the lengths of the input sequences
-    size_t a_len = a.size();
-    size_t b_len = b.size();
-
-    // initialize the maximum length of the longest common subsequence (LCS)
-    size_t max_length = 0;
-
-    // use two rows instead of a 2D matrix to optimize space
-    std::vector<size_t> prev_row(b_len + 1, 0);
-    std::vector<size_t> curr_row(b_len + 1, 0);
-
-    // iterate through the elements of a
-    for (size_t i = 1; i <= a_len; i++) {
-        // iterate through the elements of b
-        for (size_t j = 1; j <= b_len; j++) {
-            // if elements at the current positions match
-            if (a[i - 1] == b[j - 1]) {
-                // if it's the first element of either sequences, set LCS length to 1
-                if (i == 1 || j == 1) {
-                    curr_row[j] = 1;
-                } else {
-                    // increment LCS length by 1 compared to the previous element
-                    curr_row[j] = prev_row[j - 1] + 1;
-                }
-
-                // update max_length if necessary
-                if (curr_row[j] > max_length) {
-                    max_length = curr_row[j];
-                }
-            } else {
-                // reset LCS length if elements don't match
-                curr_row[j] = 0;
-            }
-        }
-
-        // update the previous row for the next iteration
-        prev_row = curr_row;
-    }
-
-    // return the maximum length of the LCS
-    return max_length;
-}
-
 static bool ends_with(const std::string & str, const std::string & suffix) {
     return str.size() >= suffix.size() && 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
 }
diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
index aeccfd369..cb6c35ce1 100644
--- a/examples/speculative-simple/speculative-simple.cpp
+++ b/examples/speculative-simple/speculative-simple.cpp
@@ -14,14 +14,6 @@
 #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
 
-struct seq_draft {
-    std::vector<int> i_batch_tgt;
-
-    std::vector<llama_token> tokens;
-
-    struct common_sampler * smpl = nullptr;
-};
-
 int main(int argc, char ** argv) {
     common_params params;
 
@@ -165,27 +157,21 @@ int main(int argc, char ** argv) {
     // note: keep the last token separate!
     llama_token id_last = inp.back();
 
+    auto prompt_dft = std::vector<llama_token>(inp.begin(), inp.end() - 1);
+
     int n_past = inp.size() - 1;
 
     // init the speculator
     struct common_speculative_params params_spec;
     params_spec.n_draft = n_draft;
     params_spec.n_min   = 5;
+    params_spec.n_reuse = 256;
+    params_spec.p_min   = 0.9f;
     params_spec.model_dft = model_dft;
     params_spec.ctx_dft   = ctx_dft;
 
     struct common_speculative * spec = common_speculative_init(params_spec);
 
-    // feed the prompt to the speculator
-    //
-    // this has to be kept synchronized with the target context
-    //
-    // TODO: simplify this by moving the context management logic in the common_speculative instance
-    //       for example, the common_speculative_add_draft can pass the entire context (or part of it) and the
-    //       speculator will automatically compute any new tokens that are not present in its context
-    //
-    common_speculative_set_prompt(spec, inp.data(), n_input - 1);
-
     llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
 
     const auto t_enc_end = ggml_time_us();
@@ -204,7 +190,7 @@ int main(int argc, char ** argv) {
         //       offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
        //       from a cache or lookup tables.
         //
-        common_speculative_add_draft(spec, batch_tgt, id_last, n_past);
+        common_speculative_add_draft(spec, batch_tgt, prompt_dft, id_last, n_past + 1);
 
         // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
         {
@@ -220,7 +206,7 @@ int main(int argc, char ** argv) {
         //       available logits from the batch and sample the next token until we run out of logits or the sampler
         //       disagrees with the draft
         //
-        const auto ids = common_speculative_sample(spec, smpl, ctx_tgt);
+        const auto ids = common_sampler_sample_n(smpl, ctx_tgt, batch_tgt);
 
         GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
 
@@ -266,9 +252,11 @@ int main(int argc, char ** argv) {
             LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
 
             llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
-            llama_kv_cache_seq_rm(ctx_dft, 0, n_past, -1);
         }
 
+        prompt_dft.push_back(id_last);
+        prompt_dft.insert(prompt_dft.end(), ids.begin(), ids.end() - 1);
+
         // remember the last accepted token for the next iteration
         id_last = id;
     }
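The example now keeps a local prompt_dft vector as the draft-side history instead of pushing tokens into the draft context directly, so the draft KV cache no longer needs to be trimmed in lockstep with the target. The bookkeeping at the end of each iteration is the subtle part; a small walkthrough with placeholder tokens (not from the patch):

// suppose the iteration started with id_last = T0 and the target accepted ids = { T1, T2, T3 }:
//
//   prompt_dft.push_back(id_last);                                   // history gains T0
//   prompt_dft.insert(prompt_dft.end(), ids.begin(), ids.end() - 1); // history gains T1, T2
//   id_last = T3;                                                    // carried to the next iteration
//
// T3 is intentionally left out of prompt_dft: on the next call, common_speculative_add_draft
// receives it separately as id_last, decodes it with logits enabled, and samples the new draft
// continuation right after it. Passing n_past + 1 as n_past_tgt places the drafted tokens
// immediately after id_last, which the target evaluates at position n_past (per the
// "[id_last, draft0, draft1, ..., draftN-1]" comment above).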