mirror of
https://github.com/ggerganov/llama.cpp.git
synced 2024-11-13 14:29:52 +00:00
rpc : resource management rework (#7562)
* rpc : resource management rework * address review comments
This commit is contained in:
parent
ee3dff6b8e
commit
2b737caae1
131
ggml-rpc.cpp
131
ggml-rpc.cpp
@ -6,6 +6,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <mutex>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
#include <unordered_set>
|
#include <unordered_set>
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
@ -47,6 +48,7 @@ struct socket_t {
|
|||||||
sockfd_t fd;
|
sockfd_t fd;
|
||||||
socket_t(sockfd_t fd) : fd(fd) {}
|
socket_t(sockfd_t fd) : fd(fd) {}
|
||||||
~socket_t() {
|
~socket_t() {
|
||||||
|
GGML_PRINT_DEBUG("[%s] closing socket %d\n", __func__, this->fd);
|
||||||
#ifdef _WIN32
|
#ifdef _WIN32
|
||||||
closesocket(this->fd);
|
closesocket(this->fd);
|
||||||
#else
|
#else
|
||||||
@ -97,7 +99,7 @@ static ggml_guid_t ggml_backend_rpc_guid() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_backend_rpc_buffer_type_context {
|
struct ggml_backend_rpc_buffer_type_context {
|
||||||
std::shared_ptr<socket_t> sock;
|
std::string endpoint;
|
||||||
std::string name;
|
std::string name;
|
||||||
size_t alignment;
|
size_t alignment;
|
||||||
size_t max_size;
|
size_t max_size;
|
||||||
@ -106,8 +108,6 @@ struct ggml_backend_rpc_buffer_type_context {
|
|||||||
struct ggml_backend_rpc_context {
|
struct ggml_backend_rpc_context {
|
||||||
std::string endpoint;
|
std::string endpoint;
|
||||||
std::string name;
|
std::string name;
|
||||||
std::shared_ptr<socket_t> sock;
|
|
||||||
ggml_backend_buffer_type_t buft;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct ggml_backend_rpc_buffer_context {
|
struct ggml_backend_rpc_buffer_context {
|
||||||
@ -231,14 +231,13 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool parse_endpoint(const char * endpoint, std::string & host, int & port) {
|
static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
|
||||||
std::string str(endpoint);
|
size_t pos = endpoint.find(':');
|
||||||
size_t pos = str.find(':');
|
|
||||||
if (pos == std::string::npos) {
|
if (pos == std::string::npos) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
host = str.substr(0, pos);
|
host = endpoint.substr(0, pos);
|
||||||
port = std::stoi(str.substr(pos + 1));
|
port = std::stoi(endpoint.substr(pos + 1));
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -273,6 +272,44 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
|
|||||||
|
|
||||||
// RPC client-side implementation
|
// RPC client-side implementation
|
||||||
|
|
||||||
|
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
||||||
|
static std::mutex mutex;
|
||||||
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
|
static std::unordered_map<std::string, std::weak_ptr<socket_t>> sockets;
|
||||||
|
static bool initialized = false;
|
||||||
|
|
||||||
|
auto it = sockets.find(endpoint);
|
||||||
|
if (it != sockets.end()) {
|
||||||
|
if (auto sock = it->second.lock()) {
|
||||||
|
return sock;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::string host;
|
||||||
|
int port;
|
||||||
|
if (!parse_endpoint(endpoint, host, port)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
#ifdef _WIN32
|
||||||
|
if (!initialized) {
|
||||||
|
WSADATA wsaData;
|
||||||
|
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
|
||||||
|
if (res != 0) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
initialized = true;
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
UNUSED(initialized);
|
||||||
|
#endif
|
||||||
|
auto sock = socket_connect(host.c_str(), port);
|
||||||
|
if (sock == nullptr) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
|
||||||
|
sockets[endpoint] = sock;
|
||||||
|
return sock;
|
||||||
|
}
|
||||||
|
|
||||||
GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
|
GGML_CALL static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffer) {
|
||||||
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
|
||||||
return ctx->name.c_str();
|
return ctx->name.c_str();
|
||||||
@ -442,7 +479,8 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
|
|||||||
std::vector<uint8_t> input(input_size, 0);
|
std::vector<uint8_t> input(input_size, 0);
|
||||||
memcpy(input.data(), &size, sizeof(size));
|
memcpy(input.data(), &size, sizeof(size));
|
||||||
std::vector<uint8_t> output;
|
std::vector<uint8_t> output;
|
||||||
bool status = send_rpc_cmd(buft_ctx->sock, ALLOC_BUFFER, input, output);
|
auto sock = get_socket(buft_ctx->endpoint);
|
||||||
|
bool status = send_rpc_cmd(sock, ALLOC_BUFFER, input, output);
|
||||||
GGML_ASSERT(status);
|
GGML_ASSERT(status);
|
||||||
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
|
GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
|
||||||
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
|
// output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
|
||||||
@ -453,7 +491,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer
|
|||||||
if (remote_ptr != 0) {
|
if (remote_ptr != 0) {
|
||||||
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
|
ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
|
||||||
ggml_backend_rpc_buffer_interface,
|
ggml_backend_rpc_buffer_interface,
|
||||||
new ggml_backend_rpc_buffer_context{buft_ctx->sock, {}, remote_ptr, "RPC"},
|
new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC"},
|
||||||
remote_size);
|
remote_size);
|
||||||
return buffer;
|
return buffer;
|
||||||
} else {
|
} else {
|
||||||
@ -508,7 +546,7 @@ GGML_CALL static bool ggml_backend_rpc_buffer_type_supports_backend(ggml_backend
|
|||||||
}
|
}
|
||||||
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
|
||||||
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
||||||
return buft_ctx->sock == rpc_ctx->sock;
|
return buft_ctx->endpoint == rpc_ctx->endpoint;
|
||||||
}
|
}
|
||||||
|
|
||||||
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
|
static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
|
||||||
@ -521,7 +559,6 @@ static ggml_backend_buffer_type_i ggml_backend_rpc_buffer_type_interface = {
|
|||||||
/* .is_host = */ NULL,
|
/* .is_host = */ NULL,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
|
GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
|
||||||
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
||||||
|
|
||||||
@ -530,16 +567,13 @@ GGML_CALL static const char * ggml_backend_rpc_name(ggml_backend_t backend) {
|
|||||||
|
|
||||||
GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
|
GGML_CALL static void ggml_backend_rpc_free(ggml_backend_t backend) {
|
||||||
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
|
||||||
ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)rpc_ctx->buft->context;
|
|
||||||
delete buft_ctx;
|
|
||||||
delete rpc_ctx->buft;
|
|
||||||
delete rpc_ctx;
|
delete rpc_ctx;
|
||||||
delete backend;
|
delete backend;
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
|
GGML_CALL static ggml_backend_buffer_type_t ggml_backend_rpc_get_default_buffer_type(ggml_backend_t backend) {
|
||||||
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
|
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
|
||||||
return ctx->buft;
|
return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
|
GGML_CALL static void ggml_backend_rpc_synchronize(ggml_backend_t backend) {
|
||||||
@ -590,7 +624,8 @@ GGML_CALL static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t
|
|||||||
std::vector<uint8_t> input;
|
std::vector<uint8_t> input;
|
||||||
serialize_graph(cgraph, input);
|
serialize_graph(cgraph, input);
|
||||||
std::vector<uint8_t> output;
|
std::vector<uint8_t> output;
|
||||||
bool status = send_rpc_cmd(rpc_ctx->sock, GRAPH_COMPUTE, input, output);
|
auto sock = get_socket(rpc_ctx->endpoint);
|
||||||
|
bool status = send_rpc_cmd(sock, GRAPH_COMPUTE, input, output);
|
||||||
GGML_ASSERT(status);
|
GGML_ASSERT(status);
|
||||||
GGML_ASSERT(output.size() == 1);
|
GGML_ASSERT(output.size() == 1);
|
||||||
return (enum ggml_status)output[0];
|
return (enum ggml_status)output[0];
|
||||||
@ -624,65 +659,48 @@ static ggml_backend_i ggml_backend_rpc_interface = {
|
|||||||
/* .event_synchronize = */ NULL,
|
/* .event_synchronize = */ NULL,
|
||||||
};
|
};
|
||||||
|
|
||||||
static std::unordered_map<std::string, ggml_backend_t> instances;
|
|
||||||
|
|
||||||
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
|
GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {
|
||||||
ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
|
static std::mutex mutex;
|
||||||
return backend != nullptr ? ggml_backend_rpc_get_default_buffer_type(backend) : nullptr;
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
}
|
// NOTE: buffer types are allocated and never freed; this is by design
|
||||||
|
static std::unordered_map<std::string, ggml_backend_buffer_type_t> buft_map;
|
||||||
GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
|
auto it = buft_map.find(endpoint);
|
||||||
std::string endpoint_str(endpoint);
|
if (it != buft_map.end()) {
|
||||||
if (instances.find(endpoint_str) != instances.end()) {
|
return it->second;
|
||||||
return instances[endpoint_str];
|
|
||||||
}
|
}
|
||||||
#ifdef _WIN32
|
auto sock = get_socket(endpoint);
|
||||||
{
|
|
||||||
WSADATA wsaData;
|
|
||||||
int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
|
|
||||||
if (res != 0) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
fprintf(stderr, "Connecting to %s\n", endpoint);
|
|
||||||
std::string host;
|
|
||||||
int port;
|
|
||||||
if (!parse_endpoint(endpoint, host, port)) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto sock = socket_connect(host.c_str(), port);
|
|
||||||
if (sock == nullptr) {
|
if (sock == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
size_t alignment = get_alignment(sock);
|
size_t alignment = get_alignment(sock);
|
||||||
size_t max_size = get_max_size(sock);
|
size_t max_size = get_max_size(sock);
|
||||||
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
|
ggml_backend_rpc_buffer_type_context * buft_ctx = new ggml_backend_rpc_buffer_type_context {
|
||||||
/* .sock = */ sock,
|
/* .endpoint = */ endpoint,
|
||||||
/* .name = */ "RPC" + std::to_string(sock->fd),
|
/* .name = */ "RPC[" + std::string(endpoint) + "]",
|
||||||
/* .alignment = */ alignment,
|
/* .alignment = */ alignment,
|
||||||
/* .max_size = */ max_size
|
/* .max_size = */ max_size
|
||||||
};
|
};
|
||||||
|
|
||||||
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
|
ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
|
||||||
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
|
/* .iface = */ ggml_backend_rpc_buffer_type_interface,
|
||||||
/* .context = */ buft_ctx
|
/* .context = */ buft_ctx
|
||||||
};
|
};
|
||||||
|
buft_map[endpoint] = buft;
|
||||||
|
return buft;
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_CALL ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
|
||||||
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
|
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
|
||||||
/* .endpoint = */ endpoint,
|
/* .endpoint = */ endpoint,
|
||||||
/* .name = */ "RPC" + std::to_string(sock->fd),
|
/* .name = */ "RPC",
|
||||||
/* .sock = */ sock,
|
|
||||||
/* .buft = */ buft
|
|
||||||
};
|
};
|
||||||
|
|
||||||
instances[endpoint] = new ggml_backend {
|
ggml_backend_t backend = new ggml_backend {
|
||||||
/* .guid = */ ggml_backend_rpc_guid(),
|
/* .guid = */ ggml_backend_rpc_guid(),
|
||||||
/* .interface = */ ggml_backend_rpc_interface,
|
/* .interface = */ ggml_backend_rpc_interface,
|
||||||
/* .context = */ ctx
|
/* .context = */ ctx
|
||||||
};
|
};
|
||||||
|
return backend;
|
||||||
return instances[endpoint];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
|
GGML_API GGML_CALL bool ggml_backend_is_rpc(ggml_backend_t backend) {
|
||||||
@ -706,14 +724,13 @@ static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * f
|
|||||||
}
|
}
|
||||||
|
|
||||||
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
|
GGML_API GGML_CALL void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
|
||||||
ggml_backend_t backend = ggml_backend_rpc_init(endpoint);
|
auto sock = get_socket(endpoint);
|
||||||
if (backend == nullptr) {
|
if (sock == nullptr) {
|
||||||
*free = 0;
|
*free = 0;
|
||||||
*total = 0;
|
*total = 0;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
ggml_backend_rpc_context * ctx = (ggml_backend_rpc_context *)backend->context;
|
get_device_memory(sock, free, total);
|
||||||
get_device_memory(ctx->sock, free, total);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RPC server-side implementation
|
// RPC server-side implementation
|
||||||
|
Loading…
Reference in New Issue
Block a user