llama : unified KV cache + batch inference API

This commit is contained in:
Georgi Gerganov 2023-09-18 10:08:22 +03:00
parent fad56936d4
commit d29e76937c
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
10 changed files with 315 additions and 236 deletions

View File

@ -436,8 +436,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.use_mmap = false; params.use_mmap = false;
} else if (arg == "--numa") { } else if (arg == "--numa") {
params.numa = true; params.numa = true;
} else if (arg == "--export") {
params.export_cgraph = true;
} else if (arg == "--verbose-prompt") { } else if (arg == "--verbose-prompt") {
params.verbose_prompt = true; params.verbose_prompt = true;
} else if (arg == "-r" || arg == "--reverse-prompt") { } else if (arg == "-r" || arg == "--reverse-prompt") {
@ -685,7 +683,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
printf(" Not recommended since this is both slower and uses more VRAM.\n"); printf(" Not recommended since this is both slower and uses more VRAM.\n");
#endif // GGML_USE_CUBLAS #endif // GGML_USE_CUBLAS
#endif #endif
printf(" --export export the computation graph to 'llama.ggml'\n");
printf(" --verbose-prompt print prompt before generation\n"); printf(" --verbose-prompt print prompt before generation\n");
fprintf(stderr, " --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n"); fprintf(stderr, " --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n"); printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
@ -782,7 +779,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
{ {
LOG("warming up the model with an empty run\n"); LOG("warming up the model with an empty run\n");
const std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads); llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
llama_reset_timings(lctx); llama_reset_timings(lctx);
} }
@ -1182,7 +1179,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
fprintf(stream, "export: %s # default: false\n", params.export_cgraph ? "true" : "false");
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", params.frequency_penalty); fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", params.frequency_penalty);
dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str()); dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());

View File

@ -111,7 +111,6 @@ struct gpt_params {
bool use_mmap = true; // use mmap for faster loads bool use_mmap = true; // use mmap for faster loads
bool use_mlock = false; // use mlock to keep model in memory bool use_mlock = false; // use mlock to keep model in memory
bool numa = false; // attempt optimizations that help on some NUMA systems bool numa = false; // attempt optimizations that help on some NUMA systems
bool export_cgraph = false; // export the computation graph
bool verbose_prompt = false; // print prompt tokens before generation bool verbose_prompt = false; // print prompt tokens before generation
}; };

View File

@ -158,7 +158,8 @@ int main(int argc, char ** argv)
} }
std::cout << std::flush; std::cout << std::flush;
int n_past = llama_get_kv_cache_token_count(ctx); int n_past = 0;
if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads)) if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
{ {
fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ ); fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );

View File

@ -198,15 +198,6 @@ int main(int argc, char ** argv) {
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info()); params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
} }
// export the cgraph and exit
if (params.export_cgraph) {
llama_eval_export(ctx, "llama.ggml");
llama_free(ctx);
llama_free_model(model);
return 0;
}
std::string path_session = params.path_prompt_cache; std::string path_session = params.path_prompt_cache;
std::vector<llama_token> session_tokens; std::vector<llama_token> session_tokens;

View File

@ -400,7 +400,7 @@ results_perplexity perplexity(llama_context * ctx, const gpt_params & params) {
return {tokens, ppl, logit_history, prob_history}; return {tokens, ppl, logit_history, prob_history};
} }
std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch, std::vector<float> hellaswag_evaluate_tokens(llama_context * ctx, const std::vector<int> & tokens, int n_past, int n_batch,
int n_vocab, int n_thread) { int n_vocab, int n_thread) {
std::vector<float> result; std::vector<float> result;
result.reserve(tokens.size() * n_vocab); result.reserve(tokens.size() * n_vocab);

View File

@ -73,10 +73,12 @@ int main(int argc, char ** argv) {
const int n_gen = std::min(32, max_context_size); const int n_gen = std::min(32, max_context_size);
while (llama_get_kv_cache_token_count(ctx) < n_gen) { int n_cur = 0;
while (n_cur < n_gen) {
// evaluate the transformer // evaluate the transformer
if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) { if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), n_cur, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__); fprintf(stderr, "%s : failed to eval\n", __func__);
return 1; return 1;
} }

6
ggml.c
View File

