mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-28 12:24:35 +00:00
speculative : fix the draft sampling
ggml-ci
This commit is contained in:
parent
be5f611000
commit
d9fb3b2e01
@ -320,7 +320,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
|||||||
return cur_p.data[cur_p.selected].id;
|
return cur_p.data[cur_p.selected].id;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
|
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
|
||||||
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
|
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
|
||||||
|
|
||||||
std::vector<llama_token> result;
|
std::vector<llama_token> result;
|
||||||
@ -330,25 +330,33 @@ std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl,
|
|||||||
for (; i < draft.size(); i++) {
|
for (; i < draft.size(); i++) {
|
||||||
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
|
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
|
||||||
|
|
||||||
|
common_sampler_accept(gsmpl, id, true);
|
||||||
|
|
||||||
|
result.push_back(id);
|
||||||
|
|
||||||
if (draft[i] != id) {
|
if (draft[i] != id) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i == draft.size()) {
|
||||||
|
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
|
||||||
|
|
||||||
|
common_sampler_accept(gsmpl, id, true);
|
||||||
|
|
||||||
result.push_back(id);
|
result.push_back(id);
|
||||||
}
|
}
|
||||||
|
|
||||||
result.push_back(common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first));
|
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
|
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
|
||||||
std::vector<int> idxs(draft.size() + 1);
|
std::vector<int> idxs(draft.size() + 1);
|
||||||
for (size_t i = 0; i < idxs.size(); ++i) {
|
for (size_t i = 0; i < idxs.size(); ++i) {
|
||||||
idxs[i] = i;
|
idxs[i] = i;
|
||||||
}
|
}
|
||||||
|
|
||||||
return common_sampler_sample_n(gsmpl, ctx, idxs, draft, grammar_first);
|
return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
|
||||||
|
@ -62,19 +62,24 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
|||||||
|
|
||||||
// generalized version of common_sampler_sample
|
// generalized version of common_sampler_sample
|
||||||
//
|
//
|
||||||
// will cross-reference the sampled tokens with a batch of draft tokens
|
// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
|
||||||
// if the sampler disagrees at some point, we stop and return the sampled tokens up to now
|
// if the sampler disagrees at some point, we stop and return the accepted tokens up to now
|
||||||
//
|
//
|
||||||
// `common_sampler_sample_n(gsmpl, ctx, { idx }, {})` is equivalent to `common_sampler_sample(gsmpl, ctx, idx)`
|
// common_sampler_sample_n(gsmpl, ctx, { idx }, {});
|
||||||
|
//
|
||||||
|
// is equivalent to
|
||||||
|
//
|
||||||
|
// common_sampler_sample(gsmpl, ctx, idx);
|
||||||
|
// common_sampler_accept(gsmpl, token, true);
|
||||||
//
|
//
|
||||||
// requires: idxs.size() == draft.size() + 1
|
// requires: idxs.size() == draft.size() + 1
|
||||||
//
|
//
|
||||||
// returns at least 1 token, up to idxs.size()
|
// returns at least 1 token, up to idxs.size()
|
||||||
//
|
//
|
||||||
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
|
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
|
||||||
|
|
||||||
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
|
// assume idxs == [ 0, 1, 2, ..., draft.size() ]
|
||||||
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
|
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
|
||||||
|
|
||||||
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
|
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
|
||||||
|
|
||||||
|
@ -163,7 +163,9 @@ int main(int argc, char ** argv) {
|
|||||||
// available logits from the batch and sample the next token until we run out of logits or the sampler
|
// available logits from the batch and sample the next token until we run out of logits or the sampler
|
||||||
// disagrees with the draft
|
// disagrees with the draft
|
||||||
//
|
//
|
||||||
const auto ids = common_sampler_sample_n(smpl, ctx_tgt, draft);
|
const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);
|
||||||
|
|
||||||
|
//LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str());
|
||||||
|
|
||||||
GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
|
GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user