diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1c21e55aa..e80e31cbe 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2418,6 +2418,14 @@ struct server_context { int32_t n_batch = llama_n_batch(ctx); int32_t n_ubatch = llama_n_ubatch(ctx); + // there are currently slots with ongoing text generation + const bool is_tg = batch.n_tokens > 0; + + // limit the batch to avoid blocking the processing + if (is_tg) { + n_batch = 32; // TODO: configurable + } + // track if this is an embedding or non-embedding batch // if we've added sampled tokens above, we are in non-embedding mode // -1: none, 0: non-embedding, 1: embedding @@ -2426,6 +2434,18 @@ struct server_context { // next, batch any pending prompts without exceeding n_batch if (params_base.cont_batching || batch.n_tokens == 0) { + // count how many slots are currently processing prompt + int n_slots_pp = 0; + for (auto & slot : slots) { + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { + n_slots_pp++; + } + } + + // determine the chunk size of the chunk prefill + // a slot cannot submit more than this number of tokens in a single batch if other slots are processing + const int32_t n_chunk_pp = std::max(n_slots_pp > 0 ? (n_batch / n_slots_pp) : n_batch, 8); + for (auto & slot : slots) { // this slot still has a prompt to be processed if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { @@ -2609,8 +2629,10 @@ struct server_context { // remove the non-common part from the cache slot.cache_tokens.resize(slot.n_past); + int n_cur = 0; + // add prompt tokens for processing in the current batch - while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch && n_cur < n_chunk_pp) { common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false); if (slot.params.cache_prompt) { @@ -2619,6 +2641,8 @@ struct server_context { slot.n_prompt_tokens_processed++; slot.n_past++; + + n_cur++; } SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);