mirror of https://github.com/ggerganov/llama.cpp.git
server : fix context shift (#5195)

* server : fix context shift + simplify self-extend
* server : take system_tokens into account
* server : more n_past fixes
* server : revert n_past_se changes
parent 4003be0e5f, commit e6f291d158
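The fix hinges on KV-cache positions being absolute: the shared system prompt occupies positions [0, system_tokens.size()) and each slot's own tokens follow it, so both the overflow check and the shift bounds must add the system prompt length to the slot-local n_past. A minimal sketch of that arithmetic, assuming illustrative values (n_system, n_past, n_ctx and n_keep below are stand-ins, not the server's fields):

#include <cstdio>

int main() {
    const int n_system = 32;   // tokens in the shared system prompt
    const int n_past   = 480;  // tokens already evaluated for this slot
    const int n_ctx    = 512;  // per-slot context size
    const int n_keep   = 64;   // prefix that must survive the shift

    // Full once system + slot tokens reach the slot's context; comparing
    // the slot's cache size alone would overflow by n_system tokens.
    if (n_system + n_past >= n_ctx) {
        const int n_left    = n_system + n_past - n_keep - 1; // shiftable span
        const int n_discard = n_left / 2;                     // drop half of it
        printf("n_left = %d, n_discard = %d\n", n_left, n_discard);
    }
    return 0;
}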
@@ -48,6 +48,7 @@ chat_completion() {
         top_p: 0.9,
         n_keep: $n_keep,
         n_predict: 256,
+        cache_prompt: true,
         stop: ["\n### Human:"],
         stream: true
     }')"
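cache_prompt: true asks the server to reuse the KV cache for the part of the prompt that matches the previous request, so each chat turn only evaluates the new suffix. As a hedged illustration (not part of the commit), the body the script builds corresponds to JSON along these lines, written here with nlohmann::json since that is the library the server itself uses; the prompt field is omitted and n_keep is a placeholder for the script's $n_keep:

#include <iostream>
#include <nlohmann/json.hpp>

int main() {
    nlohmann::json body = {
        {"top_p",        0.9},
        {"n_keep",       32},    // stand-in for the script's $n_keep
        {"n_predict",    256},
        {"cache_prompt", true},  // reuse the cached prompt prefix
        {"stop",         nlohmann::json::array({"\n### Human:"})},
        {"stream",       true},
    };
    std::cout << body.dump(2) << '\n';
    return 0;
}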
@@ -220,6 +220,7 @@ struct llama_client_slot
         infill    = false;
         ga_i      = 0;
         n_past_se = 0;
+
         generated_token_probs.clear();
 
         for (slot_image & img : images)
@@ -1227,7 +1228,7 @@ struct llama_server_context
                 std::vector<llama_token> append_tokens = tokenize(json_prompt, false); // has next image
                 for (int i = 0; i < (int) append_tokens.size(); ++i)
                 {
-                    llama_batch_add(batch, append_tokens[i], slot.n_past, { slot.id }, true);
+                    llama_batch_add(batch, append_tokens[i], system_tokens.size() + slot.n_past, { slot.id }, true);
                     slot.n_past += 1;
                 }
             }
@@ -1295,6 +1296,8 @@ struct llama_server_context
             for (llama_client_slot &slot : slots)
             {
                 slot.cache_tokens.clear();
+                slot.n_past    = 0;
+                slot.n_past_se = 0;
             }
         }
 
@@ -1364,26 +1367,26 @@ struct llama_server_context
                 kv_cache_clear();
             }
             return true;
-        } else {
+        }
+
         task_server task;
         task.type = TASK_TYPE_NEXT_RESPONSE;
         task.target_id = -1;
         queue_tasks.post(task);
-        }
 
         for (llama_client_slot &slot : slots)
         {
             if (slot.ga_n == 1)
             {
-                if (slot.is_processing() && slot.cache_tokens.size() >= (size_t) slot.n_ctx)
+                if (slot.is_processing() && system_tokens.size() + slot.cache_tokens.size() >= (size_t) slot.n_ctx)
                 {
                     // Shift context
-                    const int n_left = slot.n_past - slot.params.n_keep - 1;
+                    const int n_left = system_tokens.size() + slot.n_past - slot.params.n_keep - 1;
                     const int n_discard = n_left / 2;
 
                     LOG_TEE("slot %d: context shift - n_keep = %d, n_left = %d, n_discard = %d\n", slot.id, slot.params.n_keep, n_left, n_discard);
                     llama_kv_cache_seq_rm   (ctx, slot.id, slot.params.n_keep + 1            , slot.params.n_keep + n_discard + 1);
-                    llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, slot.n_past, -n_discard);
+                    llama_kv_cache_seq_shift(ctx, slot.id, slot.params.n_keep + 1 + n_discard, system_tokens.size() + slot.n_past, -n_discard);
 
                     for (size_t i = slot.params.n_keep + 1 + n_discard; i < slot.cache_tokens.size(); i++)
                     {
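To see what the seq_rm/seq_shift pair does to a slot's layout, here is a toy simulation (plain C++, not llama.cpp code): the cache is modeled as a vector of token positions, and the same bounds the server passes to llama_kv_cache_seq_rm and llama_kv_cache_seq_shift are applied by hand, with the system prompt length taken as zero for brevity:

#include <cstdio>
#include <vector>

int main() {
    const int n_keep = 2, n_past = 10;
    const int n_left    = n_past - n_keep - 1; // no system prompt in this toy
    const int n_discard = n_left / 2;

    std::vector<int> pos;                      // one entry per cached token
    for (int i = 0; i < n_past; ++i) pos.push_back(i);

    // llama_kv_cache_seq_rm equivalent: drop positions
    // [n_keep + 1, n_keep + n_discard + 1)
    pos.erase(pos.begin() + n_keep + 1, pos.begin() + n_keep + 1 + n_discard);

    // llama_kv_cache_seq_shift equivalent: slide the survivors down by
    // n_discard so the sequence is contiguous again
    for (size_t i = n_keep + 1; i < pos.size(); ++i) pos[i] -= n_discard;

    for (int p : pos) printf("%d ", p);        // prints: 0 1 2 3 4 5 6
    printf("\n");
    return 0;
}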
@@ -1429,8 +1432,10 @@ struct llama_server_context
                 slot.i_batch = batch.n_tokens;
 
                 const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-                llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
 
+                // TODO: we always have to take into account the "system_tokens"
+                //       this is not great and needs to be improved somehow
+                llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id }, true);
                 slot.n_past += 1;
             }
 
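The position handed to llama_batch_add is always the system prompt length plus a slot-local counter, and that counter is n_past_se while self-extend is active, as the TODO above notes. A hedged helper capturing just that selection logic, with invented numbers:

#include <cstdint>
#include <cstdio>

// Mirrors the position computation above; not the server's actual helper.
static int32_t batch_pos(size_t n_system, int32_t n_past, int32_t n_past_se) {
    const int32_t slot_npast = n_past_se > 0 ? n_past_se : n_past;
    return (int32_t) n_system + slot_npast;
}

int main() {
    printf("%d\n", batch_pos(32, 100, 0));  // self-extend off -> 132
    printf("%d\n", batch_pos(32, 100, 40)); // self-extend on  -> 72
    return 0;
}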
@@ -1591,10 +1596,13 @@ struct llama_server_context
 
                     // process the prefix of first image
                     std::vector<llama_token> prefix_tokens = has_images ? tokenize(slot.images[0].prefix_prompt, add_bos_token) : prompt_tokens;
+
                     int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-                    int ga_i = slot.ga_i;
+
+                    int32_t ga_i = slot.ga_i;
                     int32_t ga_n = slot.ga_n;
                     int32_t ga_w = slot.ga_w;
+
                     for (; slot.n_past < (int) prefix_tokens.size(); ++slot.n_past)
                     {
                         if (slot.ga_n != 1)
@@ -1606,7 +1614,7 @@ struct llama_server_context
                             }
                         }
                         llama_batch_add(batch, prefix_tokens[slot.n_past], system_tokens.size() + slot_npast, {slot.id }, false);
-                        slot_npast += 1;
+                        slot_npast++;
                     }
 
                     if (has_images && !ingest_images(slot, n_batch))
@@ -1666,6 +1674,7 @@ struct llama_server_context
                     slot.n_past_se += n_tokens;
                 }
             }
+
             llama_batch batch_view =
             {
                 n_tokens,
@@ -1818,15 +1827,15 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     printf("  -np N, --parallel N       number of slots for process requests (default: %d)\n", params.n_parallel);
     printf("  -cb, --cont-batching      enable continuous batching (a.k.a dynamic batching) (default: disabled)\n");
     printf("  -spf FNAME, --system-prompt-file FNAME\n");
-    printf("                            Set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
+    printf("                            set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
     printf("  --mmproj MMPROJ_FILE      path to a multimodal projector file for LLaVA.\n");
     printf("  --log-disable             disables logging to a file.\n");
     printf("\n");
     printf("  --override-kv KEY=TYPE:VALUE\n");
     printf("                            advanced option to override model metadata by key. may be specified multiple times.\n");
     printf("                            types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
-    printf("  -gan N, --grp-attn-n N    Set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
-    printf("  -gaw N, --grp-attn-w N    Set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
+    printf("  -gan N, --grp-attn-n N    set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`");
+    printf("  -gaw N, --grp-attn-w N    set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`");
     printf("\n");
 }
 