mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-03 15:24:35 +00:00
llama : fix cell_max logic + rename functions
This commit is contained in:
parent
36714e16d0
commit
ddad227782
31
llama.cpp
31
llama.cpp
@ -1319,7 +1319,7 @@ static bool llama_kv_cache_find_slot(
|
||||
|
||||
// find how many cells are currently in use
|
||||
static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
|
||||
for (uint32_t i = cache.size - 2; i > 0; --i) {
|
||||
for (uint32_t i = cache.size - 1; i > 0; --i) {
|
||||
if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
|
||||
return i + 1;
|
||||
}
|
||||
@ -2606,7 +2606,7 @@ static struct ggml_cgraph * llm_build_llama(
|
||||
const int n_gpu_layers = model.n_gpu_layers;
|
||||
|
||||
const int32_t n_tokens = batch.n_tokens;
|
||||
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max + n_tokens;
|
||||
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max;
|
||||
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
|
||||
|
||||
const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
|
||||
@ -2994,7 +2994,7 @@ static struct ggml_cgraph * llm_build_baichaun(
|
||||
const int n_gpu_layers = model.n_gpu_layers;
|
||||
|
||||
const int32_t n_tokens = batch.n_tokens;
|
||||
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max + n_tokens;
|
||||
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max;
|
||||
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
|
||||
|
||||
const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
|
||||
@ -3397,7 +3397,7 @@ static struct ggml_cgraph * llm_build_falcon(
|
||||
const int n_gpu_layers = model.n_gpu_layers;
|
||||
|
||||
const int32_t n_tokens = batch.n_tokens;
|
||||
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max + n_tokens;
|
||||
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max;
|
||||
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
|
||||
|
||||
const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
|
||||
@ -3758,7 +3758,7 @@ static struct ggml_cgraph * llm_build_starcoder(
|
||||
const float norm_eps = hparams.f_norm_eps;
|
||||
|
||||
const int32_t n_tokens = batch.n_tokens;
|
||||
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max + n_tokens;
|
||||
const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max;
|
||||
const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
|
||||
|
||||
auto & buf_compute = lctx.buf_compute;
|
||||
@ -4013,13 +4013,13 @@ static struct ggml_cgraph * llama_build_graph(
|
||||
return result;
|
||||
}
|
||||
|
||||
// evaluate the transformer
|
||||
// decode a batch of tokens by evaluating the transformer
|
||||
//
|
||||
// - lctx: llama context
|
||||
// - batch: batch to evaluate
|
||||
// - n_threads: number of threads to use
|
||||
//
|
||||
static bool llama_eval_internal(
|
||||
static bool llama_decode_internal(
|
||||
llama_context & lctx,
|
||||
llama_batch batch,
|
||||
int n_threads) {
|
||||
@ -4051,6 +4051,8 @@ static bool llama_eval_internal(
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_vocab = hparams.n_vocab;
|
||||
|
||||
// helpers for smoother batch API transistion
|
||||
// after deprecating the llama_eval calls, these will be removed
|
||||
std::vector<llama_pos> pos;
|
||||
std::vector<llama_seq_id> seq_id;
|
||||
|
||||
@ -4076,14 +4078,15 @@ static bool llama_eval_internal(
|
||||
// TODO: better strategies can be implemented
|
||||
kv_self.head = 0;
|
||||
|
||||
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// after enough generations, the benefit from this heuristic disappears
|
||||
// if we start defragmenting the cache, the benefit from this will be more important
|
||||
kv_self.cell_max = llama_kv_cache_cell_max(kv_self);
|
||||
|
||||
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
||||
return false;
|
||||
}
|
||||
//printf("kv_self.cell_max = %d\n", kv_self.cell_max);
|
||||
|
||||
ggml_allocr_reset(lctx.alloc);
|
||||
|
||||
@ -7329,7 +7332,7 @@ int llama_eval(
|
||||
int n_threads) {
|
||||
llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);
|
||||
|
||||
if (!llama_eval_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) {
|
||||
if (!llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@ -7354,7 +7357,7 @@ int llama_eval_embd(
|
||||
|
||||
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, };
|
||||
|
||||
if (!llama_eval_internal(*ctx, batch, n_threads)) {
|
||||
if (!llama_decode_internal(*ctx, batch, n_threads)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
@ -7391,7 +7394,7 @@ int llama_decode(
|
||||
struct llama_context * ctx,
|
||||
struct llama_batch batch,
|
||||
int n_threads) {
|
||||
if (!llama_eval_internal(*ctx, batch, n_threads)) {
|
||||
if (!llama_decode_internal(*ctx, batch, n_threads)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user