diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index a0aa30744..31dfd6240 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -921,6 +921,8 @@ struct server_context {
         slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
 
         slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
+        slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 2);
+        slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);
 
         if (slot.params.sampling.dry_base < 1.0f) {
             slot.params.sampling.dry_base = defaults.sampling.dry_base;
@@ -2322,10 +2324,29 @@ struct server_context {
                 continue;
             }
 
+            // determine the max draft that fits the current slot state
+            int n_draft_max = slot.params.speculative.n_max;
+
+            // note: n_past is not yet increased for the `id` token sampled above
+            //       also, need to leave space for 1 extra token to allow context shifts
+            n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+
+            if (slot.n_remaining > 0) {
+                n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+            }
+
+            SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+            if (n_draft_max < slot.params.speculative.n_min) {
+                SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
+
+                continue;
+            }
+
             llama_token id = slot.sampled;
 
             struct common_speculative_params params_spec;
-            params_spec.n_draft = std::min(slot.params.speculative.n_max, slot.n_ctx - slot.n_past - 1);
+            params_spec.n_draft = n_draft_max;
             params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
             params_spec.p_min   = slot.params.speculative.p_min;
 
@@ -2333,6 +2354,8 @@ struct server_context {
 
             // ignore small drafts
             if (slot.params.speculative.n_min > (int) draft.size()) {
+                SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+
                 continue;
             }
 
@@ -2344,6 +2367,8 @@ struct server_context {
                 common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
             }
 
+            SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+
             llama_decode(ctx, slot.batch_spec);
 
             // the accepted tokens from the speculation
@@ -2372,7 +2397,7 @@ struct server_context {
                 }
             }
 
-            SRV_DBG("accepted %d/%d draft tokens\n", (int) ids.size() - 1, (int) draft.size());
+            SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
         }
     }
 
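
For reference, a minimal standalone sketch of the draft-size clamping that the second hunk introduces. The numeric slot state below (n_ctx, n_past, n_remaining, and the speculative n_min/n_max values) is hypothetical; only the min/max arithmetic mirrors the patch.

// Standalone sketch (not part of the patch) of the draft-size clamping above.
// The slot values here are hypothetical stand-ins; only the arithmetic mirrors the diff.
#include <algorithm>
#include <cstdio>

int main() {
    const int n_ctx       = 4096; // slot context size
    const int n_past      = 4090; // tokens already in the context, before the sampled token
    const int n_remaining = 8;    // tokens this request may still generate
    const int n_max       = 16;   // speculative.n_max
    const int n_min       = 2;    // speculative.n_min (hunk 1 enforces n_min >= 2, n_max >= 0)

    int n_draft_max = n_max;

    // n_past does not yet count the freshly sampled token, and one extra
    // position is reserved so a context shift stays possible: hence the -2
    n_draft_max = std::min(n_draft_max, n_ctx - n_past - 2);

    // never draft more tokens than the request is still allowed to produce
    if (n_remaining > 0) {
        n_draft_max = std::min(n_draft_max, n_remaining - 1);
    }

    if (n_draft_max < n_min) {
        std::printf("max possible draft too small: %d < %d -> skip speculative decoding\n", n_draft_max, n_min);
    } else {
        // with these values: min(16, 4096 - 4090 - 2) = 4, then min(4, 8 - 1) = 4
        std::printf("max possible draft: %d\n", n_draft_max);
    }

    return 0;
}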