mirror of https://github.com/ggerganov/llama.cpp.git
speculative : refactor sampling
commit 4a7f43f28c
parent 32a67cbd16
sampling.cpp

@@ -1,27 +1,69 @@
#include "sampling.h"

llama_sampling_context llama_sampling_context_init(
const struct gpt_params & params,
llama_grammar * grammar) {
llama_sampling_context result;
struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params) {
struct llama_sampling_context * result =
(struct llama_sampling_context *) malloc(sizeof(struct llama_sampling_context));

result.params = params.sampling_params;
result.grammar = grammar;
result->params = params.sampling_params;
result->grammar = nullptr;

// if there is a grammar, parse it
if (!params.grammar.empty()) {
result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());

// will be empty (default) if there are parse errors
if (result->parsed_grammar.rules.empty()) {
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
return nullptr;
}

std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());

result->grammar = llama_grammar_init(
grammar_rules.data(),
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
}

result->prev.resize(params.n_ctx);

return result;
}

llama_token llama_sampling_sample(
struct llama_context * ctx,
struct llama_context * ctx_guidance,
struct llama_sampling_context & ctx_sampling,
const std::vector<llama_token> & last_tokens,
std::vector<llama_token_data> & candidates,
const int idx) {
const int n_ctx = llama_n_ctx(ctx);
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
void llama_sampling_free(struct llama_sampling_context * ctx) {
if (ctx->grammar != NULL) {
llama_grammar_free(ctx->grammar);
}

free(ctx);
}

void llama_sampling_reset(llama_sampling_context * ctx) {
if (ctx->grammar != NULL) {
llama_grammar_free(ctx->grammar);
}

if (!ctx->parsed_grammar.rules.empty()) {
std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());

ctx->grammar = llama_grammar_init(
grammar_rules.data(),
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
}

std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
ctx->cur.clear();
}

llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_guidance,
const int idx) {
const int n_ctx = llama_n_ctx(ctx_main);
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));

const llama_sampling_params & params = ctx_sampling->params;

const llama_sampling_params & params = ctx_sampling.params;
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
const float top_p = params.top_p;
@@ -36,41 +78,45 @@ llama_token llama_sampling_sample(
const float mirostat_eta = params.mirostat_eta;
const bool penalize_nl = params.penalize_nl;

auto & prev = ctx_sampling->prev;
auto & cur = ctx_sampling->cur;

llama_token id = 0;

float * logits = llama_get_logits_ith(ctx, idx);
float * logits = llama_get_logits_ith(ctx_main, idx);

// Apply params.logit_bias map
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
logits[it->first] += it->second;
}

candidates.clear();
cur.clear();

for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}

llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
llama_token_data_array cur_p = { cur.data(), cur.size(), false };

if (ctx_guidance) {
llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_guidance, params.cfg_scale);
}

// apply penalties
if (!last_tokens.empty()) {
const float nl_logit = logits[llama_token_nl(ctx)];
const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
if (!prev.empty()) {
const float nl_logit = logits[llama_token_nl(ctx_main)];
const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx);

llama_sample_repetition_penalty(ctx, &cur_p,
last_tokens.data() + last_tokens.size() - last_n_repeat,
llama_sample_repetition_penalty(ctx_main, &cur_p,
prev.data() + prev.size() - last_n_repeat,
last_n_repeat, repeat_penalty);
llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
last_tokens.data() + last_tokens.size() - last_n_repeat,
llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p,
prev.data() + prev.size() - last_n_repeat,
last_n_repeat, alpha_frequency, alpha_presence);

if (!penalize_nl) {
for (size_t idx = 0; idx < cur_p.size; idx++) {
if (cur_p.data[idx].id == llama_token_nl(ctx)) {
if (cur_p.data[idx].id == llama_token_nl(ctx_main)) {
cur_p.data[idx].logit = nl_logit;
break;
}
@@ -78,50 +124,58 @@ llama_token llama_sampling_sample(
}
}

if (ctx_sampling.grammar != NULL) {
llama_sample_grammar(ctx, &cur_p, ctx_sampling.grammar);
if (ctx_sampling->grammar != NULL) {
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
}

if (temp <= 0) {
// Greedy sampling
id = llama_sample_token_greedy(ctx, &cur_p);
id = llama_sample_token_greedy(ctx_main, &cur_p);
} else {
if (mirostat == 1) {
const int mirostat_m = 100;
llama_sample_temp(ctx, &cur_p, temp);
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling.mirostat_mu);
llama_sample_temp(ctx_main, &cur_p, temp);
id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
} else if (mirostat == 2) {
llama_sample_temp(ctx, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling.mirostat_mu);
llama_sample_temp(ctx_main, &cur_p, temp);
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
} else {
// Temperature sampling
size_t min_keep = std::max(1, params.n_probs);
llama_sample_top_k (ctx, &cur_p, top_k, min_keep);
llama_sample_tail_free (ctx, &cur_p, tfs_z, min_keep);
llama_sample_typical (ctx, &cur_p, typical_p, min_keep);
llama_sample_top_p (ctx, &cur_p, top_p, min_keep);
llama_sample_temp(ctx, &cur_p, temp);
llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
llama_sample_temp (ctx_main, &cur_p, temp);

{
const int n_top = 10;
LOG("top %d candidates:\n", n_top);
id = llama_sample_token(ctx_main, &cur_p);

for (int i = 0; i < n_top; i++) {
const llama_token id = cur_p.data[i].id;
(void)id; // To avoid a warning that id is unused when logging is disabled.
LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
//{
// const int n_top = 10;
// LOG("top %d candidates:\n", n_top);

// for (int i = 0; i < n_top; i++) {
// const llama_token id = cur_p.data[i].id;
// (void)id; // To avoid a warning that id is unused when logging is disabled.
// LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
// }
//}

LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
}
}

id = llama_sample_token(ctx, &cur_p);

LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
}
}

if (ctx_sampling.grammar != NULL) {
llama_grammar_accept_token(ctx, ctx_sampling.grammar, id);
}

return id;
}

void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
llama_token id) {
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
ctx_sampling->prev.push_back(id);

if (ctx_sampling->grammar != NULL) {
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
}
}
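The refactor replaces the value-type sampling context with a heap-allocated llama_sampling_context that owns the grammar, the previous tokens and the candidate buffer, so callers no longer pass last_tokens and candidates around. A minimal usage sketch of the new API, based only on the signatures above (the surrounding model setup, decoding, and EOS handling are assumed and elided):

#include "common.h"
#include "sampling.h"

// Sketch only: `ctx` is an existing llama_context with the prompt already
// evaluated, `params` is a populated gpt_params.
static int generate_n(struct llama_context * ctx, const gpt_params & params, int n_gen) {
    struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
    if (ctx_sampling == nullptr) {
        return 1; // e.g. the grammar in params.grammar failed to parse
    }

    for (int i = 0; i < n_gen; ++i) {
        // sample from the logits of the last evaluated token (idx = 0)
        const llama_token id = llama_sampling_sample(ctx_sampling, ctx, /*ctx_guidance =*/ nullptr, 0);

        // record the token so repetition penalties and the grammar see it
        llama_sampling_accept(ctx_sampling, ctx, id);

        // ... feed `id` back with llama_decode() and stop on EOS ...
    }

    llama_sampling_free(ctx_sampling);
    return 0;
}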
sampling.h

@@ -2,6 +2,8 @@

#include "llama.h"

#include "grammar-parser.h"

#include <string>
#include <vector>
#include <unordered_map>
@@ -35,7 +37,8 @@ typedef struct llama_sampling_params {
} llama_sampling_params;

// general sampler context
typedef struct llama_sampling_context {
// TODO: move to llama.h
struct llama_sampling_context {
// parameters that will be used for sampling
llama_sampling_params params;

@@ -43,45 +46,50 @@ typedef struct llama_sampling_context {
float mirostat_mu;

llama_grammar * grammar;
} llama_sampling_context;

// internal
grammar_parser::parse_state parsed_grammar;

std::vector<llama_token> prev;
std::vector<llama_token_data> cur;
};

#include "common.h"

// Create a new sampling context instance.
llama_sampling_context llama_sampling_context_init(
const struct gpt_params & params,
llama_grammar * grammar = NULL);
struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params);

// Reset the sampler context for the supplied sequence id (defaults to 0).
// This is necessary to reuse a sequence id or free memory used by sequences
// that are no longer required.
bool llama_sampling_context_reset(
llama_sampling_context & ctx_sampling,
const llama_seq_id seq = 0);
void llama_sampling_free(struct llama_sampling_context * ctx);

// Reset the sampler context
// - clear prev tokens
// - reset grammar
void llama_sampling_reset(llama_sampling_context * ctx);

// this is a common sampling function used across the examples for convenience
// it can serve as a starting point for implementing your own sampling function
// Note: When using multiple sequences, it is the caller's responsibility to call
// llama_sampling_context_reset when a sequence ends
// llama_sampling_reset when a sequence ends
//
// required:
// - ctx: context to use for sampling
// - ctx_main: context to use for sampling
// - ctx_sampling: sampling-specific context
//
// optional:
// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
// - last_tokens: needed for repetition penalty, ignore if empty
// - ctx_guidance: context to use for guidance
// - idx: sample from llama_get_logits_ith(ctx, idx)
// - seq: sequence id to associate sampler state with
//
// returns:
// - token: sampled token
// - candidates: vector of candidate tokens
//
llama_token llama_sampling_sample(
struct llama_context * ctx,
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_guidance,
struct llama_sampling_context & ctx_sampling,
const std::vector<llama_token> & last_tokens,
std::vector<llama_token_data> & candidates,
const int idx = 0);
int idx = 0);

void llama_sampling_accept(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
llama_token id);
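The comments above put sequence lifecycle management on the caller: when several sequences share the process, each one needs its own state and must be reset when it ends. A small illustrative sketch of that pattern (the seq_state wrapper and helper are hypothetical, not part of this change):

#include "common.h"
#include "sampling.h"

// Hypothetical per-sequence wrapper around the sampling state.
struct seq_state {
    struct llama_sampling_context * ctx_sampling = nullptr;
    bool active = false;
};

// When a sequence finishes and its slot is reused, clear the prev tokens
// and re-create the grammar from the already parsed rules.
static void reuse_slot(seq_state & seq) {
    llama_sampling_reset(seq.ctx_sampling);
    seq.active = false;
}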
examples/speculative/speculative.cpp

@@ -2,7 +2,6 @@

#include "common.h"
#include "llama.h"
#include "grammar-parser.h"

#include <cmath>
#include <cstdio>
@@ -19,10 +18,7 @@ struct seq_draft {

std::vector<llama_token> tokens;

struct llama_grammar * grammar = NULL;

std::vector<llama_token> last_tokens;
struct llama_sampling_context ctx_sampling;
struct llama_sampling_context * ctx_sampling;
};

int main(int argc, char ** argv) {
@@ -96,8 +92,6 @@ int main(int argc, char ** argv) {
const auto t_enc_end = ggml_time_us();

// the 2 models should have the same vocab
const int n_ctx = llama_n_ctx(ctx_tgt);
const int n_vocab = llama_n_vocab(model_tgt);
//GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));

// how many tokens to draft each time
@@ -113,69 +107,34 @@ int main(int argc, char ** argv) {
// used to determine end of generation
bool has_eos = false;

// grammar stuff
struct llama_grammar * grammar = NULL;

grammar_parser::parse_state parsed_grammar;

// if requested - load the grammar, error checking is omitted for brevity
if (!params.grammar.empty()) {
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
// will be empty (default) if there are parse errors
if (parsed_grammar.rules.empty()) {
return 1;
}

std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
}

// target model sampling context
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);

// TODO: move to llama_sampling_state
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);

std::vector<llama_token> last_tokens;
last_tokens.resize(n_ctx);
std::fill(last_tokens.begin(), last_tokens.end(), 0);

for (auto & id : inp) {
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(id);
}
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);

// draft sequence data
std::vector<seq_draft> drafts(n_seq_dft);

params.grammar.clear(); // the draft samplers will copy the target sampler's grammar
params.sampling_params.temp = 1.0f; // the draft samplers use default temperature

for (int i = 0; i < n_seq_dft; ++i) {
{
auto & last_tokens = drafts[i].last_tokens;

last_tokens.resize(n_ctx);
std::fill(last_tokens.begin(), last_tokens.end(), 0);

for (auto & id : inp) {
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(id);
}
drafts[i].ctx_sampling = llama_sampling_init(params);
}

drafts[i].ctx_sampling = llama_sampling_context_init(params, grammar);
}

llama_batch batch_dft = llama_batch_init(512, 0, 1);
llama_batch batch_tgt = llama_batch_init(512, 0, n_seq_dft);
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);

const auto t_dec_start = ggml_time_us();

// sample from the last token of the prompt
drafts[0].i_batch_tgt.resize(1);
drafts[0].i_batch_tgt[0] = 0;

while (true) {
// print current draft sequences
for (int i = 0; i < n_seq_dft; ++i) {
if (!drafts[i].active) continue;
if (!drafts[i].active) {
continue;
}

const auto & tokens = drafts[i].tokens;
@@ -189,11 +148,9 @@ int main(int argc, char ** argv) {
LOG("sampling target: i_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", i_keep, i_dft, drafts[i_keep].i_batch_tgt[i_dft]);

// sample from the target model
llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, drafts[i_keep].i_batch_tgt[i_dft]);
llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[i_keep].i_batch_tgt[i_dft]);

// remember which tokens were sampled - used for repetition penalties during sampling
last_tokens.erase(last_tokens.begin());
last_tokens.push_back(id);
llama_sampling_accept(ctx_sampling, ctx_tgt, id);

//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));

@@ -213,7 +170,9 @@ int main(int argc, char ** argv) {
bool matches = false;

for (int i = 0; i < n_seq_dft; ++i) {
if (!drafts[i].active) continue;
if (!drafts[i].active) {
continue;
}

if (i_dft < (int) drafts[i].tokens.size() && id == drafts[i].tokens[i_dft]) {
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, i, id, token_str.c_str());
@@ -263,11 +222,11 @@ int main(int argc, char ** argv) {
{
batch_dft.n_tokens = 1;

batch_dft.token[0] = id;
batch_dft.pos[0] = n_past_dft;
batch_dft.token [0] = id;
batch_dft.pos [0] = n_past_dft;
batch_dft.n_seq_id[0] = 1;
batch_dft.seq_id[0][0] = 0;
batch_dft.logits[0] = true;
batch_dft.seq_id [0][0] = 0;
batch_dft.logits [0] = true;
}

llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
@@ -281,17 +240,19 @@ int main(int argc, char ** argv) {
break;
}

if (grammar) {
for (int i = 0; i < n_seq_dft; ++i) {
auto * grammar_dft = drafts[i].grammar;
if (ctx_sampling->grammar) {
auto & grammar_dft = drafts[0].ctx_sampling->grammar;
if (grammar_dft) {
llama_grammar_free(grammar_dft);
}

grammar_dft = llama_grammar_copy(ctx_sampling.grammar);
grammar_dft = llama_grammar_copy(ctx_sampling->grammar);

LOG("copied target grammar to draft %d grammar\n", i);
LOG("copied target grammar to draft %d grammar\n", 0);
}

drafts[i].ctx_sampling->prev = ctx_sampling->prev;
}

int n_seq_cur = 1;
@@ -306,11 +267,11 @@ int main(int argc, char ** argv) {
drafts[0].i_batch_dft = 0;

batch_tgt.n_tokens = 1;
batch_tgt.token[0] = drafts[0].tokens[0];
batch_tgt.pos[0] = n_past_tgt;
batch_tgt.token [0] = drafts[0].tokens[0];
batch_tgt.pos [0] = n_past_tgt;
batch_tgt.n_seq_id[0] = 1;
batch_tgt.seq_id[0][0] = 0;
batch_tgt.logits[0] = true;
batch_tgt.seq_id [0][0] = 0;
batch_tgt.logits [0] = true;

// sample n_draft tokens from the draft model using tree-based sampling
for (int i = 0; i < n_draft; ++i) {
@@ -321,46 +282,32 @@ int main(int argc, char ** argv) {
}

for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].drafting || drafts[s].skip) continue;

auto & grammar = drafts[s].grammar;
auto & i_batch_dft = drafts[s].i_batch_dft;

float * logits = llama_get_logits_ith(ctx_dft, i_batch_dft);

// TODO: optimize
candidates.clear();
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
if (!drafts[s].drafting || drafts[s].skip) {
continue;
}

llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);

if (grammar != NULL) {
llama_sample_grammar(ctx_dft, &cur_p, grammar);
}
const auto & cur_p = drafts[s].ctx_sampling->cur;

// computes softmax and sorts the candidates
llama_sample_softmax(ctx_dft, &cur_p);

for (int k = 0; k < 3; ++k) {
for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
k, s, i, cur_p.data[k].id, cur_p.data[k].p, llama_token_to_piece(ctx_dft, cur_p.data[k].id).c_str());
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
}

// TODO: make this configurable
if (cur_p.data[0].p < 0.1) {
//if (cur_p.data[0].p < 2*cur_p.data[1].p) {
LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p.data[0].p, cur_p.data[1].p);
if (cur_p[0].p < 0.4) {
LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p);
drafts[s].drafting = false;
continue;
}

std::vector<int> sa(1, s);

// attempt to split the branch if the probability is high enough
for (int f = 1; f < 8; ++f) {
// TODO: make this configurable
if (n_seq_cur < n_seq_dft && cur_p.data[f].p > 0.10) {
if (n_seq_cur < n_seq_dft && cur_p[f].p > 0.3) {
LOG("splitting seq %3d into %3d\n", s, n_seq_cur);

llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
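In the drafting loop the manual logits, candidates, grammar and softmax handling is replaced by a single llama_sampling_sample call per draft sequence, after which the candidate list is read back from the sampling context's cur member. A reduced sketch of that pattern (the helper name and p_min threshold are illustrative; the example itself compares against 0.4):

#include "common.h"
#include "sampling.h"

// Illustrative helper mirroring the new drafting pattern: sample with the
// per-draft sampling context, then inspect the candidates it stores in `cur`.
static bool draft_is_confident(struct llama_context * ctx_dft,
                               struct llama_sampling_context * ctx_sampling,
                               int i_batch, float p_min) {
    llama_sampling_sample(ctx_sampling, ctx_dft, NULL, i_batch);

    // cur holds the candidate tokens with probabilities filled in by the
    // sampler chain; the example reads the top entries to decide whether
    // to keep drafting this branch
    const std::vector<llama_token_data> & cur_p = ctx_sampling->cur;

    return !cur_p.empty() && cur_p[0].p >= p_min;
}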
@@ -376,9 +323,18 @@ int main(int argc, char ** argv) {
}
}

drafts[n_seq_cur] = drafts[s];
// copy the draft state
drafts[n_seq_cur].active = true;
drafts[n_seq_cur].drafting = true;
drafts[n_seq_cur].skip = true;
// TODO: grammar
drafts[n_seq_cur].tokens = drafts[s].tokens;
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;

if (ctx_sampling->grammar) {
drafts[n_seq_cur].ctx_sampling->grammar =
llama_grammar_copy(drafts[s].ctx_sampling->grammar);
}

sa.push_back(n_seq_cur);
n_seq_cur++;
@@ -389,17 +345,17 @@ int main(int argc, char ** argv) {

// add drafted token for each sequence
for (int is = 0; is < (int) sa.size(); ++is) {
const llama_token id = cur_p.data[is].id;
const llama_token id = cur_p[is].id;

int s = sa[is];

auto & drafted = drafts[s].tokens;
//auto & grammar = drafts[s].grammar;

auto & i_batch_dft = drafts[s].i_batch_dft;
auto & i_batch_tgt = drafts[s].i_batch_tgt;

drafted.push_back(id);
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id);

// add unique drafted tokens to the target batch
batch_tgt.token [batch_tgt.n_tokens] = id;
@@ -413,7 +369,7 @@ int main(int argc, char ** argv) {
batch_tgt.n_tokens++;

// no need to evaluate the last drafted token, since we won't use the result
if (i == n_draft - 1) {
if (batch_tgt.n_tokens == n_draft) {
drafts[s].drafting = false;
continue;
}
@@ -441,19 +397,6 @@ int main(int argc, char ** argv) {
++n_past_cur;
++n_drafted;

// update grammar
for (int s = 0; s < n_seq_dft; ++s) {
if (!drafts[s].drafting) continue;

auto & drafted = drafts[s].tokens;
auto & grammar = drafts[s].grammar;

if (grammar != NULL) {
llama_grammar_accept_token(ctx_dft, grammar, drafted.back());
}
}

// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
if (batch_tgt.n_tokens >= n_draft) {
break;
}
@@ -473,7 +416,9 @@ int main(int argc, char ** argv) {

// the first token is always proposed by the traget model before the speculation loop so we erase it here
for (int i = 0; i < n_seq_dft; ++i) {
if (!drafts[i].active) continue;
if (!drafts[i].active) {
continue;
}

drafts[i].tokens.erase(drafts[i].tokens.begin());
}
@@ -507,13 +452,11 @@ int main(int argc, char ** argv) {
llama_free(ctx_dft);
llama_free_model(model_dft);

if (grammar) {
llama_grammar_free(grammar);

llama_sampling_free(ctx_sampling);
for (int i = 0; i < n_seq_dft; ++i) {
llama_grammar_free(drafts[i].grammar);
}
llama_sampling_free(drafts[i].ctx_sampling);
}

llama_backend_free();

fprintf(stderr, "\n\n");