@ -12462,13 +12462,11 @@ static void ggml_compute_forward_alibi_f16(
return; return;
} }
const int n_past = ((int32_t *) dst->op_params)[0]; //const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1]; const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias; float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
assert(n_past >= 0);
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int ne1 = src0->ne[1]; // seq_len_without_past const int ne1 = src0->ne[1]; // seq_len_without_past
const int ne2 = src0->ne[2]; // n_head -> this is k const int ne2 = src0->ne[2]; // n_head -> this is k
@ -12483,7 +12481,7 @@ static void ggml_compute_forward_alibi_f16(
//const int nb3 = src0->nb[3]; //const int nb3 = src0->nb[3];
GGML_ASSERT(nb0 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
GGML_ASSERT(ne1 + n_past == ne0); (void) n_past; //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
GGML_ASSERT(n_head == ne2); GGML_ASSERT(n_head == ne2);
// add alibi to src0 (KQ_scaled) // add alibi to src0 (KQ_scaled)

466
llama.cpp

File diff suppressed because it is too large Load Diff

34
llama.h
View File

@ -60,7 +60,20 @@ extern "C" {
struct llama_model; struct llama_model;
struct llama_context; struct llama_context;
typedef int llama_token; typedef int32_t llama_pos;
typedef int32_t llama_token;
typedef int32_t llama_seq_id;
// data used for batch inference
typedef struct llama_batch {
uint32_t n_tokens;
// TODO: not sure about these consts - might just get in the way all the time with no benefit
const llama_token * token;
const float * embd;
const llama_pos * pos;
const llama_seq_id * seq_id;
} llama_seq;
enum llama_log_level { enum llama_log_level {
LLAMA_LOG_LEVEL_ERROR = 2, LLAMA_LOG_LEVEL_ERROR = 2,
@ -289,8 +302,15 @@ extern "C" {
const char * path_base_model, const char * path_base_model,
int n_threads); int n_threads);
//
// KV cache API
//
// Returns the number of tokens in the KV cache // Returns the number of tokens in the KV cache
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx); LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
"avoid using this, it will be removed in the future");
LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1);
// Sets the current rng seed. // Sets the current rng seed.
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
@ -319,7 +339,7 @@ extern "C" {
LLAMA_API int llama_eval( LLAMA_API int llama_eval(
struct llama_context * ctx, struct llama_context * ctx,
const llama_token * tokens, const llama_token * tokens,
int n_tokens, uint32_t n_tokens,
int n_past, int n_past,
int n_threads); int n_threads);
@ -327,16 +347,10 @@ extern "C" {
LLAMA_API int llama_eval_embd( LLAMA_API int llama_eval_embd(
struct llama_context * ctx, struct llama_context * ctx,
const float * embd, const float * embd,
int n_tokens, uint32_t n_tokens,
int n_past, int n_past,
int n_threads); int n_threads);
// Export a static computation graph for context of 511 and batch size of 1
// NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
// parameters here to keep things simple
// IMPORTANT: do not use for anything else other than debugging and testing!
LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname);
// Token logits obtained from the last call to llama_eval() // Token logits obtained from the last call to llama_eval()
// The logits for the last token are stored in the last row // The logits for the last token are stored in the last row
// Can be mutated in order to change the probabilities of the next token // Can be mutated in order to change the probabilities of the next token

View File

@ -87,12 +87,13 @@ int main(int argc, char **argv) {
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false); std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
std::string check = llama_detokenize_spm(ctx, tokens); std::string check = llama_detokenize_spm(ctx, tokens);
if (check != str) { if (check != str) {
fprintf(stderr, "%s : error: token %d detokenizes to >%s<(%llu) but tokenization of this detokenizes to >%s<(%llu)\n", fprintf(stderr, "%s : error: token %d detokenizes to >%s<(%d) but tokenization of this detokenizes to >%s<(%d)\n",
__func__, i, str.c_str(), str.length(), check.c_str(), check.length()); __func__, i, str.c_str(), (int) str.length(), check.c_str(), (int) check.length());
if(i != 3) if (i != 3) {
return 2; return 2;
} }
} }
}
for (codepoint cp = 0x0000; cp < 0xffff; ++cp) { for (codepoint cp = 0x0000; cp < 0xffff; ++cp) {
if (cp < 0xd800 || cp > 0xdfff) { if (cp < 0xd800 || cp > 0xdfff) {
@ -100,20 +101,21 @@ int main(int argc, char **argv) {
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false); std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
std::string check = llama_detokenize_spm(ctx, tokens); std::string check = llama_detokenize_spm(ctx, tokens);
if (str != check) { if (str != check) {
fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%llu) instead of >%s<(%llu)\n", fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%d) instead of >%s<(%d)\n",
__func__, cp, check.c_str(), check.length(), str.c_str(), str.length()); __func__, cp, check.c_str(), (int) check.length(), str.c_str(), (int) str.length());
if(cp != 0 && cp != 9601) if (cp != 0 && cp != 9601) {
return 3; return 3;
} }
} }
} }
}
for (codepoint cp = 0x10000; cp < 0x0010ffff; ++cp) { for (codepoint cp = 0x10000; cp < 0x0010ffff; ++cp) {
std::string str = codepoint_to_utf8(cp); std::string str = codepoint_to_utf8(cp);
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false); std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
std::string check = llama_detokenize_spm(ctx, tokens); std::string check = llama_detokenize_spm(ctx, tokens);
if (str != check) { if (str != check) {
fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%llu) instead of >%s<(%llu)\n", fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%d) instead of >%s<(%d)\n",
__func__, cp, check.c_str(), check.length(), str.c_str(), str.length()); __func__, cp, check.c_str(), (int) check.length(), str.c_str(), (int) str.length());
return 4; return 4;
} }
} }