llama : improve llama_batch API + simplify parallel example

This commit is contained in:
Georgi Gerganov 2023-09-20 10:46:18 +03:00
parent a1327c71c6
commit addae65fd4
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
6 changed files with 111 additions and 70 deletions

View File

@ -127,11 +127,7 @@ int main(int argc, char ** argv) {
llama_seq_id g_seq_id = 0; llama_seq_id g_seq_id = 0;
std::vector<llama_token> batch_token; llama_batch batch = llama_batch_init(params.n_batch, 0);
std::vector<llama_pos> batch_pos;
std::vector<llama_seq_id> batch_seq_id;
std::vector<int8_t> batch_logits;
std::vector<client *> batch_clients;
int32_t n_total_prompt = 0; int32_t n_total_prompt = 0;
int32_t n_total_gen = 0; int32_t n_total_gen = 0;
@ -146,24 +142,15 @@ int main(int argc, char ** argv) {
{ {
LOG_TEE("%s: Evaluating the system prompt ...\n", __func__); LOG_TEE("%s: Evaluating the system prompt ...\n", __func__);
batch_pos.clear(); batch.n_tokens = n_tokens_system;
batch_seq_id.clear();
for (size_t i = 0; i < n_tokens_system; ++i) { for (uint32_t i = 0; i < batch.n_tokens; ++i) {
batch_pos.push_back(i); batch.token[i] = tokens_system[i];
batch_seq_id.push_back(0); batch.pos[i] = i;
batch.seq_id[i] = 0;
batch.logits[i] = false;
} }
llama_batch batch = {
n_tokens_system,
tokens_system.data(),
nullptr,
batch_pos.data(),
batch_seq_id.data(),
nullptr,
0, 0, 0, // unused
};
if (llama_decode(ctx, batch, params.n_threads) != 0) { if (llama_decode(ctx, batch, params.n_threads) != 0) {
LOG_TEE("%s: llama_decode() failed\n", __func__); LOG_TEE("%s: llama_decode() failed\n", __func__);
return 1; return 1;
@ -180,63 +167,72 @@ int main(int argc, char ** argv) {
LOG_TEE("Processing requests ...\n\n"); LOG_TEE("Processing requests ...\n\n");
while (true) { while (true) {
uint32_t n_tokens = 0; batch.n_tokens = 0;
batch_token.clear();
batch_pos.clear();
batch_seq_id.clear();
batch_logits.clear();
// decode any currently ongoing sequences
for (auto & client : clients) { for (auto & client : clients) {
if (client.seq_id == -1) { if (client.seq_id == -1) {
continue; continue;
} }
batch_token.push_back(client.sampled); batch.token [batch.n_tokens] = client.sampled;
batch_pos.push_back(n_tokens_system + client.n_prompt + client.n_decoded); batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded;
batch_seq_id.push_back(client.id); batch.seq_id[batch.n_tokens] = client.id;
batch_logits.push_back(true); batch.logits[batch.n_tokens] = true;
batch_clients.push_back(&client);
client.n_decoded += 1; client.n_decoded += 1;
client.i_batch = batch_token.size() - 1; client.i_batch = batch.n_tokens;
batch.n_tokens += 1;
} }
if (batch_token.empty()) { if (batch.n_tokens == 0) {
// all sequences have ended - clear the entire KV cache // all sequences have ended - clear the entire KV cache
for (int i = 0; i < n_clients; ++i) { for (int i = 0; i < n_clients; ++i) {
llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1); llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
} }
} }
if (cont_batching || batch_token.empty()) { // insert new sequences for decoding
if (cont_batching || batch.n_tokens == 0) {
for (auto & client : clients) { for (auto & client : clients) {
if (client.seq_id == -1 && g_seq_id < n_seq) { if (client.seq_id == -1 && g_seq_id < n_seq) {
client.seq_id = g_seq_id; client.seq_id = g_seq_id;
client.t_start_prompt = ggml_time_us(); client.t_start_prompt = ggml_time_us();
client.t_start_gen = 0; client.t_start_gen = 0;
client.input = k_prompts[rand() % k_prompts.size()]; client.input = k_prompts[rand() % k_prompts.size()];
client.prompt = client.input + "\nAssistant:"; client.prompt = client.input + "\nAssistant:";
client.response = ""; client.response = "";
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0); std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
std::vector<llama_token> tokens_prompt; std::vector<llama_token> tokens_prompt;
tokens_prompt = ::llama_tokenize(ctx, client.prompt, true); tokens_prompt = ::llama_tokenize(ctx, client.prompt, true);
for (size_t i = 0; i < tokens_prompt.size(); ++i) { for (size_t i = 0; i < tokens_prompt.size(); ++i) {
batch_token.push_back(tokens_prompt[i]); batch.token [batch.n_tokens] = tokens_prompt[i];
batch_pos.push_back(i + n_tokens_system); batch.pos [batch.n_tokens] = i + n_tokens_system;
batch_seq_id.push_back(client.id); batch.seq_id[batch.n_tokens] = client.id;
batch_clients.push_back(&client); batch.logits[batch.n_tokens] = false;
batch_logits.push_back(false); batch.n_tokens += 1;
}
// extract the logits only for the last token
if (batch.n_tokens > 0) {
batch.logits[batch.n_tokens - 1] = true;
} }
batch_logits.back() = true;
client.n_prompt = tokens_prompt.size(); client.n_prompt = tokens_prompt.size();
client.n_decoded = 0; client.n_decoded = 0;
client.i_batch = batch_token.size() - 1; client.i_batch = batch.n_tokens - 1;
LOG_TEE("\033[1mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);
g_seq_id += 1; g_seq_id += 1;
// insert new requests one-by-one
//if (cont_batching) { //if (cont_batching) {
// break; // break;
//} //}
@ -244,34 +240,35 @@ int main(int argc, char ** argv) {
} }
} }
if (batch_token.empty()) { if (batch.n_tokens == 0) {
break; break;
} }
// process in chunks of params.n_batch // process in chunks of params.n_batch
int32_t n_batch = params.n_batch; int32_t n_batch = params.n_batch;
for (int32_t i = 0; i < (int32_t) batch_token.size(); i += n_batch) { for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
n_tokens = std::min(n_batch, (int32_t) (batch_token.size() - i)); const uint32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
llama_batch batch = { llama_batch batch_view = {
n_tokens, n_tokens,
batch_token.data() + i, batch.token + i,
nullptr, nullptr,
batch_pos.data() + i, batch.pos + i,
batch_seq_id.data() + i, batch.seq_id + i,
batch_logits.data() + i, batch.logits + i,
0, 0, 0, // unused 0, 0, 0, // unused
}; };
const int ret = llama_decode(ctx, batch, params.n_threads); const int ret = llama_decode(ctx, batch_view, params.n_threads);
if (ret != 0) { if (ret != 0) {
if (n_batch == 1 || ret < 0) { if (n_batch == 1 || ret < 0) {
LOG_TEE("%s : failed to decode batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret); // if you get here, it means the KV cache is full - try increasing it via the context size
LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
return 1; return 1;
} }
LOG("%s : failed to decode batch, retrying with n_batch = %d\n", __func__, n_batch / 2); LOG("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
n_cache_miss += 1; n_cache_miss += 1;
@ -357,6 +354,8 @@ int main(int argc, char ** argv) {
llama_print_timings(ctx); llama_print_timings(ctx);
llama_batch_free(batch);
llama_free(ctx); llama_free(ctx);
llama_free_model(model); llama_free_model(model);

View File

@ -419,7 +419,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
} }
static std::vector<float> hellaswag_evaluate_tokens( static std::vector<float> hellaswag_evaluate_tokens(
llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch, int n_vocab, int n_thread llama_context * ctx, std::vector<int> & tokens, int n_past, int n_batch, int n_vocab, int n_thread
) { ) {
std::vector<float> result; std::vector<float> result;
result.reserve(tokens.size() * n_vocab); result.reserve(tokens.size() * n_vocab);

View File

@ -10,10 +10,12 @@ int main(int argc, char ** argv) {
gpt_params params; gpt_params params;
if (argc == 1 || argv[1][0] == '-') { if (argc == 1 || argv[1][0] == '-') {
printf("usage: %s MODEL_PATH [PROMPT]\n" , argv[0]); printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL]\n" , argv[0]);
return 1 ; return 1 ;
} }
int n_parallel = 1;
if (argc >= 2) { if (argc >= 2) {
params.model = argv[1]; params.model = argv[1];
} }
@ -22,6 +24,10 @@ int main(int argc, char ** argv) {
params.prompt = argv[2]; params.prompt = argv[2];
} }
if (argc >= 4) {
n_parallel = std::atoi(argv[3]);
}
if (params.prompt.empty()) { if (params.prompt.empty()) {
params.prompt = "Hello my name is"; params.prompt = "Hello my name is";
} }

View File

@ -134,7 +134,7 @@ int main(int argc, char ** argv) {
while (true) { while (true) {
// sample from the target model // sample from the target model
const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft); llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
// remember which tokens were sampled - used for repetition penalties during sampling // remember which tokens were sampled - used for repetition penalties during sampling
last_tokens.erase(last_tokens.begin()); last_tokens.erase(last_tokens.begin());

View File

@ -7356,7 +7356,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
int llama_eval( int llama_eval(
struct llama_context * ctx, struct llama_context * ctx,
const llama_token * tokens, llama_token * tokens,
uint32_t n_tokens, uint32_t n_tokens,
int n_past, int n_past,
int n_threads) { int n_threads) {
@ -7376,7 +7376,7 @@ int llama_eval(
int llama_eval_embd( int llama_eval_embd(
struct llama_context * ctx, struct llama_context * ctx,
const float * embd, float * embd,
uint32_t n_tokens, uint32_t n_tokens,
int n_past, int n_past,
int n_threads) { int n_threads) {
@ -7397,7 +7397,7 @@ int llama_eval_embd(
} }
struct llama_batch llama_batch_get_one( struct llama_batch llama_batch_get_one(
const llama_token * tokens, llama_token * tokens,
uint32_t n_tokens, uint32_t n_tokens,
llama_pos pos_0, llama_pos pos_0,
llama_seq_id seq_id) { llama_seq_id seq_id) {
@ -7414,6 +7414,30 @@ struct llama_batch llama_batch_get_one(
}; };
} }
struct llama_batch llama_batch_init(uint32_t n_tokens, int32_t embd) {
llama_batch batch = { n_tokens, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
if (embd) {
batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);
} else {
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
}
batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
batch.seq_id = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_tokens);
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
return batch;
}
void llama_batch_free(struct llama_batch batch) {
if (batch.token) free(batch.token);
if (batch.embd) free(batch.embd);
if (batch.pos) free(batch.pos);
if (batch.seq_id) free(batch.seq_id);
if (batch.logits) free(batch.logits);
}
int llama_decode( int llama_decode(
struct llama_context * ctx, struct llama_context * ctx,
struct llama_batch batch, struct llama_batch batch,

32
llama.h
View File

@ -70,11 +70,11 @@ extern "C" {
typedef struct llama_batch { typedef struct llama_batch {
uint32_t n_tokens; uint32_t n_tokens;
const llama_token * token; llama_token * token;
const float * embd; float * embd;
const llama_pos * pos; llama_pos * pos;
const llama_seq_id * seq_id; llama_seq_id * seq_id;
const int8_t * logits; // if 0, do not extract logits for that token int8_t * logits; // if 0, do not extract logits for that token
// NOTE: helpers for smooth API transition - can be deprecated in the future // NOTE: helpers for smooth API transition - can be deprecated in the future
// for future-proof code, use the above fields instead and ignore everything below // for future-proof code, use the above fields instead and ignore everything below
@ -84,7 +84,7 @@ extern "C" {
llama_pos all_pos_0; // used if pos == NULL llama_pos all_pos_0; // used if pos == NULL
llama_pos all_pos_1; // used if pos == NULL llama_pos all_pos_1; // used if pos == NULL
llama_seq_id all_seq_id; // used if seq_id == NULL llama_seq_id all_seq_id; // used if seq_id == NULL
} llama_seq; } llama_batch;
enum llama_log_level { enum llama_log_level {
LLAMA_LOG_LEVEL_ERROR = 2, LLAMA_LOG_LEVEL_ERROR = 2,
@ -366,34 +366,46 @@ extern "C" {
// tokens + n_tokens is the provided batch of new tokens to process // tokens + n_tokens is the provided batch of new tokens to process
// n_past is the number of tokens to use from previous eval calls // n_past is the number of tokens to use from previous eval calls
// Returns 0 on success // Returns 0 on success
// DEPRECATED: use llama_decode() instead
LLAMA_API DEPRECATED(int llama_eval( LLAMA_API DEPRECATED(int llama_eval(
struct llama_context * ctx, struct llama_context * ctx,
const llama_token * tokens, llama_token * tokens,
uint32_t n_tokens, uint32_t n_tokens,
int n_past, int n_past,
int n_threads), int n_threads),
"please use llama_decode() instead"); "please use llama_decode() instead");
// Same as llama_eval, but use float matrix input directly. // Same as llama_eval, but use float matrix input directly.
// DEPRECATED: use llama_decode() instead
LLAMA_API DEPRECATED(int llama_eval_embd( LLAMA_API DEPRECATED(int llama_eval_embd(
struct llama_context * ctx, struct llama_context * ctx,
const float * embd, float * embd,
uint32_t n_tokens, uint32_t n_tokens,
int n_past, int n_past,
int n_threads), int n_threads),
"please use llama_decode() instead"); "please use llama_decode() instead");
// Return batch for single sequence of tokens starting at pos_0 // Return batch for single sequence of tokens starting at pos_0
// If pos_0 == 0, the clear_kv flag will be auto set to true
// //
// NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
// //
LLAMA_API struct llama_batch llama_batch_get_one( LLAMA_API struct llama_batch llama_batch_get_one(
const llama_token * tokens, llama_token * tokens,
uint32_t n_tokens, uint32_t n_tokens,
llama_pos pos_0, llama_pos pos_0,
llama_seq_id seq_id); llama_seq_id seq_id);
// Allocates a batch of tokens on the heap
// The batch needs to be freed with llama_batch_free()
// If embd > 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
// The rest of the llama_batch members are allocated with size n_tokens
// All members are left uninitialized
LLAMA_API struct llama_batch llama_batch_init(uint32_t n_tokens, int32_t embd);
// Frees a batch of tokens allocated with llama_batch_init()
LLAMA_API void llama_batch_free(struct llama_batch batch);
// Positive return values does not mean a fatal error, but rather a warning. // Positive return values does not mean a fatal error, but rather a warning.
// 0 - success // 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)