From 466b513851ff8ec73889ce6414b8a15d570f77c7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 18 Sep 2023 21:34:20 +0300 Subject: [PATCH] parallel : disable hot-plug to avoid cache fragmentation --- examples/parallel/parallel.cpp | 87 ++++++++++++++++++++++------------ llama.cpp | 4 ++ 2 files changed, 62 insertions(+), 29 deletions(-) diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 1bb4d497f..23fda9d58 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -28,7 +28,7 @@ static std::string trim(const std::string & str) { } static std::string k_system = R"( -Transcript of a dialog, where the User interacts with an Assistant. +Transcript of a never ending dialog, where the User interacts with an Assistant. The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision. User: Hello, what is the temperature outside? @@ -59,6 +59,9 @@ struct client { llama_token sampled; + int64_t t_start_prompt; + int64_t t_start_gen; + int32_t n_prompt = 0; int32_t n_decoded = 0; int32_t i_batch = -1; @@ -133,33 +136,47 @@ int main(int argc, char ** argv) { for (auto & client : clients) { if (client.seq_id == -1) { - client.seq_id = g_seq_id; - client.input = k_prompts[rand() % k_prompts.size()]; - client.prompt = k_system + client.input + "\nAssistant:"; - client.response = ""; - std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0); + continue; + } - std::vector prompt_tokens; - prompt_tokens = ::llama_tokenize(ctx, client.prompt, true); + batch_token.push_back(client.sampled); + batch_pos.push_back(client.n_decoded); + batch_seq_id.push_back(client.seq_id); + batch_clients.push_back(&client); + client.n_decoded += 1; + client.i_batch = batch_token.size() - 1; + } - for (size_t i = 0; i < prompt_tokens.size(); ++i) { - batch_token.push_back(prompt_tokens[i]); - batch_pos.push_back(i); - batch_seq_id.push_back(client.seq_id); - batch_clients.push_back(&client); + if (batch_token.empty()) { + // all sequences have ended - clear the entire KV cache + llama_kv_cache_rm_tokens(ctx, -1, -1); + + for (auto & client : clients) { + if (client.seq_id == -1) { + client.seq_id = g_seq_id; + client.t_start_prompt = ggml_time_us(); + client.t_start_gen = 0; + + client.input = k_prompts[rand() % k_prompts.size()]; + client.prompt = k_system + client.input + "\nAssistant:"; + client.response = ""; + std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0); + + std::vector prompt_tokens; + prompt_tokens = ::llama_tokenize(ctx, client.prompt, true); + + for (size_t i = 0; i < prompt_tokens.size(); ++i) { + batch_token.push_back(prompt_tokens[i]); + batch_pos.push_back(i); + batch_seq_id.push_back(client.seq_id); + batch_clients.push_back(&client); + } + client.n_prompt = prompt_tokens.size(); + client.n_decoded = prompt_tokens.size(); + client.i_batch = batch_token.size() - 1; + + g_seq_id += 1; } - client.n_prompt = prompt_tokens.size(); - client.n_decoded = prompt_tokens.size(); - client.i_batch = batch_token.size() - 1; - - g_seq_id += 1; - } else { - batch_token.push_back(client.sampled); - batch_pos.push_back(client.n_decoded); - batch_seq_id.push_back(client.seq_id); - batch_clients.push_back(&client); - client.n_decoded += 1; - client.i_batch = batch_token.size() - 1; } } @@ -188,6 +205,10 @@ int main(int argc, char ** argv) { const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i); + if (client.t_start_gen == 0) { + client.t_start_gen = ggml_time_us(); + } + // remember which tokens were sampled - used for repetition penalties during sampling client.last_tokens.erase(client.last_tokens.begin()); client.last_tokens.push_back(id); @@ -199,7 +220,10 @@ int main(int argc, char ** argv) { //printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n", // client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str()); - if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict || client.response.find("User:") != std::string::npos) { + if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict || + client.response.find("User:") != std::string::npos || + client.response.find('\n') != std::string::npos) { + // basic reverse prompt const size_t pos = client.response.find("User:"); if (pos != std::string::npos) { client.response = client.response.substr(0, pos); @@ -211,13 +235,18 @@ int main(int argc, char ** argv) { n_tokens_total += client.n_decoded - client.n_prompt; - printf("\033[1mClient %d, seq %d, prompt %d t, response %d t, speed: %.2f t/s\033[0m: \n\nInput: %s\nResponse: %s\n\n", + printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, speed: PP %5.2f t/s, TG %5.2f, AVG %5.2f \033[0m: \n\nInput: %s\nResponse: %s\n\n", client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt, - (double) n_tokens_total / (t_main_end - t_main_start) * 1e6, - client.input.c_str(), ::trim(client.response).c_str()); + (double) (client.n_prompt ) / (client.t_start_gen - client.t_start_prompt) * 1e6, + (double) (client.n_decoded - client.n_prompt) / (t_main_end - client.t_start_gen) * 1e6, + (double) (client.n_decoded ) / (t_main_end - client.t_start_prompt) * 1e6, + ::trim(client.input).c_str(), + ::trim(client.response).c_str()); client.seq_id = -1; } + + client.i_batch = -1; } } diff --git a/llama.cpp b/llama.cpp index 875fd5227..f56ecc272 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2606,6 +2606,8 @@ static struct ggml_cgraph * llm_build_llama( const int32_t n_tokens = batch.n_tokens; const int32_t n_kv = llama_kv_cache_cell_max(kv_self); + //printf("n_kv = %d\n", n_kv); + const bool do_rope_shift = kv_self.has_shift || ggml_allocr_is_measure(lctx.alloc); auto & buf_compute = lctx.buf_compute; @@ -4052,6 +4054,8 @@ static bool llama_eval_internal( batch.seq_id = seq_id.data(); } + kv_self.head = 0; + if (!llama_kv_cache_find_slot(kv_self, batch)) { return false; }