From fa84c4b3e80199a5683438f062009c031a06c4fa Mon Sep 17 00:00:00 2001 From: Kerfuffle <44031344+KerfuffleV2@users.noreply.github.com> Date: Sun, 11 Jun 2023 08:19:17 -0600 Subject: [PATCH] Fix issue where interactive mode crashes when input exceeds ctx size (#1789) * Fix issue where interactive mode in the main example crashes when input exceeds ctx size * Ensure the context size is at least 8 tokens in the main example. Closes #1768 --- examples/common.cpp | 3 +++ examples/common.h | 3 ++- examples/main/main.cpp | 16 ++++++++++++++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/examples/common.cpp b/examples/common.cpp index f5d886acf..df69f2736 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -632,6 +632,9 @@ void console_set_color(console_state & con_st, console_color_t color) { case CONSOLE_COLOR_USER_INPUT: fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_GREEN); break; + case CONSOLE_COLOR_ERROR: + fprintf(con_st.out, ANSI_BOLD ANSI_COLOR_RED); + break; } con_st.color = color; fflush(con_st.out); diff --git a/examples/common.h b/examples/common.h index 826e2ae59..6fedb414a 100644 --- a/examples/common.h +++ b/examples/common.h @@ -112,7 +112,8 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params); enum console_color_t { CONSOLE_COLOR_DEFAULT=0, CONSOLE_COLOR_PROMPT, - CONSOLE_COLOR_USER_INPUT + CONSOLE_COLOR_USER_INPUT, + CONSOLE_COLOR_ERROR }; struct console_state { diff --git a/examples/main/main.cpp b/examples/main/main.cpp index de63faa3e..66d563143 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -81,6 +81,9 @@ int main(int argc, char ** argv) { if (params.n_ctx > 2048) { fprintf(stderr, "%s: warning: model does not support context sizes greater than 2048 tokens (%d specified);" "expect poor results\n", __func__, params.n_ctx); + } else if (params.n_ctx < 8) { + fprintf(stderr, "%s: warning: minimum context size is 8, using minimum size.\n", __func__); + params.n_ctx = 8; } fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT); @@ -331,6 +334,19 @@ int main(int argc, char ** argv) { while ((n_remain != 0 && !is_antiprompt) || params.interactive) { // predict if (embd.size() > 0) { + // Note: n_ctx - 4 here is to match the logic for commandline prompt handling via + // --prompt or --file which uses the same value. + auto max_embd_size = n_ctx - 4; + // Ensure the input doesn't exceed the context size by truncating embd if necessary. + if ((int)embd.size() > max_embd_size) { + auto skipped_tokens = embd.size() - max_embd_size; + console_set_color(con_st, CONSOLE_COLOR_ERROR); + printf("<>", skipped_tokens, skipped_tokens != 1 ? "s" : ""); + console_set_color(con_st, CONSOLE_COLOR_DEFAULT); + fflush(stdout); + embd.resize(max_embd_size); + } + // infinite text generation via context swapping // if we run out of context: // - take the n_keep first tokens from the original prompt (via n_past)