parallel : disable hot-plug to avoid cache fragmentation

This commit is contained in:
Georgi Gerganov 2023-09-18 21:34:20 +03:00
parent 0161372b9a
commit 466b513851
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 62 additions and 29 deletions

View File

@ -28,7 +28,7 @@ static std::string trim(const std::string & str) {
}
static std::string k_system = R"(
Transcript of a dialog, where the User interacts with an Assistant.
Transcript of a never ending dialog, where the User interacts with an Assistant.
The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
User: Hello, what is the temperature outside?
@ -59,6 +59,9 @@ struct client {
llama_token sampled;
int64_t t_start_prompt;
int64_t t_start_gen;
int32_t n_prompt = 0;
int32_t n_decoded = 0;
int32_t i_batch = -1;
@ -131,9 +134,29 @@ int main(int argc, char ** argv) {
batch_pos.clear();
batch_seq_id.clear();
for (auto & client : clients) {
if (client.seq_id == -1) {
continue;
}
batch_token.push_back(client.sampled);
batch_pos.push_back(client.n_decoded);
batch_seq_id.push_back(client.seq_id);
batch_clients.push_back(&client);
client.n_decoded += 1;
client.i_batch = batch_token.size() - 1;
}
if (batch_token.empty()) {
// all sequences have ended - clear the entire KV cache
llama_kv_cache_rm_tokens(ctx, -1, -1);
for (auto & client : clients) {
if (client.seq_id == -1) {
client.seq_id = g_seq_id;
client.t_start_prompt = ggml_time_us();
client.t_start_gen = 0;
client.input = k_prompts[rand() % k_prompts.size()];
client.prompt = k_system + client.input + "\nAssistant:";
client.response = "";
@ -153,13 +176,7 @@ int main(int argc, char ** argv) {
client.i_batch = batch_token.size() - 1;
g_seq_id += 1;
} else {
batch_token.push_back(client.sampled);
batch_pos.push_back(client.n_decoded);
batch_seq_id.push_back(client.seq_id);
batch_clients.push_back(&client);
client.n_decoded += 1;
client.i_batch = batch_token.size() - 1;
}
}
}
@ -188,6 +205,10 @@ int main(int argc, char ** argv) {
const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);
if (client.t_start_gen == 0) {
client.t_start_gen = ggml_time_us();
}
// remember which tokens were sampled - used for repetition penalties during sampling
client.last_tokens.erase(client.last_tokens.begin());
client.last_tokens.push_back(id);
@ -199,7 +220,10 @@ int main(int argc, char ** argv) {
//printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n",
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict || client.response.find("User:") != std::string::npos) {
if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict ||
client.response.find("User:") != std::string::npos ||
client.response.find('\n') != std::string::npos) {
// basic reverse prompt
const size_t pos = client.response.find("User:");
if (pos != std::string::npos) {
client.response = client.response.substr(0, pos);
@ -211,13 +235,18 @@ int main(int argc, char ** argv) {
n_tokens_total += client.n_decoded - client.n_prompt;
printf("\033[1mClient %d, seq %d, prompt %d t, response %d t, speed: %.2f t/s\033[0m: \n\nInput: %s\nResponse: %s\n\n",
printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, speed: PP %5.2f t/s, TG %5.2f, AVG %5.2f \033[0m: \n\nInput: %s\nResponse: %s\n\n",
client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt,
(double) n_tokens_total / (t_main_end - t_main_start) * 1e6,
client.input.c_str(), ::trim(client.response).c_str());
(double) (client.n_prompt ) / (client.t_start_gen - client.t_start_prompt) * 1e6,
(double) (client.n_decoded - client.n_prompt) / (t_main_end - client.t_start_gen) * 1e6,
(double) (client.n_decoded ) / (t_main_end - client.t_start_prompt) * 1e6,
::trim(client.input).c_str(),
::trim(client.response).c_str());
client.seq_id = -1;
}
client.i_batch = -1;
}
}

View File

@ -2606,6 +2606,8 @@ static struct ggml_cgraph * llm_build_llama(
const int32_t n_tokens = batch.n_tokens;
const int32_t n_kv = llama_kv_cache_cell_max(kv_self);
//printf("n_kv = %d\n", n_kv);
const bool do_rope_shift = kv_self.has_shift || ggml_allocr_is_measure(lctx.alloc);
auto & buf_compute = lctx.buf_compute;
@ -4052,6 +4054,8 @@ static bool llama_eval_internal(
batch.seq_id = seq_id.data();
}
kv_self.head = 0;
if (!llama_kv_cache_find_slot(kv_self, batch)) {
return false;
}