From 1c626e2fe1ff3e46fd3329c467c5a50484ec6592 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 16 Oct 2023 12:47:37 +0300 Subject: [PATCH] speculative : minor refactor ggml-ci --- examples/speculative/speculative.cpp | 79 +++++++++++++--------------- 1 file changed, 37 insertions(+), 42 deletions(-) diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index e873214ee..117f1b41f 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -116,8 +116,8 @@ int main(int argc, char ** argv) { params.grammar.clear(); // the draft samplers will copy the target sampler's grammar params.sampling_params.temp = 1.0f; // the draft samplers use default temperature - for (int i = 0; i < n_seq_dft; ++i) { - drafts[i].ctx_sampling = llama_sampling_init(params); + for (int s = 0; s < n_seq_dft; ++s) { + drafts[s].ctx_sampling = llama_sampling_init(params); } llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); @@ -131,24 +131,24 @@ int main(int argc, char ** argv) { while (true) { // print current draft sequences - for (int i = 0; i < n_seq_dft; ++i) { - if (!drafts[i].active) { + for (int s = 0; s < n_seq_dft; ++s) { + if (!drafts[s].active) { continue; } - const auto & tokens = drafts[i].tokens; + const auto & tokens = drafts[s].tokens; - LOG("draft %d: %s\n", i, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str()); + LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str()); } int i_dft = 0; - int i_keep = 0; + int s_keep = 0; while (true) { - LOG("sampling target: i_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", i_keep, i_dft, drafts[i_keep].i_batch_tgt[i_dft]); + LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]); // sample from the target model - llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[i_keep].i_batch_tgt[i_dft]); + llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]); llama_sampling_accept(ctx_sampling, ctx_tgt, id); @@ -169,18 +169,18 @@ int main(int argc, char ** argv) { { bool matches = false; - for (int i = 0; i < n_seq_dft; ++i) { - if (!drafts[i].active) { + for (int s = 0; s < n_seq_dft; ++s) { + if (!drafts[s].active) { continue; } - if (i_dft < (int) drafts[i].tokens.size() && id == drafts[i].tokens[i_dft]) { - LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, i, id, token_str.c_str()); + if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) { + LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str()); - i_keep = i; + s_keep = s; matches = true; } else { - drafts[i].active = false; + drafts[s].active = false; } } @@ -198,22 +198,22 @@ int main(int argc, char ** argv) { // TODO: simplify { - LOG("keeping sequence %d\n", i_keep); + LOG("keeping sequence %d\n", s_keep); - llama_kv_cache_seq_keep(ctx_dft, i_keep); - llama_kv_cache_seq_cp (ctx_dft, i_keep, 0, -1, -1); + llama_kv_cache_seq_keep(ctx_dft, s_keep); + llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1); llama_kv_cache_seq_keep(ctx_dft, 0); - llama_kv_cache_seq_rm (ctx_tgt, i_keep, n_past_tgt, -1); - llama_kv_cache_seq_keep(ctx_tgt, i_keep); - llama_kv_cache_seq_cp (ctx_tgt, i_keep, 0, -1, -1); + llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); + llama_kv_cache_seq_keep(ctx_tgt, s_keep); + llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1); llama_kv_cache_seq_keep(ctx_tgt, 0); } - for (int i = 0; i < n_seq_dft; ++i) { - drafts[i].active = false; - drafts[i].tokens.clear(); - drafts[i].i_batch_tgt.clear(); + for (int s = 0; s < n_seq_dft; ++s) { + drafts[s].active = false; + drafts[s].tokens.clear(); + drafts[s].i_batch_tgt.clear(); } // note: will be erased after the speculation phase drafts[0].tokens.push_back(id); @@ -239,9 +239,9 @@ int main(int argc, char ** argv) { int n_seq_cur = 1; int n_past_cur = n_past_dft; - for (int i = 0; i < n_seq_dft; ++i) { - drafts[i].active = false; - drafts[i].drafting = false; + for (int s = 0; s < n_seq_dft; ++s) { + drafts[s].active = false; + drafts[s].drafting = false; } drafts[0].active = true; drafts[0].drafting = true; @@ -324,19 +324,14 @@ int main(int argc, char ** argv) { for (int is = 0; is < (int) sa.size(); ++is) { const llama_token id = cur_p[is].id; - int s = sa[is]; - - auto & drafted = drafts[s].tokens; - - auto & i_batch_dft = drafts[s].i_batch_dft; - auto & i_batch_tgt = drafts[s].i_batch_tgt; + const int s = sa[is]; llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id); - drafted.push_back(id); + drafts[s].tokens.push_back(id); // add unique drafted tokens to the target batch - i_batch_tgt.push_back(batch_tgt.n_tokens); + drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens); llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true); @@ -347,7 +342,7 @@ int main(int argc, char ** argv) { } // add the token to the batch for batched decoding with the draft model - i_batch_dft = batch_dft.n_tokens; + drafts[s].i_batch_dft = batch_dft.n_tokens; llama_batch_add(batch_dft, id, n_past_cur, { s }, true); } @@ -381,12 +376,12 @@ int main(int argc, char ** argv) { } // the first token is always proposed by the traget model before the speculation loop so we erase it here - for (int i = 0; i < n_seq_dft; ++i) { - if (!drafts[i].active) { + for (int s = 0; s < n_seq_dft; ++s) { + if (!drafts[s].active) { continue; } - drafts[i].tokens.erase(drafts[i].tokens.begin()); + drafts[s].tokens.erase(drafts[s].tokens.begin()); } } @@ -411,8 +406,8 @@ int main(int argc, char ** argv) { llama_print_timings(ctx_tgt); llama_sampling_free(ctx_sampling); - for (int i = 0; i < n_seq_dft; ++i) { - llama_sampling_free(drafts[i].ctx_sampling); + for (int s = 0; s < n_seq_dft; ++s) { + llama_sampling_free(drafts[s].ctx_sampling); } llama_batch_free(batch_dft);