From 4a7f43f28caee3da6890d3447a48126ead86313f Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Sun, 15 Oct 2023 22:30:59 +0300
Subject: [PATCH] speculative : refactor sampling

---
 common/sampling.cpp                  | 166 +++++++++++++++--------
 common/sampling.h                    |  54 ++++----
 examples/speculative/speculative.cpp | 195 ++++++++++-----------------
 3 files changed, 210 insertions(+), 205 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index e0704713f..ed636f415 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -1,27 +1,69 @@
 #include "sampling.h"
 
-llama_sampling_context llama_sampling_context_init(
-        const struct gpt_params & params,
-                  llama_grammar * grammar) {
-    llama_sampling_context result;
+struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params) {
+    struct llama_sampling_context * result =
+        (struct llama_sampling_context *) malloc(sizeof(struct llama_sampling_context));
 
-    result.params  = params.sampling_params;
-    result.grammar = grammar;
+    result->params  = params.sampling_params;
+    result->grammar = nullptr;
+
+    // if there is a grammar, parse it
+    if (!params.grammar.empty()) {
+        result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
+
+        // will be empty (default) if there are parse errors
+        if (result->parsed_grammar.rules.empty()) {
+            fprintf(stderr, "%s: failed to parse grammar\n", __func__);
+            return nullptr;
+        }
+
+        std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
+
+        result->grammar = llama_grammar_init(
+                grammar_rules.data(),
+                grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
+    }
+
+    result->prev.resize(params.n_ctx);
 
     return result;
 }
 
-llama_token llama_sampling_sample(
-                  struct llama_context * ctx,
-                  struct llama_context * ctx_guidance,
-          struct llama_sampling_context & ctx_sampling,
-    const std::vector<llama_token> & last_tokens,
-     std::vector<llama_token_data> & candidates,
-                         const int idx) {
-    const int n_ctx   = llama_n_ctx(ctx);
-    const int n_vocab = llama_n_vocab(llama_get_model(ctx));
+void llama_sampling_free(struct llama_sampling_context * ctx) {
+    if (ctx->grammar != NULL) {
+        llama_grammar_free(ctx->grammar);
+    }
+
+    free(ctx);
+}
+
+void llama_sampling_reset(llama_sampling_context * ctx) {
+    if (ctx->grammar != NULL) {
+        llama_grammar_free(ctx->grammar);
+    }
+
+    if (!ctx->parsed_grammar.rules.empty()) {
+        std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
+
+        ctx->grammar = llama_grammar_init(
+                grammar_rules.data(),
+                grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
+    }
+
+    std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
+    ctx->cur.clear();
+}
+
+llama_token llama_sampling_sample(
+                  struct llama_sampling_context * ctx_sampling,
+                  struct llama_context * ctx_main,
+                  struct llama_context * ctx_guidance,
+                  const int idx) {
+    const int n_ctx   = llama_n_ctx(ctx_main);
+    const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
+
+    const llama_sampling_params & params = ctx_sampling->params;
 
-    const llama_sampling_params & params = ctx_sampling.params;
     const float   temp            = params.temp;
     const int32_t top_k           = params.top_k <= 0 ? n_vocab : params.top_k;
     const float   top_p           = params.top_p;
@@ -36,41 +78,45 @@ llama_token llama_sampling_sample(
     const float   mirostat_eta    = params.mirostat_eta;
     const bool    penalize_nl    = params.penalize_nl;
 
+    auto & prev = ctx_sampling->prev;
+    auto & cur  = ctx_sampling->cur;
+
     llama_token id = 0;
 
-    float * logits = llama_get_logits_ith(ctx, idx);
+    float * logits = llama_get_logits_ith(ctx_main, idx);
 
     // Apply params.logit_bias map
     for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
         logits[it->first] += it->second;
     }
 
-    candidates.clear();
+    cur.clear();
+
     for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-        candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+        cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
     }
 
-    llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+    llama_token_data_array cur_p = { cur.data(), cur.size(), false };
 
     if (ctx_guidance) {
-        llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
+        llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_guidance, params.cfg_scale);
     }
 
     // apply penalties
-    if (!last_tokens.empty()) {
-        const float nl_logit = logits[llama_token_nl(ctx)];
-        const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
+    if (!prev.empty()) {
+        const float nl_logit = logits[llama_token_nl(ctx_main)];
+        const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx);
 
-        llama_sample_repetition_penalty(ctx, &cur_p,
-                last_tokens.data() + last_tokens.size() - last_n_repeat,
+        llama_sample_repetition_penalty(ctx_main, &cur_p,
+                prev.data() + prev.size() - last_n_repeat,
                 last_n_repeat, repeat_penalty);
-        llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
-                last_tokens.data() + last_tokens.size() - last_n_repeat,
+        llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p,
+                prev.data() + prev.size() - last_n_repeat,
                 last_n_repeat, alpha_frequency, alpha_presence);
 
         if (!penalize_nl) {
             for (size_t idx = 0; idx < cur_p.size; idx++) {
-                if (cur_p.data[idx].id == llama_token_nl(ctx)) {
+                if (cur_p.data[idx].id == llama_token_nl(ctx_main)) {
                     cur_p.data[idx].logit = nl_logit;
                     break;
                 }
@@ -78,50 +124,58 @@ llama_token llama_sampling_sample(
         }
     }
 
-    if (ctx_sampling.grammar != NULL) {
-        llama_sample_grammar(ctx, &cur_p, ctx_sampling.grammar);
+    if (ctx_sampling->grammar != NULL) {
+        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
     }
 
     if (temp <= 0) {
         // Greedy sampling
-        id = llama_sample_token_greedy(ctx, &cur_p);
+        id = llama_sample_token_greedy(ctx_main, &cur_p);
     } else {
         if (mirostat == 1) {
             const int mirostat_m = 100;
-            llama_sample_temp(ctx, &cur_p, temp);
-            id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling.mirostat_mu);
+            llama_sample_temp(ctx_main, &cur_p, temp);
+            id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
         } else if (mirostat == 2) {
-            llama_sample_temp(ctx, &cur_p, temp);
-            id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling.mirostat_mu);
+            llama_sample_temp(ctx_main, &cur_p, temp);
+            id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
         } else {
             // Temperature sampling
            size_t min_keep = std::max(1, params.n_probs);
-            llama_sample_top_k      (ctx, &cur_p, top_k,     min_keep);
-            llama_sample_tail_free  (ctx, &cur_p, tfs_z,     min_keep);
-            llama_sample_typical    (ctx, &cur_p, typical_p, min_keep);
-            llama_sample_top_p      (ctx, &cur_p, top_p,     min_keep);
-            llama_sample_temp(ctx, &cur_p, temp);
+            llama_sample_top_k    (ctx_main, &cur_p, top_k,     min_keep);
+            llama_sample_tail_free(ctx_main, &cur_p, tfs_z,     min_keep);
+            llama_sample_typical  (ctx_main, &cur_p, typical_p, min_keep);
+            llama_sample_top_p    (ctx_main, &cur_p, top_p,     min_keep);
+            llama_sample_temp     (ctx_main, &cur_p, temp);
 
-            {
-                const int n_top = 10;
-                LOG("top %d candidates:\n", n_top);
+            id = llama_sample_token(ctx_main, &cur_p);
 
-                for (int i = 0; i < n_top; i++) {
-                    const llama_token id = cur_p.data[i].id;
-                    (void)id; // To avoid a warning that id is unused when logging is disabled.
-                    LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
-                }
-            }
+            //{
+            //    const int n_top = 10;
+            //    LOG("top %d candidates:\n", n_top);
 
-            id = llama_sample_token(ctx, &cur_p);
+            //    for (int i = 0; i < n_top; i++) {
+            //        const llama_token id = cur_p.data[i].id;
+            //        (void)id; // To avoid a warning that id is unused when logging is disabled.
+            //        LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
+            //    }
+            //}
 
-            LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
+            LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
         }
     }
 
-    if (ctx_sampling.grammar != NULL) {
-        llama_grammar_accept_token(ctx, ctx_sampling.grammar, id);
-    }
-
     return id;
 }
+
+void llama_sampling_accept(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        llama_token id) {
+    ctx_sampling->prev.erase(ctx_sampling->prev.begin());
+    ctx_sampling->prev.push_back(id);
+
+    if (ctx_sampling->grammar != NULL) {
+        llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
+    }
+}
diff --git a/common/sampling.h b/common/sampling.h
index fda5902a8..32deb26b0 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -2,6 +2,8 @@
 
 #include "llama.h"
 
+#include "grammar-parser.h"
+
 #include <string>
 #include <vector>
 #include <unordered_map>
@@ -35,7 +37,8 @@ typedef struct llama_sampling_params {
 } llama_sampling_params;
 
 // general sampler context
-typedef struct llama_sampling_context {
+// TODO: move to llama.h
+struct llama_sampling_context {
     // parameters that will be used for sampling
     llama_sampling_params params;
 
@@ -43,45 +46,50 @@ typedef struct llama_sampling_context {
     float mirostat_mu;
 
     llama_grammar * grammar;
-} llama_sampling_context;
+
+    // internal
+    grammar_parser::parse_state parsed_grammar;
+
+    std::vector<llama_token>      prev;
+    std::vector<llama_token_data> cur;
+};
 
 #include "common.h"
 
 // Create a new sampling context instance.
-llama_sampling_context llama_sampling_context_init(
-        const struct gpt_params & params,
-                  llama_grammar * grammar = NULL);
+struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params);
 
-// Reset the sampler context for the supplied sequence id (defaults to 0).
-// This is necessary to reuse a sequence id or free memory used by sequences
-// that are no longer required.
-bool llama_sampling_context_reset(
-        llama_sampling_context & ctx_sampling,
-        const llama_seq_id seq = 0);
+void llama_sampling_free(struct llama_sampling_context * ctx);
+
+// Reset the sampler context
+// - clear prev tokens
+// - reset grammar
+void llama_sampling_reset(llama_sampling_context * ctx);
 
 // this is a common sampling function used across the examples for convenience
 // it can serve as a starting point for implementing your own sampling function
 // Note: When using multiple sequences, it is the caller's responsibility to call
-//       llama_sampling_context_reset when a sequence ends
+//       llama_sampling_reset when a sequence ends
 //
 // required:
-//  - ctx:          context to use for sampling
+//  - ctx_main:     context to use for sampling
 //  - ctx_sampling: sampling-specific context
 //
 // optional:
-//  - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
-//  - last_tokens:  needed for repetition penalty, ignore if empty
-//  - idx:          sample from llama_get_logits_ith(ctx, idx)
-//  - seq:          sequence id to associate sampler state with
+//  - ctx_guidance: context to use for guidance
+//  - idx:          sample from llama_get_logits_ith(ctx, idx)
 //
 // returns:
 //  - token:      sampled token
 //  - candidates: vector of candidate tokens
 //
 llama_token llama_sampling_sample(
-        struct llama_context * ctx,
-        struct llama_context * ctx_guidance,
-        struct llama_sampling_context & ctx_sampling,
-        const std::vector<llama_token> & last_tokens,
-        std::vector<llama_token_data> & candidates,
-        const int idx = 0);
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        struct llama_context * ctx_guidance,
+        int idx = 0);
+
+void llama_sampling_accept(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        llama_token id);
diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp
index e552e0593..dadd3115b 100644
--- a/examples/speculative/speculative.cpp
+++ b/examples/speculative/speculative.cpp
@@ -2,7 +2,6 @@
 
 #include "common.h"
 #include "llama.h"
-#include "grammar-parser.h"
 
 #include <cmath>
 #include <cstdio>
@@ -19,10 +18,7 @@ struct seq_draft {
 
     std::vector<llama_token> tokens;
 
-    struct llama_grammar * grammar = NULL;
-
-    std::vector<llama_token> last_tokens;
-    struct llama_sampling_context ctx_sampling;
+    struct llama_sampling_context * ctx_sampling;
 };
 
 int main(int argc, char ** argv) {
@@ -96,8 +92,6 @@ int main(int argc, char ** argv) {
     const auto t_enc_end = ggml_time_us();
 
     // the 2 models should have the same vocab
-    const int n_ctx   = llama_n_ctx(ctx_tgt);
-    const int n_vocab = llama_n_vocab(model_tgt);
     //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
 
     // how many tokens to draft each time
@@ -113,87 +107,50 @@ int main(int argc, char ** argv) {
     // used to determine end of generation
     bool has_eos = false;
 
-    // grammar stuff
-    struct llama_grammar * grammar = NULL;
-
-    grammar_parser::parse_state parsed_grammar;
-
-    // if requested - load the grammar, error checking is omitted for brevity
-    if (!params.grammar.empty()) {
-        parsed_grammar = grammar_parser::parse(params.grammar.c_str());
-        // will be empty (default) if there are parse errors
-        if (parsed_grammar.rules.empty()) {
-            return 1;
-        }
-
-        std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
-        grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
-    }
-
     // target model sampling context
-    llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
-
-    // TODO: move to llama_sampling_state
-    std::vector<llama_token_data> candidates;
-    candidates.reserve(n_vocab);
-
-    std::vector<llama_token> last_tokens;
-    last_tokens.resize(n_ctx);
-    std::fill(last_tokens.begin(), last_tokens.end(), 0);
-
-    for (auto & id : inp) {
-        last_tokens.erase(last_tokens.begin());
-        last_tokens.push_back(id);
-    }
+    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
 
     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);
 
+    params.grammar.clear();             // the draft samplers will copy the target sampler's grammar
+    params.sampling_params.temp = 1.0f; // the draft samplers use default temperature
+
     for (int i = 0; i < n_seq_dft; ++i) {
-        {
-            auto & last_tokens = drafts[i].last_tokens;
-
-            last_tokens.resize(n_ctx);
-            std::fill(last_tokens.begin(), last_tokens.end(), 0);
-
-            for (auto & id : inp) {
-                last_tokens.erase(last_tokens.begin());
-                last_tokens.push_back(id);
-            }
-        }
-
-        drafts[i].ctx_sampling = llama_sampling_context_init(params, grammar);
+        drafts[i].ctx_sampling = llama_sampling_init(params);
     }
 
-    llama_batch batch_dft = llama_batch_init(512, 0, 1);
-    llama_batch batch_tgt = llama_batch_init(512, 0, n_seq_dft);
+    llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
+    llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
 
     const auto t_dec_start = ggml_time_us();
 
+    // sample from the last token of the prompt
     drafts[0].i_batch_tgt.resize(1);
     drafts[0].i_batch_tgt[0] = 0;
 
     while (true) {
+        // print current draft sequences
         for (int i = 0; i < n_seq_dft; ++i) {
-            if (!drafts[i].active) continue;
+            if (!drafts[i].active) {
+                continue;
+            }
 
             const auto & tokens = drafts[i].tokens;
 
             LOG("draft %d: %s\n", i, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens));
         }
 
-
        int i_dft  = 0;
        int i_keep = 0;
 
        while (true) {
            LOG("sampling target: i_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", i_keep, i_dft, drafts[i_keep].i_batch_tgt[i_dft]);
 
            // sample from the target model
-           llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, drafts[i_keep].i_batch_tgt[i_dft]);
+           llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[i_keep].i_batch_tgt[i_dft]);
 
-           // remember which tokens were sampled - used for repetition penalties during sampling
-           last_tokens.erase(last_tokens.begin());
-           last_tokens.push_back(id);
+           llama_sampling_accept(ctx_sampling, ctx_tgt, id);
 
            //LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));
@@ -213,7 +170,9 @@ int main(int argc, char ** argv) {
            bool matches = false;
 
            for (int i = 0; i < n_seq_dft; ++i) {
-               if (!drafts[i].active) continue;
+               if (!drafts[i].active) {
+                   continue;
+               }
 
                if (i_dft < (int) drafts[i].tokens.size() && id == drafts[i].tokens[i_dft]) {
                    LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, i, id, token_str.c_str());
@@ -263,11 +222,11 @@ int main(int argc, char ** argv) {
        {
            batch_dft.n_tokens = 1;
 
-           batch_dft.token[0]     = id;
-           batch_dft.pos[0]       = n_past_dft;
-           batch_dft.n_seq_id[0]  = 1;
-           batch_dft.seq_id[0][0] = 0;
-           batch_dft.logits[0]    = true;
+           batch_dft.token   [0]    = id;
+           batch_dft.pos     [0]    = n_past_dft;
+           batch_dft.n_seq_id[0]    = 1;
+           batch_dft.seq_id  [0][0] = 0;
+           batch_dft.logits  [0]    = true;
        }
 
        llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
@@ -281,17 +240,19 @@ int main(int argc, char ** argv) {
        if (n_predict > params.n_predict || has_eos) {
            break;
        }
 
-       if (grammar) {
-           for (int i = 0; i < n_seq_dft; ++i) {
-               auto * grammar_dft = drafts[i].grammar;
+       for (int i = 0; i < n_seq_dft; ++i) {
+           if (ctx_sampling->grammar) {
+               auto & grammar_dft = drafts[0].ctx_sampling->grammar;
                if (grammar_dft) {
                    llama_grammar_free(grammar_dft);
                }
 
-               grammar_dft = llama_grammar_copy(ctx_sampling.grammar);
+               grammar_dft = llama_grammar_copy(ctx_sampling->grammar);
 
-               LOG("copied target grammar to draft %d grammar\n", i);
+               LOG("copied target grammar to draft %d grammar\n", 0);
            }
+
+           drafts[i].ctx_sampling->prev = ctx_sampling->prev;
        }
 
        int n_seq_cur = 1;
@@ -305,12 +266,12 @@ int main(int argc, char ** argv) {
        drafts[0].drafting = true;
        drafts[0].i_batch_dft = 0;
 
-       batch_tgt.n_tokens = 1;
-       batch_tgt.token[0]     = drafts[0].tokens[0];
-       batch_tgt.pos[0]       = n_past_tgt;
-       batch_tgt.n_seq_id[0]  = 1;
-       batch_tgt.seq_id[0][0] = 0;
-       batch_tgt.logits[0]    = true;
+       batch_tgt.n_tokens = 1;
+       batch_tgt.token   [0]    = drafts[0].tokens[0];
+       batch_tgt.pos     [0]    = n_past_tgt;
+       batch_tgt.n_seq_id[0]    = 1;
+       batch_tgt.seq_id  [0][0] = 0;
+       batch_tgt.logits  [0]    = true;
 
        // sample n_draft tokens from the draft model using tree-based sampling
        for (int i = 0; i < n_draft; ++i) {
@@ -321,46 +282,32 @@ int main(int argc, char ** argv) {
            }
 
            for (int s = 0; s < n_seq_dft; ++s) {
-               if (!drafts[s].drafting || drafts[s].skip) continue;
-
-               auto & grammar = drafts[s].grammar;
-               auto & i_batch_dft = drafts[s].i_batch_dft;
-
-               float * logits = llama_get_logits_ith(ctx_dft, i_batch_dft);
-
-               // TODO: optimize
-               candidates.clear();
-               for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-                   candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
+               if (!drafts[s].drafting || drafts[s].skip) {
+                   continue;
                }
 
-               llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
+               llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
 
-               if (grammar != NULL) {
-                   llama_sample_grammar(ctx_dft, &cur_p, grammar);
-               }
+               const auto & cur_p = drafts[s].ctx_sampling->cur;
 
-               // computes softmax and sorts the candidates
-               llama_sample_softmax(ctx_dft, &cur_p);
-
-               for (int k = 0; k < 3; ++k) {
+               for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
                    LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                           k, s, i, cur_p.data[k].id, cur_p.data[k].p, llama_token_to_piece(ctx_dft, cur_p.data[k].id).c_str());
+                           k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
                }
 
                // TODO: make this configurable
-               if (cur_p.data[0].p < 0.1) {
-               //if (cur_p.data[0].p < 2*cur_p.data[1].p) {
-                   LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p.data[0].p, cur_p.data[1].p);
+               if (cur_p[0].p < 0.4) {
+                   LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p);
                    drafts[s].drafting = false;
                    continue;
                }
 
                std::vector<int> sa(1, s);
 
+               // attempt to split the branch if the probability is high enough
                for (int f = 1; f < 8; ++f) {
                    // TODO: make this configurable
-                   if (n_seq_cur < n_seq_dft && cur_p.data[f].p > 0.10) {
+                   if (n_seq_cur < n_seq_dft && cur_p[f].p > 0.3) {
                        LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
 
                        llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
@@ -376,9 +323,18 @@ int main(int argc, char ** argv) {
                        }
                    }
 
-                   drafts[n_seq_cur] = drafts[s];
+                   // copy the draft state
+                   drafts[n_seq_cur].active   = true;
+                   drafts[n_seq_cur].drafting = true;
                    drafts[n_seq_cur].skip = true;
-                   // TODO: grammar
+                   drafts[n_seq_cur].tokens      = drafts[s].tokens;
+                   drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
+                   drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
+
+                   if (ctx_sampling->grammar) {
+                       drafts[n_seq_cur].ctx_sampling->grammar =
+                           llama_grammar_copy(drafts[s].ctx_sampling->grammar);
+                   }
 
                    sa.push_back(n_seq_cur);
 
                    n_seq_cur++;
@@ -389,17 +345,17 @@ int main(int argc, char ** argv) {
 
                // add drafted token for each sequence
                for (int is = 0; is < (int) sa.size(); ++is) {
-                   const llama_token id = cur_p.data[is].id;
+                   const llama_token id = cur_p[is].id;
 
                    int s = sa[is];
 
                    auto & drafted = drafts[s].tokens;
-                   //auto & grammar = drafts[s].grammar;
 
                    auto & i_batch_dft = drafts[s].i_batch_dft;
                    auto & i_batch_tgt = drafts[s].i_batch_tgt;
 
                    drafted.push_back(id);
+                   llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id);
 
                    // add unique drafted tokens to the target batch
                    batch_tgt.token [batch_tgt.n_tokens] = id;
@@ -413,7 +369,7 @@ int main(int argc, char ** argv) {
                    batch_tgt.n_tokens++;
 
                    // no need to evaluate the last drafted token, since we won't use the result
-                   if (i == n_draft - 1) {
+                   if (batch_tgt.n_tokens == n_draft) {
                        drafts[s].drafting = false;
                        continue;
                    }
@@ -441,19 +397,6 @@ int main(int argc, char ** argv) {
            ++n_past_cur;
            ++n_drafted;
 
-           // update grammar
-           for (int s = 0; s < n_seq_dft; ++s) {
-               if (!drafts[s].drafting) continue;
-
-               auto & drafted = drafts[s].tokens;
-               auto & grammar = drafts[s].grammar;
-
-               if (grammar != NULL) {
-                   llama_grammar_accept_token(ctx_dft, grammar, drafted.back());
-               }
-           }
-
-           // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
            if (batch_tgt.n_tokens >= n_draft) {
                break;
            }
@@ -473,7 +416,9 @@ int main(int argc, char ** argv) {
 
        // the first token is always proposed by the traget model before the speculation loop so we erase it here
        for (int i = 0; i < n_seq_dft; ++i) {
-           if (!drafts[i].active) continue;
+           if (!drafts[i].active) {
+               continue;
+           }
 
            drafts[i].tokens.erase(drafts[i].tokens.begin());
        }
@@ -484,7 +429,7 @@ int main(int argc, char ** argv) {
 
    LOG_TEE("\n\n");
 
    LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
-   LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
+   LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
 
    LOG_TEE("\n");
    LOG_TEE("n_draft   = %d\n", n_draft);
@@ -507,13 +452,11 @@ int main(int argc, char ** argv) {
    llama_free(ctx_dft);
    llama_free_model(model_dft);
 
-   if (grammar) {
-       llama_grammar_free(grammar);
-
-       for (int i = 0; i < n_seq_dft; ++i) {
-           llama_grammar_free(drafts[i].grammar);
-       }
+   llama_sampling_free(ctx_sampling);
+   for (int i = 0; i < n_seq_dft; ++i) {
+       llama_sampling_free(drafts[i].ctx_sampling);
    }
+
    llama_backend_free();
 
    fprintf(stderr, "\n\n");
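
Editor's note (not part of the patch): a minimal usage sketch of the llama_sampling_* API introduced above, as a single-sequence caller might use it. The names params (gpt_params), ctx (llama_context *), n_generated and n_max are assumptions for illustration; the batch handling, llama_decode() calls and EOS check are elided.

    // create the sampling context from the common parameters (parses the grammar, if any)
    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
    if (ctx_sampling == nullptr) {
        fprintf(stderr, "failed to initialize sampling context (grammar parse error?)\n");
        return 1;
    }

    while (n_generated < n_max) {
        // sample from the logits at batch position 0; pass a guidance context instead of NULL to use CFG
        const llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);

        // record the token so repetition penalties and the grammar state stay in sync
        llama_sampling_accept(ctx_sampling, ctx, id);

        // ... append id to the batch, call llama_decode(), check for end of generation ...
    }

    // llama_sampling_reset(ctx_sampling); // before reusing the context for a new sequence
    llama_sampling_free(ctx_sampling);

This mirrors what the refactored speculative example does: the target sampler and each draft sampler own their own llama_sampling_context, and llama_sampling_accept() replaces the manual last_tokens bookkeeping and llama_grammar_accept_token() calls.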