Implement server mode.

This new mode works by first loading the model then listening for TCP
connections on a port. When a connection is received, arguments will be
parsed using a simple protocol:

- First the number of arguments will be read followed by a newline
  character.
- Then each argument will be read, separated by the 0 byte.
- With this we build an argument vector, similar to what is passed to
  the program entry point. We pass this to gpt_params_parse.

Finally `run` will be executed with the input/output streams connected
to the socket.

Signed-off-by: Thiago Padilha <thiago@padilha.cc>
This commit is contained in:
Thiago Padilha 2023-03-22 10:41:26 -03:00
parent bf44faa0ee
commit 3a0dcb3920
No known key found for this signature in database
GPG Key ID: 309C78E5ED1B3D5E
9 changed files with 331 additions and 2 deletions

View File

@ -244,6 +244,10 @@ add_executable(main
run.cpp)
target_link_libraries(main PRIVATE llama ggml utils)
if(NOT WIN32)
target_sources(main PRIVATE tcp_server.cpp)
endif()
add_executable(quantize quantize.cpp)
target_link_libraries(quantize PRIVATE llama ggml utils)

View File

@ -229,11 +229,14 @@ utils.o: utils.cpp utils.h
run.o: run.cpp run.h
$(CXX) $(CXXFLAGS) -c run.cpp -o run.o
tcp_server.o: tcp_server.cpp tcp_server.h
$(CXX) $(CXXFLAGS) -c tcp_server.cpp -o tcp_server.o
clean:
rm -f *.o main quantize
main: main.cpp ggml.o llama.o utils.o run.o
$(CXX) $(CXXFLAGS) main.cpp ggml.o llama.o utils.o run.o -o main $(LDFLAGS)
main: main.cpp ggml.o llama.o utils.o run.o tcp_server.o
$(CXX) $(CXXFLAGS) main.cpp ggml.o llama.o utils.o run.o tcp_server.o -o main $(LDFLAGS)
@echo "\x1b[36mrun ./main -h for help\x1b[0m"
quantize: quantize.cpp ggml.o llama.o utils.o

45
chat_tcp_client.sh Executable file
View File

@ -0,0 +1,45 @@
#!/usr/bin/env bash
PORT=${PORT:-8080}
PROMPT="${PROMPT:-"Transcript of a dialog, where the User interacts with an Assistant named Bob. Bob is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
User:Hello, Bob.
Bob:Hello. How may I help you today?
User:Please tell me the largest city in Europe.
Bob:Sure. The largest city in Europe is Moscow, the capital of Russia.
User:"}"
RPROMPT="${RPROMPT:-"User:"}"
N_PREDICT="${N_PREDICT:-"4096"}"
REPEAT_PENALTY="${REPEAT_PENALTY:-"1.0"}"
N_THREADS="${N_THREADS:-"4"}"
# Open connection to the chat server
exec 3<>/dev/tcp/127.0.0.1/${PORT}
# Pass the arguments. The protocol is really simple:
# 1. Pass the number of arguments followed by a linefeed
# 2. Pass the arguments, with each being followed by "0"
(
echo -en "12\n"
echo -en "-t\x00"
echo -en "$N_THREADS\x00"
echo -en "-n\x00"
echo -en "$N_PREDICT\x00"
echo -en "--repeat_penalty\x00"
echo -en "$REPEAT_PENALTY\x00"
echo -en "--color\x00"
echo -en "-i\x00"
echo -en "-r\x00"
echo -en "$RPROMPT\x00"
echo -en "-p\x00"
echo -en "$PROMPT\x00"
) >&3
trap exit TERM
# When we have passed the arguments, start printing socket data to the screen.
# This is done in a background job because we also want to send data when
# running in interactive mode.
cat <&3 && echo "(disconnected, press \"enter\" twice to exit)" &
cat >&3
wait

6
chat_tcp_server.sh Executable file
View File

@ -0,0 +1,6 @@
#!/usr/bin/env bash
PORT=${PORT:-8080}
MODEL=${MODEL:-models/7B/ggml-model-q4_0.bin}
./main -l ${PORT} -m $MODEL

