mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-29 04:44:34 +00:00
parallel : example for serving multiple users in parallel
This commit is contained in:
parent
1f17ea631c
commit
0161372b9a
@ -454,8 +454,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
||||
if (params.logdir.back() != DIRECTORY_SEPARATOR) {
|
||||
params.logdir += DIRECTORY_SEPARATOR;
|
||||
}
|
||||
} else if (arg == "--perplexity") {
|
||||
params.perplexity = true;
|
||||
} else if (arg == "--perplexity" || arg == "--all-logits") {
|
||||
params.logits_all = true;
|
||||
} else if (arg == "--ppl-stride") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
@ -653,7 +653,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
||||
printf(" --memory-f32 use f32 instead of f16 for memory key+value (default: disabled)\n");
|
||||
printf(" not recommended: doubles context memory required and no measurable increase in quality\n");
|
||||
printf(" --temp N temperature (default: %.1f)\n", (double)params.temp);
|
||||
printf(" --perplexity compute perplexity over each ctx window of the prompt\n");
|
||||
printf(" --logits-all return logits for all tokens in the batch (default: disabled)\n");
|
||||
printf(" --hellaswag compute HellaSwag score over random tasks from datafile supplied with -f\n");
|
||||
printf(" --hellaswag-tasks N number of tasks to use when computing the HellaSwag score (default: %zu)\n", params.hellaswag_tasks);
|
||||
printf(" --keep N number of tokens to keep from the initial prompt (default: %d, -1 = all)\n", params.n_keep);
|
||||
@ -735,7 +735,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
|
||||
lparams.f16_kv = params.memory_f16;
|
||||
lparams.use_mmap = params.use_mmap;
|
||||
lparams.use_mlock = params.use_mlock;
|
||||
lparams.logits_all = params.perplexity;
|
||||
lparams.logits_all = params.logits_all;
|
||||
lparams.embedding = params.embedding;
|
||||
lparams.rope_freq_base = params.rope_freq_base;
|
||||
lparams.rope_freq_scale = params.rope_freq_scale;
|
||||
|
@ -113,7 +113,7 @@ struct gpt_params {
|
||||
bool ignore_eos = false; // ignore generated EOS tokens
|
||||
bool instruct = false; // instruction mode (used for Alpaca models)
|
||||
bool penalize_nl = true; // consider newlines as a repeatable token
|
||||
bool perplexity = false; // compute perplexity over the prompt
|
||||
bool logits_all = false; // return logits for all tokens in the batch
|
||||
bool use_mmap = true; // use mmap for faster loads
|
||||
bool use_mlock = false; // use mlock to keep model in memory
|
||||
bool numa = false; // attempt optimizations that help on some NUMA systems
|
||||
|
@ -24,6 +24,7 @@ else()
|
||||
add_subdirectory(convert-llama2c-to-ggml)
|
||||
add_subdirectory(simple)
|
||||
add_subdirectory(speculative)
|
||||
add_subdirectory(parallel)
|
||||
add_subdirectory(embd-input)
|
||||
add_subdirectory(llama-bench)
|
||||
add_subdirectory(beam-search)
|
||||
|
@ -124,7 +124,7 @@ int main(int argc, char ** argv) {
|
||||
console::init(params.simple_io, params.use_color);
|
||||
atexit([]() { console::cleanup(); });
|
||||
|
||||
if (params.perplexity) {
|
||||
if (params.logits_all) {
|
||||
printf("\n************\n");
|
||||
printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
|
||||
printf("************\n\n");
|
||||
|
8
examples/parallel/CMakeLists.txt
Normal file
8
examples/parallel/CMakeLists.txt
Normal file
@ -0,0 +1,8 @@
|
||||
set(TARGET parallel)
|
||||
add_executable(${TARGET} parallel.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
if(TARGET BUILD_INFO)
|
||||
add_dependencies(${TARGET} BUILD_INFO)
|
||||
endif()
|
244
examples/parallel/parallel.cpp
Normal file
244
examples/parallel/parallel.cpp
Normal file
@ -0,0 +1,244 @@
|
||||
// A basic application simulating a server with multiple clients.
|
||||
// The clients submite requests to the server and they are processed in parallel.
|
||||
|
||||
#include "build-info.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// trim whitespace from the beginning and end of a string
|
||||
static std::string trim(const std::string & str) {
|
||||
size_t start = 0;
|
||||
size_t end = str.size();
|
||||
|
||||
while (start < end && isspace(str[start])) {
|
||||
start += 1;
|
||||
}
|
||||
|
||||
while (end > start && isspace(str[end - 1])) {
|
||||
end -= 1;
|
||||
}
|
||||
|
||||
return str.substr(start, end - start);
|
||||
}
|
||||
|
||||
static std::string k_system = R"(
|
||||
Transcript of a dialog, where the User interacts with an Assistant.
|
||||
The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
|
||||
|
||||
User: Hello, what is the temperature outside?
|
||||
Assistant: It is 72 degrees Fahrenheit.
|
||||
User: What is the definition of a prime number?
|
||||
Assistant: A prime number is a number that is divisible only by itself and 1.
|
||||
User: )";
|
||||
|
||||
static std::vector<std::string> k_prompts = {
|
||||
"What is the meaning of life?",
|
||||
"What is the population of Europe?",
|
||||
"List all planets in the Solar System.",
|
||||
"What is the capital of France?",
|
||||
"Tell me an interesting fact about llamas.",
|
||||
"What is the best way to cook a steak?",
|
||||
"Are you familiar with the Special Theory of Relativity and can you explain it to me?",
|
||||
"Recommend some interesting books to read.",
|
||||
"What is the best way to learn a new language?",
|
||||
"How to get a job at Google?",
|
||||
"If you could have any superpower, what would it be?",
|
||||
"I want to learn how to play the piano.",
|
||||
};
|
||||
|
||||
struct client {
|
||||
int32_t id = 0;
|
||||
|
||||
llama_seq_id seq_id = -1;
|
||||
|
||||
llama_token sampled;
|
||||
|
||||
int32_t n_prompt = 0;
|
||||
int32_t n_decoded = 0;
|
||||
int32_t i_batch = -1;
|
||||
|
||||
std::string input;
|
||||
std::string prompt;
|
||||
std::string response;
|
||||
|
||||
std::vector<llama_token> last_tokens;
|
||||
};
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
|
||||
if (gpt_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
const int n_clients = 16;
|
||||
|
||||
#ifndef LOG_DISABLE_LOGS
|
||||
log_set_target(log_filename_generator("parallel", "log"));
|
||||
LOG_TEE("Log start\n");
|
||||
log_dump_cmdline(argc, argv);
|
||||
#endif // LOG_DISABLE_LOGS
|
||||
|
||||
// init llama.cpp
|
||||
llama_backend_init(params.numa);
|
||||
|
||||
llama_model * model = NULL;
|
||||
|
||||
llama_context * ctx = NULL;
|
||||
|
||||
// load the target model
|
||||
params.logits_all = true;
|
||||
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
||||
|
||||
fprintf(stderr, "\n\n");
|
||||
fflush(stderr);
|
||||
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
const int n_vocab = llama_n_vocab(ctx);
|
||||
|
||||
std::vector<client> clients(n_clients);
|
||||
for (size_t i = 0; i < clients.size(); ++i) {
|
||||
auto & client = clients[i];
|
||||
client.id = i;
|
||||
client.last_tokens.resize(n_ctx);
|
||||
std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
|
||||
}
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
|
||||
auto t_main_start = ggml_time_us();
|
||||
|
||||
int64_t n_tokens_total = 0;
|
||||
|
||||
llama_seq_id g_seq_id = 0;
|
||||
|
||||
std::vector<llama_token> batch_token;
|
||||
std::vector<llama_pos> batch_pos;
|
||||
std::vector<llama_seq_id> batch_seq_id;
|
||||
std::vector<client *> batch_clients;
|
||||
|
||||
while (true) {
|
||||
uint32_t n_tokens = 0;
|
||||
|
||||
batch_token.clear();
|
||||
batch_pos.clear();
|
||||
batch_seq_id.clear();
|
||||
|
||||
for (auto & client : clients) {
|
||||
if (client.seq_id == -1) {
|
||||
client.seq_id = g_seq_id;
|
||||
client.input = k_prompts[rand() % k_prompts.size()];
|
||||
client.prompt = k_system + client.input + "\nAssistant:";
|
||||
client.response = "";
|
||||
std::fill(client.last_tokens.begin(), client.last_tokens.end(), 0);
|
||||
|
||||
std::vector<llama_token> prompt_tokens;
|
||||
prompt_tokens = ::llama_tokenize(ctx, client.prompt, true);
|
||||
|
||||
for (size_t i = 0; i < prompt_tokens.size(); ++i) {
|
||||
batch_token.push_back(prompt_tokens[i]);
|
||||
batch_pos.push_back(i);
|
||||
batch_seq_id.push_back(client.seq_id);
|
||||
batch_clients.push_back(&client);
|
||||
}
|
||||
client.n_prompt = prompt_tokens.size();
|
||||
client.n_decoded = prompt_tokens.size();
|
||||
client.i_batch = batch_token.size() - 1;
|
||||
|
||||
g_seq_id += 1;
|
||||
} else {
|
||||
batch_token.push_back(client.sampled);
|
||||
batch_pos.push_back(client.n_decoded);
|
||||
batch_seq_id.push_back(client.seq_id);
|
||||
batch_clients.push_back(&client);
|
||||
client.n_decoded += 1;
|
||||
client.i_batch = batch_token.size() - 1;
|
||||
}
|
||||
}
|
||||
|
||||
// process in chunks of params.n_batch
|
||||
for (size_t i = 0; i < batch_token.size(); i += params.n_batch) {
|
||||
n_tokens = std::min(params.n_batch, (int32_t) (batch_token.size() - i));
|
||||
|
||||
llama_batch batch = {
|
||||
n_tokens,
|
||||
batch_token.data() + i,
|
||||
nullptr,
|
||||
batch_pos.data() + i,
|
||||
batch_seq_id.data() + i,
|
||||
0, 0, 0, // unused
|
||||
};
|
||||
|
||||
if (llama_decode(ctx, batch, params.n_threads)) {
|
||||
LOG_TEE("%s : failed to decode batch\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
for (auto & client : clients) {
|
||||
if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.last_tokens, candidates, client.i_batch - i);
|
||||
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
client.last_tokens.erase(client.last_tokens.begin());
|
||||
client.last_tokens.push_back(id);
|
||||
|
||||
const std::string token_str = llama_token_to_piece(ctx, id);
|
||||
client.response += token_str;
|
||||
client.sampled = id;
|
||||
|
||||
//printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n",
|
||||
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
|
||||
|
||||
if (id == llama_token_eos(ctx) || client.n_decoded > params.n_predict || client.response.find("User:") != std::string::npos) {
|
||||
const size_t pos = client.response.find("User:");
|
||||
if (pos != std::string::npos) {
|
||||
client.response = client.response.substr(0, pos);
|
||||
}
|
||||
|
||||
llama_kv_cache_rm_seq(ctx, client.seq_id, 0, n_ctx);
|
||||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
||||
n_tokens_total += client.n_decoded - client.n_prompt;
|
||||
|
||||
printf("\033[1mClient %d, seq %d, prompt %d t, response %d t, speed: %.2f t/s\033[0m: \n\nInput: %s\nResponse: %s\n\n",
|
||||
client.id, client.seq_id, client.n_prompt, client.n_decoded - client.n_prompt,
|
||||
(double) n_tokens_total / (t_main_end - t_main_start) * 1e6,
|
||||
client.input.c_str(), ::trim(client.response).c_str());
|
||||
|
||||
client.seq_id = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static bool is_first = true;
|
||||
if (is_first) {
|
||||
t_main_start = ggml_time_us();
|
||||
n_tokens_total = 0;
|
||||
is_first = false;
|
||||
}
|
||||
}
|
||||
|
||||
LOG_TEE("\n\n");
|
||||
|
||||
llama_print_timings(ctx);
|
||||
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
fprintf(stderr, "\n\n");
|
||||
|
||||
return 0;
|
||||
}
|
@ -681,7 +681,7 @@ int main(int argc, char ** argv) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
params.perplexity = true;
|
||||
params.logits_all = true;
|
||||
params.n_batch = std::min(params.n_batch, params.n_ctx);
|
||||
|
||||
if (params.ppl_stride > 0) {
|
||||
|
@ -37,7 +37,7 @@ int main(int argc, char ** argv) {
|
||||
llama_context * ctx_dft = NULL;
|
||||
|
||||
// load the target model
|
||||
params.perplexity = true; // HACK: enable logits_all = true
|
||||
params.logits_all = true;
|
||||
std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
|
||||
|
||||
// load the draft model
|
||||
@ -172,7 +172,6 @@ int main(int argc, char ** argv) {
|
||||
LOG("out of drafted tokens\n");
|
||||
}
|
||||
|
||||
llama_kv_cache_rm_seq(ctx_dft, 0, n_past_dft, n_ctx);
|
||||
llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
|
||||
++n_past_dft;
|
||||
|
||||
@ -218,7 +217,6 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// sample n_draft tokens from the draft model using greedy decoding
|
||||
int n_past_cur = n_past_dft;
|
||||
|
||||
for (int i = 0; i < n_draft; ++i) {
|
||||
float * logits = llama_get_logits(ctx_dft);
|
||||
|
||||
@ -258,7 +256,6 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// evaluate the drafted token on the draft model
|
||||
llama_kv_cache_rm_seq(ctx_dft, 0, n_past_cur, n_ctx);
|
||||
llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
|
||||
++n_past_cur;
|
||||
|
||||
@ -268,7 +265,6 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
|
||||
// evaluate the target model on the drafted tokens
|
||||
llama_kv_cache_rm_seq(ctx_tgt, 0, n_past_tgt, n_ctx);
|
||||
llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
|
||||
++n_past_tgt;
|
||||
|
||||
|
@ -6673,7 +6673,7 @@ struct llama_context * llama_new_context_with_model(
|
||||
ctx->alloc = ggml_allocr_new_measure(tensor_alignment);
|
||||
|
||||
// build worst-case graph
|
||||
uint32_t n_tokens = std::min((int)hparams.n_ctx, params.n_batch);
|
||||
uint32_t n_tokens = std::max((int)hparams.n_ctx, params.n_batch);
|
||||
llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
||||
ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, 0, 0));
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user