mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-11 19:21:46 +00:00
server : fix crash when system prompt is bigger than batch size (#5714)
The system prompt is now decoded in batches. * server : fix off-by-one n_past when start of prompt matches whole cache The tokens right after the matching part would otherwise skip a pos value.
This commit is contained in:
parent
abbabc5e51
commit
f7625019c5
@ -902,11 +902,25 @@ struct llama_server_context
|
|||||||
llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
|
llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (llama_decode(ctx, batch) != 0)
|
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += params.n_batch)
|
||||||
|
{
|
||||||
|
const int32_t n_tokens = std::min(params.n_batch, (int32_t) (batch.n_tokens - i));
|
||||||
|
llama_batch batch_view = {
|
||||||
|
n_tokens,
|
||||||
|
batch.token + i,
|
||||||
|
nullptr,
|
||||||
|
batch.pos + i,
|
||||||
|
batch.n_seq_id + i,
|
||||||
|
batch.seq_id + i,
|
||||||
|
batch.logits + i,
|
||||||
|
0, 0, 0, // unused
|
||||||
|
};
|
||||||
|
if (llama_decode(ctx, batch_view) != 0)
|
||||||
{
|
{
|
||||||
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// assign the system KV cache to all parallel sequences
|
// assign the system KV cache to all parallel sequences
|
||||||
for (int32_t i = 1; i < params.n_parallel; ++i)
|
for (int32_t i = 1; i < params.n_parallel; ++i)
|
||||||
@ -1785,6 +1799,14 @@ struct llama_server_context
|
|||||||
}
|
}
|
||||||
|
|
||||||
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
|
slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
|
||||||
|
|
||||||
|
// the last token of the cache is not in the KV cache until the next call to llama_decode
|
||||||
|
// (it was sampled, pushed into the "cache_tokens", but not yet put in the context)
|
||||||
|
if (slot.n_past > 0 && slot.n_past == (int32_t) slot.cache_tokens.size())
|
||||||
|
{
|
||||||
|
slot.n_past -= 1;
|
||||||
|
}
|
||||||
|
|
||||||
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
|
slot.num_prompt_tokens_processed = slot.num_prompt_tokens - slot.n_past;
|
||||||
|
|
||||||
if (slot.ga_n != 1)
|
if (slot.ga_n != 1)
|
||||||
|
Loading…
Reference in New Issue
Block a user