llama : add llama_kv_cache_compress (EXPERIMENTAL)

This commit is contained in:
Georgi Gerganov 2024-02-25 22:16:13 +02:00
parent c24a2a6e60
commit 14d757066b
No known key found for this signature in database
GPG Key ID: BF970631944C16B7
3 changed files with 262 additions and 0 deletions

View File

@ -148,6 +148,7 @@ int main(int argc, char ** argv) {
llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
llama_kv_cache_compress(ctx, 0);
llama_kv_cache_update (ctx);
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;

253
llama.cpp
View File

@ -1738,6 +1738,9 @@ struct llama_kv_cache {
ggml_type type_k = GGML_TYPE_F16;
ggml_type type_v = GGML_TYPE_F16;
// if non-negative, compress data on next update
llama_pos compress_delta = -1;
std::vector<llama_kv_cell> cells;
std::vector<struct ggml_tensor *> k_l; // per layer
@ -2273,6 +2276,10 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama
return result;
}
static void llama_kv_cache_compress(struct llama_kv_cache & cache, llama_pos delta) {
cache.compress_delta = delta;
}
static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
cache.do_defrag = true;
}
@ -8109,6 +8116,240 @@ static int llama_decode_internal(
return 0;
}
// summary:
//
// - determine which KV cell pairs (i0, i1) to merge:
//
// abs(cell[i0].pos - cell[i1].pos) <= compress_delta
//
// - move the KV cache to the host memory for easier manipulation
// - processing is done layer-by-layer
// - convert the KV data to F32
// - merge the KV data (different ways to merge)
// - convert the KV data back to the original type
// - move the KV cache back to the device memory
// - update the KV cache metadata
//
// as a side effect, the new KV cache is defragmented
//
static void llama_kv_cache_compress_internal(struct llama_context & lctx) {
auto & kv_self = lctx.kv_self;
const auto & hparams = lctx.model.hparams;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
const uint32_t n_embd_head_k = hparams.n_embd_head_k; GGML_UNUSED(n_embd_head_k);
const uint32_t n_embd_head_v = hparams.n_embd_head_v; GGML_UNUSED(n_embd_head_v);
const uint32_t n_head_kv = hparams.n_head_kv; GGML_UNUSED(n_head_kv);
const uint32_t kv_size = kv_self.size;
const int64_t t_start = ggml_time_us();
std::vector<uint8_t> buf_q;
std::vector<float> buf_src_f32;
std::vector<float> buf_dst_f32;
struct c_pair { uint32_t i0, i1; };
struct c_info { bool merged; uint32_t id, cnt, r; };
std::vector<c_info> infos(kv_size, { false, 0, 0, 0 });
// the destination cell in the new KV cache
uint32_t id = 0;
// number of pairs merged
uint32_t n_merges = 0;
// determine which KV cells to merge
for (uint32_t i0 = 0; i0 < kv_size; ++i0) {
const auto & cell0 = kv_self.cells[i0];
if (!cell0.is_empty() && !infos[i0].merged) {
infos[i0] = { true, id, 0, 0 };
infos[id].cnt = 1;
const llama_pos p0 = cell0.pos;
for (uint32_t i1 = i0 + 1; i1 < kv_size; ++i1) {
const auto & cell1 = kv_self.cells[i1];
if (i0 != i1 && cell0.is_same_seq(cell1)) {
const llama_pos p1 = cell1.pos;
if (std::abs(p0 - p1) <= kv_self.compress_delta) {
infos[i1] = { true, id, 0, 0 };
infos[id].cnt++;
n_merges++;
}
}
}
if (i0 != id) {
kv_self.cells[id] = cell0;
}
id++;
}
}
kv_self.head = id;
kv_self.used = id;
for (uint32_t i = id; i < kv_size; ++i) {
kv_self.cells[i] = llama_kv_cell();
}
LLAMA_LOG_INFO("(tmp log) KV compress pairs: %u\n", n_merges);
ggml_type_traits_t tt_k;
ggml_type_traits_t tt_v;
tt_k = ggml_internal_get_type_traits(kv_self.type_k);
tt_v = ggml_internal_get_type_traits(kv_self.type_v);
for (uint32_t il = 0; il < n_layer; ++il) {
for (uint32_t i = 0; i < kv_size; ++i) {
infos[i].r = 0;
}
// update keys
{
const int64_t ne = n_embd_k_gqa*kv_size;
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, ne);
buf_q.resize(k_size);
buf_src_f32.resize(ne);
buf_dst_f32.resize(ne);
ggml_backend_tensor_get(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
tt_k.to_float(buf_q.data(), buf_src_f32.data(), ne);
std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
for (uint32_t i = 0; i < kv_size; ++i) {
if (!infos[i].merged) {
continue;
}
const uint32_t id = infos[i].id;
#if 1
// merge using averaging
{
const float scale = 1.0f/float(infos[id].cnt);
const int64_t os = i*n_embd_k_gqa;
const int64_t od = id*n_embd_k_gqa;
for (uint32_t j = 0; j < n_embd_k_gqa; ++j) {
buf_dst_f32[od + j] += buf_src_f32[os + j]*scale;
}
}
#else
// merge separate heads
{
for (uint32_t h = 0; h < n_head_kv; ++h) {
if ((h + il) % infos[id].cnt != infos[id].r) {
continue;
}
const int64_t os = i*n_embd_k_gqa + h*n_embd_head_k;
const int64_t od = id*n_embd_k_gqa + h*n_embd_head_k;
for (uint32_t j = 0; j < n_embd_head_k; ++j) {
buf_dst_f32[od + j] = buf_src_f32[os + j];
}
}
}
infos[id].r++;
#endif
}
tt_k.from_float(buf_dst_f32.data(), buf_q.data(), ne);
ggml_backend_tensor_set(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
}
for (uint32_t i = 0; i < kv_size; ++i) {
infos[i].r = 0;
}
// update values (note: they are transposed)
{
const int64_t ne = n_embd_v_gqa*kv_size;
const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, ne);
buf_q.resize(v_size);
buf_src_f32.resize(ne);
buf_dst_f32.resize(ne);
ggml_backend_tensor_get(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
tt_v.to_float(buf_q.data(), buf_src_f32.data(), ne);
std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
for (uint32_t i = 0; i < kv_size; ++i) {
if (!infos[i].merged) {
continue;
}
const uint32_t id = infos[i].id;
#if 1
// merge using averaging
{
const float scale = 1.0f/float(infos[id].cnt);
//printf("i: %d -> id: %d, scale: %f\n", i, id, scale);
const int64_t os = i;
const int64_t od = id;
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
buf_dst_f32[od + j*kv_size] += buf_src_f32[os + j*kv_size]*scale;
}
}
#else
// merge separate heads
{
for (uint32_t h = 0; h < n_head_kv; ++h) {
if ((h + il) % infos[id].cnt != infos[id].r) {
continue;
}
const int64_t os = i;
const int64_t od = id;
for (uint32_t j = h*n_embd_head_v; j < (h + 1)*n_embd_head_v; ++j) {
buf_dst_f32[od + j*kv_size] = buf_src_f32[os + j*kv_size];
}
}
}
infos[id].r++;
#endif
}
tt_v.from_float(buf_dst_f32.data(), buf_q.data(), ne);
ggml_backend_tensor_set(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
}
}
const int64_t t_end = ggml_time_us();
LLAMA_LOG_INFO("(tmp log) KV compress time: %.3f ms\n", (t_end - t_start)/1000.0);
}
// find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
auto & kv_self = lctx.kv_self;
@ -8340,6 +8581,14 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
}
}
// compress the KV cache data if needed
if (lctx.kv_self.compress_delta >= 0) {
llama_kv_cache_compress_internal(lctx);
lctx.kv_self.compress_delta = -1;
lctx.kv_self.do_defrag = false;
}
// defragment the KV cache if needed
if (lctx.kv_self.do_defrag) {
llama_kv_cache_defrag_internal(lctx);
@ -12450,6 +12699,10 @@ llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id se
return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
}
void llama_kv_cache_compress(struct llama_context * ctx, llama_pos delta) {
llama_kv_cache_compress(ctx->kv_self, delta);
}
void llama_kv_cache_defrag(struct llama_context * ctx) {
llama_kv_cache_defrag(ctx->kv_self);
}

View File

@ -557,6 +557,14 @@ extern "C" {
struct llama_context * ctx,
llama_seq_id seq_id);
// [EXPERIMENTAL] Compress the data in the KV cache
// This will be applied:
// - lazily on next llama_decode()
// - explicitly with llama_kv_cache_update()
LLAMA_API void llama_kv_cache_compress(
struct llama_context * ctx,
llama_pos delta);
// Defragment the KV cache
// This will be applied:
// - lazily on next llama_decode()