From 24d29c55dffdd48474cc5c1310f2e6c24fc33392 Mon Sep 17 00:00:00 2001 From: Radoslav Gerganov Date: Fri, 18 Apr 2025 10:13:42 +0300 Subject: [PATCH] rpc : add RPC_CMD_HELLO (llama/12955) Add RPC_CMD_HELLO for getting the version of the protocol implemend by the server. Follow the semantic versioning rules at https://semver.org Hopefully this bring better user experience when we make breaking changes at the protocol level and avoid issues like #12465 --- ggml/include/ggml-rpc.h | 3 ++ ggml/src/ggml-rpc/ggml-rpc.cpp | 54 +++++++++++++++++++++++++++++++++- 2 files changed, 56 insertions(+), 1 deletion(-) diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h index 4e0d210f..c8b6097f 100644 --- a/ggml/include/ggml-rpc.h +++ b/ggml/include/ggml-rpc.h @@ -7,6 +7,9 @@ extern "C" { #endif +#define RPC_PROTO_MAJOR_VERSION 1 +#define RPC_PROTO_MINOR_VERSION 0 +#define RPC_PROTO_PATCH_VERSION 0 #define GGML_RPC_MAX_SERVERS 16 // backend API diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 3189ae85..a0667b7d 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -92,12 +92,19 @@ enum rpc_cmd { RPC_CMD_GET_DEVICE_MEMORY, RPC_CMD_INIT_TENSOR, RPC_CMD_GET_ALLOC_SIZE, + RPC_CMD_HELLO, RPC_CMD_COUNT, }; // Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold const size_t HASH_THRESHOLD = 10 * 1024 * 1024; +struct rpc_msg_hello_rsp { + uint8_t major; + uint8_t minor; + uint8_t patch; +}; + struct rpc_msg_get_alloc_size_req { rpc_tensor tensor; }; @@ -400,6 +407,20 @@ static bool send_rpc_cmd(const std::shared_ptr & sock, enum rpc_cmd cm // RPC client-side implementation +static bool check_server_version(const std::shared_ptr & sock) { + rpc_msg_hello_rsp response; + bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response)); + GGML_ASSERT(status); + if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) { + fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch); + return false; + } + if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) { + fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch); + } + return true; +} + static std::shared_ptr get_socket(const std::string & endpoint) { static std::mutex mutex; std::lock_guard lock(mutex); @@ -433,6 +454,9 @@ static std::shared_ptr get_socket(const std::string & endpoint) { if (sock == nullptr) { return nullptr; } + if (!check_server_version(sock)) { + return nullptr; + } GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd); sockets[endpoint] = sock; return sock; @@ -818,6 +842,7 @@ public: } ~rpc_server(); + void hello(rpc_msg_hello_rsp & response); void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response); void get_alignment(rpc_msg_get_alignment_rsp & response); void get_max_size(rpc_msg_get_max_size_rsp & response); @@ -846,6 +871,13 @@ private: std::unordered_set buffers; }; +void rpc_server::hello(rpc_msg_hello_rsp & response) { + response.major = RPC_PROTO_MAJOR_VERSION; + response.minor = RPC_PROTO_MINOR_VERSION; + response.patch = RPC_PROTO_PATCH_VERSION; + GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch); +} + bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) { ggml_backend_buffer_type_t buft; struct ggml_init_params params { @@ -1271,8 +1303,24 @@ rpc_server::~rpc_server() { static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, sockfd_t sockfd, size_t free_mem, size_t total_mem) { rpc_server server(backend, cache_dir); + uint8_t cmd; + if (!recv_data(sockfd, &cmd, 1)) { + return; + } + // the first command sent by the client must be HELLO + if (cmd != RPC_CMD_HELLO) { + fprintf(stderr, "Expected HELLO command, update client\n"); + return; + } + if (!recv_msg(sockfd, nullptr, 0)) { + return; + } + rpc_msg_hello_rsp response; + server.hello(response); + if (!send_msg(sockfd, &response, sizeof(response))) { + return; + } while (true) { - uint8_t cmd; if (!recv_data(sockfd, &cmd, 1)) { break; } @@ -1282,6 +1330,10 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir, break; } switch (cmd) { + case RPC_CMD_HELLO: { + // HELLO command is handled above + return; + } case RPC_CMD_ALLOC_BUFFER: { rpc_msg_alloc_buffer_req request; if (!recv_msg(sockfd, &request, sizeof(request))) {