View File

@ -1,5 +1,6 @@
#include "run.h"
#include "ggml.h"
#include "tcp_server.h"
#include <iostream>
@ -125,5 +126,11 @@ int main(int argc, char ** argv) {
exit(0);
}
#ifndef _WIN32
if (params.listen_port != "") {
return listen_tcp(ctx, params);
}
#endif
return run(ctx, params, std::cin, stdout, stderr);
}

245
tcp_server.cpp Normal file
View File

@ -0,0 +1,245 @@
#include "tcp_server.h"
#include "llama.h"
#include "utils.h"
#include <iostream>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdbool.h>
#include <string.h>
#include <errno.h>
#include <signal.h>
#include <unistd.h>
#include <sys/wait.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <netdb.h>
class PosixStream : public std::istream {
public:
PosixStream(int fd) : std::istream(&buf), buf(fd) {}
~PosixStream() { close(buf.get_fd()); }
private:
class PosixStreamBuf : public std::streambuf {
public:
PosixStreamBuf(int fd) : fd(fd) {}
int get_fd() const { return fd; }
protected:
virtual int_type underflow() {
if (gptr() < egptr()) {
return traits_type::to_int_type(*gptr());
}
ssize_t num_read = ::read(fd, buffer, BUFFER_SIZE);
if (num_read <= 0) {
return traits_type::eof();
}
setg(buffer, buffer, buffer + num_read);
return traits_type::to_int_type(*gptr());
}
private:
static const int BUFFER_SIZE = 1024;
int fd;
char buffer[BUFFER_SIZE];
};
PosixStreamBuf buf;
};
void die(const char *msg, ...)
{
va_list ap;
va_start(ap, msg);
vfprintf(stderr, msg, ap);
va_end(ap);
fputc('\n', stderr);
exit(1);
}
static char *read_argument(uint8_t **param_buf, size_t *param_buf_size, FILE *instream) {
bool done = false;
uint8_t *buf = *param_buf;
size_t bufsize = *param_buf_size;
size_t bufpos = 0;
while (!done) {
if (bufpos == bufsize) {
bufsize += 1024;
buf = (uint8_t *)realloc(buf, bufsize);
if (!buf) {
die("failed to allocate memory");
}
}
int c = fgetc(instream);
if (c == EOF) {
die("unexpected EOF client socket");
}
buf[bufpos++] = (uint8_t)c;
if (c == 0) {
// done reading argument
break;
}
}
*param_buf = buf;
*param_buf_size = bufsize;
return strdup((char *)buf);
}
static int read_arguments(int argc, char **argv, FILE *instream) {
int i = 1;
size_t param_buf_size = 0;
uint8_t *param_buf = nullptr;
for (i = 1; i < argc; i++) {
argv[i] = read_argument(&param_buf, &param_buf_size, instream);
}
free(param_buf);
return i;
}
static int serve_model(llama_context * ctx,
gpt_params params,
int sock_fd)
{
int argc;
char **argv;
FILE *instream = fdopen(sock_fd, "r");
FILE *outstream = fdopen(sock_fd, "w");
setvbuf(instream, NULL, _IONBF, 0);
// start by reading the parameter count
if (fscanf(instream, "%d\n", &argc) != 1) {
fprintf(outstream, "Error: First line must be character count\n");
fflush(outstream);
return 1;
}
argc += 1; // add one extra argument to emulate the program command line
argv = (char **)malloc(argc * sizeof *argv);
argv[0] = nullptr;
if (read_arguments(argc, argv, instream) != argc) {
fprintf(outstream, "Error: Failed to read arguments\n");
fflush(outstream);
}
if (gpt_params_parse(argc, argv, params) == false) {
fprintf(outstream, "Error: Failed to parse parameters\n");
fflush(outstream);
return 1;
}
for (int i = 1; i < argc; i++) {
free(argv[i]);
}
free(argv);
PosixStream tcp_instream(sock_fd);
return run(ctx, params, tcp_instream, outstream, outstream);
}
int listen_tcp(llama_context * ctx, gpt_params params) {
int listen_fd;
int status;
pid_t child;
struct addrinfo hints;
struct addrinfo *servinfo, *p;
int yes = 1;
memset(&hints, 0, sizeof hints);
hints.ai_family = AF_INET;
hints.ai_socktype = SOCK_STREAM;
hints.ai_flags = AI_PASSIVE;
// This should only ever listen on a loopback address. Access from outside
// should be proxied via socat or similar software
status = getaddrinfo("127.0.0.1", params.listen_port.c_str(), &hints, &servinfo);
if (status) {
die("getaddrinfo error: %s", gai_strerror(status));
}
// bind to the first addrinfo we can from the getaddrinfo results
for (p = servinfo; p != NULL; p = p->ai_next) {
listen_fd = socket(p->ai_family, p->ai_socktype, p->ai_protocol);
if (listen_fd == -1) {
perror("server: socket");
continue;
}
if (setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &yes, sizeof yes)) {
die("setsockopt error: %s", params.listen_port.c_str(), strerror(errno));
}
if (bind(listen_fd, p->ai_addr, p->ai_addrlen) == 0) {
struct sockaddr_in addr_in;
socklen_t addr_in_len = sizeof(addr_in);
memset(&addr_in, 0, addr_in_len);
getsockname(listen_fd, (struct sockaddr*)&addr_in, &addr_in_len);
printf("Listening on %s:%d\n", inet_ntoa(addr_in.sin_addr), ntohs(addr_in.sin_port));
break;
}
close(listen_fd);
perror("server: bind");
}
freeaddrinfo(servinfo);
if (p == NULL) {
die("failed to bind: %s", strerror(errno));
}
if (listen(listen_fd, 20)) {
die("listen error: %s", strerror(errno));
}
// Don't track child processes, so ignore SIGCHLD to prevent zombies
signal(SIGCHLD, SIG_IGN);
for (;;) {
struct sockaddr_in client_addr;
socklen_t client_addr_len = 0;
memset(&client_addr, 0, sizeof(client_addr));
int sock_fd = accept(listen_fd,
(struct sockaddr *)&client_addr,
&client_addr_len);
if (sock_fd < 0) {
fprintf(stderr, "accept error: %s\n", strerror(errno));
break;
}
child = fork();
if (child == 0) {
// close the listen_fd since we won't use it in the child
close(listen_fd);
int ret = serve_model(ctx, params, sock_fd);
close(sock_fd);
return ret;
} else {
// close the client since we won't use it in the server
close(sock_fd);
sock_fd = 0;
}
}
close(listen_fd);
// ignore SIGTERM since we'll send it to the group
signal(SIGTERM, SIG_IGN);
// tell children to exit
kill(0, SIGTERM);
// wait for children to terminate
wait(&status);
return 0;
}

