diff --git a/common/sampling.cpp b/common/sampling.cpp index 75e2e5d29..52f4c9e22 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -320,7 +320,7 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co return cur_p.data[cur_p.selected].id; } -std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first) { +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first) { GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1"); std::vector result; @@ -330,25 +330,33 @@ std::vector common_sampler_sample_n(struct common_sampler * gsmpl, for (; i < draft.size(); i++) { const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); + common_sampler_accept(gsmpl, id, true); + + result.push_back(id); + if (draft[i] != id) { break; } + } + + if (i == draft.size()) { + const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first); + + common_sampler_accept(gsmpl, id, true); result.push_back(id); } - result.push_back(common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first)); - return result; } -std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) { +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) { std::vector idxs(draft.size() + 1); for (size_t i = 0; i < idxs.size(); ++i) { idxs[i] = i; } - return common_sampler_sample_n(gsmpl, ctx, idxs, draft, grammar_first); + return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first); } uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) { diff --git a/common/sampling.h b/common/sampling.h index f9b193ac8..883d905a6 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -62,19 +62,24 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co // generalized version of common_sampler_sample // -// will cross-reference the sampled tokens with a batch of draft tokens -// if the sampler disagrees at some point, we stop and return the sampled tokens up to now +// will cross-reference the sampled tokens with a batch of draft tokens and accept those that match +// if the sampler disagrees at some point, we stop and return the accepted tokens up to now // -// `common_sampler_sample_n(gsmpl, ctx, { idx }, {})` is equivalent to `common_sampler_sample(gsmpl, ctx, idx)` +// common_sampler_sample_n(gsmpl, ctx, { idx }, {}); +// +// is equivalent to +// +// common_sampler_sample(gsmpl, ctx, idx); +// common_sampler_accept(gsmpl, token, true); // // requires: idxs.size() == draft.size() + 1 // // returns at least 1 token, up to idxs.size() // -std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first = false); +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector & idxs, const llama_tokens & draft, bool grammar_first = false); // assume idxs == [ 0, 1, 2, ..., draft.size() ] -std::vector common_sampler_sample_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false); +std::vector common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false); uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 98a9b35d4..6699e1d85 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -163,7 +163,9 @@ int main(int argc, char ** argv) { // available logits from the batch and sample the next token until we run out of logits or the sampler // disagrees with the draft // - const auto ids = common_sampler_sample_n(smpl, ctx_tgt, draft); + const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft); + + //LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str()); GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token