server : remove self-extend features (#9860)
* server : remove self-extend
  ggml-ci
* server : fix context limit check to use slot.n_past
  ggml-ci
commit 1bde94dd02
parent 95c76e8e92
@@ -1163,14 +1163,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
        [](common_params & params, int value) {
            params.grp_attn_n = value;
        }
-   ).set_env("LLAMA_ARG_GRP_ATTN_N"));
+   ).set_env("LLAMA_ARG_GRP_ATTN_N").set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_PASSKEY}));
    add_opt(common_arg(
        {"-gaw", "--grp-attn-w"}, "N",
-       string_format("group-attention width (default: %.1f)", (double)params.grp_attn_w),
+       string_format("group-attention width (default: %d)", params.grp_attn_w),
        [](common_params & params, int value) {
            params.grp_attn_w = value;
        }
-   ).set_env("LLAMA_ARG_GRP_ATTN_W"));
+   ).set_env("LLAMA_ARG_GRP_ATTN_W").set_examples({LLAMA_EXAMPLE_MAIN}));
    add_opt(common_arg(
        {"-dkvc", "--dump-kv-cache"},
        "verbose print of the KV cache",
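This first hunk keeps `-gan`/`-gaw` in the shared argument parser but scopes them to the examples that still implement self-extend, so `llama-server` no longer accepts them; it also fixes the `-gaw` default to print as an integer. Assembled from the fragments above plus the README row removed below, the `-gan` registration plausibly reads as follows (a sketch, not a verbatim quote of `common_params_parser_init`):

```cpp
// Sketch only: the -gan option as it plausibly reads after this commit.
// The flag names, env var and set_examples() call are visible in the hunk above;
// the description string is inferred from the README entry removed below.
add_opt(common_arg(
    {"-gan", "--grp-attn-n"}, "N",
    string_format("group-attention factor (default: %d)", params.grp_attn_n),
    [](common_params & params, int value) {
        params.grp_attn_n = value;
    }
).set_env("LLAMA_ARG_GRP_ATTN_N").set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_PASSKEY}));
```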
@@ -60,8 +60,6 @@ The project is under active development, and we are [looking for feedback and co
 | `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: 1.0)<br/>(env: LLAMA_ARG_YARN_ATTN_FACTOR) |
 | `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: 1.0)<br/>(env: LLAMA_ARG_YARN_BETA_SLOW) |
 | `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: 32.0)<br/>(env: LLAMA_ARG_YARN_BETA_FAST) |
-| `-gan, --grp-attn-n N` | group-attention factor (default: 1)<br/>(env: LLAMA_ARG_GRP_ATTN_N) |
-| `-gaw, --grp-attn-w N` | group-attention width (default: 512.0)<br/>(env: LLAMA_ARG_GRP_ATTN_W) |
 | `-dkvc, --dump-kv-cache` | verbose print of the KV cache |
 | `-nkvo, --no-kv-offload` | disable KV offload<br/>(env: LLAMA_ARG_NO_KV_OFFLOAD) |
 | `-ctk, --cache-type-k TYPE` | KV cache data type for K (default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K) |
@@ -193,12 +193,6 @@ struct server_slot {

    llama_token sampled;

-   int32_t ga_i = 0;   // group-attention state
-   int32_t ga_n = 1;   // group-attention factor
-   int32_t ga_w = 512; // group-attention width
-
-   int32_t n_past_se = 0; // self-extend
-
    // stats
    size_t n_sent_text = 0; // number of sent text character
    size_t n_sent_token_probs = 0;
@@ -225,8 +219,6 @@ struct server_slot {
        n_sent_text = 0;
        n_sent_token_probs = 0;
        cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
-       ga_i = 0;
-       n_past_se = 0;

        generated_token_probs.clear();
    }
@@ -705,22 +697,6 @@ struct server_context {

        SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);

-       const int ga_n = params.grp_attn_n;
-       const int ga_w = params.grp_attn_w;
-
-       if (ga_n != 1) {
-           GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
-           GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
-           //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
-           //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
-
-           SLT_INF(slot, "slot self-extend: ga_n = %d, ga_w = %d\n", ga_n, ga_w);
-       }
-
-       slot.ga_i = 0;
-       slot.ga_n = ga_n;
-       slot.ga_w = ga_w;
-
        slot.sparams = params.sparams;

        slot.callback_on_release = [this](int) {
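The block deleted here only validated the group-attention knobs when a slot was created and copied them into it; the position arithmetic itself goes away further down. Restating the removed constraint with illustrative values (not taken from the commit):

```cpp
// ga_w had to be a multiple of ga_n (and ga_n positive) for self-extend to engage;
// example values only — the defaults ga_n = 1 / ga_w = 512 bypassed the whole block.
static_assert(512 % 4 == 0, "ga_n = 4, ga_w = 512 satisfies the removed GGML_ASSERT");
static_assert(512 % 3 != 0, "ga_n = 3, ga_w = 512 would have tripped it");
```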
@@ -916,11 +892,6 @@ struct server_context {
            slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
        }

-       if (slot.params.cache_prompt && slot.ga_n != 1) {
-           slot.params.cache_prompt = false;
-           SLT_WRN(slot, "%s", "group-attention is not supported with prompt caching. disabling cache\n");
-       }
-
        if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
            // Might be better to reject the request with a 400 ?
            slot.params.n_predict = slot.n_predict;
@@ -1131,12 +1102,13 @@ struct server_context {
        }

        // if context shift is disabled, we stop when it reaches the context limit
-       if (slot.n_decoded >= slot.n_ctx) {
+       if (slot.n_past >= slot.n_ctx) {
            slot.truncated = true;
            slot.stopped_limit = true;
            slot.has_next_token = false;

-           SLT_DBG(slot, "stopped due to running out of context capacity, n_decoded = %d, n_ctx = %d\n", slot.n_decoded, slot.n_ctx);
+           SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
+                   slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
        }

        if (llama_token_is_eog(model, result.tok)) {
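This is the `slot.n_past` fix from the commit message: the limit check now counts every position occupied in the slot's KV cache (prompt included) rather than only generated tokens. A minimal illustration with a hypothetical stand-in for the relevant `server_slot` fields:

```cpp
#include <cassert>

// Hypothetical stand-in for the fields involved in the check (illustration only).
struct slot_state {
    int n_ctx;           // context size reserved for this slot
    int n_prompt_tokens; // prompt positions already in the KV cache
    int n_decoded;       // tokens generated so far
    int n_past() const { return n_prompt_tokens + n_decoded; }
};

int main() {
    slot_state s { /*n_ctx=*/128, /*n_prompt_tokens=*/120, /*n_decoded=*/8 };
    assert(s.n_decoded < s.n_ctx);  // old check: would keep generating
    assert(s.n_past() >= s.n_ctx);  // new check: the cache is already full
    return 0;
}
```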
@@ -1148,13 +1120,13 @@ struct server_context {

        const auto n_ctx_train = llama_n_ctx_train(model);

-       if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
+       if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
            slot.truncated = true;
            slot.stopped_limit = true;
            slot.has_next_token = false; // stop prediction

            SLT_WRN(slot,
-                   "n_predict (%d) is not set and self-context extend is disabled. "
+                   "n_predict (%d) is set for infinite generation. "
                    "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
                    slot.params.n_predict, n_ctx_train);
        }
@@ -1826,8 +1798,7 @@ struct server_context {
        // apply context-shift if needed
        // TODO: simplify and improve
        for (server_slot & slot : slots) {
-           if (slot.ga_n == 1) {
-               if (slot.is_processing() && slot.n_past >= slot.n_ctx - 1) {
+           if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
                if (!params.ctx_shift) {
                    // this check is redundant (for good)
                    // we should never get here, because generation should already stopped in process_token()
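With the `ga_n == 1` guard gone, context shift is considered for every processing slot whose cache is about to fill up (`slot.n_past + 1 >= slot.n_ctx`). The shift itself is outside this hunk; the sketch below shows the usual pattern using the KV-cache calls that appear elsewhere in this diff, while the `n_keep` handling and the drop-half policy are assumptions rather than quoted code:

```cpp
// Hedged sketch of a context shift for one slot (not verbatim server.cpp code).
const int n_keep    = slot.params.n_keep;   // tokens pinned at the start (assumed)
const int n_left    = slot.n_past - n_keep; // movable part of the cache
const int n_discard = n_left / 2;           // assumed policy: drop half of it

// delete the discarded span, then slide the tail back over the hole
llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep, n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, slot.n_past, -n_discard);

slot.n_past -= n_discard;                   // new tokens can be appended again
```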
@@ -1859,7 +1830,6 @@ struct server_context {
                    slot.truncated = true;
                }
            }
-       }

        // start populating the batch for this iteration
        common_batch_clear(batch);
@@ -1872,9 +1842,7 @@ struct server_context {

            slot.i_batch = batch.n_tokens;

-           const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
-           common_batch_add(batch, slot.sampled, slot_npast, { slot.id + 1 }, true);
+           common_batch_add(batch, slot.sampled, slot.n_past, { slot.id + 1 }, true);

            slot.n_past += 1;

@@ -1993,6 +1961,8 @@ struct server_context {
                } else {
                    if (!params.ctx_shift) {
                        // if context shift is disabled, we make sure prompt size is smaller than KV size
+                       // TODO: there should be a separate parameter that control prompt truncation
+                       //       context shift should be applied only during the generation phase
                        if (slot.n_prompt_tokens >= slot.n_ctx) {
                            slot.release();
                            send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
@@ -2005,7 +1975,7 @@ struct server_context {
                    slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);

                    // if input prompt is too big, truncate it (if group attention self-extend is disabled)
-                   if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
+                   if (slot.n_prompt_tokens >= slot.n_ctx) {
                        const int n_left = slot.n_ctx - slot.params.n_keep;

                        const int n_block_size = n_left / 2;
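Prompt truncation now applies whenever the prompt does not fit the slot context, since self-extend was the only reason to skip it. Only `n_left` and `n_block_size` are visible here; the block-erasing step below is an assumed sketch of the rest of this branch, included so the feature-file numbers at the end of the diff can be checked:

```cpp
// Hedged sketch: drop whole blocks of n_block_size from the middle of the prompt,
// keeping the first n_keep tokens and the tail (assumed logic, not quoted code).
const int n_left        = slot.n_ctx - slot.params.n_keep;
const int n_block_size  = n_left / 2;
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;

std::vector<llama_token> new_tokens(
        prompt_tokens.begin(),
        prompt_tokens.begin() + slot.params.n_keep);
new_tokens.insert(
        new_tokens.end(),
        prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
        prompt_tokens.end());
```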
@@ -2032,12 +2002,7 @@ struct server_context {

                    common_sampler_reset(slot.smpl);

-                   if (!slot.params.cache_prompt) {
-                       slot.n_past_se = 0;
-                       slot.ga_i = 0;
-                   } else {
-                       GGML_ASSERT(slot.ga_n == 1);
-
+                   if (slot.params.cache_prompt) {
                        // reuse any previously computed tokens that are common with the new prompt
                        slot.n_past = common_part(slot.cache_tokens, prompt_tokens);

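With the special case for group attention removed, prompt caching reduces to reusing the longest shared prefix between the cached tokens and the new prompt, which is what `common_part` computes in this file. A hypothetical, self-contained equivalent to make the semantics concrete:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical illustration of what common_part(cache_tokens, prompt_tokens) returns:
// the length of the longest shared prefix, i.e. how many positions can stay in the cache.
static size_t common_prefix_len(const std::vector<int> & a, const std::vector<int> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        i++;
    }
    return i;
}
// slot.n_past = common_prefix_len(cache, prompt); -> only prompt[n_past:] is re-decoded
```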
@@ -2053,9 +2018,6 @@ struct server_context {
                        SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);

                        slot.n_past--;
-                       if (slot.ga_i > 0) {
-                           slot.n_past_se--;
-                       }
                    }

                    slot.n_prompt_tokens_processed = 0;
@@ -2081,52 +2043,31 @@ struct server_context {
                    }

                    // keep only the common part
-                   int p0 = slot.n_past;
-
-                   if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
+                   if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, slot.n_past, -1)) {
                        // could not partially delete (likely using a non-Transformer model)
                        llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);

-                       p0 = 0;
-
                        // there is no common part left
                        slot.n_past = 0;
-                       slot.n_past_se = 0;
-                       slot.ga_i = 0;

                        common_sampler_reset(slot.smpl);
                    }

+                   SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+
                    // remove the non-common part from the cache
                    slot.cache_tokens.resize(slot.n_past);

-                   SLT_INF(slot, "kv cache rm [%d, end)\n", p0);
-
-                   int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
-                   int32_t ga_i = slot.ga_i;
-                   int32_t ga_n = slot.ga_n;
-                   int32_t ga_w = slot.ga_w;
-
                    // add prompt tokens for processing in the current batch
-                   // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
-                   for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) {
-                       if (slot.ga_n != 1) {
-                           while (slot_npast >= ga_i + ga_w) {
-                               const int bd = (ga_w/ga_n)*(ga_n - 1);
-                               slot_npast -= bd;
-                               ga_i += ga_w/ga_n;
-                           }
-                       }
-
-                       common_batch_add(batch, prompt_tokens[slot.n_past], slot_npast, { slot.id + 1 }, false);
+                   while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                       common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id + 1 }, false);

                        if (slot.params.cache_prompt) {
                            slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
                        }

                        slot.n_prompt_tokens_processed++;
-                       slot_npast++;
+                       slot.n_past++;
                    }

                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
@@ -2167,34 +2108,6 @@ struct server_context {
        for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

-           for (auto & slot : slots) {
-               if (slot.ga_n != 1) {
-                   // context extension via Self-Extend
-                   // TODO: simplify and/or abstract this
-                   while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
-                       const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
-                       const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
-                       const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
-
-                       SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
-                       SLT_DBG(slot, "div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
-                       SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
-
-                       llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
-                       llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
-                       llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
-
-                       slot.n_past_se -= bd;
-
-                       slot.ga_i += slot.ga_w / slot.ga_n;
-
-                       SLT_DBG(slot, "\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
-                   }
-
-                   slot.n_past_se += n_tokens;
-               }
-           }
-
            llama_batch batch_view = {
                n_tokens,
                batch.token + i,
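This hunk deletes the heart of the feature: the Self-Extend remapping that compressed each full group-attention window into fewer KV positions via `llama_kv_cache_seq_add`/`llama_kv_cache_seq_div`. Plugging illustrative values into the deleted formulas shows the effect (the numbers are chosen for the example, not taken from the commit):

```cpp
// First Self-Extend step for a slot with ga_n = 4, ga_w = 512, ga_i = 0, n_past_se = 512.
constexpr int ga_n      = 4;
constexpr int ga_w      = 512;
constexpr int ga_i      = 0;
constexpr int n_past_se = 512;

constexpr int ib = (ga_n * ga_i) / ga_w;            // 0
constexpr int bd = (ga_w / ga_n) * (ga_n - 1);      // 128 * 3 = 384
constexpr int dd = (ga_w / ga_n) - ib * bd - ga_w;  // 128 - 0 - 512 = -384

static_assert(ib == 0 && bd == 384 && dd == -384, "self-extend step for this example");

// The three removed KV-cache calls then amount to:
//   llama_kv_cache_seq_add over [0, 512)   by    0  -> no shift needed for the first window
//   llama_kv_cache_seq_div over [0, 512)   by    4  -> positions 0..511 collapse to 0..127
//   llama_kv_cache_seq_add over [512, 512) by -384  -> empty range on this first pass
//
// Afterwards n_past_se becomes 512 - 384 = 128 and ga_i becomes 128: a 512-token window
// now occupies 128 KV positions, which is how self-extend stretched the effective
// context. The server now relies on context shift instead.
```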
@@ -13,6 +13,10 @@ Feature: llama.cpp server
    And 32 as batch size
    And 2 slots

+  # the prompt is 301 tokens
+  # the slot context is 256/2 = 128 tokens
+  # the prompt is truncated to keep the last 109 tokens
+  # 64 tokens are generated thanks to shifting the context when it gets full
  Scenario: Inference with context shift
    And 64 server max tokens to predict
    Then the server is starting
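The numbers in the new scenario comments are consistent with the truncation sketch shown earlier, assuming the default `n_keep` of 0 (an assumption, not stated in this diff):

```cpp
// 256-token context shared by 2 slots, 301-token prompt, n_keep assumed to be 0.
constexpr int n_ctx_slot      = 256 / 2; // 128
constexpr int n_prompt_tokens = 301;
constexpr int n_keep          = 0;

constexpr int n_left        = n_ctx_slot - n_keep;                                      // 128
constexpr int n_block_size  = n_left / 2;                                               //  64
constexpr int erased_blocks = (n_prompt_tokens - n_keep - n_block_size) / n_block_size; //   3
constexpr int n_kept        = n_prompt_tokens - erased_blocks * n_block_size;           // 301 - 192

static_assert(n_kept == 109, "matches 'truncated to keep the last 109 tokens'");
```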