mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-25 02:44:36 +00:00
Extend llama_kv_cache_seq_rm to allow matching any sequence (#3843)
* Extend llama_kv_cache_seq_rm to allow matichng any sequence * Replace llama_kv_cache_tokens_rm with llama_kv_cache_clear Use llama_kv_cache_clear for cache clearing Change calls to llama_kv_cache_tokens_rm that want to delete by position to use llama_kv_cache_seq_rm functionality
This commit is contained in:
parent
2046eb4345
commit
6e08281e58
@ -889,7 +889,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
|
|||||||
|
|
||||||
std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
|
std::vector<llama_token> tmp = { llama_token_bos(model), llama_token_eos(model), };
|
||||||
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
|
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0));
|
||||||
llama_kv_cache_tokens_rm(lctx, -1, -1);
|
llama_kv_cache_clear(lctx);
|
||||||
llama_reset_timings(lctx);
|
llama_reset_timings(lctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -185,7 +185,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
const auto t_pp_start = ggml_time_us();
|
const auto t_pp_start = ggml_time_us();
|
||||||
|
|
||||||
llama_kv_cache_tokens_rm(ctx, -1, -1);
|
llama_kv_cache_clear(ctx);
|
||||||
|
|
||||||
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
|
if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
|
||||||
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
||||||
|
@ -1037,7 +1037,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
test t(inst, lmodel, ctx);
|
test t(inst, lmodel, ctx);
|
||||||
|
|
||||||
llama_kv_cache_tokens_rm(ctx, -1, -1);
|
llama_kv_cache_clear(ctx);
|
||||||
|
|
||||||
// warmup run
|
// warmup run
|
||||||
if (t.n_prompt > 0) {
|
if (t.n_prompt > 0) {
|
||||||
@ -1048,7 +1048,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < params.reps; i++) {
|
for (int i = 0; i < params.reps; i++) {
|
||||||
llama_kv_cache_tokens_rm(ctx, -1, -1);
|
llama_kv_cache_clear(ctx);
|
||||||
|
|
||||||
uint64_t t_start = get_time_ns();
|
uint64_t t_start = get_time_ns();
|
||||||
if (t.n_prompt > 0) {
|
if (t.n_prompt > 0) {
|
||||||
|
@ -298,7 +298,7 @@ int main(int argc, char ** argv) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// remove any "future" tokens that we might have inherited from the previous session
|
// remove any "future" tokens that we might have inherited from the previous session
|
||||||
llama_kv_cache_tokens_rm(ctx, n_matching_session_tokens, -1);
|
llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
|
||||||
}
|
}
|
||||||
|
|
||||||
LOGLN(
|
LOGLN(
|
||||||
|
@ -210,7 +210,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
|
|||||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
// clear the KV cache
|
// clear the KV cache
|
||||||
llama_kv_cache_tokens_rm(ctx, -1, -1);
|
llama_kv_cache_clear(ctx);
|
||||||
|
|
||||||
for (int j = 0; j < num_batches; ++j) {
|
for (int j = 0; j < num_batches; ++j) {
|
||||||
const int batch_start = start + j * n_batch;
|
const int batch_start = start + j * n_batch;
|
||||||
@ -339,7 +339,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||||
|
|
||||||
// clear the KV cache
|
// clear the KV cache
|
||||||
llama_kv_cache_tokens_rm(ctx, -1, -1);
|
llama_kv_cache_clear(ctx);
|
||||||
|
|
||||||
for (int j = 0; j < num_batches; ++j) {
|
for (int j = 0; j < num_batches; ++j) {
|
||||||
const int batch_start = start + j * n_batch;
|
const int batch_start = start + j * n_batch;
|
||||||
@ -573,7 +573,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// clear the KV cache
|
// clear the KV cache
|
||||||
llama_kv_cache_tokens_rm(ctx, -1, -1);
|
llama_kv_cache_clear(ctx);
|
||||||
|
|
||||||
auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
|
auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab);
|
||||||
if (logits.empty()) {
|
if (logits.empty()) {
|
||||||
|
@ -857,7 +857,7 @@ struct llama_server_context
|
|||||||
|
|
||||||
void kv_cache_clear() {
|
void kv_cache_clear() {
|
||||||
// clear the entire KV cache
|
// clear the entire KV cache
|
||||||
llama_kv_cache_tokens_rm(ctx, -1, -1);
|
llama_kv_cache_clear(ctx);
|
||||||
clean_kv_cache = false;
|
clean_kv_cache = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
27
llama.cpp
27
llama.cpp
@ -1466,17 +1466,12 @@ static int32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) {
|
|||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, int32_t c1) {
|
static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
|
||||||
if (c0 < 0) c0 = 0;
|
for (int32_t i = 0; i < cache.size; ++i) {
|
||||||
if (c1 < 0) c1 = cache.size;
|
|
||||||
|
|
||||||
for (int32_t i = c0; i < c1; ++i) {
|
|
||||||
cache.cells[i].pos = -1;
|
cache.cells[i].pos = -1;
|
||||||
cache.cells[i].seq_id.clear();
|
cache.cells[i].seq_id.clear();
|
||||||
}
|
}
|
||||||
|
cache.head = 0;
|
||||||
// Searching for a free slot can start here since we know it will be empty.
|
|
||||||
cache.head = uint32_t(c0);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_kv_cache_seq_rm(
|
static void llama_kv_cache_seq_rm(
|
||||||
@ -1490,8 +1485,14 @@ static void llama_kv_cache_seq_rm(
|
|||||||
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
|
||||||
|
|
||||||
for (uint32_t i = 0; i < cache.size; ++i) {
|
for (uint32_t i = 0; i < cache.size; ++i) {
|
||||||
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
|
||||||
|
if (seq_id < 0) {
|
||||||
|
cache.cells[i].seq_id.clear();
|
||||||
|
} else if (cache.cells[i].has_seq_id(seq_id)) {
|
||||||
cache.cells[i].seq_id.erase(seq_id);
|
cache.cells[i].seq_id.erase(seq_id);
|
||||||
|
} else {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
if (cache.cells[i].seq_id.empty()) {
|
if (cache.cells[i].seq_id.empty()) {
|
||||||
cache.cells[i].pos = -1;
|
cache.cells[i].pos = -1;
|
||||||
if (new_head == cache.size) new_head = i;
|
if (new_head == cache.size) new_head = i;
|
||||||
@ -9207,8 +9208,8 @@ int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
|
|||||||
return ctx->kv_self.head;
|
return ctx->kv_self.head;
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_tokens_rm(struct llama_context * ctx, int32_t c0, int32_t c1) {
|
void llama_kv_cache_clear(struct llama_context * ctx) {
|
||||||
llama_kv_cache_tokens_rm(ctx->kv_self, c0, c1);
|
llama_kv_cache_clear(ctx->kv_self);
|
||||||
}
|
}
|
||||||
|
|
||||||
void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
void llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
|
||||||
@ -9654,7 +9655,7 @@ int llama_eval(
|
|||||||
llama_token * tokens,
|
llama_token * tokens,
|
||||||
int32_t n_tokens,
|
int32_t n_tokens,
|
||||||
int n_past) {
|
int n_past) {
|
||||||
llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
|
llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1);
|
||||||
|
|
||||||
const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0));
|
const int ret = llama_decode_internal(*ctx, llama_batch_get_one(tokens, n_tokens, n_past, 0));
|
||||||
if (ret < 0) {
|
if (ret < 0) {
|
||||||
@ -9669,7 +9670,7 @@ int llama_eval_embd(
|
|||||||
float * embd,
|
float * embd,
|
||||||
int32_t n_tokens,
|
int32_t n_tokens,
|
||||||
int n_past) {
|
int n_past) {
|
||||||
llama_kv_cache_tokens_rm(ctx->kv_self, n_past, -1);
|
llama_kv_cache_seq_rm(ctx->kv_self, -1, n_past, -1);
|
||||||
|
|
||||||
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
|
llama_batch batch = { n_tokens, nullptr, embd, nullptr, nullptr, nullptr, nullptr, n_past, 1, 0, };
|
||||||
|
|
||||||
|
11
llama.h
11
llama.h
@ -334,15 +334,12 @@ extern "C" {
|
|||||||
LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
|
LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx),
|
||||||
"avoid using this, it will be removed in the future, instead - count the tokens in user code");
|
"avoid using this, it will be removed in the future, instead - count the tokens in user code");
|
||||||
|
|
||||||
// Remove all tokens data of cells in [c0, c1)
|
// Clear the KV cache
|
||||||
// c0 < 0 : [0, c1]
|
LLAMA_API void llama_kv_cache_clear(
|
||||||
// c1 < 0 : [c0, inf)
|
struct llama_context * ctx);
|
||||||
LLAMA_API void llama_kv_cache_tokens_rm(
|
|
||||||
struct llama_context * ctx,
|
|
||||||
int32_t c0,
|
|
||||||
int32_t c1);
|
|
||||||
|
|
||||||
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
// Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
|
||||||
|
// seq_id < 0 : match any sequence
|
||||||
// p0 < 0 : [0, p1]
|
// p0 < 0 : [0, p1]
|
||||||
// p1 < 0 : [p0, inf)
|
// p1 < 0 : [p0, inf)
|
||||||
LLAMA_API void llama_kv_cache_seq_rm(
|
LLAMA_API void llama_kv_cache_seq_rm(
|
||||||
|
Loading…
Reference in New Issue
Block a user