parallel : remove questions with short answers

This commit is contained in:
Georgi Gerganov 2023-09-19 23:34:30 +03:00
parent 4b5f3cd6bf
commit 8a9aca37c1
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -31,17 +31,14 @@ static std::string k_system =
R"(Transcript of a never ending dialog, where the User interacts with an Assistant.
The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
User: Hello, what is the temperature outside?
Assistant: It is 72 degrees Fahrenheit.
User: What is the definition of a prime number?
Assistant: A prime number is a number that is divisible only by itself and 1.
User: Recommend a nice restaurant in the area.
Assistant: I recommend the restaurant "The Golden Duck". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays.
User: Who is Richard Feynman?
Assistant: Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including "Surely You're Joking, Mr. Feynman!" and "What Do You Care What Other People Think?".
User:)";
static std::vector<std::string> k_prompts = {
"What is the meaning of life?",
"What is the population of Europe?",
"List all planets in the Solar System.",
"What is the capital of France?",
"Tell me an interesting fact about llamas.",
"What is the best way to cook a steak?",
"Are you familiar with the Special Theory of Relativity and can you explain it to me?",
@ -74,6 +71,8 @@ struct client {
};
int main(int argc, char ** argv) {
srand(1234);
gpt_params params;
if (gpt_params_parse(argc, argv, params) == false) {
@ -177,6 +176,8 @@ int main(int argc, char ** argv) {
LOG_TEE("\n");
}
LOG_TEE("Processing requests ...\n\n");
while (true) {
uint32_t n_tokens = 0;
@ -192,7 +193,7 @@ int main(int argc, char ** argv) {
batch_token.push_back(client.sampled);
batch_pos.push_back(n_tokens_system + client.n_prompt + client.n_decoded);
batch_seq_id.push_back(client.seq_id);
batch_seq_id.push_back(client.id);
batch_logits.push_back(true);
batch_clients.push_back(&client);
client.n_decoded += 1;
@ -209,7 +210,7 @@ int main(int argc, char ** argv) {
if (hot_plug || batch_token.empty()) {
for (auto & client : clients) {
if (client.seq_id == -1 && g_seq_id < n_seq) {
client.seq_id = client.id;
client.seq_id = g_seq_id;
client.t_start_prompt = ggml_time_us();
client.t_start_gen = 0;
@ -224,7 +225,7 @@ int main(int argc, char ** argv) {
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
batch_token.push_back(tokens_prompt[i]);
batch_pos.push_back(i + n_tokens_system);
batch_seq_id.push_back(client.seq_id);
batch_seq_id.push_back(client.id);
batch_clients.push_back(&client);
batch_logits.push_back(false);
}
@ -236,7 +237,7 @@ int main(int argc, char ** argv) {
g_seq_id += 1;
if (hot_plug) {
break;
//break;
}
}
}
@ -318,11 +319,11 @@ int main(int argc, char ** argv) {
}
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
llama_kv_cache_seq_rm(ctx, client.seq_id, n_tokens_system, n_ctx);
llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, n_ctx);
const auto t_main_end = ggml_time_us();
LOG_TEE("\033[1mClient %2d, seq %4d, prompt %4d t, response %4d t, time %5.2f s, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput: %s\nResponse: %s\n\n",
LOG_TEE("\033[1mClient %3d, seq %4d, prompt %4d t, response %4d t, time %5.2f s, speed: PP %5.2f t/s, TG %5.2f t/s, AVG %5.2f t/s \033[0m: \n\nInput: %s\nResponse: %s\n\n",
client.id, client.seq_id, client.n_prompt, client.n_decoded,
(t_main_end - client.t_start_prompt) / 1e6,
(double) (client.n_prompt ) / (client.t_start_gen - client.t_start_prompt) * 1e6,