server : chunked prefill support

ggml-ci
Georgi Gerganov 2024-12-08 09:48:18 +02:00
parent 62e84d9848
commit a6648b9df7

@@ -2418,6 +2418,14 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
+        // there are currently slots with ongoing text generation
+        const bool is_tg = batch.n_tokens > 0;
+
+        // limit the batch to avoid blocking the processing
+        if (is_tg) {
+            n_batch = 32; // TODO: configurable
+        }
+
         // track if this is an embedding or non-embedding batch
         // if we've added sampled tokens above, we are in non-embedding mode
         // -1: none, 0: non-embedding, 1: embedding
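
This first hunk caps the prompt budget whenever the batch already holds tokens sampled for ongoing generation, so decoding slots are not stalled behind a long prefill. A minimal standalone sketch of that check (the 32-token cap and the is_tg / n_batch names come from the hunk above; the helper function itself is hypothetical, not part of the commit):

#include <cstdint>

// Sketch of the check above: if tokens from ongoing text generation are
// already queued in the batch, clamp the budget left for prompt processing.
int32_t effective_prompt_batch(int32_t n_batch, int32_t n_tokens_in_batch) {
    const bool is_tg = n_tokens_in_batch > 0; // slots with ongoing text generation
    if (is_tg) {
        n_batch = 32; // same hard-coded cap as above, TODO: configurable
    }
    return n_batch;
}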
@@ -2426,6 +2434,18 @@ struct server_context {
 
         // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
+            // count how many slots are currently processing prompt
+            int n_slots_pp = 0;
+            for (auto & slot : slots) {
+                if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
+                    n_slots_pp++;
+                }
+            }
+
+            // determine the chunk size of the chunk prefill
+            // a slot cannot submit more than this number of tokens in a single batch if other slots are processing
+            const int32_t n_chunk_pp = std::max(n_slots_pp > 0 ? (n_batch / n_slots_pp) : n_batch, 8);
+
             for (auto & slot : slots) {
                 // this slot still has a prompt to be processed
                 if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
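
The second hunk splits the batch budget evenly across the slots that are still prefilling, with a floor of 8 tokens so every slot keeps making progress. A small illustrative example of how that expression behaves (the chunk_size_pp helper and the sample numbers are assumptions for illustration, not part of the commit):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Mirrors the n_chunk_pp expression from the hunk above: each
// prompt-processing slot gets an equal share of n_batch, never less than 8.
int32_t chunk_size_pp(int32_t n_batch, int32_t n_slots_pp) {
    return std::max(n_slots_pp > 0 ? (n_batch / n_slots_pp) : n_batch, 8);
}

int main() {
    printf("%d\n", chunk_size_pp(2048, 4)); // 512 tokens per slot per batch
    printf("%d\n", chunk_size_pp(  32, 4)); // 8, the floor applies once generation shrinks n_batch to 32
    printf("%d\n", chunk_size_pp(2048, 0)); // 2048, no slot is prefilling
    return 0;
}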
@@ -2609,8 +2629,10 @@ struct server_context {
                         // remove the non-common part from the cache
                         slot.cache_tokens.resize(slot.n_past);
 
+                        int n_cur = 0;
+
                         // add prompt tokens for processing in the current batch
-                        while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+                        while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch && n_cur < n_chunk_pp) {
                             common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
 
                             if (slot.params.cache_prompt) {
@@ -2619,6 +2641,8 @@ struct server_context {
 
                             slot.n_prompt_tokens_processed++;
                             slot.n_past++;
+
+                            n_cur++;
                         }
 
                         SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
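
The remaining hunks thread the new n_cur counter through the prompt-fill loop so a single long prompt cannot consume the whole batch. A hypothetical standalone simulation of that loop (the two-slot setup and all numbers are made up for illustration; only the n_cur < n_chunk_pp guard reflects the change above):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int32_t n_batch    = 64;
    const int32_t n_chunk_pp = std::max(n_batch / 2, 8); // two prefilling slots -> 32 tokens each

    std::vector<int32_t> remaining = { 100, 10 }; // prompt tokens left per slot

    for (int i = 0; remaining[0] > 0 || remaining[1] > 0; i++) {
        int32_t n_tokens = 0; // tokens queued in this batch
        for (auto & rem : remaining) {
            int32_t n_cur = 0;
            // same guards as the while loop above: stay within the batch
            // and within this slot's chunk
            while (rem > 0 && n_tokens < n_batch && n_cur < n_chunk_pp) {
                rem--;
                n_tokens++;
                n_cur++;
            }
        }
        printf("batch %d: %d tokens, remaining = { %d, %d }\n", i, n_tokens, remaining[0], remaining[1]);
    }
    // the 100-token prompt is processed in 32-token chunks while the
    // 10-token prompt completes in the first batch instead of waiting behind it
    return 0;
}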