7
tcp_server.h Normal file
View File

@ -0,0 +1,7 @@
#pragma once
#include "utils.h"
#include "llama.h"
#include "run.h"
int listen_tcp(llama_context * ctx, gpt_params params);

View File

@ -77,6 +77,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.ignore_eos = true;
} else if (arg == "--n_parts") {
params.n_parts = std::stoi(argv[++i]);
#ifndef _WIN32
} else if (arg == "-l" || arg == "--listen") {
params.listen_port = argv[++i];
#endif
} else if (arg == "-h" || arg == "--help") {
gpt_print_usage(argc, argv, params);
exit(0);
@ -125,6 +129,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " --perplexity compute perplexity over the prompt\n");
fprintf(stderr, " -m FNAME, --model FNAME\n");
fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
#ifndef _WIN32
fprintf(stderr, " -l PORT, --listen PORT\n");
fprintf(stderr, " Run in TCP mode, listening on PORT\n");
#endif
fprintf(stderr, "\n");
}

View File

@ -42,6 +42,10 @@ struct gpt_params {
bool instruct = false; // instruction mode (used for Alpaca models)
bool ignore_eos = false; // do not stop generating after eos
bool perplexity = false; // compute perplexity over the prompt
#ifndef _WIN32
std::string listen_port = ""; // TCP port for when running in server mode
#endif
};
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);