speculative : minor refactor

ggml-ci
This commit is contained in:
Georgi Gerganov 2023-10-16 12:47:37 +03:00
parent 360a333145
commit 1c626e2fe1
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -116,8 +116,8 @@ int main(int argc, char ** argv) {
params.grammar.clear(); // the draft samplers will copy the target sampler's grammar params.grammar.clear(); // the draft samplers will copy the target sampler's grammar
params.sampling_params.temp = 1.0f; // the draft samplers use default temperature params.sampling_params.temp = 1.0f; // the draft samplers use default temperature
for (int i = 0; i < n_seq_dft; ++i) { for (int s = 0; s < n_seq_dft; ++s) {
drafts[i].ctx_sampling = llama_sampling_init(params); drafts[s].ctx_sampling = llama_sampling_init(params);
} }
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1); llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
@ -131,24 +131,24 @@ int main(int argc, char ** argv) {
while (true) { while (true) {
// print current draft sequences // print current draft sequences
for (int i = 0; i < n_seq_dft; ++i) { for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[i].active) { if (!drafts[s].active) {
continue; continue;
} }
const auto & tokens = drafts[i].tokens; const auto & tokens = drafts[s].tokens;
LOG("draft %d: %s\n", i, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str()); LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
} }
int i_dft = 0; int i_dft = 0;
int i_keep = 0; int s_keep = 0;
while (true) { while (true) {
LOG("sampling target: i_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", i_keep, i_dft, drafts[i_keep].i_batch_tgt[i_dft]); LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
// sample from the target model // sample from the target model
llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[i_keep].i_batch_tgt[i_dft]); llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
llama_sampling_accept(ctx_sampling, ctx_tgt, id); llama_sampling_accept(ctx_sampling, ctx_tgt, id);
@ -169,18 +169,18 @@ int main(int argc, char ** argv) {
{ {
bool matches = false; bool matches = false;
for (int i = 0; i < n_seq_dft; ++i) { for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[i].active) { if (!drafts[s].active) {
continue; continue;
} }
if (i_dft < (int) drafts[i].tokens.size() && id == drafts[i].tokens[i_dft]) { if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, i, id, token_str.c_str()); LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());
i_keep = i; s_keep = s;
matches = true; matches = true;
} else { } else {
drafts[i].active = false; drafts[s].active = false;
} }
} }
@ -198,22 +198,22 @@ int main(int argc, char ** argv) {
// TODO: simplify // TODO: simplify
{ {
LOG("keeping sequence %d\n", i_keep); LOG("keeping sequence %d\n", s_keep);
llama_kv_cache_seq_keep(ctx_dft, i_keep); llama_kv_cache_seq_keep(ctx_dft, s_keep);
llama_kv_cache_seq_cp (ctx_dft, i_keep, 0, -1, -1); llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_dft, 0); llama_kv_cache_seq_keep(ctx_dft, 0);
llama_kv_cache_seq_rm (ctx_tgt, i_keep, n_past_tgt, -1); llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
llama_kv_cache_seq_keep(ctx_tgt, i_keep); llama_kv_cache_seq_keep(ctx_tgt, s_keep);
llama_kv_cache_seq_cp (ctx_tgt, i_keep, 0, -1, -1); llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
llama_kv_cache_seq_keep(ctx_tgt, 0); llama_kv_cache_seq_keep(ctx_tgt, 0);
} }
for (int i = 0; i < n_seq_dft; ++i) { for (int s = 0; s < n_seq_dft; ++s) {
drafts[i].active = false; drafts[s].active = false;
drafts[i].tokens.clear(); drafts[s].tokens.clear();
drafts[i].i_batch_tgt.clear(); drafts[s].i_batch_tgt.clear();
} }
// note: will be erased after the speculation phase // note: will be erased after the speculation phase
drafts[0].tokens.push_back(id); drafts[0].tokens.push_back(id);
@ -239,9 +239,9 @@ int main(int argc, char ** argv) {
int n_seq_cur = 1; int n_seq_cur = 1;
int n_past_cur = n_past_dft; int n_past_cur = n_past_dft;
for (int i = 0; i < n_seq_dft; ++i) { for (int s = 0; s < n_seq_dft; ++s) {
drafts[i].active = false; drafts[s].active = false;
drafts[i].drafting = false; drafts[s].drafting = false;
} }
drafts[0].active = true; drafts[0].active = true;
drafts[0].drafting = true; drafts[0].drafting = true;
@ -324,19 +324,14 @@ int main(int argc, char ** argv) {
for (int is = 0; is < (int) sa.size(); ++is) { for (int is = 0; is < (int) sa.size(); ++is) {
const llama_token id = cur_p[is].id; const llama_token id = cur_p[is].id;
int s = sa[is]; const int s = sa[is];
auto & drafted = drafts[s].tokens;
auto & i_batch_dft = drafts[s].i_batch_dft;
auto & i_batch_tgt = drafts[s].i_batch_tgt;
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id); llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id);
drafted.push_back(id); drafts[s].tokens.push_back(id);
// add unique drafted tokens to the target batch // add unique drafted tokens to the target batch
i_batch_tgt.push_back(batch_tgt.n_tokens); drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true); llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
@ -347,7 +342,7 @@ int main(int argc, char ** argv) {
} }
// add the token to the batch for batched decoding with the draft model // add the token to the batch for batched decoding with the draft model
i_batch_dft = batch_dft.n_tokens; drafts[s].i_batch_dft = batch_dft.n_tokens;
llama_batch_add(batch_dft, id, n_past_cur, { s }, true); llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
} }
@ -381,12 +376,12 @@ int main(int argc, char ** argv) {
} }
// the first token is always proposed by the traget model before the speculation loop so we erase it here // the first token is always proposed by the traget model before the speculation loop so we erase it here
for (int i = 0; i < n_seq_dft; ++i) { for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[i].active) { if (!drafts[s].active) {
continue; continue;
} }
drafts[i].tokens.erase(drafts[i].tokens.begin()); drafts[s].tokens.erase(drafts[s].tokens.begin());
} }
} }
@ -411,8 +406,8 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx_tgt); llama_print_timings(ctx_tgt);
llama_sampling_free(ctx_sampling); llama_sampling_free(ctx_sampling);
for (int i = 0; i < n_seq_dft; ++i) { for (int s = 0; s < n_seq_dft; ++s) {
llama_sampling_free(drafts[i].ctx_sampling); llama_sampling_free(drafts[s].ctx_sampling);
} }
llama_batch_free(batch_dft); llama_batch_free(batch_dft);