Compare commits

...

2 Commits

Author SHA1 Message Date
Jesse Gross
accf266be6
Merge a2d4b6fc81 into 30caac3a68 2024-12-24 13:18:46 +05:30
Jesse Gross
a2d4b6fc81 llama: Ensure KV cache is fully defragmented.
Sometimes the KV cache requires defragmentation even without
triggering the threshold heuristic. In this case, decoding
will not being able to find a KV cache slot. This is particularly
difficult for the caller to handle if it happens in between
ubatches. To avoid this, we should immediately trigger a defrag.

In addition, a heavily fragmented cache can require more than
max_moves to defragment. Currently, we stop when we hit the limit
but this can leave a cache that still does not have adequate space
even after defragmentation is triggered. Instead, we should do
multiple batches of processing until everything is complete.
2024-12-17 12:43:17 -08:00

View File

@ -3053,6 +3053,13 @@ struct llama_kv_cache {
} }
}; };
// block of KV slots to move when defragging
struct llama_kv_defrag_move {
uint32_t src;
uint32_t dst;
uint32_t len;
};
struct llama_control_vector { struct llama_control_vector {
std::vector<struct ggml_tensor *> tensors; // per layer std::vector<struct ggml_tensor *> tensors; // per layer
std::vector<ggml_context_ptr> ctxs; std::vector<ggml_context_ptr> ctxs;
@ -10990,35 +10997,23 @@ struct llm_build_context {
return gf; return gf;
} }
struct ggml_cgraph * build_defrag(const std::vector<uint32_t> & ids) { struct ggml_cgraph * build_defrag(const std::vector<struct llama_kv_defrag_move> & moves) {
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false);
for (uint32_t i = 0; i < ids.size(); ++i) { for (const auto & move : moves) {
const uint32_t id = ids[i];
if (i == id || id == ids.size()) {
continue;
}
uint32_t nm = 1;
while (i + nm < ids.size() && ids[i + nm] == id + nm) {
nm++;
}
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il], ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self.k_l[il],
n_embd_k_gqa, nm, n_embd_k_gqa, move.len,
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*i)); ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.src));
ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il], ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self.k_l[il],
n_embd_k_gqa, nm, n_embd_k_gqa, move.len,
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*id)); ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*move.dst));
ggml_tensor * view_v_src; ggml_tensor * view_v_src;
ggml_tensor * view_v_dst; ggml_tensor * view_v_dst;
@ -11026,31 +11021,29 @@ struct llm_build_context {
if (flash_attn) { if (flash_attn) {
// NOTE: the V cache is not transposed when using flash attention // NOTE: the V cache is not transposed when using flash attention
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
n_embd_v_gqa, nm, n_embd_v_gqa, move.len,
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*i)); ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.src));
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
n_embd_v_gqa, nm, n_embd_v_gqa, move.len,
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa), ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa),
ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*id)); ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa*move.dst));
} else { } else {
view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il], view_v_src = ggml_view_2d(ctx0, kv_self.v_l[il],
nm, n_embd_v_gqa, move.len, n_embd_v_gqa,
ggml_row_size(kv_self.v_l[il]->type, kv_self.size), ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
ggml_row_size(kv_self.v_l[il]->type, i)); ggml_row_size(kv_self.v_l[il]->type, move.src));
view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il], view_v_dst = ggml_view_2d(ctx0, kv_self.v_l[il],
nm, n_embd_v_gqa, move.len, n_embd_v_gqa,
ggml_row_size(kv_self.v_l[il]->type, kv_self.size), ggml_row_size(kv_self.v_l[il]->type, kv_self.size),
ggml_row_size(kv_self.v_l[il]->type, id)); ggml_row_size(kv_self.v_l[il]->type, move.dst));
} }
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst)); ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
} }
i += nm - 1;
} }
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
@ -17601,7 +17594,7 @@ struct llm_build_context {
} }
}; };
static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) { static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<struct llama_kv_defrag_move> & moves) {
llama_ubatch dummy = {}; llama_ubatch dummy = {};
dummy.equal_seqs = true; dummy.equal_seqs = true;
@ -17611,7 +17604,7 @@ static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const
llm.init(); llm.init();
struct ggml_cgraph * result = llm.build_defrag(ids); struct ggml_cgraph * result = llm.build_defrag(moves);
llm.free(); llm.free();
@ -18627,7 +18620,12 @@ static int llama_decode_internal(
kv_self.head = 0; kv_self.head = 0;
} }
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch); auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
if (!slot) {
llama_kv_cache_defrag(kv_self);
llama_kv_cache_update(&lctx);
slot = llama_kv_cache_find_slot(kv_self, ubatch);
}
if (!slot) { if (!slot) {
return 1; return 1;
} }
@ -19030,8 +19028,8 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
//const int64_t t_start = ggml_time_us(); //const int64_t t_start = ggml_time_us();
// number of cells moved // groups of cells moved
uint32_t n_moves = 0; std::vector<struct llama_kv_defrag_move> moves;
// each move requires 6*n_layer tensors (see build_defrag) // each move requires 6*n_layer tensors (see build_defrag)
// - source view, destination view, copy operation // - source view, destination view, copy operation
@ -19095,19 +19093,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
// are we moving a continuous block of memory? // are we moving a continuous block of memory?
bool cont = false; bool cont = false;
// should we stop searching for the next move?
bool stop = false;
// go back and move the nf cells to the hole // go back and move the nf cells to the hole
for (; i1 < n_kv; ++i1) { for (; i1 < n_kv; ++i1) {
auto & cell1 = kv_self.cells[i1]; auto & cell1 = kv_self.cells[i1];
if (cell1.is_empty() || ids[i1] != n_kv) { if (cell1.is_empty() || ids[i1] != n_kv) {
if (n_moves == max_moves) {
stop = true;
break;
}
cont = false; cont = false;
continue; continue;
} }
@ -19123,8 +19113,10 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
kv_self.head = n_used; kv_self.head = n_used;
if (!cont) { if (!cont) {
n_moves++; moves.push_back({i1, i0 + nf, 1});
cont = true; cont = true;
} else {
moves.back().len++;
} }
nf++; nf++;
@ -19134,22 +19126,16 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
} }
} }
if (stop || n_moves == max_moves) {
break;
}
//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh); //LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);
i0 += nh - 1; i0 += nh - 1;
} }
if (n_moves == 0) { if (moves.size() == 0) {
return; return;
} }
//LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves); //LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", moves.size());
//LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*n_moves*n_layer);
#if 0 #if 0
// CPU defrag // CPU defrag
@ -19224,11 +19210,18 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
#else #else
// ggml_graph defrag // ggml_graph defrag
for (std::size_t i = 0; i < moves.size(); i += max_moves) {
std::vector<struct llama_kv_defrag_move> chunk;
auto end = std::min(i + max_moves, moves.size());
chunk.assign(moves.begin() + i, moves.begin() + end);
ggml_backend_sched_reset(lctx.sched.get()); ggml_backend_sched_reset(lctx.sched.get());
ggml_cgraph * gf = llama_build_graph_defrag(lctx, ids); //LLAMA_LOG_INFO("expected gf nodes: %u\n", 6*chunk.size()*n_layer);
ggml_cgraph * gf = llama_build_graph_defrag(lctx, chunk);
llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool); llama_graph_compute(lctx, gf, lctx.cparams.n_threads, lctx.threadpool);
}
#endif #endif
//const int64_t t_end = ggml_time_us(); //const int64_t t_end = ggml_time_us();