diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp
index 7bf9056bf..2ea49d47c 100644
--- a/examples/speculative-simple/speculative-simple.cpp
+++ b/examples/speculative-simple/speculative-simple.cpp
@@ -117,7 +117,8 @@ int main(int argc, char ** argv) {
     llama_token id_last = inp.back();
 
     // all tokens currently in the target context
-    auto prompt_tgt = std::vector<llama_token>(inp.begin(), inp.end() - 1);
+    llama_tokens prompt_tgt(inp.begin(), inp.end() - 1);
+    prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
 
     int n_past = inp.size() - 1;
 
@@ -181,54 +182,44 @@ int main(int argc, char ** argv) {
         GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
 
         n_past += ids.size() - 1;
-        n_drafted += batch_tgt.n_tokens - 1;
+        n_drafted += draft.size(); // note: we ignore the discarded small drafts
         n_accept += ids.size() - 1;
+        n_predict += ids.size();
 
         // process the accepted tokens and update contexts
         //
         // this is the standard token post-processing that we normally do
        // in this case, we do it for a group of accepted tokens at once
         //
-        {
-            llama_token id;
-            std::string token_str;
+        for (size_t i = 0; i < ids.size(); ++i) {
+            prompt_tgt.push_back(id_last);
 
-            for (size_t i = 0; i < ids.size(); ++i) {
-                id = ids[i];
+            id_last = ids[i];
 
-                ++n_predict;
-
-                if (llama_token_is_eog(model_tgt, id)) {
-                    has_eos = true;
-                    break;
-                }
-
-                token_str = common_token_to_piece(ctx_tgt, id);
-
-                if (params.use_color && i + 1 < ids.size()) {
-                    LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
-                } else {
-                    LOG("%s", token_str.c_str());
-                }
-            }
-
-            if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            if (llama_token_is_eog(model_tgt, id_last)) {
+                has_eos = true;
                 break;
             }
 
-            LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
+            const std::string token_str = common_token_to_piece(ctx_tgt, id_last);
 
-            {
-                LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
-
-                llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+            if (params.use_color && i + 1 < ids.size()) {
+                LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
+            } else {
+                LOG("%s", token_str.c_str());
             }
+        }
 
-            prompt_tgt.push_back(id_last);
-            prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
+        LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
 
-            // remember the last accepted token for the next iteration
-            id_last = id;
+        {
+            LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
+
+            llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+        }
+
+        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            break;
         }
     }
 
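
For context (not part of the diff): the key change in the second hunk is that each loop iteration appends the previous `id_last` to `prompt_tgt` before `id_last` advances to the newly accepted token, so `prompt_tgt` always mirrors what is already in the target context while the last accepted token is held back for the next iteration. Below is a minimal standalone sketch of that bookkeeping, using plain `int` tokens and made-up values instead of the llama.cpp types and calls.

```cpp
// Standalone sketch (not from the diff) of the accepted-token bookkeeping
// performed by the new for loop. Token IDs are plain ints and all llama.cpp
// calls are omitted, so only the invariant remains visible:
// prompt_tgt holds everything already in the target context, id_last holds
// the newest accepted token that has not been appended yet.
#include <cstdio>
#include <vector>

int main() {
    std::vector<int> prompt_tgt = {1, 2, 3}; // tokens already in the target context
    int id_last = 4;                         // last accepted token, kept separate

    const std::vector<int> ids = {5, 6, 7};  // tokens accepted in this iteration

    for (size_t i = 0; i < ids.size(); ++i) {
        prompt_tgt.push_back(id_last); // the previous "last" token joins the context
        id_last = ids[i];              // the freshly accepted token becomes the new "last"
    }

    for (int t : prompt_tgt) {
        printf("%d ", t);              // prints: 1 2 3 4 5 6
    }
    printf("| last: %d\n", id_last);   // prints: | last: 7
}
```

Keeping the last accepted token out of `prompt_tgt` matches the existing structure of the example, where `id_last` is the token carried into the next iteration rather than part of the already-processed prompt.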