llama : fix cell_max logic + rename functions

commit ddad227782
parent 36714e16d0
Author: Georgi Gerganov
Date:   2023-09-19 13:21:12 +03:00


@@ -1319,7 +1319,7 @@ static bool llama_kv_cache_find_slot(
 
 // find how many cells are currently in use
 static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
-    for (uint32_t i = cache.size - 2; i > 0; --i) {
+    for (uint32_t i = cache.size - 1; i > 0; --i) {
         if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
             return i + 1;
         }
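Note: the last valid index into cache.cells is cache.size - 1, so starting the scan at cache.size - 2 never inspected the final cell and could under-report the number of cells in use when the cache was completely full. For reference, a sketch of the whole function after the fix; the closing return 0 fallback lies outside the hunk and is reconstructed here, so treat it as an assumption:

```cpp
// find how many cells are currently in use (post-fix sketch; the final
// "return 0" is not visible in the hunk above and is assumed)
static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
    // scan downwards from the last valid index, cache.size - 1
    for (uint32_t i = cache.size - 1; i > 0; --i) {
        if (cache.cells[i].pos >= 0 && !cache.cells[i].seq_id.empty()) {
            return i + 1; // index of the highest occupied cell, plus one
        }
    }
    return 0; // no occupied cell found above index 0
}
```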
@@ -2606,7 +2606,7 @@ static struct ggml_cgraph * llm_build_llama(
     const int n_gpu_layers = model.n_gpu_layers;
 
     const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max + n_tokens;
+    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max;
     const int32_t kv_head  = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
 
     const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
@@ -2994,7 +2994,7 @@ static struct ggml_cgraph * llm_build_baichaun(
     const int n_gpu_layers = model.n_gpu_layers;
 
     const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max + n_tokens;
+    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max;
     const int32_t kv_head  = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
 
     const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
@@ -3397,7 +3397,7 @@ static struct ggml_cgraph * llm_build_falcon(
     const int n_gpu_layers = model.n_gpu_layers;
 
     const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max + n_tokens;
+    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max;
     const int32_t kv_head  = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
 
     const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift;
@@ -3758,7 +3758,7 @@ static struct ggml_cgraph * llm_build_starcoder(
     const float norm_eps = hparams.f_norm_eps;
 
     const int32_t n_tokens = batch.n_tokens;
-    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max + n_tokens;
+    const int32_t n_kv     = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.cell_max;
     const int32_t kv_head  = ggml_allocr_is_measure(lctx.alloc) ? n_ctx - n_tokens : kv_self.head;
 
     auto & buf_compute = lctx.buf_compute;
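The same one-line change is applied in all four graph builders (llm_build_llama, llm_build_baichaun, llm_build_falcon, llm_build_starcoder): since kv_self.cell_max is now taken after the batch has been placed into the cache (see the llama_eval_internal hunk further down), the extra + n_tokens widening is redundant. In every builder the ternaries select worst-case values while ggml_allocr is in measure mode, so buffers are sized for a full context, and the cache-derived values during a real pass. A standalone illustration with hypothetical numbers (not llama.cpp code):

```cpp
// Illustration of the measure-vs-run selection above, with made-up numbers.
#include <cstdint>
#include <cstdio>

static void show(bool is_measure, int32_t n_ctx, int32_t n_tokens,
                 int32_t cell_max, int32_t head) {
    const int32_t n_kv    = is_measure ? n_ctx            : cell_max; // attention window
    const int32_t kv_head = is_measure ? n_ctx - n_tokens : head;     // where the batch is written
    printf("measure=%d  n_kv=%-4d kv_head=%d\n", is_measure, n_kv, kv_head);
}

int main() {
    const int32_t n_ctx = 512, n_tokens = 8;
    show(true,  n_ctx, n_tokens, /*cell_max=*/0,  /*head=*/0);  // measure pass: n_kv=512, kv_head=504
    show(false, n_ctx, n_tokens, /*cell_max=*/40, /*head=*/32); // real pass:    n_kv=40,  kv_head=32
    return 0;
}
```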
@@ -4013,13 +4013,13 @@ static struct ggml_cgraph * llama_build_graph(
     return result;
 }
 
-// evaluate the transformer
+// decode a batch of tokens by evaluating the transformer
 //
 //   - lctx:      llama context
 //   - batch:     batch to evaluate
 //   - n_threads: number of threads to use
 //
-static bool llama_eval_internal(
+static bool llama_decode_internal(
          llama_context & lctx,
            llama_batch   batch,
                    int   n_threads) {
@@ -4051,6 +4051,8 @@ static bool llama_eval_internal(
 
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
 
+    // helpers for smoother batch API transition
+    // after deprecating the llama_eval calls, these will be removed
     std::vector<llama_pos>    pos;
     std::vector<llama_seq_id> seq_id;
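The two added comment lines refer to the pos and seq_id vectors declared directly below them. A sketch of what these transition helpers presumably do for callers that still enter through the deprecated llama_eval path, assuming the llama_batch fields populated by llama_batch_get_one (all_pos_0, all_pos_1, all_seq_id); this body is not part of the diff:

```cpp
// assumed continuation (not shown in the hunk): synthesize per-token positions
// and sequence ids when the caller passed a bare token array via llama_eval
if (batch.pos == nullptr) {
    pos.resize(n_tokens);
    for (int32_t i = 0; i < n_tokens; ++i) {
        pos[i] = batch.all_pos_0 + i*batch.all_pos_1;
    }
    batch.pos = pos.data();
}

if (batch.seq_id == nullptr) {
    seq_id.resize(n_tokens);
    for (int32_t i = 0; i < n_tokens; ++i) {
        seq_id[i] = batch.all_seq_id;
    }
    batch.seq_id = seq_id.data();
}
```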
@@ -4076,14 +4078,15 @@ static bool llama_eval_internal(
     // TODO: better strategies can be implemented
     kv_self.head = 0;
 
+    if (!llama_kv_cache_find_slot(kv_self, batch)) {
+        return false;
+    }
+
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
     // if we start defragmenting the cache, the benefit from this will be more important
     kv_self.cell_max = llama_kv_cache_cell_max(kv_self);
 
-    if (!llama_kv_cache_find_slot(kv_self, batch)) {
-        return false;
-    }
-
     //printf("kv_self.cell_max = %d\n", kv_self.cell_max);
 
     ggml_allocr_reset(lctx.alloc);
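This reordering is what lets the graph builders above drop the + n_tokens term: llama_kv_cache_find_slot now places the batch into the cache before llama_kv_cache_cell_max is measured, so cell_max already covers the freshly claimed cells and gives an exact bound regardless of where find_slot placed the batch. A self-contained toy model of the effect (simplified types, not llama.cpp code):

```cpp
// Toy model of the find_slot / cell_max ordering change; not llama.cpp code.
#include <cstdint>
#include <cstdio>
#include <vector>

struct toy_cell  { int32_t pos = -1; };            // pos < 0 means the cell is unused
struct toy_cache { std::vector<toy_cell> cells; };

// stand-in for llama_kv_cache_cell_max: one past the highest occupied cell
static int32_t toy_cell_max(const toy_cache & cache) {
    for (int32_t i = (int32_t) cache.cells.size() - 1; i > 0; --i) {
        if (cache.cells[i].pos >= 0) {
            return i + 1;
        }
    }
    return 0;
}

// stand-in for llama_kv_cache_find_slot: claim n_tokens cells starting at head
static void toy_find_slot(toy_cache & cache, int32_t head, int32_t n_tokens) {
    for (int32_t i = 0; i < n_tokens; ++i) {
        cache.cells[head + i].pos = head + i;
    }
}

int main() {
    toy_cache cache;
    cache.cells.resize(64);

    toy_find_slot(cache, 0, 32);      // 32 tokens already sit in the cache
    const int32_t n_tokens = 8;       // the new batch

    // old order: measure first, then place the batch -> must widen by n_tokens
    const int32_t n_kv_old = toy_cell_max(cache) + n_tokens;

    // new order: place the batch first, then measure -> it is already included
    toy_find_slot(cache, 32, n_tokens);
    const int32_t n_kv_new = toy_cell_max(cache);

    printf("old n_kv = %d, new n_kv = %d\n", n_kv_old, n_kv_new); // both print 40
    return 0;
}
```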
@@ -7329,7 +7332,7 @@ int llama_eval(
                          int   n_threads) {
     llama_kv_cache_rm_tokens(ctx->kv_self, n_past, -1);
 
-    if (!llama_eval_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) {
+    if (!llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0), n_threads)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -7354,7 +7357,7 @@ int llama_eval_embd(
 
     llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, n_past, 1, 0, };
 
-    if (!llama_eval_internal(*ctx, batch, n_threads)) {
+    if (!llama_decode_internal(*ctx, batch, n_threads)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
@@ -7391,7 +7394,7 @@ int llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch,
                          int   n_threads) {
-    if (!llama_eval_internal(*ctx, batch, n_threads)) {
+    if (!llama_decode_internal(*ctx, batch, n_threads)) {
         LLAMA_LOG_ERROR("%s: failed to eval\n", __func__);
         return 1;
     }
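With llama_eval_internal renamed to llama_decode_internal, llama_decode becomes the primary entry point while llama_eval and llama_eval_embd remain as compatibility wrappers. A hypothetical caller-side sketch of the migration, assuming the API exactly as it appears in this diff (llama_decode still takes an explicit n_threads argument):

```cpp
// Hypothetical migration sketch for client code; not part of the commit.
#include <cstdio>
#include <vector>
#include "llama.h"

static int decode_prompt(llama_context * ctx, std::vector<llama_token> & tokens,
                         int n_past, int n_threads) {
    // old path (deprecated):
    //   return llama_eval(ctx, tokens.data(), (int) tokens.size(), n_past, n_threads);

    // new path: wrap the token span in a llama_batch and decode it
    llama_batch batch = llama_batch_get_one(tokens.data(), (int) tokens.size(), n_past, 0);
    if (llama_decode(ctx, batch, n_threads) != 0) {
        fprintf(stderr, "%s: llama_decode failed\n", __func__);
        return 1;
    }
    return 0;
}
```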