mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 12:10:18 +00:00
speculative : minor refactor
ggml-ci
This commit is contained in:
parent
360a333145
commit
1c626e2fe1
@ -116,8 +116,8 @@ int main(int argc, char ** argv) {
|
||||
params.grammar.clear(); // the draft samplers will copy the target sampler's grammar
|
||||
params.sampling_params.temp = 1.0f; // the draft samplers use default temperature
|
||||
|
||||
for (int i = 0; i < n_seq_dft; ++i) {
|
||||
drafts[i].ctx_sampling = llama_sampling_init(params);
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
drafts[s].ctx_sampling = llama_sampling_init(params);
|
||||
}
|
||||
|
||||
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
|
||||
@ -131,24 +131,24 @@ int main(int argc, char ** argv) {
|
||||
|
||||
while (true) {
|
||||
// print current draft sequences
|
||||
for (int i = 0; i < n_seq_dft; ++i) {
|
||||
if (!drafts[i].active) {
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
if (!drafts[s].active) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto & tokens = drafts[i].tokens;
|
||||
const auto & tokens = drafts[s].tokens;
|
||||
|
||||
LOG("draft %d: %s\n", i, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
|
||||
LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
|
||||
}
|
||||
|
||||
int i_dft = 0;
|
||||
int i_keep = 0;
|
||||
int s_keep = 0;
|
||||
|
||||
while (true) {
|
||||
LOG("sampling target: i_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", i_keep, i_dft, drafts[i_keep].i_batch_tgt[i_dft]);
|
||||
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||
|
||||
// sample from the target model
|
||||
llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[i_keep].i_batch_tgt[i_dft]);
|
||||
llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx_tgt, id);
|
||||
|
||||
@ -169,18 +169,18 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
bool matches = false;
|
||||
|
||||
for (int i = 0; i < n_seq_dft; ++i) {
|
||||
if (!drafts[i].active) {
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
if (!drafts[s].active) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (i_dft < (int) drafts[i].tokens.size() && id == drafts[i].tokens[i_dft]) {
|
||||
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, i, id, token_str.c_str());
|
||||
if (i_dft < (int) drafts[s].tokens.size() && id == drafts[s].tokens[i_dft]) {
|
||||
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, id, token_str.c_str());
|
||||
|
||||
i_keep = i;
|
||||
s_keep = s;
|
||||
matches = true;
|
||||
} else {
|
||||
drafts[i].active = false;
|
||||
drafts[s].active = false;
|
||||
}
|
||||
}
|
||||
|
||||
@ -198,22 +198,22 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// TODO: simplify
|
||||
{
|
||||
LOG("keeping sequence %d\n", i_keep);
|
||||
LOG("keeping sequence %d\n", s_keep);
|
||||
|
||||
llama_kv_cache_seq_keep(ctx_dft, i_keep);
|
||||
llama_kv_cache_seq_cp (ctx_dft, i_keep, 0, -1, -1);
|
||||
llama_kv_cache_seq_keep(ctx_dft, s_keep);
|
||||
llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
|
||||
llama_kv_cache_seq_keep(ctx_dft, 0);
|
||||
|
||||
llama_kv_cache_seq_rm (ctx_tgt, i_keep, n_past_tgt, -1);
|
||||
llama_kv_cache_seq_keep(ctx_tgt, i_keep);
|
||||
llama_kv_cache_seq_cp (ctx_tgt, i_keep, 0, -1, -1);
|
||||
llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
|
||||
llama_kv_cache_seq_keep(ctx_tgt, s_keep);
|
||||
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
|
||||
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_seq_dft; ++i) {
|
||||
drafts[i].active = false;
|
||||
drafts[i].tokens.clear();
|
||||
drafts[i].i_batch_tgt.clear();
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
drafts[s].active = false;
|
||||
drafts[s].tokens.clear();
|
||||
drafts[s].i_batch_tgt.clear();
|
||||
}
|
||||
// note: will be erased after the speculation phase
|
||||
drafts[0].tokens.push_back(id);
|
||||
@ -239,9 +239,9 @@ int main(int argc, char ** argv) {
|
||||
int n_seq_cur = 1;
|
||||
int n_past_cur = n_past_dft;
|
||||
|
||||
for (int i = 0; i < n_seq_dft; ++i) {
|
||||
drafts[i].active = false;
|
||||
drafts[i].drafting = false;
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
drafts[s].active = false;
|
||||
drafts[s].drafting = false;
|
||||
}
|
||||
drafts[0].active = true;
|
||||
drafts[0].drafting = true;
|
||||
@ -324,19 +324,14 @@ int main(int argc, char ** argv) {
|
||||
for (int is = 0; is < (int) sa.size(); ++is) {
|
||||
const llama_token id = cur_p[is].id;
|
||||
|
||||
int s = sa[is];
|
||||
|
||||
auto & drafted = drafts[s].tokens;
|
||||
|
||||
auto & i_batch_dft = drafts[s].i_batch_dft;
|
||||
auto & i_batch_tgt = drafts[s].i_batch_tgt;
|
||||
const int s = sa[is];
|
||||
|
||||
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id);
|
||||
|
||||
drafted.push_back(id);
|
||||
drafts[s].tokens.push_back(id);
|
||||
|
||||
// add unique drafted tokens to the target batch
|
||||
i_batch_tgt.push_back(batch_tgt.n_tokens);
|
||||
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
|
||||
|
||||
llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
|
||||
|
||||
@ -347,7 +342,7 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// add the token to the batch for batched decoding with the draft model
|
||||
i_batch_dft = batch_dft.n_tokens;
|
||||
drafts[s].i_batch_dft = batch_dft.n_tokens;
|
||||
|
||||
llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
|
||||
}
|
||||
@ -381,12 +376,12 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// the first token is always proposed by the traget model before the speculation loop so we erase it here
|
||||
for (int i = 0; i < n_seq_dft; ++i) {
|
||||
if (!drafts[i].active) {
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
if (!drafts[s].active) {
|
||||
continue;
|
||||
}
|
||||
|
||||
drafts[i].tokens.erase(drafts[i].tokens.begin());
|
||||
drafts[s].tokens.erase(drafts[s].tokens.begin());
|
||||
}
|
||||
}
|
||||
|
||||
@ -411,8 +406,8 @@ int main(int argc, char ** argv) {
|
||||
llama_print_timings(ctx_tgt);
|
||||
|
||||
llama_sampling_free(ctx_sampling);
|
||||
for (int i = 0; i < n_seq_dft; ++i) {
|
||||
llama_sampling_free(drafts[i].ctx_sampling);
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
llama_sampling_free(drafts[s].ctx_sampling);
|
||||
}
|
||||
|
||||
llama_batch_free(batch_dft);
|
||||
|
Loading…
Reference in New Issue
Block a user