main : option to disable context shift (#9484)

* added cli arg to disable context shift

* reverted precommit

* updated README.md for main

* white space

* allow disabling context shift in the server

* Update common/arg.cpp

no-context-shift only works for main example

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* added server example to --no-context-shift args

* removed server changes

* white space

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Vinesh Janarthanan 2024-09-16 01:20:01 -05:00 committed by GitHub
parent c4965a64f7
commit 441b72b91f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 36 additions and 21 deletions

View File

@ -685,6 +685,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
params.n_keep = value;
}
));
add_opt(llama_arg(
{"--no-context-shift"},
format("disables context shift on inifinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
[](gpt_params & params) {
params.ctx_shift = false;
}
).set_examples({LLAMA_EXAMPLE_MAIN}));
add_opt(llama_arg(
{"--chunks"}, "N",
format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
@ -1985,4 +1992,3 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
return ctx_arg;
}

View File

@ -246,6 +246,7 @@ struct gpt_params {
bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention
bool no_perf = false; // disable performance metrics
bool ctx_shift = true; // context shift on inifinite text generation
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
bool logits_all = false; // return logits for all tokens in the batch

View File

@ -161,6 +161,8 @@ A value of -1 will enable infinite text generation, even though we have a finite
If the pause is undesirable, a value of -2 will stop generation immediately when the context is filled.
The `--no-context-shift` option allows you to stop the infinite text generation once the finite context window is full.
It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode, text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `--predict` value. If you want the model to keep going without ever producing End-of-Sequence on its own, you can use the `--ignore-eos` parameter.
### Temperature

View File

@ -559,29 +559,35 @@ int main(int argc, char ** argv) {
// if we run out of context:
// - take the n_keep first tokens from the original prompt (via n_past)
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
if (n_past + (int) embd.size() >= n_ctx) {
if (params.n_predict == -2) {
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
if (!params.ctx_shift){
LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
break;
} else {
if (params.n_predict == -2) {
LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
break;
}
const int n_left = n_past - params.n_keep;
const int n_discard = n_left/2;
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
n_past -= n_discard;
LOG_DBG("after swap: n_past = %d\n", n_past);
LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
LOG_DBG("clear session path\n");
path_session.clear();
}
const int n_left = n_past - params.n_keep;
const int n_discard = n_left/2;
LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
n_past -= n_discard;
LOG_DBG("after swap: n_past = %d\n", n_past);
LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
LOG_DBG("clear session path\n");
path_session.clear();
}
} else {
// context extension via Self-Extend