mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 04:00:16 +00:00
speculative : refactor sampling
This commit is contained in:
parent
32a67cbd16
commit
4a7f43f28c
@ -1,27 +1,69 @@
|
|||||||
#include "sampling.h"
|
#include "sampling.h"
|
||||||
|
|
||||||
llama_sampling_context llama_sampling_context_init(
|
struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params) {
|
||||||
const struct gpt_params & params,
|
struct llama_sampling_context * result =
|
||||||
llama_grammar * grammar) {
|
(struct llama_sampling_context *) malloc(sizeof(struct llama_sampling_context));
|
||||||
llama_sampling_context result;
|
|
||||||
|
|
||||||
result.params = params.sampling_params;
|
result->params = params.sampling_params;
|
||||||
result.grammar = grammar;
|
result->grammar = nullptr;
|
||||||
|
|
||||||
|
// if there is a grammar, parse it
|
||||||
|
if (!params.grammar.empty()) {
|
||||||
|
result->parsed_grammar = grammar_parser::parse(params.grammar.c_str());
|
||||||
|
|
||||||
|
// will be empty (default) if there are parse errors
|
||||||
|
if (result->parsed_grammar.rules.empty()) {
|
||||||
|
fprintf(stderr, "%s: failed to parse grammar\n", __func__);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<const llama_grammar_element *> grammar_rules(result->parsed_grammar.c_rules());
|
||||||
|
|
||||||
|
result->grammar = llama_grammar_init(
|
||||||
|
grammar_rules.data(),
|
||||||
|
grammar_rules.size(), result->parsed_grammar.symbol_ids.at("root"));
|
||||||
|
}
|
||||||
|
|
||||||
|
result->prev.resize(params.n_ctx);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token llama_sampling_sample(
|
void llama_sampling_free(struct llama_sampling_context * ctx) {
|
||||||
struct llama_context * ctx,
|
if (ctx->grammar != NULL) {
|
||||||
struct llama_context * ctx_guidance,
|
llama_grammar_free(ctx->grammar);
|
||||||
struct llama_sampling_context & ctx_sampling,
|
}
|
||||||
const std::vector<llama_token> & last_tokens,
|
|
||||||
std::vector<llama_token_data> & candidates,
|
free(ctx);
|
||||||
const int idx) {
|
}
|
||||||
const int n_ctx = llama_n_ctx(ctx);
|
|
||||||
const int n_vocab = llama_n_vocab(llama_get_model(ctx));
|
void llama_sampling_reset(llama_sampling_context * ctx) {
|
||||||
|
if (ctx->grammar != NULL) {
|
||||||
|
llama_grammar_free(ctx->grammar);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!ctx->parsed_grammar.rules.empty()) {
|
||||||
|
std::vector<const llama_grammar_element *> grammar_rules(ctx->parsed_grammar.c_rules());
|
||||||
|
|
||||||
|
ctx->grammar = llama_grammar_init(
|
||||||
|
grammar_rules.data(),
|
||||||
|
grammar_rules.size(), ctx->parsed_grammar.symbol_ids.at("root"));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::fill(ctx->prev.begin(), ctx->prev.end(), 0);
|
||||||
|
ctx->cur.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_token llama_sampling_sample(
|
||||||
|
struct llama_sampling_context * ctx_sampling,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
struct llama_context * ctx_guidance,
|
||||||
|
const int idx) {
|
||||||
|
const int n_ctx = llama_n_ctx(ctx_main);
|
||||||
|
const int n_vocab = llama_n_vocab(llama_get_model(ctx_main));
|
||||||
|
|
||||||
|
const llama_sampling_params & params = ctx_sampling->params;
|
||||||
|
|
||||||
const llama_sampling_params & params = ctx_sampling.params;
|
|
||||||
const float temp = params.temp;
|
const float temp = params.temp;
|
||||||
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
|
const int32_t top_k = params.top_k <= 0 ? n_vocab : params.top_k;
|
||||||
const float top_p = params.top_p;
|
const float top_p = params.top_p;
|
||||||
@ -36,41 +78,45 @@ llama_token llama_sampling_sample(
|
|||||||
const float mirostat_eta = params.mirostat_eta;
|
const float mirostat_eta = params.mirostat_eta;
|
||||||
const bool penalize_nl = params.penalize_nl;
|
const bool penalize_nl = params.penalize_nl;
|
||||||
|
|
||||||
|
auto & prev = ctx_sampling->prev;
|
||||||
|
auto & cur = ctx_sampling->cur;
|
||||||
|
|
||||||
llama_token id = 0;
|
llama_token id = 0;
|
||||||
|
|
||||||
float * logits = llama_get_logits_ith(ctx, idx);
|
float * logits = llama_get_logits_ith(ctx_main, idx);
|
||||||
|
|
||||||
// Apply params.logit_bias map
|
// Apply params.logit_bias map
|
||||||
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
for (auto it = params.logit_bias.begin(); it != params.logit_bias.end(); it++) {
|
||||||
logits[it->first] += it->second;
|
logits[it->first] += it->second;
|
||||||
}
|
}
|
||||||
|
|
||||||
candidates.clear();
|
cur.clear();
|
||||||
|
|
||||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
|
llama_token_data_array cur_p = { cur.data(), cur.size(), false };
|
||||||
|
|
||||||
if (ctx_guidance) {
|
if (ctx_guidance) {
|
||||||
llama_sample_classifier_free_guidance(ctx, &cur_p, ctx_guidance, params.cfg_scale);
|
llama_sample_classifier_free_guidance(ctx_main, &cur_p, ctx_guidance, params.cfg_scale);
|
||||||
}
|
}
|
||||||
|
|
||||||
// apply penalties
|
// apply penalties
|
||||||
if (!last_tokens.empty()) {
|
if (!prev.empty()) {
|
||||||
const float nl_logit = logits[llama_token_nl(ctx)];
|
const float nl_logit = logits[llama_token_nl(ctx_main)];
|
||||||
const int last_n_repeat = std::min(std::min((int)last_tokens.size(), repeat_last_n), n_ctx);
|
const int last_n_repeat = std::min(std::min((int)prev.size(), repeat_last_n), n_ctx);
|
||||||
|
|
||||||
llama_sample_repetition_penalty(ctx, &cur_p,
|
llama_sample_repetition_penalty(ctx_main, &cur_p,
|
||||||
last_tokens.data() + last_tokens.size() - last_n_repeat,
|
prev.data() + prev.size() - last_n_repeat,
|
||||||
last_n_repeat, repeat_penalty);
|
last_n_repeat, repeat_penalty);
|
||||||
llama_sample_frequency_and_presence_penalties(ctx, &cur_p,
|
llama_sample_frequency_and_presence_penalties(ctx_main, &cur_p,
|
||||||
last_tokens.data() + last_tokens.size() - last_n_repeat,
|
prev.data() + prev.size() - last_n_repeat,
|
||||||
last_n_repeat, alpha_frequency, alpha_presence);
|
last_n_repeat, alpha_frequency, alpha_presence);
|
||||||
|
|
||||||
if (!penalize_nl) {
|
if (!penalize_nl) {
|
||||||
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
for (size_t idx = 0; idx < cur_p.size; idx++) {
|
||||||
if (cur_p.data[idx].id == llama_token_nl(ctx)) {
|
if (cur_p.data[idx].id == llama_token_nl(ctx_main)) {
|
||||||
cur_p.data[idx].logit = nl_logit;
|
cur_p.data[idx].logit = nl_logit;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -78,50 +124,58 @@ llama_token llama_sampling_sample(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ctx_sampling.grammar != NULL) {
|
if (ctx_sampling->grammar != NULL) {
|
||||||
llama_sample_grammar(ctx, &cur_p, ctx_sampling.grammar);
|
llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (temp <= 0) {
|
if (temp <= 0) {
|
||||||
// Greedy sampling
|
// Greedy sampling
|
||||||
id = llama_sample_token_greedy(ctx, &cur_p);
|
id = llama_sample_token_greedy(ctx_main, &cur_p);
|
||||||
} else {
|
} else {
|
||||||
if (mirostat == 1) {
|
if (mirostat == 1) {
|
||||||
const int mirostat_m = 100;
|
const int mirostat_m = 100;
|
||||||
llama_sample_temp(ctx, &cur_p, temp);
|
llama_sample_temp(ctx_main, &cur_p, temp);
|
||||||
id = llama_sample_token_mirostat(ctx, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling.mirostat_mu);
|
id = llama_sample_token_mirostat(ctx_main, &cur_p, mirostat_tau, mirostat_eta, mirostat_m, &ctx_sampling->mirostat_mu);
|
||||||
} else if (mirostat == 2) {
|
} else if (mirostat == 2) {
|
||||||
llama_sample_temp(ctx, &cur_p, temp);
|
llama_sample_temp(ctx_main, &cur_p, temp);
|
||||||
id = llama_sample_token_mirostat_v2(ctx, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling.mirostat_mu);
|
id = llama_sample_token_mirostat_v2(ctx_main, &cur_p, mirostat_tau, mirostat_eta, &ctx_sampling->mirostat_mu);
|
||||||
} else {
|
} else {
|
||||||
// Temperature sampling
|
// Temperature sampling
|
||||||
size_t min_keep = std::max(1, params.n_probs);
|
size_t min_keep = std::max(1, params.n_probs);
|
||||||
llama_sample_top_k (ctx, &cur_p, top_k, min_keep);
|
llama_sample_top_k (ctx_main, &cur_p, top_k, min_keep);
|
||||||
llama_sample_tail_free (ctx, &cur_p, tfs_z, min_keep);
|
llama_sample_tail_free(ctx_main, &cur_p, tfs_z, min_keep);
|
||||||
llama_sample_typical (ctx, &cur_p, typical_p, min_keep);
|
llama_sample_typical (ctx_main, &cur_p, typical_p, min_keep);
|
||||||
llama_sample_top_p (ctx, &cur_p, top_p, min_keep);
|
llama_sample_top_p (ctx_main, &cur_p, top_p, min_keep);
|
||||||
llama_sample_temp(ctx, &cur_p, temp);
|
llama_sample_temp (ctx_main, &cur_p, temp);
|
||||||
|
|
||||||
{
|
id = llama_sample_token(ctx_main, &cur_p);
|
||||||
const int n_top = 10;
|
|
||||||
LOG("top %d candidates:\n", n_top);
|
|
||||||
|
|
||||||
for (int i = 0; i < n_top; i++) {
|
//{
|
||||||
const llama_token id = cur_p.data[i].id;
|
// const int n_top = 10;
|
||||||
(void)id; // To avoid a warning that id is unused when logging is disabled.
|
// LOG("top %d candidates:\n", n_top);
|
||||||
LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx, id).c_str(), cur_p.data[i].p);
|
|
||||||
|
// for (int i = 0; i < n_top; i++) {
|
||||||
|
// const llama_token id = cur_p.data[i].id;
|
||||||
|
// (void)id; // To avoid a warning that id is unused when logging is disabled.
|
||||||
|
// LOG(" - %5d: '%12s' (%.3f)\n", id, llama_token_to_piece(ctx_main, id).c_str(), cur_p.data[i].p);
|
||||||
|
// }
|
||||||
|
//}
|
||||||
|
|
||||||
|
LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx_main, id).c_str());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
id = llama_sample_token(ctx, &cur_p);
|
|
||||||
|
|
||||||
LOG("sampled token: %5d: '%s'\n", id, llama_token_to_piece(ctx, id).c_str());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ctx_sampling.grammar != NULL) {
|
|
||||||
llama_grammar_accept_token(ctx, ctx_sampling.grammar, id);
|
|
||||||
}
|
|
||||||
|
|
||||||
return id;
|
return id;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_sampling_accept(
|
||||||
|
struct llama_sampling_context * ctx_sampling,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
llama_token id) {
|
||||||
|
ctx_sampling->prev.erase(ctx_sampling->prev.begin());
|
||||||
|
ctx_sampling->prev.push_back(id);
|
||||||
|
|
||||||
|
if (ctx_sampling->grammar != NULL) {
|
||||||
|
llama_grammar_accept_token(ctx_main, ctx_sampling->grammar, id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
|
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include "grammar-parser.h"
|
||||||
|
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
@ -35,7 +37,8 @@ typedef struct llama_sampling_params {
|
|||||||
} llama_sampling_params;
|
} llama_sampling_params;
|
||||||
|
|
||||||
// general sampler context
|
// general sampler context
|
||||||
typedef struct llama_sampling_context {
|
// TODO: move to llama.h
|
||||||
|
struct llama_sampling_context {
|
||||||
// parameters that will be used for sampling
|
// parameters that will be used for sampling
|
||||||
llama_sampling_params params;
|
llama_sampling_params params;
|
||||||
|
|
||||||
@ -43,45 +46,50 @@ typedef struct llama_sampling_context {
|
|||||||
float mirostat_mu;
|
float mirostat_mu;
|
||||||
|
|
||||||
llama_grammar * grammar;
|
llama_grammar * grammar;
|
||||||
} llama_sampling_context;
|
|
||||||
|
// internal
|
||||||
|
grammar_parser::parse_state parsed_grammar;
|
||||||
|
|
||||||
|
std::vector<llama_token> prev;
|
||||||
|
std::vector<llama_token_data> cur;
|
||||||
|
};
|
||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
|
||||||
// Create a new sampling context instance.
|
// Create a new sampling context instance.
|
||||||
llama_sampling_context llama_sampling_context_init(
|
struct llama_sampling_context * llama_sampling_init(const struct gpt_params & params);
|
||||||
const struct gpt_params & params,
|
|
||||||
llama_grammar * grammar = NULL);
|
|
||||||
|
|
||||||
// Reset the sampler context for the supplied sequence id (defaults to 0).
|
void llama_sampling_free(struct llama_sampling_context * ctx);
|
||||||
// This is necessary to reuse a sequence id or free memory used by sequences
|
|
||||||
// that are no longer required.
|
// Reset the sampler context
|
||||||
bool llama_sampling_context_reset(
|
// - clear prev tokens
|
||||||
llama_sampling_context & ctx_sampling,
|
// - reset grammar
|
||||||
const llama_seq_id seq = 0);
|
void llama_sampling_reset(llama_sampling_context * ctx);
|
||||||
|
|
||||||
// this is a common sampling function used across the examples for convenience
|
// this is a common sampling function used across the examples for convenience
|
||||||
// it can serve as a starting point for implementing your own sampling function
|
// it can serve as a starting point for implementing your own sampling function
|
||||||
// Note: When using multiple sequences, it is the caller's responsibility to call
|
// Note: When using multiple sequences, it is the caller's responsibility to call
|
||||||
// llama_sampling_context_reset when a sequence ends
|
// llama_sampling_reset when a sequence ends
|
||||||
//
|
//
|
||||||
// required:
|
// required:
|
||||||
// - ctx: context to use for sampling
|
// - ctx_main: context to use for sampling
|
||||||
// - ctx_sampling: sampling-specific context
|
// - ctx_sampling: sampling-specific context
|
||||||
//
|
//
|
||||||
// optional:
|
// optional:
|
||||||
// - ctx_guidance: context to use for classifier-free guidance, ignore if NULL
|
// - ctx_guidance: context to use for guidance
|
||||||
// - last_tokens: needed for repetition penalty, ignore if empty
|
|
||||||
// - idx: sample from llama_get_logits_ith(ctx, idx)
|
// - idx: sample from llama_get_logits_ith(ctx, idx)
|
||||||
// - seq: sequence id to associate sampler state with
|
|
||||||
//
|
//
|
||||||
// returns:
|
// returns:
|
||||||
// - token: sampled token
|
// - token: sampled token
|
||||||
// - candidates: vector of candidate tokens
|
// - candidates: vector of candidate tokens
|
||||||
//
|
//
|
||||||
llama_token llama_sampling_sample(
|
llama_token llama_sampling_sample(
|
||||||
struct llama_context * ctx,
|
struct llama_sampling_context * ctx_sampling,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
struct llama_context * ctx_guidance,
|
struct llama_context * ctx_guidance,
|
||||||
struct llama_sampling_context & ctx_sampling,
|
int idx = 0);
|
||||||
const std::vector<llama_token> & last_tokens,
|
|
||||||
std::vector<llama_token_data> & candidates,
|
void llama_sampling_accept(
|
||||||
const int idx = 0);
|
struct llama_sampling_context * ctx_sampling,
|
||||||
|
struct llama_context * ctx_main,
|
||||||
|
llama_token id);
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "grammar-parser.h"
|
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
@ -19,10 +18,7 @@ struct seq_draft {
|
|||||||
|
|
||||||
std::vector<llama_token> tokens;
|
std::vector<llama_token> tokens;
|
||||||
|
|
||||||
struct llama_grammar * grammar = NULL;
|
struct llama_sampling_context * ctx_sampling;
|
||||||
|
|
||||||
std::vector<llama_token> last_tokens;
|
|
||||||
struct llama_sampling_context ctx_sampling;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
@ -96,8 +92,6 @@ int main(int argc, char ** argv) {
|
|||||||
const auto t_enc_end = ggml_time_us();
|
const auto t_enc_end = ggml_time_us();
|
||||||
|
|
||||||
// the 2 models should have the same vocab
|
// the 2 models should have the same vocab
|
||||||
const int n_ctx = llama_n_ctx(ctx_tgt);
|
|
||||||
const int n_vocab = llama_n_vocab(model_tgt);
|
|
||||||
//GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
|
//GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
|
||||||
|
|
||||||
// how many tokens to draft each time
|
// how many tokens to draft each time
|
||||||
@ -113,69 +107,34 @@ int main(int argc, char ** argv) {
|
|||||||
// used to determine end of generation
|
// used to determine end of generation
|
||||||
bool has_eos = false;
|
bool has_eos = false;
|
||||||
|
|
||||||
// grammar stuff
|
|
||||||
struct llama_grammar * grammar = NULL;
|
|
||||||
|
|
||||||
grammar_parser::parse_state parsed_grammar;
|
|
||||||
|
|
||||||
// if requested - load the grammar, error checking is omitted for brevity
|
|
||||||
if (!params.grammar.empty()) {
|
|
||||||
parsed_grammar = grammar_parser::parse(params.grammar.c_str());
|
|
||||||
// will be empty (default) if there are parse errors
|
|
||||||
if (parsed_grammar.rules.empty()) {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
|
||||||
grammar = llama_grammar_init(grammar_rules.data(), grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
|
||||||
}
|
|
||||||
|
|
||||||
// target model sampling context
|
// target model sampling context
|
||||||
llama_sampling_context ctx_sampling = llama_sampling_context_init(params, grammar);
|
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params);
|
||||||
|
|
||||||
// TODO: move to llama_sampling_state
|
|
||||||
std::vector<llama_token_data> candidates;
|
|
||||||
candidates.reserve(n_vocab);
|
|
||||||
|
|
||||||
std::vector<llama_token> last_tokens;
|
|
||||||
last_tokens.resize(n_ctx);
|
|
||||||
std::fill(last_tokens.begin(), last_tokens.end(), 0);
|
|
||||||
|
|
||||||
for (auto & id : inp) {
|
|
||||||
last_tokens.erase(last_tokens.begin());
|
|
||||||
last_tokens.push_back(id);
|
|
||||||
}
|
|
||||||
|
|
||||||
// draft sequence data
|
// draft sequence data
|
||||||
std::vector<seq_draft> drafts(n_seq_dft);
|
std::vector<seq_draft> drafts(n_seq_dft);
|
||||||
|
|
||||||
|
params.grammar.clear(); // the draft samplers will copy the target sampler's grammar
|
||||||
|
params.sampling_params.temp = 1.0f; // the draft samplers use default temperature
|
||||||
|
|
||||||
for (int i = 0; i < n_seq_dft; ++i) {
|
for (int i = 0; i < n_seq_dft; ++i) {
|
||||||
{
|
drafts[i].ctx_sampling = llama_sampling_init(params);
|
||||||
auto & last_tokens = drafts[i].last_tokens;
|
|
||||||
|
|
||||||
last_tokens.resize(n_ctx);
|
|
||||||
std::fill(last_tokens.begin(), last_tokens.end(), 0);
|
|
||||||
|
|
||||||
for (auto & id : inp) {
|
|
||||||
last_tokens.erase(last_tokens.begin());
|
|
||||||
last_tokens.push_back(id);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
drafts[i].ctx_sampling = llama_sampling_context_init(params, grammar);
|
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
|
||||||
}
|
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
|
||||||
|
|
||||||
llama_batch batch_dft = llama_batch_init(512, 0, 1);
|
|
||||||
llama_batch batch_tgt = llama_batch_init(512, 0, n_seq_dft);
|
|
||||||
|
|
||||||
const auto t_dec_start = ggml_time_us();
|
const auto t_dec_start = ggml_time_us();
|
||||||
|
|
||||||
|
// sample from the last token of the prompt
|
||||||
drafts[0].i_batch_tgt.resize(1);
|
drafts[0].i_batch_tgt.resize(1);
|
||||||
drafts[0].i_batch_tgt[0] = 0;
|
drafts[0].i_batch_tgt[0] = 0;
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
|
// print current draft sequences
|
||||||
for (int i = 0; i < n_seq_dft; ++i) {
|
for (int i = 0; i < n_seq_dft; ++i) {
|
||||||
if (!drafts[i].active) continue;
|
if (!drafts[i].active) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
const auto & tokens = drafts[i].tokens;
|
const auto & tokens = drafts[i].tokens;
|
||||||
|
|
||||||
@ -189,11 +148,9 @@ int main(int argc, char ** argv) {
|
|||||||
LOG("sampling target: i_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", i_keep, i_dft, drafts[i_keep].i_batch_tgt[i_dft]);
|
LOG("sampling target: i_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", i_keep, i_dft, drafts[i_keep].i_batch_tgt[i_dft]);
|
||||||
|
|
||||||
// sample from the target model
|
// sample from the target model
|
||||||
llama_token id = llama_sampling_sample(ctx_tgt, NULL, ctx_sampling, last_tokens, candidates, drafts[i_keep].i_batch_tgt[i_dft]);
|
llama_token id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[i_keep].i_batch_tgt[i_dft]);
|
||||||
|
|
||||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
llama_sampling_accept(ctx_sampling, ctx_tgt, id);
|
||||||
last_tokens.erase(last_tokens.begin());
|
|
||||||
last_tokens.push_back(id);
|
|
||||||
|
|
||||||
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));
|
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, last_tokens));
|
||||||
|
|
||||||
@ -213,7 +170,9 @@ int main(int argc, char ** argv) {
|
|||||||
bool matches = false;
|
bool matches = false;
|
||||||
|
|
||||||
for (int i = 0; i < n_seq_dft; ++i) {
|
for (int i = 0; i < n_seq_dft; ++i) {
|
||||||
if (!drafts[i].active) continue;
|
if (!drafts[i].active) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
if (i_dft < (int) drafts[i].tokens.size() && id == drafts[i].tokens[i_dft]) {
|
if (i_dft < (int) drafts[i].tokens.size() && id == drafts[i].tokens[i_dft]) {
|
||||||
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, i, id, token_str.c_str());
|
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, i, id, token_str.c_str());
|
||||||
@ -263,11 +222,11 @@ int main(int argc, char ** argv) {
|
|||||||
{
|
{
|
||||||
batch_dft.n_tokens = 1;
|
batch_dft.n_tokens = 1;
|
||||||
|
|
||||||
batch_dft.token[0] = id;
|
batch_dft.token [0] = id;
|
||||||
batch_dft.pos[0] = n_past_dft;
|
batch_dft.pos [0] = n_past_dft;
|
||||||
batch_dft.n_seq_id[0] = 1;
|
batch_dft.n_seq_id[0] = 1;
|
||||||
batch_dft.seq_id[0][0] = 0;
|
batch_dft.seq_id [0][0] = 0;
|
||||||
batch_dft.logits[0] = true;
|
batch_dft.logits [0] = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
||||||
@ -281,17 +240,19 @@ int main(int argc, char ** argv) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (grammar) {
|
|
||||||
for (int i = 0; i < n_seq_dft; ++i) {
|
for (int i = 0; i < n_seq_dft; ++i) {
|
||||||
auto * grammar_dft = drafts[i].grammar;
|
if (ctx_sampling->grammar) {
|
||||||
|
auto & grammar_dft = drafts[0].ctx_sampling->grammar;
|
||||||
if (grammar_dft) {
|
if (grammar_dft) {
|
||||||
llama_grammar_free(grammar_dft);
|
llama_grammar_free(grammar_dft);
|
||||||
}
|
}
|
||||||
|
|
||||||
grammar_dft = llama_grammar_copy(ctx_sampling.grammar);
|
grammar_dft = llama_grammar_copy(ctx_sampling->grammar);
|
||||||
|
|
||||||
LOG("copied target grammar to draft %d grammar\n", i);
|
LOG("copied target grammar to draft %d grammar\n", 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
drafts[i].ctx_sampling->prev = ctx_sampling->prev;
|
||||||
}
|
}
|
||||||
|
|
||||||
int n_seq_cur = 1;
|
int n_seq_cur = 1;
|
||||||
@ -306,11 +267,11 @@ int main(int argc, char ** argv) {
|
|||||||
drafts[0].i_batch_dft = 0;
|
drafts[0].i_batch_dft = 0;
|
||||||
|
|
||||||
batch_tgt.n_tokens = 1;
|
batch_tgt.n_tokens = 1;
|
||||||
batch_tgt.token[0] = drafts[0].tokens[0];
|
batch_tgt.token [0] = drafts[0].tokens[0];
|
||||||
batch_tgt.pos[0] = n_past_tgt;
|
batch_tgt.pos [0] = n_past_tgt;
|
||||||
batch_tgt.n_seq_id[0] = 1;
|
batch_tgt.n_seq_id[0] = 1;
|
||||||
batch_tgt.seq_id[0][0] = 0;
|
batch_tgt.seq_id [0][0] = 0;
|
||||||
batch_tgt.logits[0] = true;
|
batch_tgt.logits [0] = true;
|
||||||
|
|
||||||
// sample n_draft tokens from the draft model using tree-based sampling
|
// sample n_draft tokens from the draft model using tree-based sampling
|
||||||
for (int i = 0; i < n_draft; ++i) {
|
for (int i = 0; i < n_draft; ++i) {
|
||||||
@ -321,46 +282,32 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int s = 0; s < n_seq_dft; ++s) {
|
for (int s = 0; s < n_seq_dft; ++s) {
|
||||||
if (!drafts[s].drafting || drafts[s].skip) continue;
|
if (!drafts[s].drafting || drafts[s].skip) {
|
||||||
|
continue;
|
||||||
auto & grammar = drafts[s].grammar;
|
|
||||||
auto & i_batch_dft = drafts[s].i_batch_dft;
|
|
||||||
|
|
||||||
float * logits = llama_get_logits_ith(ctx_dft, i_batch_dft);
|
|
||||||
|
|
||||||
// TODO: optimize
|
|
||||||
candidates.clear();
|
|
||||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
|
||||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array cur_p = { candidates.data(), candidates.size(), false };
|
llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
|
||||||
|
|
||||||
if (grammar != NULL) {
|
const auto & cur_p = drafts[s].ctx_sampling->cur;
|
||||||
llama_sample_grammar(ctx_dft, &cur_p, grammar);
|
|
||||||
}
|
|
||||||
|
|
||||||
// computes softmax and sorts the candidates
|
for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
|
||||||
llama_sample_softmax(ctx_dft, &cur_p);
|
|
||||||
|
|
||||||
for (int k = 0; k < 3; ++k) {
|
|
||||||
LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||||
k, s, i, cur_p.data[k].id, cur_p.data[k].p, llama_token_to_piece(ctx_dft, cur_p.data[k].id).c_str());
|
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: make this configurable
|
// TODO: make this configurable
|
||||||
if (cur_p.data[0].p < 0.1) {
|
if (cur_p[0].p < 0.4) {
|
||||||
//if (cur_p.data[0].p < 2*cur_p.data[1].p) {
|
LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p[0].p, cur_p[1].p);
|
||||||
LOG("stopping drafting for seq %3d, probability too low: %.3f < 2*%.3f\n", s, cur_p.data[0].p, cur_p.data[1].p);
|
|
||||||
drafts[s].drafting = false;
|
drafts[s].drafting = false;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> sa(1, s);
|
std::vector<int> sa(1, s);
|
||||||
|
|
||||||
|
// attempt to split the branch if the probability is high enough
|
||||||
for (int f = 1; f < 8; ++f) {
|
for (int f = 1; f < 8; ++f) {
|
||||||
// TODO: make this configurable
|
// TODO: make this configurable
|
||||||
if (n_seq_cur < n_seq_dft && cur_p.data[f].p > 0.10) {
|
if (n_seq_cur < n_seq_dft && cur_p[f].p > 0.3) {
|
||||||
LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
|
LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
|
||||||
|
|
||||||
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
|
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
|
||||||
@ -376,9 +323,18 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
drafts[n_seq_cur] = drafts[s];
|
// copy the draft state
|
||||||
|
drafts[n_seq_cur].active = true;
|
||||||
|
drafts[n_seq_cur].drafting = true;
|
||||||
drafts[n_seq_cur].skip = true;
|
drafts[n_seq_cur].skip = true;
|
||||||
// TODO: grammar
|
drafts[n_seq_cur].tokens = drafts[s].tokens;
|
||||||
|
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
|
||||||
|
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
|
||||||
|
|
||||||
|
if (ctx_sampling->grammar) {
|
||||||
|
drafts[n_seq_cur].ctx_sampling->grammar =
|
||||||
|
llama_grammar_copy(drafts[s].ctx_sampling->grammar);
|
||||||
|
}
|
||||||
|
|
||||||
sa.push_back(n_seq_cur);
|
sa.push_back(n_seq_cur);
|
||||||
n_seq_cur++;
|
n_seq_cur++;
|
||||||
@ -389,17 +345,17 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
// add drafted token for each sequence
|
// add drafted token for each sequence
|
||||||
for (int is = 0; is < (int) sa.size(); ++is) {
|
for (int is = 0; is < (int) sa.size(); ++is) {
|
||||||
const llama_token id = cur_p.data[is].id;
|
const llama_token id = cur_p[is].id;
|
||||||
|
|
||||||
int s = sa[is];
|
int s = sa[is];
|
||||||
|
|
||||||
auto & drafted = drafts[s].tokens;
|
auto & drafted = drafts[s].tokens;
|
||||||
//auto & grammar = drafts[s].grammar;
|
|
||||||
|
|
||||||
auto & i_batch_dft = drafts[s].i_batch_dft;
|
auto & i_batch_dft = drafts[s].i_batch_dft;
|
||||||
auto & i_batch_tgt = drafts[s].i_batch_tgt;
|
auto & i_batch_tgt = drafts[s].i_batch_tgt;
|
||||||
|
|
||||||
drafted.push_back(id);
|
drafted.push_back(id);
|
||||||
|
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id);
|
||||||
|
|
||||||
// add unique drafted tokens to the target batch
|
// add unique drafted tokens to the target batch
|
||||||
batch_tgt.token [batch_tgt.n_tokens] = id;
|
batch_tgt.token [batch_tgt.n_tokens] = id;
|
||||||
@ -413,7 +369,7 @@ int main(int argc, char ** argv) {
|
|||||||
batch_tgt.n_tokens++;
|
batch_tgt.n_tokens++;
|
||||||
|
|
||||||
// no need to evaluate the last drafted token, since we won't use the result
|
// no need to evaluate the last drafted token, since we won't use the result
|
||||||
if (i == n_draft - 1) {
|
if (batch_tgt.n_tokens == n_draft) {
|
||||||
drafts[s].drafting = false;
|
drafts[s].drafting = false;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@ -441,19 +397,6 @@ int main(int argc, char ** argv) {
|
|||||||
++n_past_cur;
|
++n_past_cur;
|
||||||
++n_drafted;
|
++n_drafted;
|
||||||
|
|
||||||
// update grammar
|
|
||||||
for (int s = 0; s < n_seq_dft; ++s) {
|
|
||||||
if (!drafts[s].drafting) continue;
|
|
||||||
|
|
||||||
auto & drafted = drafts[s].tokens;
|
|
||||||
auto & grammar = drafts[s].grammar;
|
|
||||||
|
|
||||||
if (grammar != NULL) {
|
|
||||||
llama_grammar_accept_token(ctx_dft, grammar, drafted.back());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
|
||||||
if (batch_tgt.n_tokens >= n_draft) {
|
if (batch_tgt.n_tokens >= n_draft) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -473,7 +416,9 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
// the first token is always proposed by the traget model before the speculation loop so we erase it here
|
// the first token is always proposed by the traget model before the speculation loop so we erase it here
|
||||||
for (int i = 0; i < n_seq_dft; ++i) {
|
for (int i = 0; i < n_seq_dft; ++i) {
|
||||||
if (!drafts[i].active) continue;
|
if (!drafts[i].active) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
drafts[i].tokens.erase(drafts[i].tokens.begin());
|
drafts[i].tokens.erase(drafts[i].tokens.begin());
|
||||||
}
|
}
|
||||||
@ -507,13 +452,11 @@ int main(int argc, char ** argv) {
|
|||||||
llama_free(ctx_dft);
|
llama_free(ctx_dft);
|
||||||
llama_free_model(model_dft);
|
llama_free_model(model_dft);
|
||||||
|
|
||||||
if (grammar) {
|
llama_sampling_free(ctx_sampling);
|
||||||
llama_grammar_free(grammar);
|
|
||||||
|
|
||||||
for (int i = 0; i < n_seq_dft; ++i) {
|
for (int i = 0; i < n_seq_dft; ++i) {
|
||||||
llama_grammar_free(drafts[i].grammar);
|
llama_sampling_free(drafts[i].ctx_sampling);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_backend_free();
|
llama_backend_free();
|
||||||
|
|
||||||
fprintf(stderr, "\n\n");
|
fprintf(stderr, "\n\n");
|
||||||
|
Loading…
Reference in New Issue
Block a user