From f897eb767041bdd4d14ab7497d1d156e10149698 Mon Sep 17 00:00:00 2001 From: Karthick Date: Tue, 17 Dec 2024 22:45:47 +0530 Subject: [PATCH] whisper : support no_speech_thold (#2625) * Implement no_speech_thold no_speech_thold functionality is on par with OpenAI's whisper * Addressed review comments --- include/whisper.h | 2 +- src/whisper.cpp | 92 ++++++++++++++++++++++++++++++----------------- 2 files changed, 61 insertions(+), 33 deletions(-) diff --git a/include/whisper.h b/include/whisper.h index 9188d686..71949bdd 100644 --- a/include/whisper.h +++ b/include/whisper.h @@ -534,7 +534,7 @@ extern "C" { float temperature_inc; float entropy_thold; // similar to OpenAI's "compression_ratio_threshold" float logprob_thold; - float no_speech_thold; // TODO: not implemented + float no_speech_thold; struct { int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264 diff --git a/src/whisper.cpp b/src/whisper.cpp index 810a8d26..bcc530ae 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -867,6 +867,7 @@ struct whisper_state { whisper_token tid_last; std::vector<float> energy; // PCM signal energy + float no_speech_prob = 0.0f; // [EXPERIMENTAL] Token-level timestamps with DTW whisper_aheads_masks aheads_masks; @@ -4825,6 +4826,42 @@ static const std::vector<std::string> non_speech_tokens = { "♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯" }; +static void whisper_compute_logprobs( + const std::vector<float> & logits, + const int n_logits, + std::vector<float> & logprobs) { + const float logit_max = *std::max_element(logits.begin(), logits.end()); + float logsumexp = 0.0f; + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logsumexp += expf(logits[i] - logit_max); + } + } + logsumexp = logf(logsumexp) + logit_max; + + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logprobs[i] = logits[i] - logsumexp; + } else { + logprobs[i] = -INFINITY; + } + } +} + +static void whisper_compute_probs( + const 
std::vector<float> & logits, + const int n_logits, + const std::vector<float> & logprobs, + std::vector<float> & probs) { + for (int i = 0; i < n_logits; ++i) { + if (logits[i] == -INFINITY) { + probs[i] = 0.0f; + } else { + probs[i] = expf(logprobs[i]); + } + } +} + // process the logits for the selected decoder // - applies logit filters // - computes logprobs and probs @@ -4886,7 +4923,7 @@ static void whisper_process_logits( // suppress sot and nosp tokens logits[vocab.token_sot] = -INFINITY; - logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now + logits[vocab.token_nosp] = -INFINITY; // [TDRZ] when tinydiarize is disabled, suppress solm token if (params.tdrz_enable == false) { @@ -4985,24 +5022,7 @@ static void whisper_process_logits( } // populate the logprobs array (log_softmax) - { - const float logit_max = *std::max_element(logits.begin(), logits.end()); - float logsumexp = 0.0f; - for (int i = 0; i < n_logits; ++i) { - if (logits[i] > -INFINITY) { - logsumexp += expf(logits[i] - logit_max); - } - } - logsumexp = logf(logsumexp) + logit_max; - - for (int i = 0; i < n_logits; ++i) { - if (logits[i] > -INFINITY) { - logprobs[i] = logits[i] - logsumexp; - } else { - logprobs[i] = -INFINITY; - } - } - } + whisper_compute_logprobs(logits, n_logits, logprobs); // if sum of probability over timestamps is above any other token, sample timestamp // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437 @@ -5060,15 +5080,7 @@ static void whisper_process_logits( } // compute probs - { - for (int i = 0; i < n_logits; ++i) { - if (logits[i] == -INFINITY) { - probs[i] = 0.0f; - } else { - probs[i] = expf(logprobs[i]); - } - } - } + whisper_compute_probs(logits, n_logits, logprobs, probs); #if 0 // print first 100 logits - token string : logit @@ -5647,6 +5659,18 @@ int whisper_full_with_state( return -8; } + // Calculate no_speech probability after first decode. 
+ // This has to be done before any logit filtering. Hence we cannot use the probs from the whisper_process_logits. + { + const int n_logits = ctx->vocab.id_to_token.size(); + std::vector<float> logprobs(n_logits); + std::vector<float> probs(n_logits); + + whisper_compute_logprobs(state->logits, n_logits, logprobs); + whisper_compute_probs(state->logits, n_logits, logprobs, probs); + state->no_speech_prob = probs[whisper_token_nosp(ctx)]; + } + { const int64_t t_start_sample_us = ggml_time_us(); @@ -6038,8 +6062,9 @@ int whisper_full_with_state( if (it != (int) temperatures.size() - 1) { const auto & decoder = state->decoders[best_decoder_id]; - if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) { - WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold); + if (decoder.failed || + (decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) { + WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold, state->no_speech_prob, params.no_speech_thold); success = false; state->n_fail_p++; } @@ -6068,6 +6093,9 @@ int whisper_full_with_state( // [EXPERIMENTAL] Token-level timestamps with DTW const auto n_segments_before = state->result_all.size(); + const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold && + best_decoder.sequence.avg_logprobs < params.logprob_thold); + //WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta); // update prompt_past @@ -6076,11 +6104,11 @@ int whisper_full_with_state( prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size()); } - for (int i = 0; i < result_len; ++i) { + for (int i = 0; i < result_len && !is_no_speech; ++i) { 
prompt_past.push_back(tokens_cur[i].id); } - if (!tokens_cur.empty() && ctx->model.n_loaded > 0) { + if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) { int i0 = 0; auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));