From 3a0dcb39207a18ab3f8d825914d0c4359ae9736d Mon Sep 17 00:00:00 2001 From: Thiago Padilha Date: Wed, 22 Mar 2023 10:41:26 -0300 Subject: [PATCH] Implement server mode. This new mode works by first loading the model then listening for TCP connections on a port. When a connection is received, arguments will be parsed using a simple protocol: - First the number of arguments will be read followed by a newline character. - Then each argument will be read, separated by the 0 byte. - With this we build an argument vector, similar to what is passed to the program entry point. We pass this to gpt_params_parse. Finally `run` will be executed with the input/output streams connected to the socket. Signed-off-by: Thiago Padilha --- CMakeLists.txt | 4 + Makefile | 7 +- chat_tcp_client.sh | 45 +++++++++ chat_tcp_server.sh | 6 ++ main.cpp | 7 ++ tcp_server.cpp | 245 +++++++++++++++++++++++++++++++++++++++++++++ tcp_server.h | 7 ++ utils.cpp | 8 ++ utils.h | 4 + 9 files changed, 331 insertions(+), 2 deletions(-) create mode 100755 chat_tcp_client.sh create mode 100755 chat_tcp_server.sh create mode 100644 tcp_server.cpp create mode 100644 tcp_server.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 4db24fbbb..d95d93f99 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -244,6 +244,10 @@ add_executable(main run.cpp) target_link_libraries(main PRIVATE llama ggml utils) +if(NOT WIN32) + target_sources(main PRIVATE tcp_server.cpp) +endif() + add_executable(quantize quantize.cpp) target_link_libraries(quantize PRIVATE llama ggml utils) diff --git a/Makefile b/Makefile index 2f11ea166..59400a803 100644 --- a/Makefile +++ b/Makefile @@ -229,11 +229,14 @@ utils.o: utils.cpp utils.h run.o: run.cpp run.h $(CXX) $(CXXFLAGS) -c run.cpp -o run.o +tcp_server.o: tcp_server.cpp tcp_server.h + $(CXX) $(CXXFLAGS) -c tcp_server.cpp -o tcp_server.o + clean: rm -f *.o main quantize -main: main.cpp ggml.o llama.o utils.o run.o - $(CXX) $(CXXFLAGS) main.cpp ggml.o llama.o utils.o run.o -o main $(LDFLAGS) +main: main.cpp ggml.o llama.o utils.o run.o tcp_server.o + $(CXX) $(CXXFLAGS) main.cpp ggml.o llama.o utils.o run.o tcp_server.o -o main $(LDFLAGS) @echo "\x1b[36mrun ./main -h for help\x1b[0m" quantize: quantize.cpp ggml.o llama.o utils.o diff --git a/chat_tcp_client.sh b/chat_tcp_client.sh new file mode 100755 index 000000000..f154ae57d --- /dev/null +++ b/chat_tcp_client.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash + +PORT=${PORT:-8080} +PROMPT="${PROMPT:-"Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision. + +User:Hello, Bob. +Bob:Hello. How may I help you today? +User:Please tell me the largest city in Europe. +Bob:Sure. The largest city in Europe is Moscow, the capital of Russia. +User:"}" +RPROMPT="${RPROMPT:-"User:"}" +N_PREDICT="${N_PREDICT:-"4096"}" +REPEAT_PENALTY="${REPEAT_PENALTY:-"1.0"}" +N_THREADS="${N_THREADS:-"4"}" + +# Open connection to the chat server +exec 3<>/dev/tcp/127.0.0.1/${PORT} + +# Pass the arguments. The protocol is really simple: +# 1. Pass the number of arguments followed by a linefeed +# 2. Pass the arguments, with each being followed by "0" +( +echo -en "12\n" +echo -en "-t\x00" +echo -en "$N_THREADS\x00" +echo -en "-n\x00" +echo -en "$N_PREDICT\x00" +echo -en "--repeat_penalty\x00" +echo -en "$REPEAT_PENALTY\x00" +echo -en "--color\x00" +echo -en "-i\x00" +echo -en "-r\x00" +echo -en "$RPROMPT\x00" +echo -en "-p\x00" +echo -en "$PROMPT\x00" +) >&3 + +trap exit TERM + +# When we have passed the arguments, start printing socket data to the screen. +# This is done in a background job because we also want to send data when +# running in interactive mode. +cat <&3 && echo "(disconnected, press \"enter\" twice to exit)" & +cat >&3 +wait diff --git a/chat_tcp_server.sh b/chat_tcp_server.sh new file mode 100755 index 000000000..79320906d --- /dev/null +++ b/chat_tcp_server.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +PORT=${PORT:-8080} +MODEL=${MODEL:-models/7B/ggml-model-q4_0.bin} + +./main -l ${PORT} -m $MODEL diff --git a/main.cpp b/main.cpp index 0044025e9..975714f93 100644 --- a/main.cpp +++ b/main.cpp @@ -1,5 +1,6 @@ #include "run.h" #include "ggml.h" +#include "tcp_server.h" #include @@ -125,5 +126,11 @@ int main(int argc, char ** argv) { exit(0); } +#ifndef _WIN32 + if (params.listen_port != "") { + return listen_tcp(ctx, params); + } +#endif + return run(ctx, params, std::cin, stdout, stderr); } diff --git a/tcp_server.cpp b/tcp_server.cpp new file mode 100644 index 000000000..9077c1807 --- /dev/null +++ b/tcp_server.cpp @@ -0,0 +1,245 @@ +#include "tcp_server.h" +#include "llama.h" +#include "utils.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +class PosixStream : public std::istream { + public: + PosixStream(int fd) : std::istream(&buf), buf(fd) {} + ~PosixStream() { close(buf.get_fd()); } + + private: + class PosixStreamBuf : public std::streambuf { + public: + PosixStreamBuf(int fd) : fd(fd) {} + int get_fd() const { return fd; } + + protected: + virtual int_type underflow() { + if (gptr() < egptr()) { + return traits_type::to_int_type(*gptr()); + } + + ssize_t num_read = ::read(fd, buffer, BUFFER_SIZE); + if (num_read <= 0) { + return traits_type::eof(); + } + + setg(buffer, buffer, buffer + num_read); + return traits_type::to_int_type(*gptr()); + } + + private: + static const int BUFFER_SIZE = 1024; + int fd; + char buffer[BUFFER_SIZE]; + }; + + PosixStreamBuf buf; +}; + +void die(const char *msg, ...) +{ + va_list ap; + + va_start(ap, msg); + vfprintf(stderr, msg, ap); + va_end(ap); + fputc('\n', stderr); + exit(1); +} + +static char *read_argument(uint8_t **param_buf, size_t *param_buf_size, FILE *instream) { + bool done = false; + uint8_t *buf = *param_buf; + size_t bufsize = *param_buf_size; + size_t bufpos = 0; + while (!done) { + if (bufpos == bufsize) { + bufsize += 1024; + buf = (uint8_t *)realloc(buf, bufsize); + if (!buf) { + die("failed to allocate memory"); + } + } + + int c = fgetc(instream); + if (c == EOF) { + die("unexpected EOF client socket"); + } + buf[bufpos++] = (uint8_t)c; + if (c == 0) { + // done reading argument + break; + } + } + *param_buf = buf; + *param_buf_size = bufsize; + return strdup((char *)buf); +} + +static int read_arguments(int argc, char **argv, FILE *instream) { + int i = 1; + size_t param_buf_size = 0; + uint8_t *param_buf = nullptr; + + for (i = 1; i < argc; i++) { + argv[i] = read_argument(¶m_buf, ¶m_buf_size, instream); + } + + free(param_buf); + return i; +} + +static int serve_model(llama_context * ctx, + gpt_params params, + int sock_fd) +{ + int argc; + char **argv; + FILE *instream = fdopen(sock_fd, "r"); + FILE *outstream = fdopen(sock_fd, "w"); + setvbuf(instream, NULL, _IONBF, 0); + + // start by reading the parameter count + if (fscanf(instream, "%d\n", &argc) != 1) { + fprintf(outstream, "Error: First line must be character count\n"); + fflush(outstream); + return 1; + } + + argc += 1; // add one extra argument to emulate the program command line + argv = (char **)malloc(argc * sizeof *argv); + argv[0] = nullptr; + if (read_arguments(argc, argv, instream) != argc) { + fprintf(outstream, "Error: Failed to read arguments\n"); + fflush(outstream); + } + + if (gpt_params_parse(argc, argv, params) == false) { + fprintf(outstream, "Error: Failed to parse parameters\n"); + fflush(outstream); + return 1; + } + + for (int i = 1; i < argc; i++) { + free(argv[i]); + } + free(argv); + + PosixStream tcp_instream(sock_fd); + + return run(ctx, params, tcp_instream, outstream, outstream); +} + +int listen_tcp(llama_context * ctx, gpt_params params) { + int listen_fd; + int status; + pid_t child; + struct addrinfo hints; + struct addrinfo *servinfo, *p; + int yes = 1; + + memset(&hints, 0, sizeof hints); + hints.ai_family = AF_INET; + hints.ai_socktype = SOCK_STREAM; + hints.ai_flags = AI_PASSIVE; + + // This should only ever listen on a loopback address. Access from outside + // should be proxied via socat or similar software + status = getaddrinfo("127.0.0.1", params.listen_port.c_str(), &hints, &servinfo); + if (status) { + die("getaddrinfo error: %s", gai_strerror(status)); + } + + // bind to the first addrinfo we can from the getaddrinfo results + for (p = servinfo; p != NULL; p = p->ai_next) { + listen_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol); + if (listen_fd == -1) { + perror("server: socket"); + continue; + } + + if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &yes, sizeof yes)) { + die("setsockopt error: %s", params.listen_port.c_str(), strerror(errno)); + } + + if (bind(listen_fd, p->ai_addr, p->ai_addrlen) == 0) { + struct sockaddr_in addr_in; + socklen_t addr_in_len = sizeof(addr_in); + memset(&addr_in, 0, addr_in_len); + getsockname(listen_fd, (struct sockaddr*)&addr_in, &addr_in_len); + + printf("Listening on %s:%d\n", inet_ntoa(addr_in.sin_addr), ntohs(addr_in.sin_port)); + break; + } + + close(listen_fd); + perror("server: bind"); + } + + freeaddrinfo(servinfo); + + if (p == NULL) { + die("failed to bind: %s", strerror(errno)); + } + + if (listen(listen_fd, 20)) { + die("listen error: %s", strerror(errno)); + } + // Don't track child processes, so ignore SIGCHLD to prevent zombies + signal(SIGCHLD, SIG_IGN); + + for (;;) { + struct sockaddr_in client_addr; + socklen_t client_addr_len = 0; + memset(&client_addr, 0, sizeof(client_addr)); + + int sock_fd = accept(listen_fd, + (struct sockaddr *)&client_addr, + &client_addr_len); + if (sock_fd < 0) { + fprintf(stderr, "accept error: %s\n", strerror(errno)); + break; + } + + child = fork(); + if (child == 0) { + // close the listen_fd since we won't use it in the child + close(listen_fd); + int ret = serve_model(ctx, params, sock_fd); + close(sock_fd); + return ret; + } else { + // close the client since we won't use it in the server + close(sock_fd); + sock_fd = 0; + } + } + close(listen_fd); + + // ignore SIGTERM since we'll send it to the group + signal(SIGTERM, SIG_IGN); + // tell children to exit + kill(0, SIGTERM); + // wait for children to terminate + wait(&status); + return 0; +} diff --git a/tcp_server.h b/tcp_server.h new file mode 100644 index 000000000..38d6ecc81 --- /dev/null +++ b/tcp_server.h @@ -0,0 +1,7 @@ +#pragma once + +#include "utils.h" +#include "llama.h" +#include "run.h" + +int listen_tcp(llama_context * ctx, gpt_params params); diff --git a/utils.cpp b/utils.cpp index 1d5309c3a..78baf924c 100644 --- a/utils.cpp +++ b/utils.cpp @@ -77,6 +77,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.ignore_eos = true; } else if (arg == "--n_parts") { params.n_parts = std::stoi(argv[++i]); +#ifndef _WIN32 + } else if (arg == "-l" || arg == "--listen") { + params.listen_port = argv[++i]; +#endif } else if (arg == "-h" || arg == "--help") { gpt_print_usage(argc, argv, params); exit(0); @@ -125,6 +129,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " --perplexity compute perplexity over the prompt\n"); fprintf(stderr, " -m FNAME, --model FNAME\n"); fprintf(stderr, " model path (default: %s)\n", params.model.c_str()); +#ifndef _WIN32 + fprintf(stderr, " -l PORT, --listen PORT\n"); + fprintf(stderr, " Run in TCP mode, listening on PORT\n"); +#endif fprintf(stderr, "\n"); } diff --git a/utils.h b/utils.h index b0de556c9..487892b12 100644 --- a/utils.h +++ b/utils.h @@ -42,6 +42,10 @@ struct gpt_params { bool instruct = false; // instruction mode (used for Alpaca models) bool ignore_eos = false; // do not stop generating after eos bool perplexity = false; // compute perplexity over the prompt + +#ifndef _WIN32 + std::string listen_port = ""; // TCP port for when running in server mode +#endif }; bool gpt_params_parse(int argc, char ** argv, gpt_params & params);