diff --git a/common/common.cpp b/common/common.cpp index c64590385..781f2166b 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -904,6 +904,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.interactive_specials = true; return true; } + if (arg == "--no-special") { + params.no_special = true; + return true; + } if (arg == "--embedding") { params.embedding = true; return true; @@ -1364,6 +1368,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param printf(" -i, --interactive run in interactive mode\n"); printf(" --interactive-specials allow special tokens in user text, in interactive mode\n"); printf(" --interactive-first run in interactive mode and wait for input right away\n"); + printf(" --no-special control tokens output disabled\n"); printf(" -cnv, --conversation run in conversation mode (does not print special tokens and suffix/prefix)\n"); printf(" -ins, --instruct run in instruction mode (use with Alpaca models)\n"); printf(" -cml, --chatml run in chatml mode (use with ChatML-compatible models)\n"); diff --git a/common/common.h b/common/common.h index f68f3c297..5388f6b68 100644 --- a/common/common.h +++ b/common/common.h @@ -146,6 +146,7 @@ struct gpt_params { bool use_color = false; // use color to distinguish generations and inputs bool interactive = false; // interactive mode bool interactive_specials = false; // whether to allow special tokens from user, during interactive mode + bool no_special = false; // disable control token output bool conversation = false; // conversation mode (does not print special tokens and suffix/prefix) bool chatml = false; // chatml mode (used for models trained on chatml syntax) bool prompt_cache_all = false; // save user input and generations to prompt cache diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 09fa85fce..ac35772f1 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -740,18 +740,32 @@ int main(int argc, char ** argv) { // display text if (input_echo && display) { for (auto id : embd) { - const std::string token_str = llama_token_to_piece(ctx, id, !params.conversation); - printf("%s", token_str.c_str()); + const std::string token_str = llama_token_to_piece(ctx, id); + // Console/Stream Output + if (!llama_token_is_control(llama_get_model(ctx), id)) { + // Stream Output Token To Standard Output + fprintf(stdout, "%s", token_str.c_str()); + } else if (!params.no_special && !params.conversation) { + // Stream Control Token To Standard Output Stream + fprintf(stdout, "%s", token_str.c_str()); + } + + // Record Displayed Tokens To Log + // Note: Generated tokens are created one by one hence this check if (embd.size() > 1) { + // Incoming Requested Tokens input_tokens.push_back(id); } else { + // Outgoing Generated Tokens output_tokens.push_back(id); output_ss << token_str; } + + fflush(stdout); } - fflush(stdout); } + // reset color to default if there is no pending user input if (input_echo && (int) embd_inp.size() == n_consumed) { console::set_display(console::reset); diff --git a/llama.cpp b/llama.cpp index 85cb3140d..989d27b9d 100644 --- a/llama.cpp +++ b/llama.cpp @@ -17861,6 +17861,10 @@ bool llama_token_is_eog(const struct llama_model * model, llama_token token) { ); } +bool llama_token_is_control(const struct llama_model * model, llama_token token) { + return llama_is_control_token(model->vocab, token); +} + llama_token llama_token_bos(const struct llama_model * model) { return model->vocab.special_bos_id; } diff --git a/llama.h b/llama.h index 16cece5db..16676269d 100644 --- a/llama.h +++ b/llama.h @@ -823,6 +823,9 @@ extern "C" { // Check if the token is supposed to end generation (end-of-generation, eg. EOS, EOT, etc.) LLAMA_API bool llama_token_is_eog(const struct llama_model * model, llama_token token); + // Identify if Token Id is a control token or a render-able token + LLAMA_API bool llama_token_is_control(const struct llama_model * model, llama_token token); + // Special tokens LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence