diff --git a/src/llama.cpp b/src/llama.cpp
index 7f2f00031..b0ee74bc4 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2739,6 +2739,9 @@ struct llama_context {
     std::vector<uint8_t> buf_compute_meta;
     ggml_backend_sched_t sched = nullptr;
 
+    std::vector<uint8_t> buf_compute_meta_next;
+    struct ggml_cgraph * gf_next = nullptr;
+
     ggml_abort_callback abort_callback      = nullptr;
     void *              abort_callback_data = nullptr;
@@ -8383,7 +8386,7 @@ struct llm_build_context {
     const float norm_rms_eps;
 
     const int32_t n_tokens;
-    const int32_t n_kv;     // size of KV cache to consider (n_kv <= kv_self.size)
+          int32_t n_kv;     // size of KV cache to consider (n_kv <= kv_self.size)
     const int32_t n_outputs;
     const int32_t n_outputs_enc;
     const int32_t kv_head;  // index of where we store new KV data in the cache
@@ -8405,7 +8408,8 @@ struct llm_build_context {
          llama_context & lctx,
      const llama_batch & batch,
     const llm_build_cb & cb,
-                  bool   worst_case) :
+                  bool   worst_case,
+                  bool   prepare_only = false) :
         model            (lctx.model),
         lctx             (lctx),
         hparams          (model.hparams),
@@ -8442,8 +8446,12 @@ struct llm_build_context {
         pooling_type     (cparams.pooling_type),
         rope_type        (hparams.rope_type),
         cb               (cb),
-        buf_compute_meta (lctx.buf_compute_meta) {
+        buf_compute_meta (prepare_only ? lctx.buf_compute_meta_next : lctx.buf_compute_meta) {
             // all initializations should be done in init()
+            if (prepare_only) {
+                const uint32_t pad = llama_kv_cache_get_padding(cparams);
+                n_kv = std::min(kv_self.size, std::max(pad, GGML_PAD(llama_kv_cache_cell_max(kv_self) + 1, pad)));
+            }
         }
 
     void init() {
@@ -13805,7 +13813,8 @@ static struct ggml_cgraph * llama_build_graph_s_copy(llama_context & lctx) {
 static struct ggml_cgraph * llama_build_graph(
          llama_context & lctx,
      const llama_batch & batch,
-                  bool   worst_case) {
+                  bool   worst_case,
+                  bool   prepare_only = false) {
     const auto & model = lctx.model;
 
     // this callback allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.)
@@ -13841,7 +13850,7 @@ static struct ggml_cgraph * llama_build_graph(
     struct ggml_cgraph * result = NULL;
 
-    struct llm_build_context llm(lctx, batch, cb, worst_case);
+    struct llm_build_context llm(lctx, batch, cb, worst_case, prepare_only);
 
     llm.init();
@@ -14536,7 +14545,8 @@ static void llama_graph_compute(
 //
 static int llama_decode_internal(
          llama_context & lctx,
-           llama_batch   batch_all) { // TODO: rename back to batch
+           llama_batch   batch_all, // TODO: rename back to batch
+                  bool   prepare_only = false) {
 
     lctx.is_encoding = false;
     const uint32_t n_tokens_all = batch_all.n_tokens;
@@ -14556,10 +14566,12 @@ static int llama_decode_internal(
     GGML_ASSERT((cparams.causal_attn || cparams.n_ubatch >= n_tokens_all) && "non-causal attention requires n_ubatch >= n_tokens");
 
-    if (lctx.t_compute_start_us == 0) {
-        lctx.t_compute_start_us = ggml_time_us();
+    if (!prepare_only) {
+        if (lctx.t_compute_start_us == 0) {
+            lctx.t_compute_start_us = ggml_time_us();
+        }
+        lctx.n_queued_tokens += n_tokens_all;
     }
-    lctx.n_queued_tokens += n_tokens_all;
 
     auto & kv_self = lctx.kv_self;
@@ -14612,6 +14624,10 @@ static int llama_decode_internal(
         }
     }
 
+    if (n_tokens_all != 1) {
+        lctx.gf_next = nullptr;
+    }
+
     for (uint32_t cur_token = 0; cur_token < n_tokens_all; cur_token += n_ubatch) {
         const uint32_t n_tokens = std::min(n_ubatch, n_tokens_all - cur_token);
         llama_batch u_batch = {
@@ -14678,7 +14694,7 @@ static int llama_decode_internal(
         }
 
         // non-causal masks do not use the KV cache
-        if (hparams.causal_attn) {
+        if (hparams.causal_attn && !prepare_only) {
             llama_kv_cache_update(&lctx);
 
             // if we have enough unused cells before the current head ->
@@ -14703,10 +14719,23 @@ static int llama_decode_internal(
         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
-        ggml_backend_sched_reset(lctx.sched);
-        ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
-        ggml_cgraph * gf = llama_build_graph(lctx, u_batch, false);
+        ggml_cgraph * gf = lctx.gf_next;
+
+        if (!gf) {
+            ggml_backend_sched_reset(lctx.sched);
+            ggml_backend_sched_set_eval_callback(lctx.sched, lctx.cparams.cb_eval, lctx.cparams.cb_eval_user_data);
+            gf = llama_build_graph(lctx, u_batch, false, prepare_only);
+            ggml_backend_sched_alloc_graph(lctx.sched, gf);
+        }
+
+        if (prepare_only) {
+            lctx.gf_next = gf;
+            return 0;
+        }
+
+        lctx.gf_next = nullptr;
+
         // the output is always the last tensor in the graph
         struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
@@ -14732,7 +14761,6 @@ static int llama_decode_internal(
         }
 
         // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);
 
-        ggml_backend_sched_alloc_graph(lctx.sched, gf);
 
         llama_set_inputs(lctx, u_batch);
@@ -14836,6 +14864,15 @@ static int llama_decode_internal(
     // overlap with device computation.
     ggml_backend_sched_reset(lctx.sched);
 
+    if (n_tokens_all == 1 && !prepare_only) {
+        // prepare graph for the next token
+        llama_token next_token_dummy = 0;
+        llama_pos n_past = batch_all.all_pos_0 + 1;
+        llama_seq_id seq_id = 0;
+        llama_batch batch_next = llama_batch_get_one(&next_token_dummy, 1, n_past, seq_id);
+        llama_decode_internal(lctx, batch_next, true);
+    }
+
     return 0;
 }
@@ -16940,6 +16977,7 @@ struct llama_context * llama_new_context_with_model(
         // buffer used to store the computation graph and the tensor meta data
         ctx->buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
+        ctx->buf_compute_meta_next.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false));
 
         // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
         bool pipeline_parallel =
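
Not part of the patch: the sketch below shows the caller-side pattern this change targets. It is a minimal, illustrative example; the helper name generate_greedy and the greedy argmax sampling are assumptions made only to keep it self-contained, while llama_batch_get_one, llama_decode, llama_get_logits_ith and llama_n_vocab are the public llama.cpp API already used in this diff. In such a single-token decode loop, the second and later llama_decode calls pick up lctx.gf_next and skip llama_build_graph/ggml_backend_sched_alloc_graph; any call with n_tokens_all != 1 (e.g. prompt processing) clears the cached graph again.

#include "llama.h"

// Illustrative only -- `ctx`/`model` are assumed to be initialized and the prompt
// already decoded up to position `n_past`; `tok` is the last sampled token.
static void generate_greedy(llama_context * ctx, const llama_model * model,
                            llama_token tok, llama_pos n_past, int n_gen) {
    for (int i = 0; i < n_gen; ++i) {
        // single-token batch: the case for which the patch caches lctx.gf_next,
        // so every decode after the first reuses the pre-built, pre-allocated graph
        llama_batch batch = llama_batch_get_one(&tok, 1, n_past, 0);
        if (llama_decode(ctx, batch) != 0) {
            break; // decode failed
        }
        n_past += 1;

        // pick the next token by greedy argmax over the logits of the last token
        const float * logits  = llama_get_logits_ith(ctx, -1);
        const int     n_vocab = llama_n_vocab(model);
        tok = 0;
        for (llama_token v = 1; v < n_vocab; ++v) {
            if (logits[v] > logits[tok]) {
                tok = v;
            }
        }
    }
}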