llama : add llama_kv_cache_shift_seq + no more context swaps

This commit is contained in:
Georgi Gerganov 2023-09-18 18:00:25 +03:00
parent 86c90e34f5
commit 0cbf3bfef8
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
4 changed files with 66 additions and 29 deletions

View File

@ -781,6 +781,7 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), }; std::vector<llama_token> tmp = { llama_token_bos(lctx), llama_token_eos(lctx), };
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads); llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0), params.n_threads);
llama_kv_cache_keep_seq(lctx, -1);
llama_reset_timings(lctx); llama_reset_timings(lctx);
} }

View File

@ -499,18 +499,23 @@ int main(int argc, char ** argv) {
break; break;
} }
const int n_left = n_past - params.n_keep; const int n_left = n_past - params.n_keep - 1;
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d\n", n_past, n_left, n_ctx, params.n_keep); const int n_discard = n_left/2;
// always keep the first token - BOS LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past = std::max(1, params.n_keep); n_past, n_left, n_ctx, params.n_keep, n_discard);
n_past_guidance = std::max(1, params.n_keep + guidance_offset);
llama_kv_cache_rm_seq (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
llama_kv_cache_shift_seq(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
n_past -= n_discard;
if (ctx_guidance) {
n_past_guidance -= n_discard;
}
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance); LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
// insert n_left/2 tokens at the start of embd from last_tokens
embd.insert(embd.begin(), last_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_tokens.end() - embd.size());
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd)); LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
LOG("clear session path\n"); LOG("clear session path\n");

View File

