speculative : simplify the implementation (#10504)

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-26 12:29:38 +02:00 committed by GitHub
parent 9a4b79bcfa
commit 811872a59d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -117,7 +117,8 @@ int main(int argc, char ** argv) {
llama_token id_last = inp.back();
// all tokens currently in the target context
auto prompt_tgt = std::vector<llama_token>(inp.begin(), inp.end() - 1);
llama_tokens prompt_tgt(inp.begin(), inp.end() - 1);
prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
int n_past = inp.size() - 1;
@ -181,29 +182,26 @@ int main(int argc, char ** argv) {
GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
n_past += ids.size() - 1;
n_drafted += batch_tgt.n_tokens - 1;
n_drafted += draft.size(); // note: we ignore the discarded small drafts
n_accept += ids.size() - 1;
n_predict += ids.size();
// process the accepted tokens and update contexts
//
// this is the standard token post-processing that we normally do
// in this case, we do it for a group of accepted tokens at once
//
{
llama_token id;
std::string token_str;
for (size_t i = 0; i < ids.size(); ++i) {
id = ids[i];
prompt_tgt.push_back(id_last);
++n_predict;
id_last = ids[i];
if (llama_token_is_eog(model_tgt, id)) {
if (llama_token_is_eog(model_tgt, id_last)) {
has_eos = true;
break;
}
token_str = common_token_to_piece(ctx_tgt, id);
const std::string token_str = common_token_to_piece(ctx_tgt, id_last);
if (params.use_color && i + 1 < ids.size()) {
LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
@ -212,11 +210,7 @@ int main(int argc, char ** argv) {
}
}
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
break;
}
LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
{
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
@ -224,11 +218,8 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
}
prompt_tgt.push_back(id_last);
prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
// remember the last accepted token for the next iteration
id_last = id;
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
break;
}
}