mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-11 21:39:52 +00:00
rpc : prevent crashes on invalid input (#9040)
Add more checks which prevent RPC server from crashing if invalid input is received from client
This commit is contained in:
parent
554b049068
commit
18eaf29f4c
@ -82,17 +82,18 @@ static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of
|
|||||||
|
|
||||||
// RPC commands
|
// RPC commands
|
||||||
enum rpc_cmd {
|
enum rpc_cmd {
|
||||||
ALLOC_BUFFER = 0,
|
RPC_CMD_ALLOC_BUFFER = 0,
|
||||||
GET_ALIGNMENT,
|
RPC_CMD_GET_ALIGNMENT,
|
||||||
GET_MAX_SIZE,
|
RPC_CMD_GET_MAX_SIZE,
|
||||||
BUFFER_GET_BASE,
|
RPC_CMD_BUFFER_GET_BASE,
|
||||||
FREE_BUFFER,
|
RPC_CMD_FREE_BUFFER,
|
||||||
BUFFER_CLEAR,
|
RPC_CMD_BUFFER_CLEAR,
|
||||||
SET_TENSOR,
|
RPC_CMD_SET_TENSOR,
|
||||||
GET_TENSOR,
|
RPC_CMD_GET_TENSOR,
|
||||||
COPY_TENSOR,
|
RPC_CMD_COPY_TENSOR,
|
||||||
GRAPH_COMPUTE,
|
RPC_CMD_GRAPH_COMPUTE,
|
||||||
GET_DEVICE_MEMORY,
|
RPC_CMD_GET_DEVICE_MEMORY,
|
||||||
|
RPC_CMD_COUNT,
|
||||||
};
|
};
|
||||||
|
|
||||||
// RPC data structures
|
// RPC data structures
|
||||||
@ -330,7 +331,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t
|
|||||||
uint64_t remote_ptr = ctx->remote_ptr;
|
uint64_t remote_ptr = ctx->remote_ptr;
|
||||||
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
|
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
|
||||||
std::vector<uint8_t> output;
|
std::vector<uint8_t> output;
|
||||||
bool status = send_rpc_cmd(ctx->sock, FREE_BUFFER, input, output);
|
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output);
|
||||||
GGML_ASSERT(status);
|
GGML_ASSERT(status);
|
||||||
GGML_ASSERT(output.empty());
|
GGML_ASSERT(output.empty());
|
||||||
delete ctx;
|
delete ctx;
|
||||||
@ -346,7 +347,7 @@ GGML_CALL static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t b
|
|||||||
uint64_t remote_ptr = ctx->remote_ptr;
|
uint64_t remote_ptr = ctx->remote_ptr;
|
||||||
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
|
memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
|
||||||
std::vector<uint8_t> output;
|
std::vector<uint8_t> output;
|
||||||
bool status = send_rpc_cmd(ctx->sock, BUFFER_GET_BASE, input, output);
|
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output);
|
||||||
GGML_ASSERT(status);
|
GGML_ASSERT(status);
|
||||||
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
||||||
// output serialization format: | base_ptr (8 bytes) |
|
// output serialization format: | base_ptr (8 bytes) |
|
||||||
@ -405,7 +406,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t b
|
|||||||
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
||||||
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
|
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
|
||||||
std::vector<uint8_t> output;
|
std::vector<uint8_t> output;
|
||||||
bool status = send_rpc_cmd(ctx->sock, SET_TENSOR, input, output);
|
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output);
|
||||||
GGML_ASSERT(status);
|
GGML_ASSERT(status);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -419,7 +420,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t b
|
|||||||
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
|
||||||
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
|
memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
|
||||||
std::vector<uint8_t> output;
|
std::vector<uint8_t> output;
|
||||||
bool status = send_rpc_cmd(ctx->sock, GET_TENSOR, input, output);
|
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output);
|
||||||
GGML_ASSERT(status);
|
GGML_ASSERT(status);
|
||||||
GGML_ASSERT(output.size() == size);
|
GGML_ASSERT(output.size() == size);
|
||||||
// output serialization format: | data (size bytes) |
|
// output serialization format: | data (size bytes) |
|
||||||
@ -444,7 +445,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t b
|
|||||||
memcpy(input.data(), &rpc_src, sizeof(rpc_src));
|
memcpy(input.data(), &rpc_src, sizeof(rpc_src));
|
||||||
memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
|
memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
|
||||||
std::vector<uint8_t> output;
|
std::vector<uint8_t> output;
|
||||||
bool status = send_rpc_cmd(ctx->sock, COPY_TENSOR, input, output);
|
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output);
|
||||||
GGML_ASSERT(status);
|
GGML_ASSERT(status);
|
||||||
// output serialization format: | result (1 byte) |
|
// output serialization format: | result (1 byte) |
|
||||||
GGML_ASSERT(output.size() == 1);
|
GGML_ASSERT(output.size() == 1);
|
||||||
@ -459,7 +460,7 @@ GGML_CALL static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer
|
|||||||
memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
|
memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
|
||||||
memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
|
memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
|
||||||
std::vector<uint8_t> output;
|
std::vector<uint8_t> output;
|
||||||
bool status = send_rpc_cmd(ctx->sock, BUFFER_CLEAR, input, output);
|
bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output);
|
||||||
GGML_ASSERT(status);
|
GGML_ASSERT(status);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -488,7 +489,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
|
|||||||
memcpy(input.data(), &size, sizeof(size));
|
memcpy(input.data(), &size, sizeof(size));
|
||||||
std::vector<uint8_t> output;
|
std::vector<uint8_t> output;
|
||||||
auto sock = get_socket(buft_ctx->endpoint);
|
auto sock = get_socket(buft_ctx->endpoint);
|
||||||
bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
|
bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output);
|
||||||
GGML_ASSERT(status);
|
GGML_ASSERT(status);
|
||||||
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
|
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
|
||||||
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
|
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
|
||||||
@ -511,7 +512,7 @@ static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
|
|||||||
// input serialization format: | 0 bytes |
|
// input serialization format: | 0 bytes |
|
||||||
std::vector<uint8_t> input;
|
std::vector<uint8_t> input;
|
||||||
std::vector<uint8_t> output;
|
std::vector<uint8_t> output;
|
||||||
bool status = send_rpc_cmd(sock, GET_ALIGNMENT, input, output);
|
bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output);
|
||||||
GGML_ASSERT(status);
|
GGML_ASSERT(status);
|
||||||
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
||||||
// output serialization format: | alignment (8 bytes) |
|
// output serialization format: | alignment (8 bytes) |
|
||||||
@ -529,7 +530,7 @@ static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
|
|||||||
// input serialization format: | 0 bytes |
|
// input serialization format: | 0 bytes |
|
||||||
std::vector<uint8_t> input;
|
std::vector<uint8_t> input;
|
||||||
std::vector<uint8_t> output;
|
std::vector<uint8_t> output;
|
||||||
bool status = send_rpc_cmd(sock, GET_MAX_SIZE, input, output);
|
bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output);
|
||||||
GGML_ASSERT(status);
|
GGML_ASSERT(status);
|
||||||
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
GGML_ASSERT(output.size() == sizeof(uint64_t));
|
||||||
// output serialization format: | max_size (8 bytes) |
|
// output serialization format: | max_size (8 bytes) |
|
||||||
@ -622,7 +623,7 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
|
|||||||
serialize_graph(cgraph, input);
|
serialize_graph(cgraph, input);
|
||||||
std::vector<uint8_t> output;
|
std::vector<uint8_t> output;
|
||||||
auto sock = get_socket(rpc_ctx->endpoint);
|
auto sock = get_socket(rpc_ctx->endpoint);
|
||||||
bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
|
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output);
|
||||||
GGML_ASSERT(status);
|
GGML_ASSERT(status);
|
||||||
GGML_ASSERT(output.size() == 1);
|
GGML_ASSERT(output.size() == 1);
|
||||||
return (enum ggml_status)output[0];
|
return (enum ggml_status)output[0];
|
||||||
@ -719,7 +720,7 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
|
|||||||
// input serialization format: | 0 bytes |
|
// input serialization format: | 0 bytes |
|
||||||
std::vector<uint8_t> input;
|
std::vector<uint8_t> input;
|
||||||
std::vector<uint8_t> output;
|
std::vector<uint8_t> output;
|
||||||
bool status = send_rpc_cmd(sock, GET_DEVICE_MEMORY, input, output);
|
bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
|
||||||
GGML_ASSERT(status);
|
GGML_ASSERT(status);
|
||||||
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
|
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
|
||||||
// output serialization format: | free (8 bytes) | total (8 bytes) |
|
// output serialization format: | free (8 bytes) | total (8 bytes) |
|
||||||
@ -1098,59 +1099,69 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
|
|||||||
if (!recv_data(sockfd, &cmd, 1)) {
|
if (!recv_data(sockfd, &cmd, 1)) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
if (cmd >= RPC_CMD_COUNT) {
|
||||||
|
// fail fast if the command is invalid
|
||||||
|
fprintf(stderr, "Unknown command: %d\n", cmd);
|
||||||
|
break;
|
||||||
|
}
|
||||||
std::vector<uint8_t> input;
|
std::vector<uint8_t> input;
|
||||||
std::vector<uint8_t> output;
|
std::vector<uint8_t> output;
|
||||||
uint64_t input_size;
|
uint64_t input_size;
|
||||||
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
|
if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
try {
|
||||||
input.resize(input_size);
|
input.resize(input_size);
|
||||||
|
} catch (const std::bad_alloc & e) {
|
||||||
|
fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size);
|
||||||
|
break;
|
||||||
|
}
|
||||||
if (!recv_data(sockfd, input.data(), input_size)) {
|
if (!recv_data(sockfd, input.data(), input_size)) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
bool ok = true;
|
bool ok = true;
|
||||||
switch (cmd) {
|
switch (cmd) {
|
||||||
case ALLOC_BUFFER: {
|
case RPC_CMD_ALLOC_BUFFER: {
|
||||||
ok = server.alloc_buffer(input, output);
|
ok = server.alloc_buffer(input, output);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case GET_ALIGNMENT: {
|
case RPC_CMD_GET_ALIGNMENT: {
|
||||||
server.get_alignment(output);
|
server.get_alignment(output);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case GET_MAX_SIZE: {
|
case RPC_CMD_GET_MAX_SIZE: {
|
||||||
server.get_max_size(output);
|
server.get_max_size(output);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case BUFFER_GET_BASE: {
|
case RPC_CMD_BUFFER_GET_BASE: {
|
||||||
ok = server.buffer_get_base(input, output);
|
ok = server.buffer_get_base(input, output);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case FREE_BUFFER: {
|
case RPC_CMD_FREE_BUFFER: {
|
||||||
ok = server.free_buffer(input);
|
ok = server.free_buffer(input);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case BUFFER_CLEAR: {
|
case RPC_CMD_BUFFER_CLEAR: {
|
||||||
ok = server.buffer_clear(input);
|
ok = server.buffer_clear(input);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case SET_TENSOR: {
|
case RPC_CMD_SET_TENSOR: {
|
||||||
ok = server.set_tensor(input);
|
ok = server.set_tensor(input);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case GET_TENSOR: {
|
case RPC_CMD_GET_TENSOR: {
|
||||||
ok = server.get_tensor(input, output);
|
ok = server.get_tensor(input, output);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case COPY_TENSOR: {
|
case RPC_CMD_COPY_TENSOR: {
|
||||||
ok = server.copy_tensor(input, output);
|
ok = server.copy_tensor(input, output);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case GRAPH_COMPUTE: {
|
case RPC_CMD_GRAPH_COMPUTE: {
|
||||||
ok = server.graph_compute(input, output);
|
ok = server.graph_compute(input, output);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
case GET_DEVICE_MEMORY: {
|
case RPC_CMD_GET_DEVICE_MEMORY: {
|
||||||
// output serialization format: | free (8 bytes) | total (8 bytes) |
|
// output serialization format: | free (8 bytes) | total (8 bytes) |
|
||||||
output.resize(2*sizeof(uint64_t), 0);
|
output.resize(2*sizeof(uint64_t), 0);
|
||||||
memcpy(output.data(), &free_mem, sizeof(free_mem));
|
memcpy(output.data(), &free_mem, sizeof(free_mem));
|
||||||
@ -1203,8 +1214,10 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
|
printf("Accepted client connection, free_mem=%zu, total_mem=%zu\n", free_mem, total_mem);
|
||||||
|
fflush(stdout);
|
||||||
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
|
rpc_serve_client(backend, client_socket->fd, free_mem, total_mem);
|
||||||
printf("Client connection closed\n");
|
printf("Client connection closed\n");
|
||||||
|
fflush(stdout);
|
||||||
}
|
}
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
WSACleanup();
|
WSACleanup();
|
||||||
|
Loading…
Reference in New Issue
Block a user