parallel : example for serving multiple users in parallel

This commit is contained in:
Georgi Gerganov 2023-09-18 20:30:05 +03:00
parent 1f17ea631c
commit 0161372b9a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
9 changed files with 262 additions and 13 deletions

View File

@ -454,8 +454,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
if (params.logdir.back() != DIRECTORY_SEPARATOR) { if (params.logdir.back() != DIRECTORY_SEPARATOR) {
params.logdir += DIRECTORY_SEPARATOR; params.logdir += DIRECTORY_SEPARATOR;
} }
} else if (arg == "--perplexity") { } else if (arg == "--perplexity" || arg == "--all-logits") {
params.perplexity = true; params.logits_all = true;
} else if (arg == "--ppl-stride") { } else if (arg == "--ppl-stride") {
if (++i >= argc) { if (++i >= argc) {
invalid_param = true; invalid_param = true;
@ -653,7 +653,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n"); printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
printf(" not recommended: doubles context memory required and no measurable increase in quality\n"); printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
printf(" --temp N temperature (default: %.1f)\n", (double)params.temp); printf(" --temp N temperature (default: %.1f)\n", (double)params.temp);
printf(" --perplexity compute perplexity over each ctx window of the prompt\n"); printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n");
printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n"); printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks); printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep); printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
@ -735,7 +735,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
lparams.f16_kv = params.memory_f16; lparams.f16_kv = params.memory_f16;
lparams.use_mmap = params.use_mmap; lparams.use_mmap = params.use_mmap;
lparams.use_mlock = params.use_mlock; lparams.use_mlock = params.use_mlock;
lparams.logits_all = params.perplexity; lparams.logits_all = params.logits_all;
lparams.embedding = params.embedding; lparams.embedding = params.embedding;
lparams.rope_freq_base = params.rope_freq_base; lparams.rope_freq_base = params.rope_freq_base;
lparams.rope_freq_scale = params.rope_freq_scale; lparams.rope_freq_scale = params.rope_freq_scale;

View File

@ -113,7 +113,7 @@ struct gpt_params {
bool ignore_eos = false; // ignore generated EOS tokens bool ignore_eos = false; // ignore generated EOS tokens
bool instruct = false; // instruction mode (used for Alpaca models) bool instruct = false; // instruction mode (used for Alpaca models)
bool penalize_nl = true; // consider newlines as a repeatable token bool penalize_nl = true; // consider newlines as a repeatable token
bool perplexity = false; // compute perplexity over the prompt bool logits_all = false; // return logits for all tokens in the batch
bool use_mmap = true; // use mmap for faster loads bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory bool use_mlock = false; // use mlock to keep model in memory
bool numa = false; // attempt optimizations that help on some NUMA systems bool numa = false; // attempt optimizations that help on some NUMA systems

View File

@ -24,6 +24,7 @@ else()
add_subdirectory(convert-llama2c-to-ggml) add_subdirectory(convert-llama2c-to-ggml)
add_subdirectory(simple) add_subdirectory(simple)
add_subdirectory(speculative) add_subdirectory(speculative)
add_subdirectory(parallel)
add_subdirectory(embd-input) add_subdirectory(embd-input)
add_subdirectory(llama-bench) add_subdirectory(llama-bench)
add_subdirectory(beam-search) add_subdirectory(beam-search)

View File

@ -124,7 +124,7 @@ int main(int argc, char ** argv) {
console::init(params.simple_io, params.use_color); console::init(params.simple_io, params.use_color);
atexit([]() { console::cleanup(); }); atexit([]() { console::cleanup(); });
if (params.perplexity) { if (params.logits_all) {
printf("\n************\n"); printf("\n************\n");
printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__); printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
printf("************\n\n"); printf("************\n\n");

View File

@ -0,0 +1,8 @@
set(TARGET parallel)
add_executable(${TARGET} parallel.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO)
endif()

View File

@ -0,0 +1,244 @@
// A basic application simulating a server with multiple clients.
// The clients submite requests to the server and they are processed in parallel.
#include "build-info.h"
#include "common.h"
#include "llama.h"
#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
// trim whitespace from the beginning and end of a string
static std::string trim(const std::string & str) {
size_t start = 0;
size_t end = str.size();
while (start < end && isspace(str[start])) {
start += 1;
}
while (end > start && isspace(str[end - 1])) {
end -= 1;
}
return str.substr(start, end - start);
}
static std::string k_system = R"(
Transcript of a dialog, where the User interacts with an Assistant.
The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
User: Hello, what is the temperature outside?
Assistant: It is 72 degrees Fahrenheit.
User: What is the definition of a prime number?
Assistant: A prime number is a number that is divisible only by itself and 1.
User: )";
static std::vector<std::string> k_prompts = {
"What is the meaning of life?",
"What is the population of Europe?",
"List all planets in the Solar System.",
"What is the capital of France?",
"Tell me an interesting fact about llamas.",
"What is the best way to cook a steak?",
"Are you familiar with the Special Theory of Relativity and can you explain it to me?",
"Recommend some interesting books to read.",
"What is the best way to learn a new language?",
"How to get a job at Google?",
"If you could have any superpower, what would it be?",
"I want to learn how to play the piano.",
};
struct client {
int32_t id = 0;
llama_seq_id seq_id = -1;
llama_token sampled;
int32_t n_prompt = 0;
int32_t n_decoded = 0;
int32_t i_batch = -1;
std::string input;
std::string prompt;
std::string response;
std::vector<llama_token> last_tokens;
};
int main(int argc, char ** argv) {
gpt_params params;
if (gpt_params_parse(argc, argv, params) == false) {
return 1;
}
const int n_clients = 16;
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("parallel", "log"));
LOG_TEE("Log start\n");
log_dump_cmdline(argc, argv);
#endif // LOG_DISABLE_LOGS
// init llama.cpp
llama_backend_init(params.numa);
llama_model * model = NULL;
llama_context * ctx = NULL;
// load the target model
params.logits_all = true;
std::tie(model, ctx) = llama_init_from_gpt_params(params);
fprintf(stderr, "\n\n");
fflush(stderr);
const int n_ctx = llama_n_ctx(ctx);
const int n_vocab = llama_n_vocab(ctx);
std::vector<client> clients(n_clients);
for (size_t i = 0; i < clients.size(); ++i) {
auto & client = clients[i];
client.id = i;
client.last_tokens.resize(n_ctx);
std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
}
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
auto t_main_start = ggml_time_us();
int64_t n_tokens_total = 0;
llama_seq_id g_seq_id = 0;
std::vector<llama_token> batch_token;
std::vector<llama_pos> batch_pos;
std::vector<llama_seq_id> batch_seq_id;
std::vector<client *> batch_clients;
while (true) {
uint32_t n_tokens = 0;
batch_token.clear();
batch_pos.clear();
batch_seq_id.clear();
for (auto & client : clients) {
if (client.seq_id == -1) {
client.seq_id = g_seq_id;
client.input = k_prompts[rand() % k_prompts.size()];
client.prompt = k_system + client.input + "\nAssistant:";
client.response = "";
std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
std::vector<llama_token> prompt_tokens;
prompt_tokens = ::llama_tokenize(ctx, client.prompt, true);
for (size_t i = 0; i < prompt_tokens.size(); ++i) {
batch_token.push_back(prompt_tokens[i]);
batch_pos.push_back(i);
batch_seq_id.push_back(client.seq_id);
batch_clients.push_back(&client);
}
client.n_prompt = prompt_tokens.size();
client.n_decoded = prompt_tokens.size();
client.i_batch = batch_token.size() - 1;
g_seq_id += 1;
} else {
batch_token.push_back(client.sampled);
batch_pos.push_back(client.n_decoded);
batch_seq_id.push_back(client.seq_id);
batch_clients.push_back(&client);
client.n_decoded += 1;
client.i_batch = batch_token.size() - 1;
}
}
// process in chunks of params.n_batch
for (size_t i = 0; i < batch_token.size(); i += params.n_batch) {
n_tokens = std::min(params.n_batch, (int32_t) (batch_token.size() - i));
llama_batch batch = {
n_tokens,
batch_token.data() + i,
nullptr,
batch_pos.data() + i,
batch_seq_id.data() + i,
0, 0, 0, // unused
};
if (llama_decode(ctx, batch, params.n_threads)) {
LOG_TEE("%s : failed to decode batch\n", __func__);
return 1;
}
for (auto & client : clients) {
if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
continue;
}
const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);
// remember which tokens were sampled - used for repetition penalties during sampling
client.last_tokens.erase(client.last_tokens.begin());
client.last_tokens.push_back(id);
const std::string token_str = llama_token_to_piece(ctx, id);
client.response += token_str;
client.sampled = id;
//printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n",
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict || client.response.find("User:") != std::string::npos) {
const size_t pos = client.response.find("User:");
if (pos != std::string::npos) {
client.response = client.response.substr(0, pos);
}
llama_kv_cache_rm_seq(ctx, client.seq_id, 0, n_ctx);
const auto t_main_end = ggml_time_us();
n_tokens_total += client.n_decoded - client.n_prompt;
printf("\033[1mClient %d, seq %d, prompt %d t, response %d t, speed: %.2f t/s\033[0m: \n\nInput: %s\nResponse: %s\n\n",
client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt,
(double) n_tokens_total / (t_main_end - t_main_start) * 1e6,
client.input.c_str(), ::trim(client.response).c_str());
client.seq_id = -1;
}
}
}
static bool is_first = true;
if (is_first) {
t_main_start = ggml_time_us();
n_tokens_total = 0;
is_first = false;
}
}
LOG_TEE("\n\n");
llama_print_timings(ctx);
llama_free(ctx);
llama_free_model(model);
llama_backend_free();
fprintf(stderr, "\n\n");
return 0;
}

View File

@ -681,7 +681,7 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
params.perplexity = true; params.logits_all = true;
params.n_batch = std::min(params.n_batch, params.n_ctx); params.n_batch = std::min(params.n_batch, params.n_ctx);
if (params.ppl_stride > 0) { if (params.ppl_stride > 0) {

View File

@ -37,7 +37,7 @@ int main(int argc, char ** argv) {
llama_context * ctx_dft = NULL; llama_context * ctx_dft = NULL;
// load the target model // load the target model
params.perplexity = true; // HACK: enable logits_all = true params.logits_all = true;
std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params); std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
// load the draft model // load the draft model
@ -172,7 +172,6 @@ int main(int argc, char ** argv) {
LOG("out of drafted tokens\n"); LOG("out of drafted tokens\n");
} }
llama_kv_cache_rm_seq(ctx_dft, 0, n_past_dft, n_ctx);
llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads); llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
++n_past_dft; ++n_past_dft;
@ -218,7 +217,6 @@ int main(int argc, char ** argv) {
// sample n_draft tokens from the draft model using greedy decoding // sample n_draft tokens from the draft model using greedy decoding
int n_past_cur = n_past_dft; int n_past_cur = n_past_dft;
for (int i = 0; i < n_draft; ++i) { for (int i = 0; i < n_draft; ++i) {
float * logits = llama_get_logits(ctx_dft); float * logits = llama_get_logits(ctx_dft);
@ -258,7 +256,6 @@ int main(int argc, char ** argv) {
} }
// evaluate the drafted token on the draft model // evaluate the drafted token on the draft model
llama_kv_cache_rm_seq(ctx_dft, 0, n_past_cur, n_ctx);
llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads); llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
++n_past_cur; ++n_past_cur;
@ -268,7 +265,6 @@ int main(int argc, char ** argv) {
} }
// evaluate the target model on the drafted tokens // evaluate the target model on the drafted tokens
llama_kv_cache_rm_seq(ctx_tgt, 0, n_past_tgt, n_ctx);
llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads); llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
++n_past_tgt; ++n_past_tgt;

View File

@ -6673,7 +6673,7 @@ struct llama_context * llama_new_context_with_model(
ctx->alloc = ggml_allocr_new_measure(tensor_alignment); ctx->alloc = ggml_allocr_new_measure(tensor_alignment);
// build worst-case graph // build worst-case graph
uint32_t n_tokens = std::min((int)hparams.n_ctx, params.n_batch); uint32_t n_tokens = std::max((int)hparams.n_ctx, params.n_batch);
llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, 0, 0)); ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, 0, 0));