mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-13 20:14:29 +00:00
server : fix speculative decoding with context shift (#10641)
* server : fix speculative decoding with context shift ggml-ci * server : take into account speculative limits ggml-ci * server : add tests
This commit is contained in:
parent
59f4db1088
commit
1da7b76569
@ -921,6 +921,8 @@ struct server_context {
|
|||||||
slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
|
slot.params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
|
||||||
|
|
||||||
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
|
slot.params.speculative.n_min = std::min(slot.params.speculative.n_max, slot.params.speculative.n_min);
|
||||||
|
slot.params.speculative.n_min = std::max(slot.params.speculative.n_min, 2);
|
||||||
|
slot.params.speculative.n_max = std::max(slot.params.speculative.n_max, 0);
|
||||||
|
|
||||||
if (slot.params.sampling.dry_base < 1.0f) {
|
if (slot.params.sampling.dry_base < 1.0f) {
|
||||||
slot.params.sampling.dry_base = defaults.sampling.dry_base;
|
slot.params.sampling.dry_base = defaults.sampling.dry_base;
|
||||||
@ -2322,10 +2324,29 @@ struct server_context {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// determine the max draft that fits the current slot state
|
||||||
|
int n_draft_max = slot.params.speculative.n_max;
|
||||||
|
|
||||||
|
// note: n_past is not yet increased for the `id` token sampled above
|
||||||
|
// also, need to leave space for 1 extra token to allow context shifts
|
||||||
|
n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
|
||||||
|
|
||||||
|
if (slot.n_remaining > 0) {
|
||||||
|
n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
|
||||||
|
|
||||||
|
if (n_draft_max < slot.params.speculative.n_min) {
|
||||||
|
SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
|
||||||
|
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
llama_token id = slot.sampled;
|
llama_token id = slot.sampled;
|
||||||
|
|
||||||
struct common_speculative_params params_spec;
|
struct common_speculative_params params_spec;
|
||||||
params_spec.n_draft = slot.params.speculative.n_max;
|
params_spec.n_draft = n_draft_max;
|
||||||
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
|
params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
|
||||||
params_spec.p_min = slot.params.speculative.p_min;
|
params_spec.p_min = slot.params.speculative.p_min;
|
||||||
|
|
||||||
@ -2333,6 +2354,8 @@ struct server_context {
|
|||||||
|
|
||||||
// ignore small drafts
|
// ignore small drafts
|
||||||
if (slot.params.speculative.n_min > (int) draft.size()) {
|
if (slot.params.speculative.n_min > (int) draft.size()) {
|
||||||
|
SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2344,6 +2367,8 @@ struct server_context {
|
|||||||
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
|
common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
|
||||||
|
|
||||||
llama_decode(ctx, slot.batch_spec);
|
llama_decode(ctx, slot.batch_spec);
|
||||||
|
|
||||||
// the accepted tokens from the speculation
|
// the accepted tokens from the speculation
|
||||||
@ -2372,7 +2397,7 @@ struct server_context {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
SRV_DBG("accepted %d/%d draft tokens\n", (int) ids.size() - 1, (int) draft.size());
|
SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,6 +82,37 @@ def test_different_draft_min_draft_max():
|
|||||||
last_content = res.body["content"]
|
last_content = res.body["content"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_slot_ctx_not_exceeded():
|
||||||
|
global server
|
||||||
|
server.n_ctx = 64
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/completion", data={
|
||||||
|
"prompt": "Hello " * 56,
|
||||||
|
"temperature": 0.0,
|
||||||
|
"top_k": 1,
|
||||||
|
"speculative.p_min": 0.0,
|
||||||
|
})
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert len(res.body["content"]) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_with_ctx_shift():
|
||||||
|
global server
|
||||||
|
server.n_ctx = 64
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/completion", data={
|
||||||
|
"prompt": "Hello " * 56,
|
||||||
|
"temperature": 0.0,
|
||||||
|
"top_k": 1,
|
||||||
|
"n_predict": 64,
|
||||||
|
"speculative.p_min": 0.0,
|
||||||
|
})
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert len(res.body["content"]) > 0
|
||||||
|
assert res.body["tokens_predicted"] == 64
|
||||||
|
assert res.body["truncated"] == True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n_slots,n_requests", [
|
@pytest.mark.parametrize("n_slots,n_requests", [
|
||||||
(1, 2),
|
(1, 2),
|
||||||
(2, 2),
|
(2, 2),
|
||||||
|
Loading…
Reference in New Issue
Block a user