From 944ce4943947387ad60e56ed0360bafdb5a6befb Mon Sep 17 00:00:00 2001 From: Sacha Arbonel Date: Sat, 21 Dec 2024 11:05:05 +0100 Subject: [PATCH] server : add option to suppress non-speech tokens (#2649) * The parameter will suppress non-speech tokens like [LAUGH], [SIGH], etc. from the output when enabled. * add to whisper_params_parse * add missing param --- examples/server/server.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index af467513..f0484b34 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -76,6 +76,7 @@ struct whisper_params { bool no_timestamps = false; bool use_gpu = true; bool flash_attn = false; + bool suppress_non_speech_tokens = false; std::string language = "en"; std::string prompt = ""; @@ -135,6 +136,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str()); fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str()); fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false"); + fprintf(stderr, " -sns, --suppress-non-speech [%-7s] suppress non-speech tokens\n", params.suppress_non_speech_tokens ? "true" : "false"); fprintf(stderr, "\n"); } @@ -179,6 +181,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } + else if (arg == "-sns" || arg == "--suppress-non-speech") { params.suppress_non_speech_tokens = true; } // server params else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); } else if ( arg == "--host") { sparams.hostname = argv[++i]; } @@ -472,6 +475,10 @@ void get_req_parameters(const Request & req, whisper_params & params) { params.temperature_inc = std::stof(req.get_file_value("temperature_inc").content); } + if (req.has_file("suppress_non_speech")) + { + params.suppress_non_speech_tokens = parse_str_to_bool(req.get_file_value("suppress_non_speech").content); + } } } // namespace @@ -786,6 +793,8 @@ int main(int argc, char ** argv) { wparams.no_timestamps = params.no_timestamps; wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format; + wparams.suppress_non_speech_tokens = params.suppress_non_speech_tokens; + whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 }; // this callback is called on each new segment