rpc : resource management rework (#7562)

* rpc : resource management rework

* address review comments
This commit is contained in:
Radoslav Gerganov 2024-05-28 18:13:36 +03:00 committed by GitHub
parent ee3dff6b8e
commit 2b737caae1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -6,6 +6,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include <memory> #include <memory>
#include <mutex>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#ifdef _WIN32 #ifdef _WIN32
@ -47,6 +48,7 @@ struct socket_t {
sockfd_t fd; sockfd_t fd;
socket_t(sockfd_t fd) : fd(fd) {} socket_t(sockfd_t fd) : fd(fd) {}
~socket_t() { ~socket_t() {
GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
#ifdef _WIN32 #ifdef _WIN32
closesocket(this->fd); closesocket(this->fd);
#else #else
@ -97,7 +99,7 @@ static ggml_guid_t ggml_backend_rpc_guid() {
} }
struct ggml_backend_rpc_buffer_type_context { struct ggml_backend_rpc_buffer_type_context {
std::shared_ptr<socket_t> sock; std::string endpoint;
std::string name; std::string name;
size_t alignment; size_t alignment;
size_t max_size; size_t max_size;
@ -106,8 +108,6 @@ struct ggml_backend_rpc_buffer_type_context {
struct ggml_backend_rpc_context { struct ggml_backend_rpc_context {
std::string endpoint; std::string endpoint;
std::string name; std::string name;
std::shared_ptr<socket_t> sock;
ggml_backend_buffer_type_t buft;
}; };
struct ggml_backend_rpc_buffer_context { struct ggml_backend_rpc_buffer_context {
@ -231,14 +231,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
return true; return true;
} }
static bool parse_endpoint(const char * endpoint, std::string & host, int & port) { static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
std::string str(endpoint); size_t pos = endpoint.find(':');
size_t pos = str.find(':');
if (pos == std::string::npos) { if (pos == std::string::npos) {
return false; return false;
} }
host = str.substr(0, pos); host = endpoint.substr(0, pos);
port = std::stoi(str.substr(pos + 1)); port = std::stoi(endpoint.substr(pos + 1));
return true; return true;
} }
@ -273,6 +272,44 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
// RPC client-side implementation // RPC client-side implementation
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
static bool initialized = false;
auto it = sockets.find(endpoint);
if (it != sockets.end()) {
if (auto sock = it->second.lock()) {
return sock;
}
}
std::string host;
int port;
if (!parse_endpoint(endpoint, host, port)) {
return nullptr;
}
#ifdef _WIN32
if (!initialized) {
WSADATA wsaData;
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
if (res != 0) {
return nullptr;
}
initialized = true;
}
#else
UNUSED(initialized);
#endif
auto sock = socket_connect(host.c_str(), port);
if (sock == nullptr) {
return nullptr;
}
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
sockets[endpoint] = sock;
return sock;
}
GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) { GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context; ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
return ctx->name.c_str(); return ctx->name.c_str();
@ -442,7 +479,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
std::vector<uint8_t> input(input_size, 0); std::vector<uint8_t> input(input_size, 0);
memcpy(input.data(), &size, sizeof(size)); memcpy(input.data(), &size, sizeof(size));
std::vector<uint8_t> output; std::vector<uint8_t> output;
bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output); auto sock = get_socket(buft_ctx->endpoint);
bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
GGML_ASSERT(status); GGML_ASSERT(status);
GGML_ASSERT(output.size() == 2*sizeof(uint64_t)); GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) | // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
@ -453,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
if (remote_ptr != 0) { if (remote_ptr != 0) {
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft, ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
ggml_backend_rpc_buffer_interface, ggml_backend_rpc_buffer_interface,
new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"}, new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
remote_size); remote_size);
return buffer; return buffer;
} else { } else {
@ -508,7 +546,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
} }
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context; ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
return buft_ctx->sock == rpc_ctx->sock; return buft_ctx->endpoint == rpc_ctx->endpoint;
} }
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = { static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
@ -521,7 +559,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
/* .is_host = */ NULL, /* .is_host = */ NULL,
}; };
GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) { GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
@ -530,16 +567,13 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) { GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context; ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
delete buft_ctx;
delete rpc_ctx->buft;
delete rpc_ctx; delete rpc_ctx;
delete backend; delete backend;
} }
GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) { GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context; ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
return ctx->buft; return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
} }
GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) { GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
@ -590,7 +624,8 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
std::vector<uint8_t> input; std::vector<uint8_t> input;
serialize_graph(cgraph, input); serialize_graph(cgraph, input);
std::vector<uint8_t> output; std::vector<uint8_t> output;
bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output); auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
GGML_ASSERT(status); GGML_ASSERT(status);
GGML_ASSERT(output.size() == 1); GGML_ASSERT(output.size() == 1);
return (enum ggml_status)output[0]; return (enum ggml_status)output[0];
@ -624,42 +659,24 @@ static ggml_backend_i ggml_backend_rpc_interface = {
/* .event_synchronize = */ NULL, /* .event_synchronize = */ NULL,
}; };
static std::unordered_map<std::string, ggml_backend_t> instances;
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
ggml_backend_t backend = ggml_backend_rpc_init(endpoint); static std::mutex mutex;
return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr; std::lock_guard<std::mutex> lock(mutex);
// NOTE: buffer types are allocated and never freed; this is by design
static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
auto it = buft_map.find(endpoint);
if (it != buft_map.end()) {
return it->second;
} }
auto sock = get_socket(endpoint);
GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
std::string endpoint_str(endpoint);
if (instances.find(endpoint_str) != instances.end()) {
return instances[endpoint_str];
}
#ifdef _WIN32
{
WSADATA wsaData;
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
if (res != 0) {
return nullptr;
}
}
#endif
fprintf(stderr, "Connecting to %s\n", endpoint);
std::string host;
int port;
if (!parse_endpoint(endpoint, host, port)) {
return nullptr;
}
auto sock = socket_connect(host.c_str(), port);
if (sock == nullptr) { if (sock == nullptr) {
return nullptr; return nullptr;
} }
size_t alignment = get_alignment(sock); size_t alignment = get_alignment(sock);
size_t max_size = get_max_size(sock); size_t max_size = get_max_size(sock);
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context { ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
/* .sock = */ sock, /* .endpoint = */ endpoint,
/* .name = */ "RPC" + std::to_string(sock->fd), /* .name = */ "RPC[" + std::string(endpoint) + "]",
/* .alignment = */ alignment, /* .alignment = */ alignment,
/* .max_size = */ max_size /* .max_size = */ max_size
}; };
@ -668,21 +685,22 @@ GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
/* .iface = */ ggml_backend_rpc_buffer_type_interface, /* .iface = */ ggml_backend_rpc_buffer_type_interface,
/* .context = */ buft_ctx /* .context = */ buft_ctx
}; };
buft_map[endpoint] = buft;
return buft;
}
GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context { ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
/* .endpoint = */ endpoint, /* .endpoint = */ endpoint,
/* .name = */ "RPC" + std::to_string(sock->fd), /* .name = */ "RPC",
/* .sock = */ sock,
/* .buft = */ buft
}; };
instances[endpoint] = new ggml_backend { ggml_backend_t backend = new ggml_backend {
/* .guid = */ ggml_backend_rpc_guid(), /* .guid = */ ggml_backend_rpc_guid(),
/* .interface = */ ggml_backend_rpc_interface, /* .interface = */ ggml_backend_rpc_interface,
/* .context = */ ctx /* .context = */ ctx
}; };
return backend;
return instances[endpoint];
} }
GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) { GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
@ -706,14 +724,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
} }
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) { GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
ggml_backend_t backend = ggml_backend_rpc_init(endpoint); auto sock = get_socket(endpoint);
if (backend == nullptr) { if (sock == nullptr) {
*free = 0; *free = 0;
*total = 0; *total = 0;
return; return;
} }
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context; get_device_memory(sock, free, total);
get_device_memory(ctx->sock, free, total);
} }
// RPC server-side implementation // RPC server-side implementation