From 4183517076d6f8de8f89d6b2eae3a1456ce5e29c Mon Sep 17 00:00:00 2001 From: Sacha Arbonel Date: Sat, 21 Dec 2024 16:00:08 +0100 Subject: [PATCH] server : add no-speech threshold parameter and functionality (#2654) --- examples/server/server.cpp | 7 ++++++- include/whisper.h | 2 ++ src/whisper.cpp | 9 +++++++-- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 0608bb6b..a2ef726f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -61,6 +61,7 @@ struct whisper_params { float logprob_thold = -1.00f; float temperature = 0.00f; float temperature_inc = 0.20f; + float no_speech_thold = 0.6f; bool debug_mode = false; bool translate = false; @@ -137,6 +138,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str()); fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false"); fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); + fprintf(stderr, " -nth N, --no-speech-thold N [%-7.2f] no speech threshold\n", params.no_speech_thold); fprintf(stderr, "\n"); } @@ -182,6 +184,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } + else if (arg == "-nth" || arg == "--no-speech-thold") { params.no_speech_thold = std::stof(argv[++i]); } + // server params else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); } else if ( arg == "--host") { sparams.hostname = argv[++i]; } @@ -790,6 +794,7 @@ int main(int argc, char ** argv) { wparams.beam_search.beam_size = params.beam_size; wparams.temperature = params.temperature; + wparams.no_speech_thold = params.no_speech_thold; wparams.temperature_inc = params.temperature_inc; wparams.entropy_thold = params.entropy_thold; wparams.logprob_thold = params.logprob_thold; @@ -942,7 +947,7 @@ int main(int argc, char ** argv) { // TODO compression_ratio and no_speech_prob are not implemented yet // segment["compression_ratio"] = 0; - // segment["no_speech_prob"] = 0; + segment["no_speech_prob"] = whisper_full_get_segment_no_speech_prob(ctx, i); jres["segments"].push_back(segment); } diff --git a/include/whisper.h b/include/whisper.h index 6e0db505..03ce110d 100644 --- a/include/whisper.h +++ b/include/whisper.h @@ -665,6 +665,8 @@ extern "C" { WHISPER_API void whisper_log_set(ggml_log_callback log_callback, void * user_data); + // Get the no_speech probability for the specified segment + WHISPER_API float whisper_full_get_segment_no_speech_prob (struct whisper_context * ctx, int i_segment); #ifdef __cplusplus } #endif diff --git a/src/whisper.cpp b/src/whisper.cpp index ff200ea0..5a9f3df8 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -428,6 +428,7 @@ struct whisper_segment { int64_t t1; std::string text; + float no_speech_prob; std::vector tokens; @@ -6147,7 +6148,7 @@ int whisper_full_with_state( //printf("tt0 = %d, tt1 = %d, text = %s, token = %s, token_id = %d, tid = %d\n", tt0, tt1, text.c_str(), ctx->vocab.id_to_token[tokens_cur[i].id].c_str(), tokens_cur[i].id, tokens_cur[i].tid); - result_all.push_back({ tt0, tt1, text, {}, speaker_turn_next }); + result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next }); for (int j = i0; j <= i; j++) { result_all.back().tokens.push_back(tokens_cur[j]); } @@ -6192,7 +6193,7 @@ int whisper_full_with_state( } } - result_all.push_back({ tt0, tt1, text, {} , speaker_turn_next }); + result_all.push_back({ tt0, tt1, text, state->no_speech_prob, {}, speaker_turn_next }); for (int j = i0; j < (int) tokens_cur.size(); j++) { result_all.back().tokens.push_back(tokens_cur[j]); } @@ -6459,6 +6460,10 @@ float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int return ctx->state->result_all[i_segment].tokens[i_token].p; } +float whisper_full_get_segment_no_speech_prob(struct whisper_context * ctx, int i_segment) { + return ctx->state->result_all[i_segment].no_speech_prob; +} + // ================================================================================================= //