speculative : refactor sampling

Georgi Gerganov 2023-10-15 22:30:59 +03:00
parent 32a67cbd16
commit 4a7f43f28c
3 changed files with 210 additions and 205 deletions

common/sampling.cpp

@@ -1,27 +1,69 @@
 #include "sampling.h"

-llama_sampling_context llama_sampling_context_init(
-        const struct gpt_params & params,
-                  llama_grammar * grammar) {
-    llama_sampling_context result;
+struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params) {
+    struct llama_sampling_context * result =
+        (struct llama_sampling_context *) malloc(sizeof(struct llama_sampling_context));

-    result.params  = params.sampling_params;
-    result.grammar = grammar;
+    result->params  = params.sampling_params;
+    result->grammar = nullptr;
+
+    // if there is a grammar, parse it
+    if (!params.grammar.empty()) {
+        result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+
+        // will be empty (default) if there are parse errors
+        if (result->parsed_grammar.rules.empty()) {
+            fprintf(stderr, "%s: failed to parse grammar\n", __func__);
+            return nullptr;
+        }
+
+        std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
+
+        result->grammar = llama_grammar_init(
+                grammar_rules.data(),
+                grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
+    }
+
+    result->prev.resize(params.n_ctx);

     return result;
 }

-llama_token llama_sampling_sample(
-                  struct llama_context * ctx,
-                  struct llama_context * ctx_guidance,
-        struct llama_sampling_context & ctx_sampling,
-        const std::vector<llama_token> & last_tokens,
-        std::vector<llama_token_data> & candidates,
-                  const int idx) {
-    const int n_ctx   = llama_n_ctx(ctx);
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+void llama_sampling_free(struct llama_sampling_context * ctx) {
+    if (ctx->grammar != NULL) {
+        llama_grammar_free(ctx->grammar);
+    }
+
+    free(ctx);
+}
+
+void llama_sampling_reset(llama_sampling_context * ctx) {
+    if (ctx->grammar != NULL) {
+        llama_grammar_free(ctx->grammar);
+    }
+
+    if (!ctx->parsed_grammar.rules.empty()) {
+        std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
+
+        ctx->grammar = llama_grammar_init(
+                grammar_rules.data(),
+                grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
+    }
+
+    std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
+    ctx->cur.clear();
+}
+
+llama_token llama_sampling_sample(
+                  struct llama_sampling_context * ctx_sampling,
+                  struct llama_context * ctx_main,
+                  struct llama_context * ctx_guidance,
+                  const int idx) {
+    const int n_ctx   = llama_n_ctx(ctx_main);
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

-    const llama_sampling_params & params = ctx_sampling.params;
+    const llama_sampling_params & params = ctx_sampling->params;

     const float temp = params.temp;
     const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
     const float top_p = params.top_p;
@@ -36,41 +78,45 @@ llama_token llama_sampling_sample(
     const float mirostat_eta = params.mirostat_eta;
     const bool penalize_nl = params.penalize_nl;

+    auto & prev = ctx_sampling->prev;
+    auto & cur  = ctx_sampling->cur;
+
     llama_token id = 0;

-    float * logits = llama_get_logits_ith(ctx, idx);
+    float * logits = llama_get_logits_ith(ctx_main, idx);

     // Apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
         logits[it->first] += it->second;
     }

-    candidates.clear();
+    cur.clear();
     for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
     }

-    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+    llama_token_data_array cur_p = { cur.data(), cur.size(), false };

     if (ctx_guidance) {
-        llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
+        llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_guidance, params.cfg_scale);
     }

     // apply penalties
-    if (!last_tokens.empty()) {
-        const float nl_logit = logits[llama_token_nl(ctx)];
-        const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
+    if (!prev.empty()) {
+        const float nl_logit = logits[llama_token_nl(ctx_main)];
+        const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx);

-        llama_sample_repetition_penalty(ctx, &cur_p,
-                last_tokens.data() + last_tokens.size() - last_n_repeat,
+        llama_sample_repetition_penalty(ctx_main, &cur_p,
+                prev.data() + prev.size() - last_n_repeat,
                 last_n_repeat, repeat_penalty);
-        llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
-                last_tokens.data() + last_tokens.size() - last_n_repeat,
+        llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p,
+                prev.data() + prev.size() - last_n_repeat,
                 last_n_repeat, alpha_frequency, alpha_presence);

         if (!penalize_nl) {
             for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(ctx)) {
+                if (cur_p.data[idx].id == llama_token_nl(ctx_main)) {
                     cur_p.data[idx].logit = nl_logit;
                     break;
                 }
@@ -78,50 +124,58 @@ llama_token llama_sampling_sample(
             }
         }
     }

-    if (ctx_sampling.grammar != NULL) {
-        llama_sample_grammar(ctx, &cur_p, ctx_sampling.grammar);
+    if (ctx_sampling->grammar != NULL) {
+        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }

     if (temp <= 0) {
         // Greedy sampling
-        id = llama_sample_token_greedy(ctx, &cur_p);
+        id = llama_sample_token_greedy(ctx_main, &cur_p);
     } else {
         if (mirostat == 1) {
             const int mirostat_m = 100;
-            llama_sample_temp(ctx, &cur_p, temp);
-            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling.mirostat_mu);
+            llama_sample_temp(ctx_main, &cur_p, temp);
+            id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
         } else if (mirostat == 2) {
-            llama_sample_temp(ctx, &cur_p, temp);
-            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling.mirostat_mu);
+            llama_sample_temp(ctx_main, &cur_p, temp);
+            id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
         } else {
             // Temperature sampling
             size_t min_keep = std::max(1, params.n_probs);
-            llama_sample_top_k    (ctx, &cur_p, top_k, min_keep);
-            llama_sample_tail_free(ctx, &cur_p, tfs_z, min_keep);
-            llama_sample_typical  (ctx, &cur_p, typical_p, min_keep);
-            llama_sample_top_p    (ctx, &cur_p, top_p, min_keep);
-            llama_sample_temp(ctx, &cur_p, temp);
-
-            {
-                const int n_top = 10;
-                LOG("top %d candidates:\n", n_top);
-
-                for (int i = 0; i < n_top; i++) {
-                    const llama_token id = cur_p.data[i].id;
-                    (void)id; // To avoid a warning that id is unused when logging is disabled.
-                    LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
-                }
-            }
-
-            id = llama_sample_token(ctx, &cur_p);
-
-            LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
+            llama_sample_top_k    (ctx_main, &cur_p, top_k,     min_keep);
+            llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep);
+            llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep);
+            llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep);
+            llama_sample_temp     (ctx_main, &cur_p, temp);
+
+            id = llama_sample_token(ctx_main, &cur_p);
+
+            //{
+            //    const int n_top = 10;
+            //    LOG("top %d candidates:\n", n_top);
+
+            //    for (int i = 0; i < n_top; i++) {
+            //        const llama_token id = cur_p.data[i].id;
+            //        (void)id; // To avoid a warning that id is unused when logging is disabled.
+            //        LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
+            //    }
+            //}
+
+            LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
         }
     }

-    if (ctx_sampling.grammar != NULL) {
-        llama_grammar_accept_token(ctx, ctx_sampling.grammar, id);
-    }
-
     return id;
 }
+
+void llama_sampling_accept(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        llama_token id) {
+    ctx_sampling->prev.erase(ctx_sampling->prev.begin());
+    ctx_sampling->prev.push_back(id);
+
+    if (ctx_sampling->grammar != NULL) {
+        llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
+    }
+}

common/sampling.h

@@ -2,6 +2,8 @@
 #include "llama.h"

+#include "grammar-parser.h"
+
 #include <string>
 #include <vector>
 #include <unordered_map>
@@ -35,7 +37,8 @@ typedef struct llama_sampling_params {
 } llama_sampling_params;

 // general sampler context
-typedef struct llama_sampling_context {
+// TODO: move to llama.h
+struct llama_sampling_context {
     // parameters that will be used for sampling
     llama_sampling_params params;
@@ -43,45 +46,50 @@ typedef struct llama_sampling_context {
     float mirostat_mu;

     llama_grammar * grammar;
-} llama_sampling_context;
+
+    // internal
+    grammar_parser::parse_state parsed_grammar;
+
+    std::vector<llama_token>      prev;
+    std::vector<llama_token_data> cur;
+};

 #include "common.h"

 // Create a new sampling context instance.
-llama_sampling_context llama_sampling_context_init(
-        const struct gpt_params & params,
-        llama_grammar * grammar = NULL);
+struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params);
+
+void llama_sampling_free(struct llama_sampling_context * ctx);

-// Reset the sampler context for the supplied sequence id (defaults to 0).
-// This is necessary to reuse a sequence id or free memory used by sequences
-// that are no longer required.
-bool llama_sampling_context_reset(
-        llama_sampling_context & ctx_sampling,
-        const llama_seq_id seq = 0);
+// Reset the sampler context
+// - clear prev tokens
+// - reset grammar
+void llama_sampling_reset(llama_sampling_context * ctx);

 // this is a common sampling function used across the examples for convenience
 // it can serve as a starting point for implementing your own sampling function
 // Note: When using multiple sequences, it is the caller's responsibility to call
-//       llama_sampling_context_reset when a sequence ends
+//       llama_sampling_reset when a sequence ends
 //
 // required:
-//  - ctx:          context to use for sampling
+//  - ctx_main:     context to use for sampling
 //  - ctx_sampling: sampling-specific context
 //
 // optional:
-//  - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
-//  - last_tokens:  needed for repetition penalty, ignore if empty
+//  - ctx_guidance: context to use for guidance
 //  - idx:          sample from llama_get_logits_ith(ctx, idx)
-//  - seq:          sequence id to associate sampler state with
 //
 // returns:
 //  - token:      sampled token
 //  - candidates: vector of candidate tokens
 //
 llama_token llama_sampling_sample(
-        struct llama_context * ctx,
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
         struct llama_context * ctx_guidance,
-        struct llama_sampling_context & ctx_sampling,
-        const std::vector<llama_token> & last_tokens,
-        std::vector<llama_token_data> & candidates,
-        const int idx = 0);
+        int idx = 0);
+
+void llama_sampling_accept(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        llama_token id);
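
The refactor replaces the caller-managed grammar, last_tokens and candidates bookkeeping with a single llama_sampling_context that is created, driven and freed explicitly. The following is a minimal usage sketch, not part of this commit: generate() is a hypothetical helper, it assumes a llama_context * ctx and gpt_params params are already set up as in the existing examples, and decoding and error handling are omitted.

#include "common.h"
#include "sampling.h"

// hypothetical helper: sample n_gen tokens with the refactored API
static void generate(struct llama_context * ctx, const struct gpt_params & params, int n_gen) {
    // parses params.grammar (if any) and sizes the prev-token buffer to params.n_ctx
    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);

    for (int i = 0; i < n_gen; ++i) {
        // sample from the logits of the last evaluated token (idx defaults to 0)
        const llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL);

        // record the token so repetition penalties and the grammar state stay in sync
        llama_sampling_accept(ctx_sampling, ctx, id);

        // ... feed `id` back to the model (e.g. via llama_decode) before the next iteration ...
    }

    // reuse the same sampler state for another sequence, then release it
    llama_sampling_reset(ctx_sampling);
    llama_sampling_free(ctx_sampling);
}
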

examples/speculative/speculative.cpp

@@ -2,7 +2,6 @@
 #include "common.h"
 #include "llama.h"
-#include "grammar-parser.h"

 #include <cmath>
 #include <cstdio>
@@ -19,10 +18,7 @@ struct seq_draft {
     std::vector<llama_token> tokens;

-    struct llama_grammar * grammar = NULL;
-
-    std::vector<llama_token> last_tokens;
-    struct llama_sampling_context ctx_sampling;
+    struct llama_sampling_context * ctx_sampling;
 };

 int main(int argc, char ** argv) {
@@ -96,8 +92,6 @@ int main(int argc, char ** argv) {
     const auto t_enc_end = ggml_time_us();

     // the 2 models should have the same vocab
-    const int n_ctx   = llama_n_ctx(ctx_tgt);
-    const int n_vocab = llama_n_vocab(model_tgt);
     //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));

     // how many tokens to draft each time
@@ -113,69 +107,34 @@ int main(int argc, char ** argv) {
     // used to determine end of generation
     bool has_eos = false;

-    // grammar stuff
-    struct llama_grammar * grammar = NULL;
-
-    grammar_parser::parse_state parsed_grammar;
-
-    // if requested - load the grammar, error checking is omitted for brevity
-    if (!params.grammar.empty()) {
-        parsed_grammar = grammar_parser::parse(params.grammar.c_str());
-
-        // will be empty (default) if there are parse errors
-        if (parsed_grammar.rules.empty()) {
-            return 1;
-        }
-
-        std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-        grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-    }
-
     // target model sampling context
-    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
-
-    // TODO: move to llama_sampling_state
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
-
-    std::vector<llama_token> last_tokens;
-    last_tokens.resize(n_ctx);
-    std::fill(last_tokens.begin(), last_tokens.end(), 0);
-
-    for (auto & id : inp) {
-        last_tokens.erase(last_tokens.begin());
-        last_tokens.push_back(id);
-    }
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);

     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);

+    params.grammar.clear(); // the draft samplers will copy the target sampler's grammar
+    params.sampling_params.temp = 1.0f; // the draft samplers use default temperature
+
     for (int i = 0; i < n_seq_dft; ++i) {
-        {
-            auto & last_tokens = drafts[i].last_tokens;
-            last_tokens.resize(n_ctx);
-            std::fill(last_tokens.begin(), last_tokens.end(), 0);
-
-            for (auto & id : inp) {
-                last_tokens.erase(last_tokens.begin());
-                last_tokens.push_back(id);
-            }
-        }
-
-        drafts[i].ctx_sampling = llama_sampling_context_init(params, grammar);
+        drafts[i].ctx_sampling = llama_sampling_init(params);
     }

-    llama_batch batch_dft = llama_batch_init(512, 0, 1);
-    llama_batch batch_tgt = llama_batch_init(512, 0, n_seq_dft);
+    llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
+    llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);

     const auto t_dec_start = ggml_time_us();

+    // sample from the last token of the prompt
     drafts[0].i_batch_tgt.resize(1);
     drafts[0].i_batch_tgt[0] = 0;

     while (true) {
+        // print current draft sequences
         for (int i = 0; i < n_seq_dft; ++i) {
-            if (!drafts[i].active) continue;
+            if (!drafts[i].active) {
+                continue;
+            }

             const auto & tokens = drafts[i].tokens;
@@ -189,11 +148,9 @@ int main(int argc, char ** argv) {
         LOG("sampling target: i_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", i_keep, i_dft, drafts[i_keep].i_batch_tgt[i_dft]);

         // sample from the target model
-        llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, drafts[i_keep].i_batch_tgt[i_dft]);
+        llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[i_keep].i_batch_tgt[i_dft]);

-        // remember which tokens were sampled - used for repetition penalties during sampling
-        last_tokens.erase(last_tokens.begin());
-        last_tokens.push_back(id);
+        llama_sampling_accept(ctx_sampling, ctx_tgt, id);

         //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));
@@ -213,7 +170,9 @@ int main(int argc, char ** argv) {
         bool matches = false;

         for (int i = 0; i < n_seq_dft; ++i) {
-            if (!drafts[i].active) continue;
+            if (!drafts[i].active) {
+                continue;
+            }

             if (i_dft < (int) drafts[i].tokens.size() && id == drafts[i].tokens[i_dft]) {
                 LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, i, id, token_str.c_str());
@@ -281,17 +240,19 @@ int main(int argc, char ** argv) {
             break;
         }

-        if (grammar) {
-            for (int i = 0; i < n_seq_dft; ++i) {
-                auto * grammar_dft = drafts[i].grammar;
-                if (grammar_dft) {
-                    llama_grammar_free(grammar_dft);
-                }
-                grammar_dft = llama_grammar_copy(ctx_sampling.grammar);
-
-                LOG("copied target grammar to draft %d grammar\n", i);
-            }
+        for (int i = 0; i < n_seq_dft; ++i) {
+            if (ctx_sampling->grammar) {
+                auto & grammar_dft = drafts[0].ctx_sampling->grammar;
+                if (grammar_dft) {
+                    llama_grammar_free(grammar_dft);
+                }
+                grammar_dft = llama_grammar_copy(ctx_sampling->grammar);
+
+                LOG("copied target grammar to draft %d grammar\n", 0);
+            }
+
+            drafts[i].ctx_sampling->prev = ctx_sampling->prev;
         }

         int n_seq_cur = 1;
@@ -321,46 +282,32 @@ int main(int argc, char ** argv) {
         }

         for (int s = 0; s < n_seq_dft; ++s) {
-            if (!drafts[s].drafting || drafts[s].skip) continue;
-
-            auto & grammar = drafts[s].grammar;
-            auto & i_batch_dft = drafts[s].i_batch_dft;
-
-            float * logits = llama_get_logits_ith(ctx_dft, i_batch_dft);
-
-            // TODO: optimize
-            candidates.clear();
-            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-                candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-            }
-
-            llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
-
-            if (grammar != NULL) {
-                llama_sample_grammar(ctx_dft, &cur_p, grammar);
-            }
-
-            // computes softmax and sorts the candidates
-            llama_sample_softmax(ctx_dft, &cur_p);
-
-            for (int k = 0; k < 3; ++k) {
+            if (!drafts[s].drafting || drafts[s].skip) {
+                continue;
+            }
+
+            llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
+
+            const auto & cur_p = drafts[s].ctx_sampling->cur;
+
+            for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
                 LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                        k, s, i, cur_p.data[k].id, cur_p.data[k].p, llama_token_to_piece(ctx_dft, cur_p.data[k].id).c_str());
+                        k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
             }

             // TODO: make this configurable
-            if (cur_p.data[0].p < 0.1) {
-            //if (cur_p.data[0].p < 2*cur_p.data[1].p) {
-                LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p.data[0].p, cur_p.data[1].p);
+            if (cur_p[0].p < 0.4) {
+                LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p);
                 drafts[s].drafting = false;
                 continue;
             }

             std::vector<int> sa(1, s);

+            // attempt to split the branch if the probability is high enough
             for (int f = 1; f < 8; ++f) {
                 // TODO: make this configurable
-                if (n_seq_cur < n_seq_dft && cur_p.data[f].p > 0.10) {
+                if (n_seq_cur < n_seq_dft && cur_p[f].p > 0.3) {
                     LOG("splitting seq %3d into %3d\n", s, n_seq_cur);

                     llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
@@ -376,9 +323,18 @@ int main(int argc, char ** argv) {
                    }
                    }

-                    drafts[n_seq_cur] = drafts[s];
-                    drafts[n_seq_cur].skip = true;
-
-                    // TODO: grammar
+                    // copy the draft state
+                    drafts[n_seq_cur].active   = true;
+                    drafts[n_seq_cur].drafting = true;
+                    drafts[n_seq_cur].skip     = true;
+
+                    drafts[n_seq_cur].tokens      = drafts[s].tokens;
+                    drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
+                    drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
+
+                    if (ctx_sampling->grammar) {
+                        drafts[n_seq_cur].ctx_sampling->grammar =
+                            llama_grammar_copy(drafts[s].ctx_sampling->grammar);
+                    }

                     sa.push_back(n_seq_cur);

                     n_seq_cur++;
@@ -389,17 +345,17 @@ int main(int argc, char ** argv) {
             // add drafted token for each sequence
             for (int is = 0; is < (int) sa.size(); ++is) {
-                const llama_token id = cur_p.data[is].id;
+                const llama_token id = cur_p[is].id;

                 int s = sa[is];

                 auto & drafted = drafts[s].tokens;
-                //auto & grammar = drafts[s].grammar;

                 auto & i_batch_dft = drafts[s].i_batch_dft;
                 auto & i_batch_tgt = drafts[s].i_batch_tgt;

                 drafted.push_back(id);
+                llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id);

                 // add unique drafted tokens to the target batch
                 batch_tgt.token [batch_tgt.n_tokens] = id;
@@ -413,7 +369,7 @@ int main(int argc, char ** argv) {
                 batch_tgt.n_tokens++;

                 // no need to evaluate the last drafted token, since we won't use the result
-                if (i == n_draft - 1) {
+                if (batch_tgt.n_tokens == n_draft) {
                     drafts[s].drafting = false;
                     continue;
                 }
@@ -441,19 +397,6 @@ int main(int argc, char ** argv) {
             ++n_past_cur;
             ++n_drafted;

-            // update grammar
-            for (int s = 0; s < n_seq_dft; ++s) {
-                if (!drafts[s].drafting) continue;
-
-                auto & drafted = drafts[s].tokens;
-                auto & grammar = drafts[s].grammar;
-
-                if (grammar != NULL) {
-                    llama_grammar_accept_token(ctx_dft, grammar, drafted.back());
-                }
-            }
-
-            // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
             if (batch_tgt.n_tokens >= n_draft) {
                 break;
             }
@@ -473,7 +416,9 @@ int main(int argc, char ** argv) {
         // the first token is always proposed by the traget model before the speculation loop so we erase it here
         for (int i = 0; i < n_seq_dft; ++i) {
-            if (!drafts[i].active) continue;
+            if (!drafts[i].active) {
+                continue;
+            }

             drafts[i].tokens.erase(drafts[i].tokens.begin());
         }
@@ -507,13 +452,11 @@ int main(int argc, char ** argv) {
     llama_free(ctx_dft);
     llama_free_model(model_dft);

-    if (grammar) {
-        llama_grammar_free(grammar);
-        for (int i = 0; i < n_seq_dft; ++i) {
-            llama_grammar_free(drafts[i].grammar);
-        }
+    llama_sampling_free(ctx_sampling);
+    for (int i = 0; i < n_seq_dft; ++i) {
+        llama_sampling_free(drafts[i].ctx_sampling);
     }

     llama_backend_free();

     fprintf(stderr, "\n\n");