whisper : support no_speech_thold (#2625)
Some checks are pending
Bindings Tests (Ruby) / ubuntu-latest (push) Waiting to run
CI / ubuntu-latest (linux/amd64) (push) Waiting to run
CI / ubuntu-latest (linux/arm/v7) (push) Waiting to run
CI / ubuntu-latest (linux/arm64) (push) Waiting to run
CI / ubuntu-latest (linux/ppc64le) (push) Waiting to run
CI / macOS-latest (push) Waiting to run
CI / ubuntu-latest-gcc (linux/amd64, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/amd64, Release) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/arm/v7, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/arm/v7, Release) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/arm64, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/arm64, Release) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/ppc64le, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/ppc64le, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/amd64, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/amd64, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/arm64, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/arm64, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/ppc64le, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/ppc64le, Release) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, ADDRESS) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, THREAD) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, UNDEFINED) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Waiting to run
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Waiting to run
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Waiting to run
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Waiting to run
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Waiting to run
CI / windows-blas (Win32, ON, Release, x86, 2.28.5, ON) (push) Waiting to run
CI / windows-blas (x64, ON, Release, x64, 2.28.5, ON) (push) Waiting to run
CI / emscripten (Release) (push) Waiting to run
CI / ios-xcode-build (Release) (push) Waiting to run
CI / android (push) Waiting to run
CI / quantize (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64,linux/arm64 tag:main]) (push) Waiting to run

* Implement no_speech_thold

no_speech_thold functionality is on par with OpenAI's whisper

* Addressed review comments
This commit is contained in:
Karthick 2024-12-17 22:45:47 +05:30 committed by GitHub
parent 2f2841bfce
commit f897eb7670
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 61 additions and 33 deletions

View File

