mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 04:44:34 +00:00
parallel : disable hot-plug to avoid cache fragmentation
This commit is contained in:
parent
0161372b9a
commit
466b513851
@ -28,7 +28,7 @@ static std::string trim(const std::string & str) {
|
||||
}
|
||||
|
||||
static std::string k_system = R"(
|
||||
Transcript of a dialog, where the User interacts with an Assistant.
|
||||
Transcript of a never ending dialog, where the User interacts with an Assistant.
|
||||
The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
|
||||
|
||||
User: Hello, what is the temperature outside?
|
||||
@ -59,6 +59,9 @@ struct client {
|
||||
|
||||
llama_token sampled;
|
||||
|
||||
int64_t t_start_prompt;
|
||||
int64_t t_start_gen;
|
||||
|
||||
int32_t n_prompt = 0;
|
||||
int32_t n_decoded = 0;
|
||||
int32_t i_batch = -1;
|
||||
@ -133,33 +136,47 @@ int main(int argc, char ** argv) {
|
||||
|
||||
for (auto & client : clients) {
|
||||
if (client.seq_id == -1) {
|
||||
client.seq_id = g_seq_id;
|
||||
client.input = k_prompts[rand() % k_prompts.size()];
|
||||
client.prompt = k_system + client.input + "\nAssistant:";
|
||||
client.response = "";
|
||||
std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
|
||||
continue;
|
||||
}
|
||||
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
prompt_tokens = ::llama_tokenize(ctx, client.prompt, true);
|
||||
batch_token.push_back(client.sampled);
|
||||
batch_pos.push_back(client.n_decoded);
|
||||
batch_seq_id.push_back(client.seq_id);
|
||||
batch_clients.push_back(&client);
|
||||
client.n_decoded += 1;
|
||||
client.i_batch = batch_token.size() - 1;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < prompt_tokens.size(); ++i) {
|
||||
batch_token.push_back(prompt_tokens[i]);
|
||||
batch_pos.push_back(i);
|
||||
batch_seq_id.push_back(client.seq_id);
|
||||
batch_clients.push_back(&client);
|
||||
if (batch_token.empty()) {
|
||||
// all sequences have ended - clear the entire KV cache
|
||||
llama_kv_cache_rm_tokens(ctx, -1, -1);
|
||||
|
||||
for (auto & client : clients) {
|
||||
if (client.seq_id == -1) {
|
||||
client.seq_id = g_seq_id;
|
||||
client.t_start_prompt = ggml_time_us();
|
||||
client.t_start_gen = 0;
|
||||
|
||||
client.input = k_prompts[rand() % k_prompts.size()];
|
||||
client.prompt = k_system + client.input + "\nAssistant:";
|
||||
client.response = "";
|
||||
std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
|
||||
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
prompt_tokens = ::llama_tokenize(ctx, client.prompt, true);
|
||||
|
||||
for (size_t i = 0; i < prompt_tokens.size(); ++i) {
|
||||
batch_token.push_back(prompt_tokens[i]);
|
||||
batch_pos.push_back(i);
|
||||
batch_seq_id.push_back(client.seq_id);
|
||||
batch_clients.push_back(&client);
|
||||
}
|
||||
client.n_prompt = prompt_tokens.size();
|
||||
client.n_decoded = prompt_tokens.size();
|
||||
client.i_batch = batch_token.size() - 1;
|
||||
|
||||
g_seq_id += 1;
|
||||
}
|
||||
client.n_prompt = prompt_tokens.size();
|
||||
client.n_decoded = prompt_tokens.size();
|
||||
client.i_batch = batch_token.size() - 1;
|
||||
|
||||
g_seq_id += 1;
|
||||
} else {
|
||||
batch_token.push_back(client.sampled);
|
||||
batch_pos.push_back(client.n_decoded);
|
||||
batch_seq_id.push_back(client.seq_id);
|
||||
batch_clients.push_back(&client);
|
||||
client.n_decoded += 1;
|
||||
client.i_batch = batch_token.size() - 1;
|
||||
}
|
||||
}
|
||||
|
||||
@ -188,6 +205,10 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);
|
||||
|
||||
if (client.t_start_gen == 0) {
|
||||
client.t_start_gen = ggml_time_us();
|
||||
}
|
||||
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
client.last_tokens.erase(client.last_tokens.begin());
|
||||
client.last_tokens.push_back(id);
|
||||
@ -199,7 +220,10 @@ int main(int argc, char ** argv) {
|
||||
//printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n",
|
||||
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
|
||||
|
||||
if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict || client.response.find("User:") != std::string::npos) {
|
||||
if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict ||
|
||||
client.response.find("User:") != std::string::npos ||
|
||||
client.response.find('\n') != std::string::npos) {
|
||||
// basic reverse prompt
|
||||
const size_t pos = client.response.find("User:");
|
||||
if (pos != std::string::npos) {
|
||||
client.response = client.response.substr(0, pos);
|
||||
@ -211,13 +235,18 @@ int main(int argc, char ** argv) {
|
||||
|
||||
n_tokens_total += client.n_decoded - client.n_prompt;
|
||||
|
||||
printf("\033[1mClient %d, seq %d, prompt %d t, response %d t, speed: %.2f t/s\033[0m: \n\nInput: %s\nResponse: %s\n\n",
|
||||
printf("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, speed: PP %5.2f t/s, TG %5.2f, AVG %5.2f \033[0m: \n\nInput: %s\nResponse: %s\n\n",
|
||||
client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt,
|
||||
(double) n_tokens_total / (t_main_end - t_main_start) * 1e6,
|
||||
client.input.c_str(), ::trim(client.response).c_str());
|
||||
(double) (client.n_prompt ) / (client.t_start_gen - client.t_start_prompt) * 1e6,
|
||||
(double) (client.n_decoded - client.n_prompt) / (t_main_end - client.t_start_gen) * 1e6,
|
||||
(double) (client.n_decoded ) / (t_main_end - client.t_start_prompt) * 1e6,
|
||||
::trim(client.input).c_str(),
|
||||
::trim(client.response).c_str());
|
||||
|
||||
client.seq_id = -1;
|
||||
}
|
||||
|
||||
client.i_batch = -1;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2606,6 +2606,8 @@ static struct ggml_cgraph * llm_build_llama(
|
||||
const int32_t n_tokens = batch.n_tokens;
|
||||
const int32_t n_kv = llama_kv_cache_cell_max(kv_self);
|
||||
|
||||
//printf("n_kv = %d\n", n_kv);
|
||||
|
||||
const bool do_rope_shift = kv_self.has_shift || ggml_allocr_is_measure(lctx.alloc);
|
||||
|
||||
auto & buf_compute = lctx.buf_compute;
|
||||
@ -4052,6 +4054,8 @@ static bool llama_eval_internal(
|
||||
batch.seq_id = seq_id.data();
|
||||
}
|
||||
|
||||
kv_self.head = 0;
|
||||
|
||||
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
||||
return false;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user