mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2025-01-11 19:21:46 +00:00
main : add self-extend support (#4815)
* examples : add passkey test * passkey : better prints * passkey : select pass key pos from CLI * passkey : simplify n_past logic * llama : "self-extend"-like context extension * passkey : add comment * main : add Self-Extend support * llama : add comment about llama_kv_cache_seq_div
This commit is contained in:
parent
b0034d93ce
commit
52531fdff8
@ -220,6 +220,20 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params.n_ctx = std::stoi(argv[i]);
|
params.n_ctx = std::stoi(argv[i]);
|
||||||
|
} else if (arg == "--grp-attn-n" || arg == "-gan") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
params.grp_attn_n = std::stoi(argv[i]);
|
||||||
|
} else if (arg == "--grp-attn-w" || arg == "-gaw") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
params.grp_attn_w = std::stoi(argv[i]);
|
||||||
} else if (arg == "--rope-freq-base") {
|
} else if (arg == "--rope-freq-base") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
@ -904,6 +918,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
|
|||||||
printf(" Not recommended since this is both slower and uses more VRAM.\n");
|
printf(" Not recommended since this is both slower and uses more VRAM.\n");
|
||||||
#endif // GGML_USE_CUBLAS
|
#endif // GGML_USE_CUBLAS
|
||||||
#endif
|
#endif
|
||||||
|
printf(" -gan N, --grp-attn-n N\n");
|
||||||
|
printf(" group-attention factor (default: %d)\n", params.grp_attn_n);
|
||||||
|
printf(" -gat N, --grp-attn-w N\n");
|
||||||
|
printf(" group-attention width (default: %.1f)\n", (double)params.grp_attn_w);
|
||||||
printf(" --verbose-prompt print prompt before generation\n");
|
printf(" --verbose-prompt print prompt before generation\n");
|
||||||
printf(" -dkvc, --dump-kv-cache\n");
|
printf(" -dkvc, --dump-kv-cache\n");
|
||||||
printf(" verbose print of the KV cache\n");
|
printf(" verbose print of the KV cache\n");
|
||||||
|
@ -62,6 +62,8 @@ struct gpt_params {
|
|||||||
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors
|
||||||
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
|
float tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
|
||||||
int32_t n_beams = 0; // if non-zero then use beam search of given width.
|
int32_t n_beams = 0; // if non-zero then use beam search of given width.
|
||||||
|
int32_t grp_attn_n = 1; // group-attention factor
|
||||||
|
int32_t grp_attn_w = 512; // group-attention width
|
||||||
float rope_freq_base = 0.0f; // RoPE base frequency
|
float rope_freq_base = 0.0f; // RoPE base frequency
|
||||||
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
|
float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
|
||||||
float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
|
float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
|
||||||
|
@ -439,6 +439,21 @@ int main(int argc, char ** argv) {
|
|||||||
LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
|
LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
|
||||||
LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
|
LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
|
||||||
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
||||||
|
|
||||||
|
// group-attention state
|
||||||
|
// number of grouped KV tokens so far (used only if params.grp_attn_n > 1)
|
||||||
|
int ga_i = 0;
|
||||||
|
|
||||||
|
const int ga_n = params.grp_attn_n;
|
||||||
|
const int ga_w = params.grp_attn_w;
|
||||||
|
|
||||||
|
if (ga_n != 1) {
|
||||||
|
GGML_ASSERT(ga_n > 0 && "grp_attn_n must be positive"); // NOLINT
|
||||||
|
GGML_ASSERT(ga_w % ga_n == 0 && "grp_attn_w must be a multiple of grp_attn_n"); // NOLINT
|
||||||
|
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of grp_attn_w"); // NOLINT
|
||||||
|
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * grp_attn_n"); // NOLINT
|
||||||
|
LOG_TEE("self-extend: n_ctx_train = %d, grp_attn_n = %d, grp_attn_w = %d\n", n_ctx_train, ga_n, ga_w);
|
||||||
|
}
|
||||||
LOG_TEE("\n\n");
|
LOG_TEE("\n\n");
|
||||||
|
|
||||||
if (params.interactive) {
|
if (params.interactive) {
|
||||||
@ -500,37 +515,61 @@ int main(int argc, char ** argv) {
|
|||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
}
|
}
|
||||||
|
|
||||||
// infinite text generation via context swapping
|
if (ga_n == 1) {
|
||||||
// if we run out of context:
|
// infinite text generation via context shifting
|
||||||
// - take the n_keep first tokens from the original prompt (via n_past)
|
// if we run out of context:
|
||||||
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
|
// - take the n_keep first tokens from the original prompt (via n_past)
|
||||||
if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
|
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
|
||||||
if (params.n_predict == -2) {
|
if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) > n_ctx) {
|
||||||
LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
|
if (params.n_predict == -2) {
|
||||||
break;
|
LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int n_left = n_past - params.n_keep - 1;
|
||||||
|
const int n_discard = n_left/2;
|
||||||
|
|
||||||
|
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
|
||||||
|
n_past, n_left, n_ctx, params.n_keep, n_discard);
|
||||||
|
|
||||||
|
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
|
||||||
|
llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
|
||||||
|
|
||||||
|
n_past -= n_discard;
|
||||||
|
|
||||||
|
if (ctx_guidance) {
|
||||||
|
n_past_guidance -= n_discard;
|
||||||
|
}
|
||||||
|
|
||||||
|
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
|
||||||
|
|
||||||
|
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
|
||||||
|
|
||||||
|
LOG("clear session path\n");
|
||||||
|
path_session.clear();
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// context extension via Self-Extend
|
||||||
|
while (n_past >= ga_i + ga_w) {
|
||||||
|
const int ib = (ga_n*ga_i)/ga_w;
|
||||||
|
const int bd = (ga_w/ga_n)*(ga_n - 1);
|
||||||
|
const int dd = (ga_w/ga_n) - ib*bd - ga_w;
|
||||||
|
|
||||||
const int n_left = n_past - params.n_keep - 1;
|
LOG("\n");
|
||||||
const int n_discard = n_left/2;
|
LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i, n_past, ib*bd, ga_i + ib*bd, n_past + ib*bd);
|
||||||
|
LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
|
||||||
|
LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
|
||||||
|
|
||||||
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
|
llama_kv_cache_seq_shift(ctx, 0, ga_i, n_past, ib*bd);
|
||||||
n_past, n_left, n_ctx, params.n_keep, n_discard);
|
llama_kv_cache_seq_div (ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
|
||||||
|
llama_kv_cache_seq_shift(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
|
||||||
|
|
||||||
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
|
n_past -= bd;
|
||||||
llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
|
|
||||||
|
|
||||||
n_past -= n_discard;
|
ga_i += ga_w/ga_n;
|
||||||
|
|
||||||
if (ctx_guidance) {
|
LOG("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", n_past + bd, n_past, ga_i);
|
||||||
n_past_guidance -= n_discard;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
|
|
||||||
|
|
||||||
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
|
|
||||||
|
|
||||||
LOG("clear session path\n");
|
|
||||||
path_session.clear();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
|
// try to reuse a matching prefix from the loaded session instead of re-eval (via n_past)
|
||||||
|
4
llama.h
4
llama.h
@ -484,6 +484,10 @@ extern "C" {
|
|||||||
llama_pos p1,
|
llama_pos p1,
|
||||||
llama_pos delta);
|
llama_pos delta);
|
||||||
|
|
||||||
|
// Integer division of the positions by factor of `d > 1`
|
||||||
|
// If the KV cache is RoPEd, the KV data is updated accordingly
|
||||||
|
// p0 < 0 : [0, p1]
|
||||||
|
// p1 < 0 : [p0, inf)
|
||||||
LLAMA_API void llama_kv_cache_seq_div(
|
LLAMA_API void llama_kv_cache_seq_div(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
llama_seq_id seq_id,
|
llama_seq_id seq_id,
|
||||||
|
Loading…
Reference in New Issue
Block a user