mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-12-26 11:24:35 +00:00
llama : add llama_kv_cache_compress (EXPERIMENTAL)
This commit is contained in:
parent
c24a2a6e60
commit
14d757066b
@ -148,6 +148,7 @@ int main(int argc, char ** argv) {
|
||||
|
||||
llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
|
||||
llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
|
||||
llama_kv_cache_compress(ctx, 0);
|
||||
llama_kv_cache_update (ctx);
|
||||
|
||||
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
|
||||
|
253
llama.cpp
253
llama.cpp
@ -1738,6 +1738,9 @@ struct llama_kv_cache {
|
||||
ggml_type type_k = GGML_TYPE_F16;
|
||||
ggml_type type_v = GGML_TYPE_F16;
|
||||
|
||||
// if non-negative, compress data on next update
|
||||
llama_pos compress_delta = -1;
|
||||
|
||||
std::vector<llama_kv_cell> cells;
|
||||
|
||||
std::vector<struct ggml_tensor *> k_l; // per layer
|
||||
@ -2273,6 +2276,10 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama
|
||||
return result;
|
||||
}
|
||||
|
||||
static void llama_kv_cache_compress(struct llama_kv_cache & cache, llama_pos delta) {
|
||||
cache.compress_delta = delta;
|
||||
}
|
||||
|
||||
static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
|
||||
cache.do_defrag = true;
|
||||
}
|
||||
@ -8109,6 +8116,240 @@ static int llama_decode_internal(
|
||||
return 0;
|
||||
}
|
||||
|
||||
// summary:
|
||||
//
|
||||
// - determine which KV cell pairs (i0, i1) to merge:
|
||||
//
|
||||
// abs(cell[i0].pos - cell[i1].pos) <= compress_delta
|
||||
//
|
||||
// - move the KV cache to the host memory for easier manipulation
|
||||
// - processing is done layer-by-layer
|
||||
// - convert the KV data to F32
|
||||
// - merge the KV data (different ways to merge)
|
||||
// - convert the KV data back to the original type
|
||||
// - move the KV cache back to the device memory
|
||||
// - update the KV cache metadata
|
||||
//
|
||||
// as a side effect, the new KV cache is defragmented
|
||||
//
|
||||
static void llama_kv_cache_compress_internal(struct llama_context & lctx) {
|
||||
auto & kv_self = lctx.kv_self;
|
||||
|
||||
const auto & hparams = lctx.model.hparams;
|
||||
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||
const uint32_t n_embd_head_k = hparams.n_embd_head_k; GGML_UNUSED(n_embd_head_k);
|
||||
const uint32_t n_embd_head_v = hparams.n_embd_head_v; GGML_UNUSED(n_embd_head_v);
|
||||
const uint32_t n_head_kv = hparams.n_head_kv; GGML_UNUSED(n_head_kv);
|
||||
const uint32_t kv_size = kv_self.size;
|
||||
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
std::vector<uint8_t> buf_q;
|
||||
|
||||
std::vector<float> buf_src_f32;
|
||||
std::vector<float> buf_dst_f32;
|
||||
|
||||
struct c_pair { uint32_t i0, i1; };
|
||||
struct c_info { bool merged; uint32_t id, cnt, r; };
|
||||
|
||||
std::vector<c_info> infos(kv_size, { false, 0, 0, 0 });
|
||||
|
||||
// the destination cell in the new KV cache
|
||||
uint32_t id = 0;
|
||||
|
||||
// number of pairs merged
|
||||
uint32_t n_merges = 0;
|
||||
|
||||
// determine which KV cells to merge
|
||||
for (uint32_t i0 = 0; i0 < kv_size; ++i0) {
|
||||
const auto & cell0 = kv_self.cells[i0];
|
||||
|
||||
if (!cell0.is_empty() && !infos[i0].merged) {
|
||||
infos[i0] = { true, id, 0, 0 };
|
||||
infos[id].cnt = 1;
|
||||
|
||||
const llama_pos p0 = cell0.pos;
|
||||
|
||||
for (uint32_t i1 = i0 + 1; i1 < kv_size; ++i1) {
|
||||
const auto & cell1 = kv_self.cells[i1];
|
||||
|
||||
if (i0 != i1 && cell0.is_same_seq(cell1)) {
|
||||
const llama_pos p1 = cell1.pos;
|
||||
|
||||
if (std::abs(p0 - p1) <= kv_self.compress_delta) {
|
||||
infos[i1] = { true, id, 0, 0 };
|
||||
infos[id].cnt++;
|
||||
n_merges++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (i0 != id) {
|
||||
kv_self.cells[id] = cell0;
|
||||
}
|
||||
|
||||
id++;
|
||||
}
|
||||
}
|
||||
|
||||
kv_self.head = id;
|
||||
kv_self.used = id;
|
||||
|
||||
for (uint32_t i = id; i < kv_size; ++i) {
|
||||
kv_self.cells[i] = llama_kv_cell();
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("(tmp log) KV compress pairs: %u\n", n_merges);
|
||||
|
||||
ggml_type_traits_t tt_k;
|
||||
ggml_type_traits_t tt_v;
|
||||
|
||||
tt_k = ggml_internal_get_type_traits(kv_self.type_k);
|
||||
tt_v = ggml_internal_get_type_traits(kv_self.type_v);
|
||||
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
for (uint32_t i = 0; i < kv_size; ++i) {
|
||||
infos[i].r = 0;
|
||||
}
|
||||
|
||||
// update keys
|
||||
{
|
||||
const int64_t ne = n_embd_k_gqa*kv_size;
|
||||
|
||||
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, ne);
|
||||
|
||||
buf_q.resize(k_size);
|
||||
|
||||
buf_src_f32.resize(ne);
|
||||
buf_dst_f32.resize(ne);
|
||||
|
||||
ggml_backend_tensor_get(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
|
||||
|
||||
tt_k.to_float(buf_q.data(), buf_src_f32.data(), ne);
|
||||
|
||||
std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
|
||||
|
||||
for (uint32_t i = 0; i < kv_size; ++i) {
|
||||
if (!infos[i].merged) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint32_t id = infos[i].id;
|
||||
|
||||
#if 1
|
||||
// merge using averaging
|
||||
{
|
||||
const float scale = 1.0f/float(infos[id].cnt);
|
||||
|
||||
const int64_t os = i*n_embd_k_gqa;
|
||||
const int64_t od = id*n_embd_k_gqa;
|
||||
|
||||
for (uint32_t j = 0; j < n_embd_k_gqa; ++j) {
|
||||
buf_dst_f32[od + j] += buf_src_f32[os + j]*scale;
|
||||
}
|
||||
}
|
||||
#else
|
||||
// merge separate heads
|
||||
{
|
||||
for (uint32_t h = 0; h < n_head_kv; ++h) {
|
||||
if ((h + il) % infos[id].cnt != infos[id].r) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64_t os = i*n_embd_k_gqa + h*n_embd_head_k;
|
||||
const int64_t od = id*n_embd_k_gqa + h*n_embd_head_k;
|
||||
|
||||
for (uint32_t j = 0; j < n_embd_head_k; ++j) {
|
||||
buf_dst_f32[od + j] = buf_src_f32[os + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
infos[id].r++;
|
||||
#endif
|
||||
}
|
||||
|
||||
tt_k.from_float(buf_dst_f32.data(), buf_q.data(), ne);
|
||||
|
||||
ggml_backend_tensor_set(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < kv_size; ++i) {
|
||||
infos[i].r = 0;
|
||||
}
|
||||
|
||||
// update values (note: they are transposed)
|
||||
{
|
||||
const int64_t ne = n_embd_v_gqa*kv_size;
|
||||
|
||||
const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, ne);
|
||||
|
||||
buf_q.resize(v_size);
|
||||
|
||||
buf_src_f32.resize(ne);
|
||||
buf_dst_f32.resize(ne);
|
||||
|
||||
ggml_backend_tensor_get(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
|
||||
|
||||
tt_v.to_float(buf_q.data(), buf_src_f32.data(), ne);
|
||||
|
||||
std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
|
||||
|
||||
for (uint32_t i = 0; i < kv_size; ++i) {
|
||||
if (!infos[i].merged) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint32_t id = infos[i].id;
|
||||
|
||||
#if 1
|
||||
// merge using averaging
|
||||
{
|
||||
const float scale = 1.0f/float(infos[id].cnt);
|
||||
//printf("i: %d -> id: %d, scale: %f\n", i, id, scale);
|
||||
|
||||
const int64_t os = i;
|
||||
const int64_t od = id;
|
||||
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
buf_dst_f32[od + j*kv_size] += buf_src_f32[os + j*kv_size]*scale;
|
||||
}
|
||||
}
|
||||
#else
|
||||
// merge separate heads
|
||||
{
|
||||
for (uint32_t h = 0; h < n_head_kv; ++h) {
|
||||
if ((h + il) % infos[id].cnt != infos[id].r) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64_t os = i;
|
||||
const int64_t od = id;
|
||||
|
||||
for (uint32_t j = h*n_embd_head_v; j < (h + 1)*n_embd_head_v; ++j) {
|
||||
buf_dst_f32[od + j*kv_size] = buf_src_f32[os + j*kv_size];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
infos[id].r++;
|
||||
#endif
|
||||
}
|
||||
|
||||
tt_v.from_float(buf_dst_f32.data(), buf_q.data(), ne);
|
||||
|
||||
ggml_backend_tensor_set(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t t_end = ggml_time_us();
|
||||
|
||||
LLAMA_LOG_INFO("(tmp log) KV compress time: %.3f ms\n", (t_end - t_start)/1000.0);
|
||||
}
|
||||
|
||||
// find holes from the beginning of the KV cache and fill them by moving data from the end of the cache
|
||||
static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
auto & kv_self = lctx.kv_self;
|
||||
@ -8340,6 +8581,14 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
||||
}
|
||||
}
|
||||
|
||||
// compress the KV cache data if needed
|
||||
if (lctx.kv_self.compress_delta >= 0) {
|
||||
llama_kv_cache_compress_internal(lctx);
|
||||
|
||||
lctx.kv_self.compress_delta = -1;
|
||||
lctx.kv_self.do_defrag = false;
|
||||
}
|
||||
|
||||
// defragment the KV cache if needed
|
||||
if (lctx.kv_self.do_defrag) {
|
||||
llama_kv_cache_defrag_internal(lctx);
|
||||
@ -12450,6 +12699,10 @@ llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id se
|
||||
return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
|
||||
}
|
||||
|
||||
void llama_kv_cache_compress(struct llama_context * ctx, llama_pos delta) {
|
||||
llama_kv_cache_compress(ctx->kv_self, delta);
|
||||
}
|
||||
|
||||
void llama_kv_cache_defrag(struct llama_context * ctx) {
|
||||
llama_kv_cache_defrag(ctx->kv_self);
|
||||
}
|
||||
|
8
llama.h
8
llama.h
@ -557,6 +557,14 @@ extern "C" {
|
||||
struct llama_context * ctx,
|
||||
llama_seq_id seq_id);
|
||||
|
||||
// [EXPERIMENTAL] Compress the data in the KV cache
|
||||
// This will be applied:
|
||||
// - lazily on next llama_decode()
|
||||
// - explicitly with llama_kv_cache_update()
|
||||
LLAMA_API void llama_kv_cache_compress(
|
||||
struct llama_context * ctx,
|
||||
llama_pos delta);
|
||||
|
||||
// Defragment the KV cache
|
||||
// This will be applied:
|
||||
// - lazily on next llama_decode()
|
||||
|
Loading…
Reference in New Issue
Block a user