mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-31 22:04:35 +00:00
llama : improve llama_batch API + simplify parallel example
This commit is contained in:
parent
a1327c71c6
commit
addae65fd4
@ -127,11 +127,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
llama_seq_id g_seq_id = 0;
|
llama_seq_id g_seq_id = 0;
|
||||||
|
|
||||||
std::vector<llama_token> batch_token;
|
llama_batch batch = llama_batch_init(params.n_batch, 0);
|
||||||
std::vector<llama_pos> batch_pos;
|
|
||||||
std::vector<llama_seq_id> batch_seq_id;
|
|
||||||
std::vector<int8_t> batch_logits;
|
|
||||||
std::vector<client *> batch_clients;
|
|
||||||
|
|
||||||
int32_t n_total_prompt = 0;
|
int32_t n_total_prompt = 0;
|
||||||
int32_t n_total_gen = 0;
|
int32_t n_total_gen = 0;
|
||||||
@ -146,24 +142,15 @@ int main(int argc, char ** argv) {
|
|||||||
{
|
{
|
||||||
LOG_TEE("%s: Evaluating the system prompt ...\n", __func__);
|
LOG_TEE("%s: Evaluating the system prompt ...\n", __func__);
|
||||||
|
|
||||||
batch_pos.clear();
|
batch.n_tokens = n_tokens_system;
|
||||||
batch_seq_id.clear();
|
|
||||||
|
|
||||||
for (size_t i = 0; i < n_tokens_system; ++i) {
|
for (uint32_t i = 0; i < batch.n_tokens; ++i) {
|
||||||
batch_pos.push_back(i);
|
batch.token[i] = tokens_system[i];
|
||||||
batch_seq_id.push_back(0);
|
batch.pos[i] = i;
|
||||||
|
batch.seq_id[i] = 0;
|
||||||
|
batch.logits[i] = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_batch batch = {
|
|
||||||
n_tokens_system,
|
|
||||||
tokens_system.data(),
|
|
||||||
nullptr,
|
|
||||||
batch_pos.data(),
|
|
||||||
batch_seq_id.data(),
|
|
||||||
nullptr,
|
|
||||||
0, 0, 0, // unused
|
|
||||||
};
|
|
||||||
|
|
||||||
if (llama_decode(ctx, batch, params.n_threads) != 0) {
|
if (llama_decode(ctx, batch, params.n_threads) != 0) {
|
||||||
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
@ -180,63 +167,72 @@ int main(int argc, char ** argv) {
|
|||||||
LOG_TEE("Processing requests ...\n\n");
|
LOG_TEE("Processing requests ...\n\n");
|
||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
uint32_t n_tokens = 0;
|
batch.n_tokens = 0;
|
||||||
|
|
||||||
batch_token.clear();
|
|
||||||
batch_pos.clear();
|
|
||||||
batch_seq_id.clear();
|
|
||||||
batch_logits.clear();
|
|
||||||
|
|
||||||
|
// decode any currently ongoing sequences
|
||||||
for (auto & client : clients) {
|
for (auto & client : clients) {
|
||||||
if (client.seq_id == -1) {
|
if (client.seq_id == -1) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
batch_token.push_back(client.sampled);
|
batch.token [batch.n_tokens] = client.sampled;
|
||||||
batch_pos.push_back(n_tokens_system + client.n_prompt + client.n_decoded);
|
batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded;
|
||||||
batch_seq_id.push_back(client.id);
|
batch.seq_id[batch.n_tokens] = client.id;
|
||||||
batch_logits.push_back(true);
|
batch.logits[batch.n_tokens] = true;
|
||||||
batch_clients.push_back(&client);
|
|
||||||
client.n_decoded += 1;
|
client.n_decoded += 1;
|
||||||
client.i_batch = batch_token.size() - 1;
|
client.i_batch = batch.n_tokens;
|
||||||
|
|
||||||
|
batch.n_tokens += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (batch_token.empty()) {
|
if (batch.n_tokens == 0) {
|
||||||
// all sequences have ended - clear the entire KV cache
|
// all sequences have ended - clear the entire KV cache
|
||||||
for (int i = 0; i < n_clients; ++i) {
|
for (int i = 0; i < n_clients; ++i) {
|
||||||
llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
|
llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cont_batching || batch_token.empty()) {
|
// insert new sequences for decoding
|
||||||
|
if (cont_batching || batch.n_tokens == 0) {
|
||||||
for (auto & client : clients) {
|
for (auto & client : clients) {
|
||||||
if (client.seq_id == -1 && g_seq_id < n_seq) {
|
if (client.seq_id == -1 && g_seq_id < n_seq) {
|
||||||
client.seq_id = g_seq_id;
|
client.seq_id = g_seq_id;
|
||||||
|
|
||||||
client.t_start_prompt = ggml_time_us();
|
client.t_start_prompt = ggml_time_us();
|
||||||
client.t_start_gen = 0;
|
client.t_start_gen = 0;
|
||||||
|
|
||||||
client.input = k_prompts[rand() % k_prompts.size()];
|
client.input = k_prompts[rand() % k_prompts.size()];
|
||||||
client.prompt = client.input + "\nAssistant:";
|
client.prompt = client.input + "\nAssistant:";
|
||||||
client.response = "";
|
client.response = "";
|
||||||
|
|
||||||
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
|
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
|
||||||
|
|
||||||
std::vector<llama_token> tokens_prompt;
|
std::vector<llama_token> tokens_prompt;
|
||||||
tokens_prompt = ::llama_tokenize(ctx, client.prompt, true);
|
tokens_prompt = ::llama_tokenize(ctx, client.prompt, true);
|
||||||
|
|
||||||
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
|
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
|
||||||
batch_token.push_back(tokens_prompt[i]);
|
batch.token [batch.n_tokens] = tokens_prompt[i];
|
||||||
batch_pos.push_back(i + n_tokens_system);
|
batch.pos [batch.n_tokens] = i + n_tokens_system;
|
||||||
batch_seq_id.push_back(client.id);
|
batch.seq_id[batch.n_tokens] = client.id;
|
||||||
batch_clients.push_back(&client);
|
batch.logits[batch.n_tokens] = false;
|
||||||
batch_logits.push_back(false);
|
batch.n_tokens += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// extract the logits only for the last token
|
||||||
|
if (batch.n_tokens > 0) {
|
||||||
|
batch.logits[batch.n_tokens - 1] = true;
|
||||||
}
|
}
|
||||||
batch_logits.back() = true;
|
|
||||||
|
|
||||||
client.n_prompt = tokens_prompt.size();
|
client.n_prompt = tokens_prompt.size();
|
||||||
client.n_decoded = 0;
|
client.n_decoded = 0;
|
||||||
client.i_batch = batch_token.size() - 1;
|
client.i_batch = batch.n_tokens - 1;
|
||||||
|
|
||||||
|
LOG_TEE("\033[1mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);
|
||||||
|
|
||||||
g_seq_id += 1;
|
g_seq_id += 1;
|
||||||
|
|
||||||
|
// insert new requests one-by-one
|
||||||
//if (cont_batching) {
|
//if (cont_batching) {
|
||||||
// break;
|
// break;
|
||||||
//}
|
//}
|
||||||
@ -244,34 +240,35 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (batch_token.empty()) {
|
if (batch.n_tokens == 0) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
// process in chunks of params.n_batch
|
// process in chunks of params.n_batch
|
||||||
int32_t n_batch = params.n_batch;
|
int32_t n_batch = params.n_batch;
|
||||||
|
|
||||||
for (int32_t i = 0; i < (int32_t) batch_token.size(); i += n_batch) {
|
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
|
||||||
n_tokens = std::min(n_batch, (int32_t) (batch_token.size() - i));
|
const uint32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||||
|
|
||||||
llama_batch batch = {
|
llama_batch batch_view = {
|
||||||
n_tokens,
|
n_tokens,
|
||||||
batch_token.data() + i,
|
batch.token + i,
|
||||||
nullptr,
|
nullptr,
|
||||||
batch_pos.data() + i,
|
batch.pos + i,
|
||||||
batch_seq_id.data() + i,
|
batch.seq_id + i,
|
||||||
batch_logits.data() + i,
|
batch.logits + i,
|
||||||
0, 0, 0, // unused
|
0, 0, 0, // unused
|
||||||
};
|
};
|
||||||
|
|
||||||
const int ret = llama_decode(ctx, batch, params.n_threads);
|
const int ret = llama_decode(ctx, batch_view, params.n_threads);
|
||||||
if (ret != 0) {
|
if (ret != 0) {
|
||||||
if (n_batch == 1 || ret < 0) {
|
if (n_batch == 1 || ret < 0) {
|
||||||
LOG_TEE("%s : failed to decode batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
|
// if you get here, it means the KV cache is full - try increasing it via the context size
|
||||||
|
LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG("%s : failed to decode batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
|
LOG("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
|
||||||
|
|
||||||
n_cache_miss += 1;
|
n_cache_miss += 1;
|
||||||
|
|
||||||
@ -357,6 +354,8 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
llama_print_timings(ctx);
|
llama_print_timings(ctx);
|
||||||
|
|
||||||
|
llama_batch_free(batch);
|
||||||
|
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
|
|
||||||
|
@ -419,7 +419,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<float> hellaswag_evaluate_tokens(
|
static std::vector<float> hellaswag_evaluate_tokens(
|
||||||
llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch, int n_vocab, int n_thread
|
llama_context * ctx, std::vector<int> & tokens, int n_past, int n_batch, int n_vocab, int n_thread
|
||||||
) {
|
) {
|
||||||
std::vector<float> result;
|
std::vector<float> result;
|
||||||
result.reserve(tokens.size() * n_vocab);
|
result.reserve(tokens.size() * n_vocab);
|
||||||
|
@ -10,10 +10,12 @@ int main(int argc, char ** argv) {
|
|||||||
gpt_params params;
|
gpt_params params;
|
||||||
|
|
||||||
if (argc == 1 || argv[1][0] == '-') {
|
if (argc == 1 || argv[1][0] == '-') {
|
||||||
printf("usage: %s MODEL_PATH [PROMPT]\n" , argv[0]);
|
printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL]\n" , argv[0]);
|
||||||
return 1 ;
|
return 1 ;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int n_parallel = 1;
|
||||||
|
|
||||||
if (argc >= 2) {
|
if (argc >= 2) {
|
||||||
params.model = argv[1];
|
params.model = argv[1];
|
||||||
}
|
}
|
||||||
@ -22,6 +24,10 @@ int main(int argc, char ** argv) {
|
|||||||
params.prompt = argv[2];
|
params.prompt = argv[2];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (argc >= 4) {
|
||||||
|
n_parallel = std::atoi(argv[3]);
|
||||||
|
}
|
||||||
|
|
||||||
if (params.prompt.empty()) {
|
if (params.prompt.empty()) {
|
||||||
params.prompt = "Hello my name is";
|
params.prompt = "Hello my name is";
|
||||||
}
|
}
|
||||||
|
@ -134,7 +134,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
while (true) {
|
while (true) {
|
||||||
// sample from the target model
|
// sample from the target model
|
||||||
const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
|
llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
|
||||||
|
|
||||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||||
last_tokens.erase(last_tokens.begin());
|
last_tokens.erase(last_tokens.begin());
|
||||||
|
30
llama.cpp
30
llama.cpp
@ -7356,7 +7356,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
|
|||||||
|
|
||||||
int llama_eval(
|
int llama_eval(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const llama_token * tokens,
|
llama_token * tokens,
|
||||||
uint32_t n_tokens,
|
uint32_t n_tokens,
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_threads) {
|
int n_threads) {
|
||||||
@ -7376,7 +7376,7 @@ int llama_eval(
|
|||||||
|
|
||||||
int llama_eval_embd(
|
int llama_eval_embd(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const float * embd,
|
float * embd,
|
||||||
uint32_t n_tokens,
|
uint32_t n_tokens,
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_threads) {
|
int n_threads) {
|
||||||
@ -7397,7 +7397,7 @@ int llama_eval_embd(
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct llama_batch llama_batch_get_one(
|
struct llama_batch llama_batch_get_one(
|
||||||
const llama_token * tokens,
|
llama_token * tokens,
|
||||||
uint32_t n_tokens,
|
uint32_t n_tokens,
|
||||||
llama_pos pos_0,
|
llama_pos pos_0,
|
||||||
llama_seq_id seq_id) {
|
llama_seq_id seq_id) {
|
||||||
@ -7414,6 +7414,30 @@ struct llama_batch llama_batch_get_one(
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct llama_batch llama_batch_init(uint32_t n_tokens, int32_t embd) {
|
||||||
|
llama_batch batch = { n_tokens, nullptr, nullptr, nullptr, nullptr, nullptr, 0, 0, 0, };
|
||||||
|
|
||||||
|
if (embd) {
|
||||||
|
batch.embd = (float *) malloc(sizeof(float) * n_tokens * embd);
|
||||||
|
} else {
|
||||||
|
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_tokens);
|
||||||
|
}
|
||||||
|
|
||||||
|
batch.pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens);
|
||||||
|
batch.seq_id = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_tokens);
|
||||||
|
batch.logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
|
||||||
|
|
||||||
|
return batch;
|
||||||
|
}
|
||||||
|
|
||||||
|
void llama_batch_free(struct llama_batch batch) {
|
||||||
|
if (batch.token) free(batch.token);
|
||||||
|
if (batch.embd) free(batch.embd);
|
||||||
|
if (batch.pos) free(batch.pos);
|
||||||
|
if (batch.seq_id) free(batch.seq_id);
|
||||||
|
if (batch.logits) free(batch.logits);
|
||||||
|
}
|
||||||
|
|
||||||
int llama_decode(
|
int llama_decode(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
struct llama_batch batch,
|
struct llama_batch batch,
|
||||||
|
32
llama.h
32
llama.h
@ -70,11 +70,11 @@ extern "C" {
|
|||||||
typedef struct llama_batch {
|
typedef struct llama_batch {
|
||||||
uint32_t n_tokens;
|
uint32_t n_tokens;
|
||||||
|
|
||||||
const llama_token * token;
|
llama_token * token;
|
||||||
const float * embd;
|
float * embd;
|
||||||
const llama_pos * pos;
|
llama_pos * pos;
|
||||||
const llama_seq_id * seq_id;
|
llama_seq_id * seq_id;
|
||||||
const int8_t * logits; // if 0, do not extract logits for that token
|
int8_t * logits; // if 0, do not extract logits for that token
|
||||||
|
|
||||||
// NOTE: helpers for smooth API transition - can be deprecated in the future
|
// NOTE: helpers for smooth API transition - can be deprecated in the future
|
||||||
// for future-proof code, use the above fields instead and ignore everything below
|
// for future-proof code, use the above fields instead and ignore everything below
|
||||||
@ -84,7 +84,7 @@ extern "C" {
|
|||||||
llama_pos all_pos_0; // used if pos == NULL
|
llama_pos all_pos_0; // used if pos == NULL
|
||||||
llama_pos all_pos_1; // used if pos == NULL
|
llama_pos all_pos_1; // used if pos == NULL
|
||||||
llama_seq_id all_seq_id; // used if seq_id == NULL
|
llama_seq_id all_seq_id; // used if seq_id == NULL
|
||||||
} llama_seq;
|
} llama_batch;
|
||||||
|
|
||||||
enum llama_log_level {
|
enum llama_log_level {
|
||||||
LLAMA_LOG_LEVEL_ERROR = 2,
|
LLAMA_LOG_LEVEL_ERROR = 2,
|
||||||
@ -366,34 +366,46 @@ extern "C" {
|
|||||||
// tokens + n_tokens is the provided batch of new tokens to process
|
// tokens + n_tokens is the provided batch of new tokens to process
|
||||||
// n_past is the number of tokens to use from previous eval calls
|
// n_past is the number of tokens to use from previous eval calls
|
||||||
// Returns 0 on success
|
// Returns 0 on success
|
||||||
|
// DEPRECATED: use llama_decode() instead
|
||||||
LLAMA_API DEPRECATED(int llama_eval(
|
LLAMA_API DEPRECATED(int llama_eval(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const llama_token * tokens,
|
llama_token * tokens,
|
||||||
uint32_t n_tokens,
|
uint32_t n_tokens,
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_threads),
|
int n_threads),
|
||||||
"please use llama_decode() instead");
|
"please use llama_decode() instead");
|
||||||
|
|
||||||
// Same as llama_eval, but use float matrix input directly.
|
// Same as llama_eval, but use float matrix input directly.
|
||||||
|
// DEPRECATED: use llama_decode() instead
|
||||||
LLAMA_API DEPRECATED(int llama_eval_embd(
|
LLAMA_API DEPRECATED(int llama_eval_embd(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const float * embd,
|
float * embd,
|
||||||
uint32_t n_tokens,
|
uint32_t n_tokens,
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_threads),
|
int n_threads),
|
||||||
"please use llama_decode() instead");
|
"please use llama_decode() instead");
|
||||||
|
|
||||||
// Return batch for single sequence of tokens starting at pos_0
|
// Return batch for single sequence of tokens starting at pos_0
|
||||||
// If pos_0 == 0, the clear_kv flag will be auto set to true
|
|
||||||
//
|
//
|
||||||
// NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
|
// NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
|
||||||
//
|
//
|
||||||
LLAMA_API struct llama_batch llama_batch_get_one(
|
LLAMA_API struct llama_batch llama_batch_get_one(
|
||||||
const llama_token * tokens,
|
llama_token * tokens,
|
||||||
uint32_t n_tokens,
|
uint32_t n_tokens,
|
||||||
llama_pos pos_0,
|
llama_pos pos_0,
|
||||||
llama_seq_id seq_id);
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
|
// Allocates a batch of tokens on the heap
|
||||||
|
// The batch needs to be freed with llama_batch_free()
|
||||||
|
// If embd > 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
|
||||||
|
// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
|
||||||
|
// The rest of the llama_batch members are allocated with size n_tokens
|
||||||
|
// All members are left uninitialized
|
||||||
|
LLAMA_API struct llama_batch llama_batch_init(uint32_t n_tokens, int32_t embd);
|
||||||
|
|
||||||
|
// Frees a batch of tokens allocated with llama_batch_init()
|
||||||
|
LLAMA_API void llama_batch_free(struct llama_batch batch);
|
||||||
|
|
||||||
// Positive return values does not mean a fatal error, but rather a warning.
|
// Positive return values does not mean a fatal error, but rather a warning.
|
||||||
// 0 - success
|
// 0 - success
|
||||||
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
|
||||||
|
Loading…
Reference in New Issue
Block a user