diff --git a/main.cpp b/main.cpp index 8ce9af8c3..0044025e9 100644 --- a/main.cpp +++ b/main.cpp @@ -1,6 +1,8 @@ #include "run.h" #include "ggml.h" +#include + std::vector softmax(const std::vector& logits) { std::vector probs(logits.size()); @@ -123,5 +125,5 @@ int main(int argc, char ** argv) { exit(0); } - return run(ctx, params); + return run(ctx, params, std::cin, stdout, stderr); } diff --git a/run.cpp b/run.cpp index 7b0543732..ab430eb92 100644 --- a/run.cpp +++ b/run.cpp @@ -44,7 +44,7 @@ enum console_state { static console_state con_st = CONSOLE_STATE_DEFAULT; static bool con_use_color = false; -void set_console_state(console_state new_st) +void set_console_state(FILE *stream, console_state new_st) { if (!con_use_color) return; // only emit color code if state changed @@ -52,13 +52,13 @@ void set_console_state(console_state new_st) con_st = new_st; switch(con_st) { case CONSOLE_STATE_DEFAULT: - printf(ANSI_COLOR_RESET); + fprintf(stream, ANSI_COLOR_RESET); return; case CONSOLE_STATE_PROMPT: - printf(ANSI_COLOR_YELLOW); + fprintf(stream, ANSI_COLOR_YELLOW); return; case CONSOLE_STATE_USER_INPUT: - printf(ANSI_BOLD ANSI_COLOR_GREEN); + fprintf(stream, ANSI_BOLD ANSI_COLOR_GREEN); return; } } @@ -68,7 +68,7 @@ static bool is_interacting = false; #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) void sigint_handler(int signo) { - set_console_state(CONSOLE_STATE_DEFAULT); + set_console_state(stdout, CONSOLE_STATE_DEFAULT); printf("\n"); // this also force flush stdout. if (signo == SIGINT) { if (!is_interacting) { @@ -80,13 +80,17 @@ void sigint_handler(int signo) { } #endif -int run(llama_context * ctx, gpt_params params) { +int run(llama_context * ctx, + gpt_params params, + std::istream & instream, + FILE *outstream, + FILE *errstream) { if (params.seed <= 0) { params.seed = time(NULL); } - fprintf(stderr, "%s: seed = %d\n", __func__, params.seed); + fprintf(errstream, "%s: seed = %d\n", __func__, params.seed); std::mt19937 rng(params.seed); if (params.random_prompt) { @@ -138,13 +142,13 @@ int run(llama_context * ctx, gpt_params params) { params.interactive = true; } - fprintf(stderr, "\n"); - fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str()); - fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); + fprintf(errstream, "\n"); + fprintf(errstream, "%s: prompt: '%s'\n", __func__, params.prompt.c_str()); + fprintf(errstream, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size()); for (int i = 0; i < (int) embd_inp.size(); i++) { - fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i])); + fprintf(errstream, "%6d -> '%s'\n", embd_inp[i], llama_token_to_str(ctx, embd_inp[i])); } - fprintf(stderr, "\n"); + fprintf(errstream, "\n"); if (params.interactive) { #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) struct sigaction sigint_action; @@ -156,16 +160,16 @@ int run(llama_context * ctx, gpt_params params) { signal(SIGINT, sigint_handler); #endif - fprintf(stderr, "%s: interactive mode on.\n", __func__); + fprintf(errstream, "%s: interactive mode on.\n", __func__); if(params.antiprompt.size()) { for (auto antiprompt : params.antiprompt) { - fprintf(stderr, "Reverse prompt: '%s'\n", antiprompt.c_str()); + fprintf(errstream, "Reverse prompt: '%s'\n", antiprompt.c_str()); } } } - fprintf(stderr, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); - fprintf(stderr, "\n\n"); + fprintf(errstream, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty); + fprintf(errstream, "\n\n"); std::vector embd; @@ -174,7 +178,7 @@ int run(llama_context * ctx, gpt_params params) { std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0); if (params.interactive) { - fprintf(stderr, "== Running in interactive mode. ==\n" + fprintf(errstream, "== Running in interactive mode. ==\n" #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) " - Press Ctrl+C to interject at any time.\n" #endif @@ -199,13 +203,13 @@ int run(llama_context * ctx, gpt_params params) { } #endif // the first thing we will do is to output the prompt, so set color accordingly - set_console_state(CONSOLE_STATE_PROMPT); + set_console_state(outstream, CONSOLE_STATE_PROMPT); while (remaining_tokens > 0 || params.interactive) { // predict if (embd.size() > 0) { if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) { - fprintf(stderr, "%s : failed to eval\n", __func__); + fprintf(errstream, "%s : failed to eval\n", __func__); return 1; } } @@ -263,13 +267,13 @@ int run(llama_context * ctx, gpt_params params) { // display text if (!input_noecho) { for (auto id : embd) { - printf("%s", llama_token_to_str(ctx, id)); + fprintf(outstream, "%s", llama_token_to_str(ctx, id)); } - fflush(stdout); + fflush(outstream); } // reset color to default if we there is no pending user input if (!input_noecho && (int)embd_inp.size() == input_consumed) { - set_console_state(CONSOLE_STATE_DEFAULT); + set_console_state(outstream, CONSOLE_STATE_DEFAULT); } // in interactive mode, and not currently processing queued inputs; @@ -290,20 +294,20 @@ int run(llama_context * ctx, gpt_params params) { } if (is_interacting) { // potentially set color to indicate we are taking user input - set_console_state(CONSOLE_STATE_USER_INPUT); + set_console_state(outstream, CONSOLE_STATE_USER_INPUT); if (params.instruct) { input_consumed = embd_inp.size(); embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end()); - printf("\n> "); + fprintf(outstream, "\n> "); } std::string buffer; std::string line; bool another_line = true; do { - std::getline(std::cin, line); + std::getline(instream, line); if (line.empty() || line.back() != '\\') { another_line = false; } else { @@ -313,7 +317,7 @@ int run(llama_context * ctx, gpt_params params) { } while (another_line); // done taking input, reset color - set_console_state(CONSOLE_STATE_DEFAULT); + set_console_state(outstream, CONSOLE_STATE_DEFAULT); auto line_inp = ::llama_tokenize(ctx, buffer, false); embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end()); @@ -334,7 +338,7 @@ int run(llama_context * ctx, gpt_params params) { if (params.interactive) { is_interacting = true; } else { - fprintf(stderr, " [end of text]\n"); + fprintf(errstream, " [end of text]\n"); break; } } @@ -354,7 +358,7 @@ int run(llama_context * ctx, gpt_params params) { llama_free(ctx); - set_console_state(CONSOLE_STATE_DEFAULT); + set_console_state(outstream, CONSOLE_STATE_DEFAULT); return 0; } diff --git a/run.h b/run.h index 3603396da..39c8e9f06 100644 --- a/run.h +++ b/run.h @@ -3,4 +3,8 @@ #include "llama.h" #include "utils.h" -int run(llama_context * ctx, gpt_params params); +int run(llama_context * ctx, + gpt_params params, + std::istream & instream, + FILE *outstream, + FILE *errstream);