diff --git a/examples/server/server.cpp b/examples/server/server.cpp index df508839..8b6c5a96 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -14,10 +14,23 @@ #include #include #include +#include +#include +#include +#include +#include +#if defined (_WIN32) +#include +#endif using namespace httplib; using json = nlohmann::ordered_json; +enum server_state { + SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet + SERVER_STATE_READY, // Server is ready and model is loaded +}; + namespace { // output formats @@ -27,6 +40,20 @@ const std::string srt_format = "srt"; const std::string vjson_format = "verbose_json"; const std::string vtt_format = "vtt"; +std::function shutdown_handler; +std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; + +inline void signal_handler(int signal) { + if (is_terminating.test_and_set()) { + // in case it hangs, we can force terminate the server by hitting Ctrl+C twice + // this is for better developer experience, we can remove when the server is stable enough + fprintf(stderr, "Received second interrupt, terminating immediately.\n"); + exit(1); + } + + shutdown_handler(signal); +} + struct server_params { std::string hostname = "127.0.0.1"; @@ -654,6 +681,9 @@ int main(int argc, char ** argv) { } } + std::unique_ptr svr = std::make_unique(); + std::atomic state{SERVER_STATE_LOADING_MODEL}; + struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); if (ctx == nullptr) { @@ -663,9 +693,10 @@ int main(int argc, char ** argv) { // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr); + state.store(SERVER_STATE_READY); - Server svr; - svr.set_default_headers({{"Server", "whisper.cpp"}, + + svr->set_default_headers({{"Server", "whisper.cpp"}, {"Access-Control-Allow-Origin", "*"}, {"Access-Control-Allow-Headers", "content-type, authorization"}}); @@ -744,15 +775,15 @@ int main(int argc, char ** argv) { whisper_params default_params = params; // this is only called if no index.html is found in the public --path - svr.Get(sparams.request_path + "/", [&default_content](const Request &, Response &res){ + svr->Get(sparams.request_path + "/", [&](const Request &, Response &res){ res.set_content(default_content, "text/html"); return false; }); - svr.Options(sparams.request_path + sparams.inference_path, [&](const Request &, Response &){ + svr->Options(sparams.request_path + sparams.inference_path, [&](const Request &, Response &){ }); - svr.Post(sparams.request_path + sparams.inference_path, [&](const Request &req, Response &res){ + svr->Post(sparams.request_path + sparams.inference_path, [&](const Request &req, Response &res){ // acquire whisper model mutex lock std::lock_guard lock(whisper_mutex); @@ -1068,8 +1099,9 @@ int main(int argc, char ** argv) { // reset params to their defaults params = default_params; }); - svr.Post(sparams.request_path + "/load", [&](const Request &req, Response &res){ + svr->Post(sparams.request_path + "/load", [&](const Request &req, Response &res){ std::lock_guard lock(whisper_mutex); + state.store(SERVER_STATE_LOADING_MODEL); if (!req.has_file("model")) { fprintf(stderr, "error: no 'model' field in the request\n"); @@ -1101,18 +1133,25 @@ int main(int argc, char ** argv) { // initialize openvino encoder. this has no effect on whisper.cpp builds that don't have OpenVINO configured whisper_ctx_init_openvino_encoder(ctx, nullptr, params.openvino_encode_device.c_str(), nullptr); + state.store(SERVER_STATE_READY); const std::string success = "Load was successful!"; res.set_content(success, "application/text"); // check if the model is in the file system }); - svr.Get(sparams.request_path + "/health", [&](const Request &, Response &res){ - const std::string health_response = "{\"status\":\"ok\"}"; - res.set_content(health_response, "application/json"); + svr->Get(sparams.request_path + "/health", [&](const Request &, Response &res){ + server_state current_state = state.load(); + if (current_state == SERVER_STATE_READY) { + const std::string health_response = "{\"status\":\"ok\"}"; + res.set_content(health_response, "application/json"); + } else { + res.set_content("{\"status\":\"loading model\"}", "application/json"); + res.status = 503; + } }); - svr.set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) { + svr->set_exception_handler([](const Request &, Response &res, std::exception_ptr ep) { const char fmt[] = "500 Internal Server Error\n%s"; char buf[BUFSIZ]; try { @@ -1126,7 +1165,7 @@ int main(int argc, char ** argv) { res.status = 500; }); - svr.set_error_handler([](const Request &req, Response &res) { + svr->set_error_handler([](const Request &req, Response &res) { if (res.status == 400) { res.set_content("Invalid request", "text/plain"); } else if (res.status != 500) { @@ -1136,10 +1175,10 @@ int main(int argc, char ** argv) { }); // set timeouts and change hostname and port - svr.set_read_timeout(sparams.read_timeout); - svr.set_write_timeout(sparams.write_timeout); + svr->set_read_timeout(sparams.read_timeout); + svr->set_write_timeout(sparams.write_timeout); - if (!svr.bind_to_port(sparams.hostname, sparams.port)) + if (!svr->bind_to_port(sparams.hostname, sparams.port)) { fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port); @@ -1147,18 +1186,50 @@ int main(int argc, char ** argv) { } // Set the base directory for serving static files - svr.set_base_dir(sparams.public_path); + svr->set_base_dir(sparams.public_path); // to make it ctrl+clickable: printf("\nwhisper server listening at http://%s:%d\n\n", sparams.hostname.c_str(), sparams.port); - if (!svr.listen_after_bind()) - { - return 1; - } + shutdown_handler = [&](int signal) { + printf("\nCaught signal %d, shutting down gracefully...\n", signal); + if (svr) { + svr->stop(); + } + }; - whisper_print_timings(ctx); - whisper_free(ctx); +#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) + struct sigaction sigint_action; + sigint_action.sa_handler = signal_handler; + sigemptyset (&sigint_action.sa_mask); + sigint_action.sa_flags = 0; + sigaction(SIGINT, &sigint_action, NULL); + sigaction(SIGTERM, &sigint_action, NULL); +#elif defined (_WIN32) + auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { + return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; + }; + SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); +#endif + + // clean up function, to be called before exit + auto clean_up = [&]() { + whisper_print_timings(ctx); + whisper_free(ctx); + }; + + std::thread t([&] { + if (!svr->listen_after_bind()) { + fprintf(stderr, "error: server listen failed\n"); + } + }); + + svr->wait_until_ready(); + + t.join(); + + + clean_up(); return 0; }