Mirror of https://github.com/ggerganov/llama.cpp.git
server : avoid context swaps by shifting the KV cache
parent ce2d995af2
commit c5650ed470
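Summary of the change: previously, when the prompt plus the generated tokens filled the context window, the server reset the context, keeping the first params.n_keep tokens plus the most recent (params.n_ctx - params.n_keep)/2 tokens and re-evaluating everything after the kept prefix from scratch. After this change the server instead edits the KV cache in place: it removes a block of cells just after the kept prefix (llama_kv_cache_seq_rm) and slides the following cells back (llama_kv_cache_seq_shift), so generation continues without re-evaluating the retained tokens. The prompt-loading path now also clears stale cache entries past the common prefix explicitly, and the blanket per-batch cache clearing in the decode loop is dropped.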
@@ -381,6 +381,10 @@ struct llama_server_context
         // compare the evaluated prompt with the new prompt
         n_past = common_part(embd, prompt_tokens);
+
+        // since #3228 we now have to manually manage the KV cache
+        llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx);
+
         embd = prompt_tokens;
         if (n_past == num_prompt_tokens)
         {
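The llama_kv_cache_seq_rm call added above trims the cache down to the part of the previous evaluation that still matches the new prompt, so only the differing tail has to be decoded again. A minimal standalone sketch of that idea, with hypothetical token values and a simplified stand-in for the server's common_part() helper (not the actual server code):

    #include <cstdio>
    #include <vector>

    // Length of the common prefix of two token sequences, mirroring what the
    // server's common_part() helper computes for the cached and the new prompt.
    static size_t common_prefix(const std::vector<int> & a, const std::vector<int> & b) {
        size_t i = 0;
        while (i < a.size() && i < b.size() && a[i] == b[i]) {
            i++;
        }
        return i;
    }

    int main() {
        const std::vector<int> cached = {1, 7, 7, 9, 4, 2};    // tokens already evaluated (in the KV cache)
        const std::vector<int> prompt = {1, 7, 7, 9, 5, 8, 3}; // new request from the client
        const int n_ctx = 512;                                 // example context size

        const size_t n_past = common_prefix(cached, prompt);   // 4 tokens can be reused

        // The server keeps cells [0, n_past) and clears [n_past, n_ctx), which is
        // what llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx) does, so only
        // prompt tokens from position n_past onwards need to be decoded.
        printf("reuse %zu cached tokens, clear cells [%zu, %d)\n", n_past, n_past, n_ctx);
        return 0;
    }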
@@ -411,19 +415,27 @@ struct llama_server_context
         if (embd.size() >= (size_t)params.n_ctx)
         {
-            // Reset context
-            const int n_left = (params.n_ctx - params.n_keep) / 2;
+            // Shift context
+
+            const int n_left    = n_past - params.n_keep - 1;
+            const int n_discard = n_left/2;
+
+            llama_kv_cache_seq_rm   (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
+            llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past,        -n_discard);
+
+            for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++)
+            {
+                embd[i - n_discard] = embd[i];
+            }
+            embd.resize(embd.size() - n_discard);
+
+            n_past -= n_discard;
 
-            std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
-            new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
-            embd = new_tokens;
-            n_past = params.n_keep;
             truncated = true;
             LOG_VERBOSE("input truncated", {
                 {"n_ctx", params.n_ctx},
                 {"n_keep", params.n_keep},
                 {"n_left", n_left},
-                {"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
             });
         }
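The shift replaces the old reset: instead of discarding the cache and re-evaluating the truncated prompt, n_discard cells right after the kept prefix are removed, the remaining cells are slid back, and the server's token buffer is compacted to match. A small self-contained simulation of the same arithmetic on a plain std::vector, using made-up values for n_keep and the tokens (the real code operates on the llama KV cache through llama_kv_cache_seq_rm / llama_kv_cache_seq_shift):

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_keep = 4;                 // tokens that must always stay at the front (example value)
        std::vector<int> embd;                // stand-in for the server's token buffer
        for (int i = 0; i < 16; i++) {
            embd.push_back(100 + i);          // 16 dummy tokens
        }
        int n_past = (int) embd.size();       // everything has been evaluated so far

        // Same arithmetic as the patch:
        const int n_left    = n_past - n_keep - 1;   // tokens after the kept prefix (minus one slot)
        const int n_discard = n_left / 2;            // drop half of them

        // KV cache side (conceptually):
        //   remove cells [n_keep + 1, n_keep + 1 + n_discard)
        //   shift  cells [n_keep + 1 + n_discard, n_past) back by n_discard
        // Token buffer side, mirroring the for-loop and resize in the patch:
        for (size_t i = n_keep + 1 + n_discard; i < embd.size(); i++) {
            embd[i - n_discard] = embd[i];
        }
        embd.resize(embd.size() - n_discard);
        n_past -= n_discard;

        printf("n_left=%d n_discard=%d n_past=%d remaining=%zu\n",
               n_left, n_discard, n_past, embd.size());
        // With these example values: n_left=11, n_discard=5, n_past=11, remaining=11.
        return 0;
    }

Both ranges start at params.n_keep + 1 rather than params.n_keep, which presumably leaves one extra slot (the BOS token) untouched along with the kept prefix.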
@@ -435,9 +447,6 @@ struct llama_server_context
                 n_eval = params.n_batch;
             }
 
-            // since #3228 we now have to manually manage the KV cache
-            llama_kv_cache_tokens_rm(ctx, n_past, -1);
-
             if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads))
             {
                 LOG_ERROR("failed to eval", {
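With stale entries now removed once when the prompt is loaded (first hunk) and the cache shifted in place when it fills up (second hunk), the blanket llama_kv_cache_tokens_rm call before every decode batch is no longer needed, so it is dropped here.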