diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 862dde996..5c0022832 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -10,7 +10,7 @@ int main(int argc, char ** argv) { gpt_params params; if (argc == 1 || argv[1][0] == '-') { - printf("usage: %s MODEL_PATH N_JUNK I_POS SEED\n" , argv[0]); + printf("usage: %s MODEL_PATH N_JUNK N_GRP I_POS SEED\n" , argv[0]); return 1 ; } @@ -18,6 +18,7 @@ int main(int argc, char ** argv) { int n_junk = 250; // number of times to repeat the junk text int n_keep = 32; // number of tokens in the prompt prefix + int n_grp = 1; // if more than 1 - perform LongLM SelfExtend int i_pos = -1; // position of the passkey in the junk text if (argc >= 2) { @@ -29,11 +30,15 @@ int main(int argc, char ** argv) { } if (argc >= 4) { - i_pos = std::stoi(argv[3]); + n_grp = std::stoi(argv[3]); } if (argc >= 5) { - seed = std::stoi(argv[4]); + i_pos = std::stoi(argv[4]); + } + + if (argc >= 6) { + seed = std::stoi(argv[5]); } if (seed == -1) { @@ -86,11 +91,13 @@ int main(int argc, char ** argv) { llama_context_params ctx_params = llama_context_default_params(); ctx_params.seed = seed; - ctx_params.n_ctx = llama_n_ctx_train(model) + n_keep; + ctx_params.n_ctx = llama_n_ctx_train(model)*n_grp + n_keep; ctx_params.n_batch = 512; ctx_params.n_threads = params.n_threads; ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; + GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp"); + llama_context * ctx = llama_new_context_with_model(model, ctx_params); if (ctx == NULL) { @@ -113,11 +120,12 @@ int main(int argc, char ** argv) { // total length of the sequences including the prompt const int n_len = n_tokens_all + n_predict; - const int n_ctx = llama_n_ctx(ctx) - n_keep; - const int n_kv_req = llama_n_ctx(ctx); - const int n_batch = ctx_params.n_batch; + const int n_ctx = llama_n_ctx(ctx) - n_keep; + const int n_kv_req = llama_n_ctx(ctx); + const int n_batch = ctx_params.n_batch; + const int n_batch_grp = ctx_params.n_batch/n_grp; - LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req); + LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch); // print the prompt token-by-token @@ -132,6 +140,17 @@ int main(int argc, char ** argv) { // fill the KV cache for (int i = 0; i < n_ctx; i += n_batch) { + if (i > 0 && n_grp > 1) { + // if SelfExtend is enabled, we compress the position from the last batch by a factor of n_grp + const int ib = i/n_batch - 1; + const int bd = n_batch_grp*(n_grp - 1); + + llama_kv_cache_seq_shift(ctx, 0, n_past - n_batch, n_past, ib*bd); + llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); + + n_past -= bd; + } + llama_batch_clear(batch); for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { diff --git a/llama.cpp b/llama.cpp index 91aa3f8e7..63853d1c3 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1903,6 +1903,28 @@ static void llama_kv_cache_seq_shift( cache.head = new_head != cache.size ? new_head : 0; } +static void llama_kv_cache_seq_div( + struct llama_kv_cache & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { + if (p0 < 0) p0 = 0; + if (p1 < 0) p1 = std::numeric_limits::max(); + + for (uint32_t i = 0; i < cache.size; ++i) { + if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { + cache.has_shift = true; + + { + llama_pos p_old = cache.cells[i].pos; + cache.cells[i].pos /= d; + cache.cells[i].delta += cache.cells[i].pos - p_old; + } + } + } +} + // // model loading and saving // @@ -10140,9 +10162,21 @@ void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { } void llama_kv_cache_seq_shift(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + if (delta == 0) { + return; + } + llama_kv_cache_seq_shift(ctx->kv_self, seq_id, p0, p1, delta); } +void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (d == 1) { + return; + } + + llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d); +} + // Returns the *maximum* size of the state size_t llama_get_state_size(const struct llama_context * ctx) { // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. diff --git a/llama.h b/llama.h index 461d4604a..5305de90b 100644 --- a/llama.h +++ b/llama.h @@ -484,6 +484,13 @@ extern "C" { llama_pos p1, llama_pos delta); + LLAMA_API void llama_kv_cache_seq_div( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d); + // // State / sessions //