From 9f42e75489e38d09792ccc169f2eb25a4387afdd Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 18 Sep 2023 14:23:52 +0300 Subject: [PATCH] llama : add new llama_decode() API that works with llama_batch --- common/common.cpp | 2 +- examples/beam-search/beam-search.cpp | 2 +- examples/embd-input/embd-input-lib.cpp | 5 +- examples/embedding/embedding.cpp | 2 +- examples/llama-bench/llama-bench.cpp | 4 +- examples/main/main.cpp | 4 +- examples/perplexity/perplexity.cpp | 6 +- examples/save-load-state/save-load-state.cpp | 16 +-- examples/server/server.cpp | 2 +- examples/simple/simple.cpp | 2 +- examples/speculative/speculative.cpp | 12 +- llama.cpp | 119 ++++++++++++------- llama.h | 45 +++++-- 13 files changed, 146 insertions(+), 75 deletions(-) diff --git a/common/common.cpp b/common/common.cpp index b8d306ae2..b638efe9e 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -780,7 +780,7 @@ std::tuple llama_init_from_gpt_par LOG("warming up the model with an empty run\n"); std::vector tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; - llama_eval(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, params.n_threads); + llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads); llama_reset_timings(lctx); } diff --git a/examples/beam-search/beam-search.cpp b/examples/beam-search/beam-search.cpp index 2e0481ad6..63da7c3ec 100644 --- a/examples/beam-search/beam-search.cpp +++ b/examples/beam-search/beam-search.cpp @@ -160,7 +160,7 @@ int main(int argc, char ** argv) int n_past = 0; - if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads)) + if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), tokens_list.size(), n_past, 0), params.n_threads)) { fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ ); return 1; diff --git a/examples/embd-input/embd-input-lib.cpp b/examples/embd-input/embd-input-lib.cpp index fc6e44eb2..ed0966a51 100644 --- a/examples/embd-input/embd-input-lib.cpp +++ b/examples/embd-input/embd-input-lib.cpp @@ -79,7 +79,8 @@ bool eval_float(void * model, float * input, int N){ if (n_eval > n_batch) { n_eval = n_batch; } - if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) { + llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, false }; + if (llama_decode(ctx, batch, params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; } @@ -100,7 +101,7 @@ bool eval_tokens(void * model, std::vector tokens) { if (n_eval > params.n_batch) { n_eval = params.n_batch; } - if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0), params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; } diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 0788f362c..54a156b28 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -77,7 +77,7 @@ int main(int argc, char ** argv) { while (!embd_inp.empty()) { int n_tokens = std::min(params.n_batch, (int) embd_inp.size()); - if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0), params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return 1; } diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 34ddfde39..2551f8422 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -891,7 +891,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat int n_processed = 0; while (n_processed < n_prompt) { int n_tokens = std::min(n_prompt - n_processed, n_batch); - llama_eval(ctx, tokens.data(), n_tokens, n_past + n_processed, n_threads); + llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0), n_threads); n_processed += n_tokens; } } @@ -899,7 +899,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) { llama_token token = llama_token_bos(ctx); for (int i = 0; i < n_gen; i++) { - llama_eval(ctx, &token, 1, n_past + i, n_threads); + llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0), n_threads); } } diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 19cbbb2a1..3e78fdaa0 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -571,7 +571,7 @@ int main(int argc, char ** argv) { for (int i = 0; i < input_size; i += params.n_batch) { int n_eval = std::min(input_size - i, params.n_batch); - if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) { + if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0), params.n_threads)) { LOG_TEE("%s : failed to eval\n", __func__); return 1; } @@ -588,7 +588,7 @@ int main(int argc, char ** argv) { LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd)); - if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0), params.n_threads)) { LOG_TEE("%s : failed to eval\n", __func__); return 1; } diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 4958cdfb9..2a046d55e 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -199,7 +199,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const int batch_size = std::min(end - batch_start, n_batch); //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); - if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) { //fprintf(stderr, "%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; } @@ -331,7 +331,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par tokens[batch_start] = llama_token_bos(ctx); } - if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; } @@ -409,7 +409,7 @@ static std::vector hellaswag_evaluate_tokens( for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) { size_t n_tokens = tokens.size() - i_chunk * n_batch; n_tokens = std::min(n_tokens, size_t(n_batch)); - if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) { + if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0), n_thread)) { fprintf(stderr, "%s : failed to eval\n", __func__); return {}; } diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index eac307904..5e1a097be 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -34,11 +34,11 @@ int main(int argc, char ** argv) { auto last_n_tokens_data = std::vector(params.repeat_last_n, 0); // init - auto model = llama_load_model_from_file(params.model.c_str(), lparams); + auto * model = llama_load_model_from_file(params.model.c_str(), lparams); if (model == nullptr) { return 1; } - auto ctx = llama_new_context_with_model(model, lparams); + auto * ctx = llama_new_context_with_model(model, lparams); if (ctx == nullptr) { llama_free_model(model); return 1; @@ -53,7 +53,7 @@ int main(int argc, char ** argv) { } // evaluate prompt - llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads); + llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0), params.n_threads); last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens); n_past += n_prompt_tokens; @@ -77,7 +77,7 @@ int main(int argc, char ** argv) { printf("\n%s", params.prompt.c_str()); for (auto i = 0; i < params.n_predict; i++) { - auto logits = llama_get_logits(ctx); + auto * logits = llama_get_logits(ctx); auto n_vocab = llama_n_vocab(ctx); std::vector candidates; candidates.reserve(n_vocab); @@ -90,7 +90,7 @@ int main(int argc, char ** argv) { last_n_tokens_data.push_back(next_token); printf("%s", next_token_str.c_str()); - if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx); llama_free_model(model); @@ -105,7 +105,7 @@ int main(int argc, char ** argv) { llama_free(ctx); // make new context - auto ctx2 = llama_new_context_with_model(model, lparams); + auto * ctx2 = llama_new_context_with_model(model, lparams); // Load state (rng, logits, embedding and kv_cache) from file { @@ -137,7 +137,7 @@ int main(int argc, char ** argv) { // second run for (auto i = 0; i < params.n_predict; i++) { - auto logits = llama_get_logits(ctx2); + auto * logits = llama_get_logits(ctx2); auto n_vocab = llama_n_vocab(ctx2); std::vector candidates; candidates.reserve(n_vocab); @@ -150,7 +150,7 @@ int main(int argc, char ** argv) { last_n_tokens_data.push_back(next_token); printf("%s", next_token_str.c_str()); - if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); llama_free(ctx2); llama_free_model(model); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1bb8e92c0..6c81bd618 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -434,7 +434,7 @@ struct llama_server_context { n_eval = params.n_batch; } - if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads)) + if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads)) { LOG_ERROR("failed to eval", { {"n_eval", n_eval}, diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 37eaf3b2c..33ef0770b 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -76,7 +76,7 @@ int main(int argc, char ** argv) { while (n_cur < n_gen) { // evaluate the transformer - if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), n_cur, params.n_threads)) { + if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), int(tokens_list.size()), n_cur, 0), params.n_threads)) { fprintf(stderr, "%s : failed to eval\n", __func__); return 1; } diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index aa904183f..06173393c 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -70,9 +70,9 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // eval the prompt with both models - llama_eval(ctx_tgt, inp.data(), int(inp.size() - 1), 0, params.n_threads); - llama_eval(ctx_tgt, &inp.back(), 1, inp.size() - 1, params.n_threads); - llama_eval(ctx_dft, inp.data(), int(inp.size()), 0, params.n_threads); + llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0), params.n_threads); + llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0), params.n_threads); + llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0), params.n_threads); const auto t_enc_end = ggml_time_us(); @@ -172,7 +172,7 @@ int main(int argc, char ** argv) { LOG("out of drafted tokens\n"); } - llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads); + llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads); ++n_past_dft; // heuristic for n_draft @@ -256,7 +256,7 @@ int main(int argc, char ** argv) { } // evaluate the drafted token on the draft model - llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads); + llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads); ++n_past_cur; if (grammar_dft != NULL) { @@ -265,7 +265,7 @@ int main(int argc, char ** argv) { } // evaluate the target model on the drafted tokens - llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads); + llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads); ++n_past_tgt; // the first token is always proposed by the traget model before the speculation loop diff --git a/llama.cpp b/llama.cpp index 0e1c8755c..601f557ef 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1265,7 +1265,7 @@ static bool llama_kv_cache_init( // updates the cache head static bool llama_kv_cache_find_slot( struct llama_kv_cache & cache, - struct llama_batch & batch) { + const struct llama_batch & batch) { const uint32_t n_ctx = cache.size; const uint32_t n_tokens = batch.n_tokens; @@ -2522,7 +2522,7 @@ static bool llama_model_load( static struct ggml_cgraph * llm_build_llama( llama_context & lctx, - llama_batch & batch) { + const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; @@ -2876,7 +2876,7 @@ static struct ggml_cgraph * llm_build_llama( static struct ggml_cgraph * llm_build_baichaun( llama_context & lctx, - llama_batch & batch) { + const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; @@ -3247,7 +3247,7 @@ static struct ggml_cgraph * llm_build_baichaun( static struct ggml_cgraph * llm_build_falcon( llama_context & lctx, - llama_batch & batch) { + const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; @@ -3577,7 +3577,7 @@ static struct ggml_cgraph * llm_build_falcon( static struct ggml_cgraph * llm_build_starcoder( llama_context & lctx, - llama_batch & batch) { + const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; @@ -3819,7 +3819,7 @@ static struct ggml_cgraph * llm_build_starcoder( static struct ggml_cgraph * llama_build_graph( llama_context & lctx, - llama_batch & batch) { + const llama_batch & batch) { const auto & model = lctx.model; struct ggml_cgraph * result = NULL; @@ -3856,7 +3856,7 @@ static struct ggml_cgraph * llama_build_graph( // static bool llama_eval_internal( llama_context & lctx, - llama_batch & batch, + llama_batch batch, int n_threads) { const uint32_t n_tokens = batch.n_tokens; @@ -3886,6 +3886,31 @@ static bool llama_eval_internal( const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; + std::vector pos; + std::vector seq_id; + + if (batch.pos == nullptr) { + pos.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + pos[i] = batch.all_pos_0 + i*batch.all_pos_1; + } + + batch.pos = pos.data(); + } + + if (batch.seq_id == nullptr) { + seq_id.resize(n_tokens); + for (uint32_t i = 0; i < n_tokens; i++) { + seq_id[i] = batch.all_seq_id; + } + + batch.seq_id = seq_id.data(); + } + + if (batch.clear_kv) { + llama_kv_cache_clear(kv_self, 0, -1); + } + if (!llama_kv_cache_find_slot(kv_self, batch)) { return false; } @@ -4820,6 +4845,13 @@ struct llama_grammar * llama_grammar_copy(const struct llama_grammar * grammar) // sampling // +void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { + if (seed == LLAMA_DEFAULT_SEED) { + seed = time(NULL); + } + ctx->rng.seed(seed); +} + void llama_sample_softmax(struct llama_context * ctx, llama_token_data_array * candidates) { GGML_ASSERT(candidates->size > 0); @@ -5469,7 +5501,7 @@ struct llama_beam_search_data { } else { // beam is not at end-of-sentence, so branch with next top_k tokens. if (!beam.tokens.empty()) { - llama_eval(ctx, beam.tokens.data(), beam.tokens.size(), n_past, n_threads); + llama_decode(ctx, llama_batch_get_one(beam.tokens.data(), beam.tokens.size(), n_past, 0), n_threads); } llama_logit_info logit_info(ctx); std::vector next_tokens = logit_info.top_k(n_beams); @@ -5543,7 +5575,7 @@ struct llama_beam_search_data { callback(callback_data, get_beams_state(false)); // Sets common_prefix_length update_beams_from_beam_views(); // Update values (p,eob) that callback may have changed. if (common_prefix_length) { - llama_eval(ctx, beams[0].tokens.data(), common_prefix_length, n_past, n_threads); + llama_decode(ctx, llama_batch_get_one(beams[0].tokens.data(), common_prefix_length, n_past, 0), n_threads); n_past += common_prefix_length; } // Zero-out next_beam probabilities to place them last in following min-heap. @@ -6505,8 +6537,7 @@ struct llama_context * llama_new_context_with_model( // build worst-case graph uint32_t n_tokens = std::min((int)hparams.n_ctx, params.n_batch); llama_token token = llama_token_bos(ctx); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph - llama_batch batch = { n_tokens, &token, nullptr, nullptr, nullptr }; - ggml_cgraph * gf = llama_build_graph(*ctx, batch); + ggml_cgraph * gf = llama_build_graph(*ctx, llama_batch_get_one(&token, n_tokens, 0, 0)); #ifdef GGML_USE_METAL if (params.n_gpu_layers > 0) { @@ -6714,15 +6745,6 @@ void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1) { llama_kv_cache_clear(ctx->kv_self, p0, p1); } -#define LLAMA_MAX_RNG_STATE (64*1024) - -void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed) { - if (seed == LLAMA_DEFAULT_SEED) { - seed = time(NULL); - } - ctx->rng.seed(seed); -} - // Returns the *maximum* size of the state size_t llama_get_state_size(const struct llama_context * ctx) { // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. @@ -7116,21 +7138,9 @@ int llama_eval( uint32_t n_tokens, int n_past, int n_threads) { - std::vector pos(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - pos[i] = n_past + i; - } - - std::vector seq_id(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - seq_id[i] = 0; - } - - llama_batch batch = { n_tokens, tokens, nullptr, pos.data(), seq_id.data(), }; - llama_kv_cache_clear(ctx->kv_self, n_past, -1); - if (!llama_eval_internal(*ctx, batch, n_threads)) { + if (!llama_eval_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) { LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); return 1; } @@ -7151,18 +7161,47 @@ int llama_eval_embd( uint32_t n_tokens, int n_past, int n_threads) { - std::vector pos(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - pos[i] = n_past + i; + llama_kv_cache_clear(ctx->kv_self, n_past, -1); + + llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, n_past, 1, 0, n_past == 0, }; + + if (!llama_eval_internal(*ctx, batch, n_threads)) { + LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); + return 1; } - std::vector seq_id(n_tokens); - for (uint32_t i = 0; i < n_tokens; i++) { - seq_id[i] = 0; + // get a more accurate load time, upon first eval + // TODO: fix this + if (!ctx->has_evaluated_once) { + ctx->t_load_us = ggml_time_us() - ctx->t_start_us; + ctx->has_evaluated_once = true; } - llama_batch batch = { n_tokens, nullptr, embd, pos.data(), seq_id.data(), }; + return 0; +} +struct llama_batch llama_batch_get_one( + const llama_token * tokens, + uint32_t n_tokens, + llama_pos pos_0, + llama_seq_id seq_id) { + return { + /*n_tokens =*/ n_tokens, + /*tokens =*/ tokens, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*seq_id =*/ nullptr, + /*all_pos_0 =*/ pos_0, + /*all_pos_1 =*/ 1, + /*all_seq_id =*/ seq_id, + /*clear_kv =*/ pos_0 == 0, + }; +} + +int llama_decode( + struct llama_context * ctx, + struct llama_batch batch, + int n_threads) { if (!llama_eval_internal(*ctx, batch, n_threads)) { LLAMA_LOG_ERROR("%s: failed to eval\n", __func__); return 1; diff --git a/llama.h b/llama.h index 0af9c1089..b844e172b 100644 --- a/llama.h +++ b/llama.h @@ -37,6 +37,8 @@ #define LLAMA_DEFAULT_SEED 0xFFFFFFFF +#define LLAMA_MAX_RNG_STATE (64*1024) + #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN @@ -70,9 +72,20 @@ extern "C" { // TODO: not sure about these consts - might just get in the way all the time with no benefit const llama_token * token; - const float * embd; + const float * embd; const llama_pos * pos; const llama_seq_id * seq_id; + + // NOTE: helpers for smooth API transition - can be deprecated in the future + // for future-proof code, use the above fields instead and ignore everything below + // + // pos[i] = all_pos_0 + i*all_pos_1 + // + llama_pos all_pos_0; // used if pos == NULL + llama_pos all_pos_1; // used if pos == NULL + llama_seq_id all_seq_id; // used if seq_id == NULL + + bool clear_kv; // if true, clear the entire KV cache. common usage for perplexity calculations } llama_seq; enum llama_log_level { @@ -312,9 +325,6 @@ extern "C" { LLAMA_API void llama_kv_clear(struct llama_context * ctx, int32_t p0, int32_t p1); - // Sets the current rng seed. - LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); - // Returns the maximum size in bytes of the state (rng, logits, embedding // and kv_cache) - will often be smaller after compacting tokens LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx); @@ -336,19 +346,37 @@ extern "C" { // tokens + n_tokens is the provided batch of new tokens to process // n_past is the number of tokens to use from previous eval calls // Returns 0 on success - LLAMA_API int llama_eval( + LLAMA_API DEPRECATED(int llama_eval( struct llama_context * ctx, const llama_token * tokens, uint32_t n_tokens, int n_past, - int n_threads); + int n_threads), + "please use llama_decode() instead"); // Same as llama_eval, but use float matrix input directly. - LLAMA_API int llama_eval_embd( + LLAMA_API DEPRECATED(int llama_eval_embd( struct llama_context * ctx, const float * embd, uint32_t n_tokens, int n_past, + int n_threads), + "please use llama_decode() instead"); + + // Return batch for single sequence of tokens starting at pos_0 + // If pos_0 == 0, the clear_kv flag will be auto set to true + // + // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it + // + LLAMA_API struct llama_batch llama_batch_get_one( + const llama_token * tokens, + uint32_t n_tokens, + llama_pos pos_0, + llama_seq_id seq_id); + + LLAMA_API int llama_decode( + struct llama_context * ctx, + struct llama_batch batch, int n_threads); // Token logits obtained from the last call to llama_eval() @@ -434,6 +462,9 @@ extern "C" { // Sampling functions // + // Sets the current rng seed. + LLAMA_API void llama_set_rng_seed(struct llama_context * ctx, uint32_t seed); + /// @details Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. LLAMA_API void llama_sample_repetition_penalty(struct llama_context * ctx, llama_token_data_array * candidates, const llama_token * last_tokens, size_t last_tokens_size, float penalty);