llama : remove useless return value for some llama_cache_* functions

This commit is contained in:
Francis Couture-Harpin 2024-04-29 12:59:43 -04:00
parent c460ff1a1c
commit b6fafd1747
2 changed files with 19 additions and 42 deletions

View File

@ -2887,7 +2887,6 @@ static bool llama_cache_init(
bool offload) { bool offload) {
const struct llama_hparams & hparams = model.hparams; const struct llama_hparams & hparams = model.hparams;
// TODO: per layer n_embd_* // TODO: per layer n_embd_*
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
@ -3010,6 +3009,8 @@ static bool llama_cache_find_slot(
const uint32_t rs_size = cache.rs.size; const uint32_t rs_size = cache.rs.size;
const uint32_t n_tokens = batch.n_tokens; const uint32_t n_tokens = batch.n_tokens;
// FIXME: on failure, leave all caches in a consistent state.
if (rs_size > 0) { if (rs_size > 0) {
// For recurrent state architectures (like Mamba), // For recurrent state architectures (like Mamba),
// each cache cell can store the state for a whole sequence. // each cache cell can store the state for a whole sequence.
@ -3509,7 +3510,7 @@ static void llama_cache_seq_keep(struct llama_cache & cache, llama_seq_id seq_id
} }
} }
static llama_pos llama_cache_seq_add( static void llama_cache_seq_add(
struct llama_cache & cache, struct llama_cache & cache,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
@ -3519,8 +3520,6 @@ static llama_pos llama_cache_seq_add(
if (p0 < 0) { p0 = 0; } if (p0 < 0) { p0 = 0; }
if (p1 < 0) { p1 = std::numeric_limits<llama_pos>::max(); } if (p1 < 0) { p1 = std::numeric_limits<llama_pos>::max(); }
llama_pos n_past = p0;
if (cache.rs.size > 0) { if (cache.rs.size > 0) {
// for Mamba-like models, only the pos needs to be shifted // for Mamba-like models, only the pos needs to be shifted
auto & seq = cache.rs.seq_tails[seq_id]; auto & seq = cache.rs.seq_tails[seq_id];
@ -3541,9 +3540,6 @@ static llama_pos llama_cache_seq_add(
} }
} }
} }
if (n_past <= rs_cell.pos) {
n_past = rs_cell.pos + 1;
}
} }
// If we freed up a slot, set head to it so searching can start there. // If we freed up a slot, set head to it so searching can start there.
@ -3573,9 +3569,6 @@ static llama_pos llama_cache_seq_add(
} }
} }
} }
if (n_past <= kv_cell.pos) {
n_past = kv_cell.pos + 1;
}
} }
} }
@ -3583,11 +3576,9 @@ static llama_pos llama_cache_seq_add(
// Otherwise we just start the next search from the beginning. // Otherwise we just start the next search from the beginning.
cache.kv.head = new_head != cache.kv.size ? new_head : 0; cache.kv.head = new_head != cache.kv.size ? new_head : 0;
} }
return n_past;
} }
static llama_pos llama_cache_seq_div( static void llama_cache_seq_div(
struct llama_cache & cache, struct llama_cache & cache,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
@ -3596,8 +3587,6 @@ static llama_pos llama_cache_seq_div(
if (p0 < 0) { p0 = 0; } if (p0 < 0) { p0 = 0; }
if (p1 < 0) { p1 = std::numeric_limits<llama_pos>::max(); } if (p1 < 0) { p1 = std::numeric_limits<llama_pos>::max(); }
llama_pos n_past = p0;
if (cache.rs.size > 0) { if (cache.rs.size > 0) {
// for Mamba-like models, only the pos needs to be changed // for Mamba-like models, only the pos needs to be changed
auto & seq = cache.rs.seq_tails[seq_id]; auto & seq = cache.rs.seq_tails[seq_id];
@ -3609,9 +3598,6 @@ static llama_pos llama_cache_seq_div(
rs_cell.pos /= d; rs_cell.pos /= d;
} }
cell_id = rs_cell.prev; cell_id = rs_cell.prev;
if (n_past <= rs_cell.pos) {
n_past = rs_cell.pos + 1;
}
} }
} }
@ -3628,14 +3614,9 @@ static llama_pos llama_cache_seq_div(
kv_cell.delta += kv_cell.pos - p_old; kv_cell.delta += kv_cell.pos - p_old;
} }
} }
if (n_past <= kv_cell.pos) {
n_past = kv_cell.pos + 1;
}
} }
} }
} }
return n_past;
} }
static llama_pos llama_cache_seq_pos_max(struct llama_cache & cache, llama_seq_id seq_id) { static llama_pos llama_cache_seq_pos_max(struct llama_cache & cache, llama_seq_id seq_id) {
@ -16935,13 +16916,11 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) {
llama_cache_seq_keep(ctx, seq_id); llama_cache_seq_keep(ctx, seq_id);
} }
llama_pos llama_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { void llama_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; }
if (delta == 0) { if (delta == 0) { return; }
return llama_cache_seq_pos_max(ctx->cache, seq_id) + 1;
}
return llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta); llama_cache_seq_add(ctx->cache, seq_id, p0, p1, delta);
} }
// deprecated // deprecated
@ -16949,13 +16928,11 @@ void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, lla
llama_cache_seq_add(ctx, seq_id, p0, p1, delta); llama_cache_seq_add(ctx, seq_id, p0, p1, delta);
} }
llama_pos llama_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { void llama_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) {
if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; }
if (d == 1) { if (d == 1) { return; }
return llama_cache_seq_pos_max(ctx->cache, seq_id) + 1;
}
return llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d); llama_cache_seq_div(ctx->cache, seq_id, p0, p1, d);
} }
// deprecated // deprecated

14
llama.h
View File

@ -562,7 +562,8 @@ extern "C" {
// seq_id < 0 : match any sequence // seq_id < 0 : match any sequence
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
// Returns n_past // Returns n_past (one more than the largest remaining pos in the seq_id)
// which is only meaningful to handle for partial removals.
LLAMA_API llama_pos llama_cache_seq_rm( LLAMA_API llama_pos llama_cache_seq_rm(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
@ -579,7 +580,8 @@ extern "C" {
// Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
// Returns n_past // Returns n_past (one more than the largest remaining pos in the destination seq_id)
// which is only meaningful to handle when partially copying.
LLAMA_API llama_pos llama_cache_seq_cp( LLAMA_API llama_pos llama_cache_seq_cp(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id_src, llama_seq_id seq_id_src,
@ -609,8 +611,7 @@ extern "C" {
// - explicitly with llama_kv_cache_update() // - explicitly with llama_kv_cache_update()
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
// Returns n_past LLAMA_API void llama_cache_seq_add(
LLAMA_API llama_pos llama_cache_seq_add(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
@ -630,8 +631,7 @@ extern "C" {
// - explicitly with llama_kv_cache_update() // - explicitly with llama_kv_cache_update()
// p0 < 0 : [0, p1] // p0 < 0 : [0, p1]
// p1 < 0 : [p0, inf) // p1 < 0 : [p0, inf)
// Returns n_past LLAMA_API void llama_cache_seq_div(
LLAMA_API llama_pos llama_cache_seq_div(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
@ -652,7 +652,7 @@ extern "C" {
LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max( LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max(
struct llama_context * ctx, struct llama_context * ctx,
llama_seq_id seq_id), llama_seq_id seq_id),
"use llama_cache_seq_pos_max instead, which also now returns -1 instead of 0 when the seq_id has no cells"); "use llama_cache_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells");
// Defragment the KV cache // Defragment the KV cache
// This will be applied: // This will be applied: