mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-04-25 13:30:12 +00:00
rpc : add RPC_CMD_HELLO (llama/12955)
Add RPC_CMD_HELLO for getting the version of the protocol implemend by the server. Follow the semantic versioning rules at https://semver.org Hopefully this bring better user experience when we make breaking changes at the protocol level and avoid issues like #12465
This commit is contained in:
parent
36019c35a3
commit
24d29c55df
@ -7,6 +7,9 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
#define RPC_PROTO_MAJOR_VERSION 1
|
||||
#define RPC_PROTO_MINOR_VERSION 0
|
||||
#define RPC_PROTO_PATCH_VERSION 0
|
||||
#define GGML_RPC_MAX_SERVERS 16
|
||||
|
||||
// backend API
|
||||
|
@ -92,12 +92,19 @@ enum rpc_cmd {
|
||||
RPC_CMD_GET_DEVICE_MEMORY,
|
||||
RPC_CMD_INIT_TENSOR,
|
||||
RPC_CMD_GET_ALLOC_SIZE,
|
||||
RPC_CMD_HELLO,
|
||||
RPC_CMD_COUNT,
|
||||
};
|
||||
|
||||
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
|
||||
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
|
||||
|
||||
struct rpc_msg_hello_rsp {
|
||||
uint8_t major;
|
||||
uint8_t minor;
|
||||
uint8_t patch;
|
||||
};
|
||||
|
||||
struct rpc_msg_get_alloc_size_req {
|
||||
rpc_tensor tensor;
|
||||
};
|
||||
@ -400,6 +407,20 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
|
||||
|
||||
// RPC client-side implementation
|
||||
|
||||
static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
|
||||
rpc_msg_hello_rsp response;
|
||||
bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
|
||||
GGML_ASSERT(status);
|
||||
if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
|
||||
fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
|
||||
return false;
|
||||
}
|
||||
if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
|
||||
fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
||||
static std::mutex mutex;
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
@ -433,6 +454,9 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
||||
if (sock == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!check_server_version(sock)) {
|
||||
return nullptr;
|
||||
}
|
||||
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
|
||||
sockets[endpoint] = sock;
|
||||
return sock;
|
||||
@ -818,6 +842,7 @@ public:
|
||||
}
|
||||
~rpc_server();
|
||||
|
||||
void hello(rpc_msg_hello_rsp & response);
|
||||
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
|
||||
void get_alignment(rpc_msg_get_alignment_rsp & response);
|
||||
void get_max_size(rpc_msg_get_max_size_rsp & response);
|
||||
@ -846,6 +871,13 @@ private:
|
||||
std::unordered_set<ggml_backend_buffer_t> buffers;
|
||||
};
|
||||
|
||||
void rpc_server::hello(rpc_msg_hello_rsp & response) {
|
||||
response.major = RPC_PROTO_MAJOR_VERSION;
|
||||
response.minor = RPC_PROTO_MINOR_VERSION;
|
||||
response.patch = RPC_PROTO_PATCH_VERSION;
|
||||
GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
|
||||
}
|
||||
|
||||
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
|
||||
ggml_backend_buffer_type_t buft;
|
||||
struct ggml_init_params params {
|
||||
@ -1271,8 +1303,24 @@ rpc_server::~rpc_server() {
|
||||
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
||||
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
|
||||
rpc_server server(backend, cache_dir);
|
||||
uint8_t cmd;
|
||||
if (!recv_data(sockfd, &cmd, 1)) {
|
||||
return;
|
||||
}
|
||||
// the first command sent by the client must be HELLO
|
||||
if (cmd != RPC_CMD_HELLO) {
|
||||
fprintf(stderr, "Expected HELLO command, update client\n");
|
||||
return;
|
||||
}
|
||||
if (!recv_msg(sockfd, nullptr, 0)) {
|
||||
return;
|
||||
}
|
||||
rpc_msg_hello_rsp response;
|
||||
server.hello(response);
|
||||
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||
return;
|
||||
}
|
||||
while (true) {
|
||||
uint8_t cmd;
|
||||
if (!recv_data(sockfd, &cmd, 1)) {
|
||||
break;
|
||||
}
|
||||
@ -1282,6 +1330,10 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
||||
break;
|
||||
}
|
||||
switch (cmd) {
|
||||
case RPC_CMD_HELLO: {
|
||||
// HELLO command is handled above
|
||||
return;
|
||||
}
|
||||
case RPC_CMD_ALLOC_BUFFER: {
|
||||
rpc_msg_alloc_buffer_req request;
|
||||
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user