rpc : add RPC_CMD_HELLO (llama/12955)

Add RPC_CMD_HELLO for getting the version of the protocol implemend by
the server. Follow the semantic versioning rules at https://semver.org

Hopefully this bring better user experience when we make breaking
changes at the protocol level and avoid issues like #12465
This commit is contained in:
Radoslav Gerganov 2025-04-18 10:13:42 +03:00 committed by Georgi Gerganov
parent 36019c35a3
commit 24d29c55df
2 changed files with 56 additions and 1 deletions

View File

@ -7,6 +7,9 @@
extern "C" {
#endif
#define RPC_PROTO_MAJOR_VERSION 1
#define RPC_PROTO_MINOR_VERSION 0
#define RPC_PROTO_PATCH_VERSION 0
#define GGML_RPC_MAX_SERVERS 16
// backend API

View File

@ -92,12 +92,19 @@ enum rpc_cmd {
RPC_CMD_GET_DEVICE_MEMORY,
RPC_CMD_INIT_TENSOR,
RPC_CMD_GET_ALLOC_SIZE,
RPC_CMD_HELLO,
RPC_CMD_COUNT,
};
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
struct rpc_msg_hello_rsp {
uint8_t major;
uint8_t minor;
uint8_t patch;
};
struct rpc_msg_get_alloc_size_req {
rpc_tensor tensor;
};
@ -400,6 +407,20 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
// RPC client-side implementation
static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
rpc_msg_hello_rsp response;
bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
GGML_ASSERT(status);
if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
return false;
}
if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
}
return true;
}
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
@ -433,6 +454,9 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
if (sock == nullptr) {
return nullptr;
}
if (!check_server_version(sock)) {
return nullptr;
}
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
sockets[endpoint] = sock;
return sock;
@ -818,6 +842,7 @@ public:
}
~rpc_server();
void hello(rpc_msg_hello_rsp & response);
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
void get_alignment(rpc_msg_get_alignment_rsp & response);
void get_max_size(rpc_msg_get_max_size_rsp & response);
@ -846,6 +871,13 @@ private:
std::unordered_set<ggml_backend_buffer_t> buffers;
};
void rpc_server::hello(rpc_msg_hello_rsp & response) {
response.major = RPC_PROTO_MAJOR_VERSION;
response.minor = RPC_PROTO_MINOR_VERSION;
response.patch = RPC_PROTO_PATCH_VERSION;
GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
}
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
ggml_backend_buffer_type_t buft;
struct ggml_init_params params {
@ -1271,8 +1303,24 @@ rpc_server::~rpc_server() {
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
rpc_server server(backend, cache_dir);
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
return;
}
// the first command sent by the client must be HELLO
if (cmd != RPC_CMD_HELLO) {
fprintf(stderr, "Expected HELLO command, update client\n");
return;
}
if (!recv_msg(sockfd, nullptr, 0)) {
return;
}
rpc_msg_hello_rsp response;
server.hello(response);
if (!send_msg(sockfd, &response, sizeof(response))) {
return;
}
while (true) {
uint8_t cmd;
if (!recv_data(sockfd, &cmd, 1)) {
break;
}
@ -1282,6 +1330,10 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
break;
}
switch (cmd) {
case RPC_CMD_HELLO: {
// HELLO command is handled above
return;
}
case RPC_CMD_ALLOC_BUFFER: {
rpc_msg_alloc_buffer_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {