speculative : fix the draft sampling

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-24 12:50:17 +02:00
parent be5f611000
commit d9fb3b2e01
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 26 additions and 11 deletions

View File

@ -320,7 +320,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
return cur_p.data[cur_p.selected].id; return cur_p.data[cur_p.selected].id;
} }
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) { std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");
std::vector<llama_token> result; std::vector<llama_token> result;
@ -330,25 +330,33 @@ std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl,
for (; i < draft.size(); i++) { for (; i < draft.size(); i++) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
common_sampler_accept(gsmpl, id, true);
result.push_back(id);
if (draft[i] != id) { if (draft[i] != id) {
break; break;
} }
}
if (i == draft.size()) {
const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
common_sampler_accept(gsmpl, id, true);
result.push_back(id); result.push_back(id);
} }
result.push_back(common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first));
return result; return result;
} }
// Convenience overload for callers that do not pass explicit indices:
// fills idxs with 0, 1, ..., draft.size() (draft.size() + 1 entries, which is
// the size the idxs-taking overload asserts) and forwards to that overload.
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) { std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
std::vector<int> idxs(draft.size() + 1); std::vector<int> idxs(draft.size() + 1);
for (size_t i = 0; i < idxs.size(); ++i) { for (size_t i = 0; i < idxs.size(); ++i) {
idxs[i] = i; idxs[i] = i;
} }
// delegate to the overload that takes the explicit index list
return common_sampler_sample_n(gsmpl, ctx, idxs, draft, grammar_first); return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
} }
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {

View File

@ -62,19 +62,24 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
// generalized version of common_sampler_sample // generalized version of common_sampler_sample
// //
// will cross-reference the sampled tokens with a batch of draft tokens // will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
// if the sampler disagrees at some point, we stop and return the sampled tokens up to now // if the sampler disagrees at some point, we stop and return the accepted tokens up to now
// //
// `common_sampler_sample_n(gsmpl, ctx, { idx }, {})` is equivalent to `common_sampler_sample(gsmpl, ctx, idx)` // common_sampler_sample_and_accept_n(gsmpl, ctx, { idx }, {});
//
// is equivalent to
//
// const llama_token token = common_sampler_sample(gsmpl, ctx, idx);
// common_sampler_accept(gsmpl, token, true);
// //
// requires: idxs.size() == draft.size() + 1 // requires: idxs.size() == draft.size() + 1
// //
// returns at least 1 token, up to idxs.size() // returns at least 1 token, up to idxs.size()
// //
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false); std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
// assume idxs == [ 0, 1, 2, ..., draft.size() ] // assume idxs == [ 0, 1, 2, ..., draft.size() ]
std::vector<llama_token> common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false); std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);

View File

@ -163,7 +163,9 @@ int main(int argc, char ** argv) {
// available logits from the batch and sample the next token until we run out of logits or the sampler // available logits from the batch and sample the next token until we run out of logits or the sampler
// disagrees with the draft // disagrees with the draft
// //
const auto ids = common_sampler_sample_n(smpl, ctx_tgt, draft); const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);
//LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str());
GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token