@ -534,7 +534,7 @@ extern "C" {
float temperature_inc;
float entropy_thold; // similar to OpenAI's "compression_ratio_threshold"
float logprob_thold;
float no_speech_thold; // TODO: not implemented
float no_speech_thold;
struct {
int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264

View File

@ -867,6 +867,7 @@ struct whisper_state {
whisper_token tid_last;
std::vector<float> energy; // PCM signal energy
float no_speech_prob = 0.0f;
// [EXPERIMENTAL] Token-level timestamps with DTW
whisper_aheads_masks aheads_masks;
@ -4825,6 +4826,42 @@ static const std::vector<std::string> non_speech_tokens = {
"♪♪♪","", "", "", "", "", "", ""
};
// Compute a numerically-stable log-softmax over the first n_logits entries
// of `logits`, writing the result into `logprobs`.
//
//  - `logits`   : raw model outputs; entries equal to -INFINITY are treated
//                 as suppressed tokens and stay -INFINITY in the output
//  - `n_logits` : number of leading entries to process
//  - `logprobs` : output buffer; the caller must have sized it to at least
//                 n_logits elements (this function does not resize it)
//
// Stability trick: subtract the max logit before exponentiating so expf()
// cannot overflow (standard log-sum-exp formulation).
static void whisper_compute_logprobs(
        const std::vector<float> & logits,
        const int                  n_logits,
              std::vector<float> & logprobs) {
    if (n_logits <= 0) {
        // nothing to do; also avoids UB in max_element on an empty range
        return;
    }

    // max over the SAME range the loops below use — `logits` may hold more
    // than n_logits entries, and a max taken past n_logits would skew the
    // normalization constant
    const float logit_max = *std::max_element(logits.begin(), logits.begin() + n_logits);

    float logsumexp = 0.0f;
    for (int i = 0; i < n_logits; ++i) {
        if (logits[i] > -INFINITY) {
            logsumexp += expf(logits[i] - logit_max);
        }
    }
    logsumexp = logf(logsumexp) + logit_max;

    for (int i = 0; i < n_logits; ++i) {
        if (logits[i] > -INFINITY) {
            logprobs[i] = logits[i] - logsumexp;
        } else {
            // suppressed token: probability 0, log-probability -inf
            logprobs[i] = -INFINITY;
        }
    }
}
// Convert precomputed log-probabilities back into probabilities.
//
//  - `logits`   : raw model outputs, consulted only to detect suppressed
//                 tokens (entries equal to -INFINITY)
//  - `n_logits` : number of leading entries to process
//  - `logprobs` : log-softmax values matching `logits`
//  - `probs`    : output buffer; caller must have sized it to at least
//                 n_logits elements
//
// A token whose logit was forced to -INFINITY gets probability 0 exactly,
// instead of relying on expf(-INFINITY) evaluating to 0.
static void whisper_compute_probs(
        const std::vector<float> & logits,
        const int                  n_logits,
        const std::vector<float> & logprobs,
              std::vector<float> & probs) {
    for (int k = 0; k < n_logits; ++k) {
        const bool suppressed = (logits[k] == -INFINITY);
        probs[k] = suppressed ? 0.0f : expf(logprobs[k]);
    }
}
// process the logits for the selected decoder
// - applies logit filters
// - computes logprobs and probs
@ -4886,7 +4923,7 @@ static void whisper_process_logits(
// suppress sot and nosp tokens
logits[vocab.token_sot] = -INFINITY;
logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now
logits[vocab.token_nosp] = -INFINITY;
// [TDRZ] when tinydiarize is disabled, suppress solm token
if (params.tdrz_enable == false) {
@ -4985,24 +5022,7 @@ static void whisper_process_logits(
}
// populate the logprobs array (log_softmax)
{
const float logit_max = *std::max_element(logits.begin(), logits.end());
float logsumexp = 0.0f;
for (int i = 0; i < n_logits; ++i) {
if (logits[i] > -INFINITY) {
logsumexp += expf(logits[i] - logit_max);
}
}
logsumexp = logf(logsumexp) + logit_max;
for (int i = 0; i < n_logits; ++i) {
if (logits[i] > -INFINITY) {
logprobs[i] = logits[i] - logsumexp;
} else {
logprobs[i] = -INFINITY;
}
}
}
whisper_compute_logprobs(logits, n_logits, logprobs);
// if sum of probability over timestamps is above any other token, sample timestamp
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
@ -5060,15 +5080,7 @@ static void whisper_process_logits(
}
// compute probs
{
for (int i = 0; i < n_logits; ++i) {
if (logits[i] == -INFINITY) {
probs[i] = 0.0f;
} else {
probs[i] = expf(logprobs[i]);
}
}
}
whisper_compute_probs(logits, n_logits, logprobs, probs);
#if 0
// print first 100 logits - token string : logit
@ -5647,6 +5659,18 @@ int whisper_full_with_state(
return -8;
}
// Calculate no_speech probability after first decode.
// This has to be done before any logit filtering. Hence we cannot use the probs from the whisper_process_logits.
{
const int n_logits = ctx->vocab.id_to_token.size();
std::vector<float> logprobs(n_logits);
std::vector<float> probs(n_logits);
whisper_compute_logprobs(state->logits, n_logits, logprobs);
whisper_compute_probs(state->logits, n_logits, logprobs, probs);
state->no_speech_prob = probs[whisper_token_nosp(ctx)];
}
{
const int64_t t_start_sample_us = ggml_time_us();
@ -6038,8 +6062,9 @@ int whisper_full_with_state(
if (it != (int) temperatures.size() - 1) {
const auto & decoder = state->decoders[best_decoder_id];
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
if (decoder.failed ||
(decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) {
WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold, state->no_speech_prob, params.no_speech_thold);
success = false;
state->n_fail_p++;
}
@ -6068,6 +6093,9 @@ int whisper_full_with_state(
// [EXPERIMENTAL] Token-level timestamps with DTW
const auto n_segments_before = state->result_all.size();
const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold &&
best_decoder.sequence.avg_logprobs < params.logprob_thold);
//WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
// update prompt_past
@ -6076,11 +6104,11 @@ int whisper_full_with_state(
prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
}
for (int i = 0; i < result_len; ++i) {
for (int i = 0; i < result_len && !is_no_speech; ++i) {
prompt_past.push_back(tokens_cur[i].id);
}
if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
int i0 = 0;
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));