From a1327c71c6cfee3b1697aa6646f52f1de249120b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 20 Sep 2023 09:24:02 +0300 Subject: [PATCH] parallel : rename hot-plug to continuous-batching --- common/common.cpp | 8 ++++---- common/common.h | 2 +- examples/parallel/parallel.cpp | 12 ++++++------ 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index 8bd006960..303b38240 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -372,8 +372,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.multiline_input = true; } else if (arg == "--simple-io") { params.simple_io = true; - } else if (arg == "--hot-plug") { - params.hot_plug = true; + } else if (arg == "-cb" || arg == "--cont-batching") { + params.cont_batching = true; } else if (arg == "--color") { params.use_color = true; } else if (arg == "--mlock") { @@ -675,7 +675,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { printf(" --chunks N max number of chunks to process (default: %d, -1 = all)\n", params.n_chunks); printf(" -np N, --parallel N number of parallel sequences to decode (default: %d)\n", params.n_parallel); printf(" -ns N, --sequences N number of sequences to decode (default: %d)\n", params.n_sequences); - printf(" --hot-plug enable hot-plugging of new sequences for decoding (default: disabled)\n"); + printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: disabled)\n"); if (llama_mlock_supported()) { printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n"); } @@ -1270,7 +1270,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale); fprintf(stream, "seed: %d # default: -1 (random seed)\n", params.seed); fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); - fprintf(stream, "hot_plug: %s # default: false\n", params.hot_plug ? "true" : "false"); + fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); fprintf(stream, "temp: %f # default: 0.8\n", params.temp); const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + LLAMA_MAX_DEVICES); diff --git a/common/common.h b/common/common.h index 9269a5d36..4218f4698 100644 --- a/common/common.h +++ b/common/common.h @@ -110,7 +110,7 @@ struct gpt_params { bool interactive_first = false; // wait for user input immediately bool multiline_input = false; // reverse the usage of `\` bool simple_io = false; // improves compatibility with subprocesses and limited consoles - bool hot_plug = false; // hot-plug new sequences for decoding + bool cont_batching = false; // insert new sequences for decoding on-the-fly bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix bool ignore_eos = false; // ignore generated EOS tokens diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index a7b5bad71..4af4d2cd2 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -86,7 +86,7 @@ int main(int argc, char ** argv) { const int32_t n_seq = params.n_sequences; // insert new requests as soon as the previous one is done - const bool hot_plug = params.hot_plug; + const bool cont_batching = params.cont_batching; #ifndef LOG_DISABLE_LOGS log_set_target(log_filename_generator("parallel", "log")); @@ -140,7 +140,7 @@ int main(int argc, char ** argv) { const auto t_main_start = ggml_time_us(); LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__); - LOG_TEE("%s: n_parallel = %d, n_sequences = %d, hot_plug = %d, system tokens = %d\n", __func__, n_clients, n_seq, hot_plug, n_tokens_system); + LOG_TEE("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system); LOG_TEE("\n"); { @@ -208,7 +208,7 @@ int main(int argc, char ** argv) { } } - if (hot_plug || batch_token.empty()) { + if (cont_batching || batch_token.empty()) { for (auto & client : clients) { if (client.seq_id == -1 && g_seq_id < n_seq) { client.seq_id = g_seq_id; @@ -237,9 +237,9 @@ int main(int argc, char ** argv) { client.i_batch = batch_token.size() - 1; g_seq_id += 1; - if (hot_plug) { - //break; - } + //if (cont_batching) { + // break; + //} } } }