mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 04:44:34 +00:00
speculative : simplify (cont)
ggml-ci
This commit is contained in:
parent
e4c122b93c
commit
0d4d0c1559
@ -320,7 +320,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
return cur_p.data[cur_p.selected].id;
|
||||
}
|
||||
|
||||
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first) {
|
||||
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
|
||||
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
|
||||
|
||||
std::vector<llama_token> result;
|
||||
@ -342,23 +342,10 @@ std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl,
|
||||
return result;
|
||||
}
|
||||
|
||||
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first) {
|
||||
std::vector<int> idxs;
|
||||
idxs.reserve(batch.n_tokens);
|
||||
|
||||
std::vector<llama_token> draft;
|
||||
draft.reserve(batch.n_tokens);
|
||||
|
||||
for (int i = 0; i < batch.n_tokens; i++) {
|
||||
if (batch.logits[i] == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (idxs.size() > 0) {
|
||||
GGML_ASSERT(batch.pos[idxs.back()] + 1 == batch.pos[i]);
|
||||
draft.push_back(batch.token[i]);
|
||||
}
|
||||
idxs.push_back(i);
|
||||
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
|
||||
std::vector<int> idxs(draft.size() + 1);
|
||||
for (size_t i = 0; i < idxs.size(); ++i) {
|
||||
idxs[i] = i;
|
||||
}
|
||||
|
||||
return common_sampler_sample_n(gsmpl, ctx, idxs, draft, grammar_first);
|
||||
|
@ -71,9 +71,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||
//
|
||||
// returns at least 1 token, up to idxs.size()
|
||||
//
|
||||
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const std::vector<llama_token> & draft, bool grammar_first = false);
|
||||
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
|
||||
|
||||
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const struct llama_batch & batch, bool grammar_first = false);
|
||||
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
|
||||
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
|
||||
|
||||
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
|
||||
|
||||
|
@ -10,24 +10,19 @@
|
||||
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
|
||||
|
||||
struct common_speculative {
|
||||
struct common_speculative_params params;
|
||||
|
||||
llama_batch batch;
|
||||
|
||||
struct llama_context * ctx;
|
||||
struct common_sampler * smpl;
|
||||
|
||||
llama_batch batch;
|
||||
llama_tokens prompt;
|
||||
};
|
||||
|
||||
struct common_speculative * common_speculative_init(
|
||||
struct common_speculative_params params,
|
||||
struct llama_context * ctx_dft) {
|
||||
auto * result = new common_speculative {
|
||||
/* .params = */ params,
|
||||
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
|
||||
/* .ctx = */ ctx_dft,
|
||||
/* .smpl = */ nullptr,
|
||||
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
|
||||
/* .prompt = */ {},
|
||||
};
|
||||
|
||||
@ -130,12 +125,11 @@ bool common_speculative_are_compatible(
|
||||
return true;
|
||||
}
|
||||
|
||||
void common_speculative_add_draft(
|
||||
llama_tokens common_speculative_gen_draft(
|
||||
struct common_speculative * spec,
|
||||
struct llama_batch & batch_tgt,
|
||||
struct common_speculative_params params,
|
||||
const llama_tokens & prompt_tgt,
|
||||
llama_token id_last,
|
||||
llama_token n_past_tgt) {
|
||||
llama_token id_last) {
|
||||
auto & batch = spec->batch;
|
||||
auto & ctx = spec->ctx;
|
||||
auto & smpl = spec->smpl;
|
||||
@ -144,7 +138,7 @@ void common_speculative_add_draft(
|
||||
int reuse_i = 0;
|
||||
int reuse_n = 0;
|
||||
|
||||
const int n_ctx = llama_n_ctx(ctx) - spec->params.n_draft;
|
||||
const int n_ctx = llama_n_ctx(ctx) - params.n_draft;
|
||||
|
||||
const int i_start = std::max<int>(0, (int) prompt_tgt.size() - n_ctx);
|
||||
|
||||
@ -156,7 +150,7 @@ void common_speculative_add_draft(
|
||||
cur++;
|
||||
}
|
||||
|
||||
if ((cur >= spec->params.n_reuse || prompt_tgt.size() <= n_ctx) && cur > reuse_n) {
|
||||
if ((cur >= params.n_reuse || prompt_tgt.size() <= n_ctx) && cur > reuse_n) {
|
||||
reuse_i = i;
|
||||
reuse_n = cur;
|
||||
}
|
||||
@ -207,8 +201,11 @@ void common_speculative_add_draft(
|
||||
|
||||
common_sampler_reset(smpl);
|
||||
|
||||
llama_tokens result;
|
||||
result.reserve(params.n_draft);
|
||||
|
||||
// sample n_draft tokens from the draft model
|
||||
for (int i = 0; i < spec->params.n_draft; ++i) {
|
||||
for (int i = 0; i < params.n_draft; ++i) {
|
||||
common_batch_clear(batch);
|
||||
|
||||
common_sampler_sample(smpl, ctx, 0, true);
|
||||
@ -224,15 +221,15 @@ void common_speculative_add_draft(
|
||||
const llama_token id = cur_p->data[0].id;
|
||||
|
||||
// only collect very high-confidence draft tokens
|
||||
if (cur_p->data[0].p < spec->params.p_min) {
|
||||
if (cur_p->data[0].p < params.p_min) {
|
||||
break;
|
||||
}
|
||||
|
||||
common_sampler_accept(smpl, id, true);
|
||||
|
||||
common_batch_add(batch_tgt, id, n_past_tgt + i, { 0 }, true);
|
||||
result.push_back(id);
|
||||
|
||||
if (batch_tgt.n_tokens > spec->params.n_draft) {
|
||||
if (result.size() >= params.n_draft) {
|
||||
break;
|
||||
}
|
||||
|
||||
@ -244,9 +241,5 @@ void common_speculative_add_draft(
|
||||
prompt.push_back(id);
|
||||
}
|
||||
|
||||
// don't waste time on small batches
|
||||
// TODO: do not evaluate the draft model for that many rounds
|
||||
if (batch_tgt.n_tokens < spec->params.n_min) {
|
||||
batch_tgt.n_tokens = 1;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
@ -7,15 +7,12 @@ struct common_speculative;
|
||||
|
||||
struct common_speculative_params {
|
||||
int n_draft = 16;
|
||||
int n_min = 5; // do not add drafts smaller than this, TODO: leave this to user?
|
||||
int n_reuse = 256;
|
||||
|
||||
float p_min = 0.9f;
|
||||
};
|
||||
|
||||
struct common_speculative * common_speculative_init(
|
||||
struct common_speculative_params params,
|
||||
struct llama_context * ctx_dft);
|
||||
struct common_speculative * common_speculative_init(struct llama_context * ctx_dft);
|
||||
|
||||
void common_speculative_free(struct common_speculative * spec);
|
||||
|
||||
@ -25,9 +22,8 @@ bool common_speculative_are_compatible(
|
||||
|
||||
// sample up to n_draft tokens and add them to the batch using the draft model
|
||||
//
|
||||
void common_speculative_add_draft(
|
||||
llama_tokens common_speculative_gen_draft(
|
||||
struct common_speculative * spec,
|
||||
struct llama_batch & batch_tgt,
|
||||
struct common_speculative_params params,
|
||||
const llama_tokens & prompt,
|
||||
llama_token id_last,
|
||||
llama_token n_past_tgt);
|
||||
llama_token id_last);
|
||||
|
@ -13,6 +13,9 @@
|
||||
int main(int argc, char ** argv) {
|
||||
common_params params;
|
||||
|
||||
// minimum size of the draft to use
|
||||
const int n_min = 5;
|
||||
|
||||
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
|
||||
return 1;
|
||||
}
|
||||
@ -92,31 +95,29 @@ int main(int argc, char ** argv) {
|
||||
// everything until here is standard initialization
|
||||
// the relevant stuff for speculative decoding starts here
|
||||
|
||||
const int n_input = inp.size();
|
||||
|
||||
const auto t_enc_start = ggml_time_us();
|
||||
|
||||
// target model sampling context
|
||||
struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);
|
||||
|
||||
// eval the prompt
|
||||
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), n_input - 1));
|
||||
llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
|
||||
|
||||
// note: keep the last token separate!
|
||||
llama_token id_last = inp.back();
|
||||
|
||||
auto prompt_dft = std::vector<llama_token>(inp.begin(), inp.end() - 1);
|
||||
// all tokens currently in the target context
|
||||
auto prompt_tgt = std::vector<llama_token>(inp.begin(), inp.end() - 1);
|
||||
|
||||
int n_past = inp.size() - 1;
|
||||
|
||||
// init the speculator
|
||||
struct common_speculative_params params_spec;
|
||||
params_spec.n_draft = n_draft;
|
||||
params_spec.n_min = 5;
|
||||
params_spec.n_reuse = 256;
|
||||
params_spec.p_min = 0.9f;
|
||||
|
||||
struct common_speculative * spec = common_speculative_init(params_spec, ctx_dft);
|
||||
struct common_speculative * spec = common_speculative_init(ctx_dft);
|
||||
|
||||
llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
|
||||
|
||||
@ -125,21 +126,30 @@ int main(int argc, char ** argv) {
|
||||
const auto t_dec_start = ggml_time_us();
|
||||
|
||||
while (true) {
|
||||
// always have a token to evaluate from before
|
||||
common_batch_clear(batch_tgt);
|
||||
common_batch_add (batch_tgt, id_last, n_past, { 0 }, true);
|
||||
|
||||
// optionally, append draft tokens to the target batch
|
||||
// optionally, generate draft tokens that can be appended to the target batch
|
||||
//
|
||||
// this is the most important part of the speculation. the more probable tokens that are provided here
|
||||
// the better the performance will be. in theory, this computation can be performed asynchronously and even
|
||||
// offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
|
||||
// from a cache or lookup tables.
|
||||
//
|
||||
common_speculative_add_draft(spec, batch_tgt, prompt_dft, id_last, n_past + 1);
|
||||
llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);
|
||||
|
||||
// always have a token to evaluate from before - id_last
|
||||
common_batch_clear(batch_tgt);
|
||||
common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true);
|
||||
|
||||
// evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
|
||||
{
|
||||
// do not waste time on small drafts
|
||||
if (draft.size() < n_min) {
|
||||
draft.clear();
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < draft.size(); ++i) {
|
||||
common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
|
||||
}
|
||||
|
||||
//LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
|
||||
|
||||
llama_decode(ctx_tgt, batch_tgt);
|
||||
@ -152,11 +162,11 @@ int main(int argc, char ** argv) {
|
||||
// available logits from the batch and sample the next token until we run out of logits or the sampler
|
||||
// disagrees with the draft
|
||||
//
|
||||
const auto ids = common_sampler_sample_n(smpl, ctx_tgt, batch_tgt);
|
||||
const auto ids = common_sampler_sample_n(smpl, ctx_tgt, draft);
|
||||
|
||||
GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
|
||||
|
||||
n_past += ids.size();
|
||||
n_past += ids.size() - 1;
|
||||
n_drafted += batch_tgt.n_tokens - 1;
|
||||
n_accept += ids.size() - 1;
|
||||
|
||||
@ -192,7 +202,7 @@ int main(int argc, char ** argv) {
|
||||
break;
|
||||
}
|
||||
|
||||
LOG_DBG("accepted %d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, id, token_str.c_str());
|
||||
LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d, '%s')\n", (int) ids.size() - 1, (int) draft.size(), id, token_str.c_str());
|
||||
|
||||
{
|
||||
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
|
||||
@ -200,8 +210,8 @@ int main(int argc, char ** argv) {
|
||||
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
|
||||
}
|
||||
|
||||
prompt_dft.push_back(id_last);
|
||||
prompt_dft.insert(prompt_dft.end(), ids.begin(), ids.end() - 1);
|
||||
prompt_tgt.push_back(id_last);
|
||||
prompt_tgt.insert(prompt_tgt.end(), ids.begin(), ids.end() - 1);
|
||||
|
||||
// remember the last accepted token for the next iteration
|
||||
id_last = id;
|
||||
@ -210,6 +220,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
auto t_dec_end = ggml_time_us();
|
||||
|
||||
const int n_input = inp.size();
|
||||
|
||||
LOG("\n\n");
|
||||
|
||||
LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
|
||||
|
Loading…
Reference in New Issue
Block a user