@ -1007,7 +1007,8 @@ struct llama_layer {
}; };
struct llama_kv_cell { struct llama_kv_cell {
llama_pos pos = -1; llama_pos pos = -1;
llama_pos delta = 0;
std::set<llama_seq_id> seq_id; std::set<llama_seq_id> seq_id;
@ -1018,7 +1019,7 @@ struct llama_kv_cell {
// ring-buffer of cached KV data // ring-buffer of cached KV data
struct llama_kv_cache { struct llama_kv_cache {
bool is_roped = false; bool has_shift = false;
uint32_t head = 0; uint32_t head = 0;
uint32_t size = 0; uint32_t size = 0;
@ -1223,6 +1224,8 @@ static bool llama_kv_cache_init(
const int64_t n_mem = n_layer*n_ctx; const int64_t n_mem = n_layer*n_ctx;
const int64_t n_elements = n_embd*n_mem; const int64_t n_elements = n_embd*n_mem;
cache.has_shift = false;
cache.head = 0; cache.head = 0;
cache.size = n_ctx; cache.size = n_ctx;
@ -1333,9 +1336,13 @@ void llama_kv_cache_rm_tokens(struct llama_kv_cache & cache, int32_t c0, int32_t
} }
} }
void llama_kv_cache_rm_seq(struct llama_kv_cache & cache, llama_seq_id seq_id) { void llama_kv_cache_rm_seq(
struct llama_kv_cache & cache,
llama_seq_id seq_id,
llama_pos p0,
llama_pos p1) {
for (uint32_t i = 0; i < cache.size; ++i) { for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id)) { if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.cells[i].seq_id.erase(seq_id); cache.cells[i].seq_id.erase(seq_id);
if (cache.cells[i].seq_id.empty()) { if (cache.cells[i].seq_id.empty()) {
cache.cells[i].pos = -1; cache.cells[i].pos = -1;
@ -1353,18 +1360,22 @@ void llama_kv_cache_keep_seq(struct llama_kv_cache & cache, llama_seq_id seq_id)
} }
} }
void llama_kv_cache_shift( void llama_kv_cache_shift_seq(
struct llama_context & ctx, struct llama_kv_cache & cache,
llama_seq_id seq_id, llama_seq_id seq_id,
llama_pos p0, llama_pos p0,
llama_pos p1, llama_pos p1,
llama_pos delta) { llama_pos delta) {
auto & hparams = ctx.model.hparams;
auto & cache = ctx.kv_self;
for (uint32_t i = 0; i < cache.size; ++i) { for (uint32_t i = 0; i < cache.size; ++i) {
if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
cache.cells[i].pos += delta; cache.cells[i].pos += delta;
if (cache.cells[i].pos < 0) {
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
} else {
cache.has_shift = true;
cache.cells[i].delta = delta;
}
} }
} }
} }
@ -2595,6 +2606,8 @@ static struct ggml_cgraph * llm_build_llama(
const int32_t n_tokens = batch.n_tokens; const int32_t n_tokens = batch.n_tokens;
const int32_t n_kv = llama_kv_cache_cell_max(kv_self); const int32_t n_kv = llama_kv_cache_cell_max(kv_self);
const bool do_rope_shift = kv_self.has_shift || ggml_allocr_is_measure(lctx.alloc);
auto & buf_compute = lctx.buf_compute; auto & buf_compute = lctx.buf_compute;
struct ggml_init_params params = { struct ggml_init_params params = {
@ -2698,6 +2711,16 @@ static struct ggml_cgraph * llm_build_llama(
} }
} }
// K_shift
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx);
ggml_allocr_alloc(lctx.alloc, K_shift);
if (!ggml_allocr_is_measure(lctx.alloc)) {
int * data = (int *) K_shift->data;
for (int i = 0; i < n_ctx; ++i) {
data[i] = kv_self.cells[i].delta;
}
}
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
ggml_format_name(inpL, "layer_inp_%d", il); ggml_format_name(inpL, "layer_inp_%d", il);
@ -2723,6 +2746,17 @@ static struct ggml_cgraph * llm_build_llama(
ggml_set_name(cur, "attention_norm_0"); ggml_set_name(cur, "attention_norm_0");
} }
if (do_rope_shift) {
ggml_build_forward_expand(gf,
ggml_rope_custom_inplace(ctx0,
ggml_view_3d(ctx0, kv_self.k,
n_embd_head, n_head_kv, n_ctx,
ggml_element_size(kv_self.k)*n_embd_head,
ggml_element_size(kv_self.k)*n_embd_gqa,
ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il),
K_shift, n_embd_head, 0, 0, freq_base, freq_scale));
}
// self-attention // self-attention
{ {
// compute Q and K and RoPE them // compute Q and K and RoPE them
@ -4033,7 +4067,8 @@ static bool llama_eval_internal(
#endif #endif
// update the kv ring buffer // update the kv ring buffer
lctx.kv_self.head += n_tokens; lctx.kv_self.head += n_tokens;
lctx.kv_self.has_shift = false;
#ifdef GGML_PERF #ifdef GGML_PERF
// print timing information per ggml operation (for debugging purposes) // print timing information per ggml operation (for debugging purposes)
@ -6562,10 +6597,6 @@ struct llama_context * llama_new_context_with_model(
return nullptr; return nullptr;
} }
if (model->arch == LLM_ARCH_LLAMA) {
ctx->kv_self.is_roped = true;
}
{ {
const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v); const size_t memory_size = ggml_nbytes(ctx->kv_self.k) + ggml_nbytes(ctx->kv_self.v);
LLAMA_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); LLAMA_LOG_INFO("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
@ -6803,16 +6834,16 @@ void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1
llama_kv_cache_rm_tokens(ctx->kv_self, c0, c1); llama_kv_cache_rm_tokens(ctx->kv_self, c0, c1);
} }
void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id) { void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
llama_kv_cache_rm_seq(ctx->kv_self, seq_id); llama_kv_cache_rm_seq(ctx->kv_self, seq_id, p0, p1);
} }
void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id) { void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id) {
llama_kv_cache_keep_seq(ctx->kv_self, seq_id); llama_kv_cache_keep_seq(ctx->kv_self, seq_id);
} }
void llama_kv_cache_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { void llama_kv_cache_shift_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) {
llama_kv_cache_shift(*ctx, seq_id, p0, p1, delta); llama_kv_cache_shift_seq(ctx->kv_self, seq_id, p0, p1, delta);
} }
// Returns the *maximum* size of the state // Returns the *maximum* size of the state

View File

@ -324,15 +324,15 @@ extern "C" {
// Remove all tokens data of cells in [c0, c1) // Remove all tokens data of cells in [c0, c1)
LLAMA_API void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1); LLAMA_API void llama_kv_cache_rm_tokens(struct llama_context * ctx, int32_t c0, int32_t c1);
// Removes all tokens that belong to the specified sequence // Removes all tokens that belong to the specified sequence and have positions in [p0, p1)
LLAMA_API void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id); LLAMA_API void llama_kv_cache_rm_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1);
// Removes all tokens that do not belong to the specified sequence // Removes all tokens that do not belong to the specified sequence
LLAMA_API void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id); LLAMA_API void llama_kv_cache_keep_seq(struct llama_context * ctx, llama_seq_id seq_id);
// Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1)
// If the KV cache is RoPEd, the KV data is updated accordingly // If the KV cache is RoPEd, the KV data is updated accordingly
LLAMA_API void llama_kv_cache_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta); LLAMA_API void llama_kv_cache_shift_seq(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta);
// //
// State / sessions // State / sessions