mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-04-25 21:39:44 +00:00
rpc : add RPC_CMD_HELLO (llama/12955)
Add RPC_CMD_HELLO for getting the version of the protocol implemented by the server. Follow the semantic versioning rules at https://semver.org. Hopefully this brings a better user experience when we make breaking changes at the protocol level and avoids issues like #12465.
This commit is contained in:
parent
36019c35a3
commit
24d29c55df
@ -7,6 +7,9 @@
|
|||||||
extern "C" {
|
extern "C" {
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
// Version of the RPC protocol as major.minor.patch (semantic versioning, https://semver.org).
// The server reports these values in its RPC_CMD_HELLO response and the client compares
// them with its own copy when it opens a connection.
#define RPC_PROTO_MAJOR_VERSION 1
|
||||||
|
#define RPC_PROTO_MINOR_VERSION 0
|
||||||
|
#define RPC_PROTO_PATCH_VERSION 0
|
||||||
#define GGML_RPC_MAX_SERVERS 16
|
#define GGML_RPC_MAX_SERVERS 16
|
||||||
|
|
||||||
// backend API
|
// backend API
|
||||||
|
@ -92,12 +92,19 @@ enum rpc_cmd {
|
|||||||
RPC_CMD_GET_DEVICE_MEMORY,
|
RPC_CMD_GET_DEVICE_MEMORY,
|
||||||
RPC_CMD_INIT_TENSOR,
|
RPC_CMD_INIT_TENSOR,
|
||||||
RPC_CMD_GET_ALLOC_SIZE,
|
RPC_CMD_GET_ALLOC_SIZE,
|
||||||
|
RPC_CMD_HELLO,
|
||||||
RPC_CMD_COUNT,
|
RPC_CMD_COUNT,
|
||||||
};
|
};
|
||||||
|
|
||||||
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
|
// Try RPC_CMD_SET_TENSOR_HASH first when data size is larger than this threshold
|
||||||
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
|
const size_t HASH_THRESHOLD = 10 * 1024 * 1024;
|
||||||
|
|
||||||
|
// Response payload of RPC_CMD_HELLO: the protocol version implemented by the server,
// one byte per component. The HELLO request itself carries no payload.
struct rpc_msg_hello_rsp {
|
||||||
|
uint8_t major;
|
||||||
|
uint8_t minor;
|
||||||
|
uint8_t patch;
|
||||||
|
};
|
||||||
|
|
||||||
struct rpc_msg_get_alloc_size_req {
|
struct rpc_msg_get_alloc_size_req {
|
||||||
rpc_tensor tensor;
|
rpc_tensor tensor;
|
||||||
};
|
};
|
||||||
@ -400,6 +407,20 @@ static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cm
|
|||||||
|
|
||||||
// RPC client-side implementation
|
// RPC client-side implementation
|
||||||
|
|
||||||
|
// Sends RPC_CMD_HELLO (empty request) over `sock` and compares the version in the reply
// with this build's RPC_PROTO_*_VERSION constants.
//
// Returns false, after printing the server version to stderr, when the major versions
// differ or the server's minor version is greater than ours. Returns true otherwise,
// printing a warning to stderr if the minor or patch component differs.
// A failed send_rpc_cmd() is not reported through the return value; it trips GGML_ASSERT.
static bool check_server_version(const std::shared_ptr<socket_t> & sock) {
|
||||||
|
rpc_msg_hello_rsp response;
|
||||||
|
bool status = send_rpc_cmd(sock, RPC_CMD_HELLO, nullptr, 0, &response, sizeof(response));
|
||||||
|
GGML_ASSERT(status);
|
||||||
|
// hard mismatch: refuse to use this server
// NOTE(review): a server with a *newer* minor is rejected while an older one is accepted;
// confirm this is the intended direction of the compatibility rule.
if (response.major != RPC_PROTO_MAJOR_VERSION || response.minor > RPC_PROTO_MINOR_VERSION) {
|
||||||
|
fprintf(stderr, "RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
// soft mismatch: major matches and server minor <= ours here, so only warn
if (response.minor != RPC_PROTO_MINOR_VERSION || response.patch != RPC_PROTO_PATCH_VERSION) {
|
||||||
|
fprintf(stderr, "WARNING: RPC server version mismatch: %d.%d.%d\n", response.major, response.minor, response.patch);
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
||||||
static std::mutex mutex;
|
static std::mutex mutex;
|
||||||
std::lock_guard<std::mutex> lock(mutex);
|
std::lock_guard<std::mutex> lock(mutex);
|
||||||
@ -433,6 +454,9 @@ static std::shared_ptr<socket_t> get_socket(const std::string & endpoint) {
|
|||||||
if (sock == nullptr) {
|
if (sock == nullptr) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
if (!check_server_version(sock)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
|
GGML_PRINT_DEBUG("[%s] connected to %s, sockfd=%d\n", __func__, endpoint.c_str(), sock->fd);
|
||||||
sockets[endpoint] = sock;
|
sockets[endpoint] = sock;
|
||||||
return sock;
|
return sock;
|
||||||
@ -818,6 +842,7 @@ public:
|
|||||||
}
|
}
|
||||||
~rpc_server();
|
~rpc_server();
|
||||||
|
|
||||||
|
void hello(rpc_msg_hello_rsp & response);
|
||||||
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
|
void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
|
||||||
void get_alignment(rpc_msg_get_alignment_rsp & response);
|
void get_alignment(rpc_msg_get_alignment_rsp & response);
|
||||||
void get_max_size(rpc_msg_get_max_size_rsp & response);
|
void get_max_size(rpc_msg_get_max_size_rsp & response);
|
||||||
@ -846,6 +871,13 @@ private:
|
|||||||
std::unordered_set<ggml_backend_buffer_t> buffers;
|
std::unordered_set<ggml_backend_buffer_t> buffers;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Handler for RPC_CMD_HELLO: fills `response` with the protocol version this server
// was built with (RPC_PROTO_*_VERSION) and logs it through GGML_PRINT_DEBUG.
void rpc_server::hello(rpc_msg_hello_rsp & response) {
|
||||||
|
response.major = RPC_PROTO_MAJOR_VERSION;
|
||||||
|
response.minor = RPC_PROTO_MINOR_VERSION;
|
||||||
|
response.patch = RPC_PROTO_PATCH_VERSION;
|
||||||
|
GGML_PRINT_DEBUG("[%s] version: %d.%d.%d\n", __func__, response.major, response.minor, response.patch);
|
||||||
|
}
|
||||||
|
|
||||||
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
|
bool rpc_server::get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response) {
|
||||||
ggml_backend_buffer_type_t buft;
|
ggml_backend_buffer_type_t buft;
|
||||||
struct ggml_init_params params {
|
struct ggml_init_params params {
|
||||||
@ -1271,8 +1303,24 @@ rpc_server::~rpc_server() {
|
|||||||
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
||||||
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
|
sockfd_t sockfd, size_t free_mem, size_t total_mem) {
|
||||||
rpc_server server(backend, cache_dir);
|
rpc_server server(backend, cache_dir);
|
||||||
|
uint8_t cmd;
|
||||||
|
if (!recv_data(sockfd, &cmd, 1)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
// the first command sent by the client must be HELLO
|
||||||
|
if (cmd != RPC_CMD_HELLO) {
|
||||||
|
fprintf(stderr, "Expected HELLO command, update client\n");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!recv_msg(sockfd, nullptr, 0)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
rpc_msg_hello_rsp response;
|
||||||
|
server.hello(response);
|
||||||
|
if (!send_msg(sockfd, &response, sizeof(response))) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
while (true) {
|
while (true) {
|
||||||
uint8_t cmd;
|
|
||||||
if (!recv_data(sockfd, &cmd, 1)) {
|
if (!recv_data(sockfd, &cmd, 1)) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@ -1282,6 +1330,10 @@ static void rpc_serve_client(ggml_backend_t backend, const char * cache_dir,
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
switch (cmd) {
|
switch (cmd) {
|
||||||
|
case RPC_CMD_HELLO: {
|
||||||
|
// HELLO command is handled above
|
||||||
|
return;
|
||||||
|
}
|
||||||
case RPC_CMD_ALLOC_BUFFER: {
|
case RPC_CMD_ALLOC_BUFFER: {
|
||||||
rpc_msg_alloc_buffer_req request;
|
rpc_msg_alloc_buffer_req request;
|
||||||
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
if (!recv_msg(sockfd, &request, sizeof(request))) {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user