From ddad2277827865f69456e1864973f580e6c241c3 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 19 Sep 2023 13:21:12 +0300
Subject: [PATCH] llama : fix cell_max logic + rename functions

---
 llama.cpp | 31 +++++++++++++++++--------------
 1 file changed, 17 insertions(+), 14 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index abfc16c1a..0ecda7268 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1319,7 +1319,7 @@ static bool llama_kv_cache_find_slot(

 // find how many cells are currently in use
 static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
-    for (uint32_t i = cache.size - 2; i > 0; --i) {
+    for (uint32_t i = cache.size - 1; i > 0; --i) {
         if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
             return i + 1;
         }
@@ -2606,7 +2606,7 @@ static struct ggml_cgraph * llm_build_llama(
     const int n_gpu_layers = model.n_gpu_layers;

     const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max + n_tokens;
+    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max;
     const int32_t kv_head  = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;

     const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
@@ -2994,7 +2994,7 @@ static struct ggml_cgraph * llm_build_baichaun(
     const int n_gpu_layers = model.n_gpu_layers;

     const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max + n_tokens;
+    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max;
     const int32_t kv_head  = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;

     const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
@@ -3397,7 +3397,7 @@ static struct ggml_cgraph * llm_build_falcon(
     const int n_gpu_layers = model.n_gpu_layers;

     const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max + n_tokens;
+    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max;
     const int32_t kv_head  = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;

     const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
@@ -3758,7 +3758,7 @@ static struct ggml_cgraph * llm_build_starcoder(
     const float norm_eps = hparams.f_norm_eps;

     const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max + n_tokens;
+    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max;
     const int32_t kv_head  = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;

     auto & buf_compute = lctx.buf_compute;
@@ -4013,13 +4013,13 @@ static struct ggml_cgraph * llama_build_graph(
     return result;
 }

-// evaluate the transformer
+// decode a batch of tokens by evaluating the transformer
 //
 //   - lctx:      llama context
 //   - batch:     batch to evaluate
 //   - n_threads: number of threads to use
 //
-static bool llama_eval_internal(
+static bool llama_decode_internal(
         llama_context & lctx,
           llama_batch   batch,
                   int   n_threads) {
@@ -4051,6 +4051,8 @@ static bool llama_eval_internal(
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;

+    // helpers for smoother batch API transition
+    // after deprecating the llama_eval calls, these will be removed
     std::vector<llama_pos>    pos;
     std::vector<llama_seq_id> seq_id;

@@ -4076,14 +4078,15 @@ static bool llama_eval_internal(
     // TODO: better strategies can be implemented
     kv_self.head = 0;

+    if (!llama_kv_cache_find_slot(kv_self, batch)) {
+        return false;
+    }
+
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
     // if we start defragmenting the cache, the benefit from this will be more important
     kv_self.cell_max = llama_kv_cache_cell_max(kv_self);
-
-    if (!llama_kv_cache_find_slot(kv_self, batch)) {
-        return false;
-    }
+    //printf("kv_self.cell_max = %d\n", kv_self.cell_max);

     ggml_allocr_reset(lctx.alloc);

@@ -7329,7 +7332,7 @@ int llama_eval(
         int n_threads) {
     llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);

-    if (!llama_eval_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) {
+    if (!llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -7354,7 +7357,7 @@ int llama_eval_embd(

     llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, };

-    if (!llama_eval_internal(*ctx, batch, n_threads)) {
+    if (!llama_decode_internal(*ctx, batch, n_threads)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -7391,7 +7394,7 @@ int llama_decode(
         struct llama_context * ctx,
         struct llama_batch batch,
         int n_threads) {
-    if (!llama_eval_internal(*ctx, batch, n_threads)) {
+    if (!llama_decode_internal(*ctx, batch, n_threads)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
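
For reference, below is a minimal standalone C++ sketch of the corrected cell_max scan. The kv_cell and kv_cache structs and the main() driver are simplified stand-ins invented for illustration, not the real llama.cpp types; only the loop mirrors llama_kv_cache_cell_max as patched above. It shows the off-by-one the patch fixes: a backward scan that starts at cache.size - 2 can never count the final cell, while starting at cache.size - 1 does.

// minimal sketch with simplified stand-ins for the llama.cpp KV-cache types
#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

struct kv_cell {
    int32_t pos = -1;             // -1 marks an unused cell
    std::set<int32_t> seq_id;     // sequences that reference this cell
};

struct kv_cache {
    uint32_t size = 0;
    std::vector<kv_cell> cells;
};

// mirrors llama_kv_cache_cell_max after the fix:
// returns one past the highest cell that still holds a token
static int32_t cell_max(const kv_cache & cache) {
    for (uint32_t i = cache.size - 1; i > 0; --i) {
        if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
            return i + 1;
        }
    }
    return 0;
}

int main() {
    kv_cache cache;
    cache.size = 8;
    cache.cells.resize(cache.size);

    // occupy only the very last cell - the case a scan starting at size - 2 misses
    cache.cells[7].pos = 7;
    cache.cells[7].seq_id.insert(0);

    printf("cell_max = %d\n", cell_max(cache)); // prints 8; the old start index yielded 0
    return 0;
}

Note also the ordering change in the patch: llama_kv_cache_find_slot now runs before kv_self.cell_max is recomputed, so the freshly inserted batch is already in the cache when cell_max is measured, which is why the graph builders can drop the "+ n_tokens" term from n_kv.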