mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-31 22:04:35 +00:00
llama : unified KV cache + batch inference API
This commit is contained in:
parent
fad56936d4
commit
d29e76937c
@ -436,8 +436,6 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
|
|||||||
params.use_mmap = false;
|
params.use_mmap = false;
|
||||||
} else if (arg == "--numa") {
|
} else if (arg == "--numa") {
|
||||||
params.numa = true;
|
params.numa = true;
|
||||||
} else if (arg == "--export") {
|
|
||||||
params.export_cgraph = true;
|
|
||||||
} else if (arg == "--verbose-prompt") {
|
} else if (arg == "--verbose-prompt") {
|
||||||
params.verbose_prompt = true;
|
params.verbose_prompt = true;
|
||||||
} else if (arg == "-r" || arg == "--reverse-prompt") {
|
} else if (arg == "-r" || arg == "--reverse-prompt") {
|
||||||
@ -685,7 +683,6 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||||||
printf(" Not recommended since this is both slower and uses more VRAM.\n");
|
printf(" Not recommended since this is both slower and uses more VRAM.\n");
|
||||||
#endif // GGML_USE_CUBLAS
|
#endif // GGML_USE_CUBLAS
|
||||||
#endif
|
#endif
|
||||||
printf(" --export export the computation graph to 'llama.ggml'\n");
|
|
||||||
printf(" --verbose-prompt print prompt before generation\n");
|
printf(" --verbose-prompt print prompt before generation\n");
|
||||||
fprintf(stderr, " --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
|
fprintf(stderr, " --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
|
||||||
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
|
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
|
||||||
@ -782,7 +779,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
|||||||
{
|
{
|
||||||
LOG("warming up the model with an empty run\n");
|
LOG("warming up the model with an empty run\n");
|
||||||
|
|
||||||
const std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
|
std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
|
||||||
llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
|
llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads);
|
||||||
llama_reset_timings(lctx);
|
llama_reset_timings(lctx);
|
||||||
}
|
}
|
||||||
@ -1182,7 +1179,6 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
|
|||||||
fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
|
fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false");
|
||||||
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
|
fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx);
|
||||||
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
|
fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false");
|
||||||
fprintf(stream, "export: %s # default: false\n", params.export_cgraph ? "true" : "false");
|
|
||||||
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
|
fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n");
|
||||||
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", params.frequency_penalty);
|
fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", params.frequency_penalty);
|
||||||
dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());
|
dump_string_yaml_multiline(stream, "grammar", params.grammar.c_str());
|
||||||
|
@ -111,7 +111,6 @@ struct gpt_params {
|
|||||||
bool use_mmap = true; // use mmap for faster loads
|
bool use_mmap = true; // use mmap for faster loads
|
||||||
bool use_mlock = false; // use mlock to keep model in memory
|
bool use_mlock = false; // use mlock to keep model in memory
|
||||||
bool numa = false; // attempt optimizations that help on some NUMA systems
|
bool numa = false; // attempt optimizations that help on some NUMA systems
|
||||||
bool export_cgraph = false; // export the computation graph
|
|
||||||
bool verbose_prompt = false; // print prompt tokens before generation
|
bool verbose_prompt = false; // print prompt tokens before generation
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -158,7 +158,8 @@ int main(int argc, char ** argv)
|
|||||||
}
|
}
|
||||||
std::cout << std::flush;
|
std::cout << std::flush;
|
||||||
|
|
||||||
int n_past = llama_get_kv_cache_token_count(ctx);
|
int n_past = 0;
|
||||||
|
|
||||||
if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
|
if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
|
||||||
{
|
{
|
||||||
fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
|
fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
|
||||||
|
@ -198,15 +198,6 @@ int main(int argc, char ** argv) {
|
|||||||
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
||||||
}
|
}
|
||||||
|
|
||||||
// export the cgraph and exit
|
|
||||||
if (params.export_cgraph) {
|
|
||||||
llama_eval_export(ctx, "llama.ggml");
|
|
||||||
llama_free(ctx);
|
|
||||||
llama_free_model(model);
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string path_session = params.path_prompt_cache;
|
std::string path_session = params.path_prompt_cache;
|
||||||
std::vector<llama_token> session_tokens;
|
std::vector<llama_token> session_tokens;
|
||||||
|
|
||||||
|
@ -73,10 +73,12 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
const int n_gen = std::min(32, max_context_size);
|
const int n_gen = std::min(32, max_context_size);
|
||||||
|
|
||||||
while (llama_get_kv_cache_token_count(ctx) < n_gen) {
|
int n_cur = 0;
|
||||||
|
|
||||||
|
while (n_cur < n_gen) {
|
||||||
// evaluate the transformer
|
// evaluate the transformer
|
||||||
|
|
||||||
if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) {
|
if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), n_cur, params.n_threads)) {
|
||||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
6
ggml.c
6
ggml.c
@ -12462,13 +12462,11 @@ static void ggml_compute_forward_alibi_f16(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int n_past = ((int32_t *) dst->op_params)[0];
|
//const int n_past = ((int32_t *) dst->op_params)[0];
|
||||||
const int n_head = ((int32_t *) dst->op_params)[1];
|
const int n_head = ((int32_t *) dst->op_params)[1];
|
||||||
float max_bias;
|
float max_bias;
|
||||||
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
|
||||||
|
|
||||||
assert(n_past >= 0);
|
|
||||||
|
|
||||||
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
|
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
|
||||||
const int ne1 = src0->ne[1]; // seq_len_without_past
|
const int ne1 = src0->ne[1]; // seq_len_without_past
|
||||||
const int ne2 = src0->ne[2]; // n_head -> this is k
|
const int ne2 = src0->ne[2]; // n_head -> this is k
|
||||||
@ -12483,7 +12481,7 @@ static void ggml_compute_forward_alibi_f16(
|
|||||||
//const int nb3 = src0->nb[3];
|
//const int nb3 = src0->nb[3];
|
||||||
|
|
||||||
GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
|
GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
|
||||||
GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
|
//GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
|
||||||
GGML_ASSERT(n_head == ne2);
|
GGML_ASSERT(n_head == ne2);
|
||||||
|
|
||||||
// add alibi to src0 (KQ_scaled)
|
// add alibi to src0 (KQ_scaled)
|
||||||
|
34
llama.h
34
llama.h
@ -60,7 +60,20 @@ extern "C" {
|
|||||||
struct llama_model;
|
struct llama_model;
|
||||||
struct llama_context;
|
struct llama_context;
|
||||||
|
|
||||||
typedef int llama_token;
|
typedef int32_t llama_pos;
|
||||||
|
typedef int32_t llama_token;
|
||||||
|
typedef int32_t llama_seq_id;
|
||||||
|
|
||||||
|
// data used for batch inference
|
||||||
|
typedef struct llama_batch {
|
||||||
|
uint32_t n_tokens;
|
||||||
|
|
||||||
|
// TODO: not sure about these consts - might just get in the way all the time with no benefit
|
||||||
|
const llama_token * token;
|
||||||
|
const float * embd;
|
||||||
|
const llama_pos * pos;
|
||||||
|
const llama_seq_id * seq_id;
|
||||||
|
} llama_seq;
|
||||||
|
|
||||||
enum llama_log_level {
|
enum llama_log_level {
|
||||||
LLAMA_LOG_LEVEL_ERROR = 2,
|
LLAMA_LOG_LEVEL_ERROR = 2,
|
||||||
@ -289,8 +302,15 @@ extern "C" {
|
|||||||
const char * path_base_model,
|
const char * path_base_model,
|
||||||
int n_threads);
|
int n_threads);
|
||||||
|
|
||||||
|
//
|
||||||
|
// KV cache API
|
||||||
|
//
|
||||||
|
|
||||||
// Returns the number of tokens in the KV cache
|
// Returns the number of tokens in the KV cache
|
||||||
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
|
LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
|
||||||
|
"avoid using this, it will be removed in the future");
|
||||||
|
|
||||||
|
LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1);
|
||||||
|
|
||||||
// Sets the current rng seed.
|
// Sets the current rng seed.
|
||||||
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
|
LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed);
|
||||||
@ -319,7 +339,7 @@ extern "C" {
|
|||||||
LLAMA_API int llama_eval(
|
LLAMA_API int llama_eval(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const llama_token * tokens,
|
const llama_token * tokens,
|
||||||
int n_tokens,
|
uint32_t n_tokens,
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_threads);
|
int n_threads);
|
||||||
|
|
||||||
@ -327,16 +347,10 @@ extern "C" {
|
|||||||
LLAMA_API int llama_eval_embd(
|
LLAMA_API int llama_eval_embd(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const float * embd,
|
const float * embd,
|
||||||
int n_tokens,
|
uint32_t n_tokens,
|
||||||
int n_past,
|
int n_past,
|
||||||
int n_threads);
|
int n_threads);
|
||||||
|
|
||||||
// Export a static computation graph for context of 511 and batch size of 1
|
|
||||||
// NOTE: since this functionality is mostly for debugging and demonstration purposes, we hardcode these
|
|
||||||
// parameters here to keep things simple
|
|
||||||
// IMPORTANT: do not use for anything else other than debugging and testing!
|
|
||||||
LLAMA_API int llama_eval_export(struct llama_context * ctx, const char * fname);
|
|
||||||
|
|
||||||
// Token logits obtained from the last call to llama_eval()
|
// Token logits obtained from the last call to llama_eval()
|
||||||
// The logits for the last token are stored in the last row
|
// The logits for the last token are stored in the last row
|
||||||
// Can be mutated in order to change the probabilities of the next token
|
// Can be mutated in order to change the probabilities of the next token
|
||||||
|
@ -87,12 +87,13 @@ int main(int argc, char **argv) {
|
|||||||
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
|
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
|
||||||
std::string check = llama_detokenize_spm(ctx, tokens);
|
std::string check = llama_detokenize_spm(ctx, tokens);
|
||||||
if (check != str) {
|
if (check != str) {
|
||||||
fprintf(stderr, "%s : error: token %d detokenizes to >%s<(%llu) but tokenization of this detokenizes to >%s<(%llu)\n",
|
fprintf(stderr, "%s : error: token %d detokenizes to >%s<(%d) but tokenization of this detokenizes to >%s<(%d)\n",
|
||||||
__func__, i, str.c_str(), str.length(), check.c_str(), check.length());
|
__func__, i, str.c_str(), (int) str.length(), check.c_str(), (int) check.length());
|
||||||
if(i != 3)
|
if (i != 3) {
|
||||||
return 2;
|
return 2;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (codepoint cp = 0x0000; cp < 0xffff; ++cp) {
|
for (codepoint cp = 0x0000; cp < 0xffff; ++cp) {
|
||||||
if (cp < 0xd800 || cp > 0xdfff) {
|
if (cp < 0xd800 || cp > 0xdfff) {
|
||||||
@ -100,20 +101,21 @@ int main(int argc, char **argv) {
|
|||||||
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
|
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
|
||||||
std::string check = llama_detokenize_spm(ctx, tokens);
|
std::string check = llama_detokenize_spm(ctx, tokens);
|
||||||
if (str != check) {
|
if (str != check) {
|
||||||
fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%llu) instead of >%s<(%llu)\n",
|
fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%d) instead of >%s<(%d)\n",
|
||||||
__func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
|
__func__, cp, check.c_str(), (int) check.length(), str.c_str(), (int) str.length());
|
||||||
if(cp != 0 && cp != 9601)
|
if (cp != 0 && cp != 9601) {
|
||||||
return 3;
|
return 3;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
for (codepoint cp = 0x10000; cp < 0x0010ffff; ++cp) {
|
for (codepoint cp = 0x10000; cp < 0x0010ffff; ++cp) {
|
||||||
std::string str = codepoint_to_utf8(cp);
|
std::string str = codepoint_to_utf8(cp);
|
||||||
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
|
std::vector<llama_token> tokens = llama_tokenize(ctx, str, false);
|
||||||
std::string check = llama_detokenize_spm(ctx, tokens);
|
std::string check = llama_detokenize_spm(ctx, tokens);
|
||||||
if (str != check) {
|
if (str != check) {
|
||||||
fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%llu) instead of >%s<(%llu)\n",
|
fprintf(stderr, "%s : error: codepoint %d detokenizes to >%s<(%d) instead of >%s<(%d)\n",
|
||||||
__func__, cp, check.c_str(), check.length(), str.c_str(), str.length());
|
__func__, cp, check.c_str(), (int) check.length(), str.c_str(), (int) str.length());
|
||||||
return 4;
|
return 4;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user