mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-18 20:27:53 +00:00
whisper : support no_speech_thold (#2625)
Some checks are pending
Bindings Tests (Ruby) / ubuntu-latest (push) Waiting to run
CI / ubuntu-latest (linux/amd64) (push) Waiting to run
CI / ubuntu-latest (linux/arm/v7) (push) Waiting to run
CI / ubuntu-latest (linux/arm64) (push) Waiting to run
CI / ubuntu-latest (linux/ppc64le) (push) Waiting to run
CI / macOS-latest (push) Waiting to run
CI / ubuntu-latest-gcc (linux/amd64, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/amd64, Release) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/arm/v7, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/arm/v7, Release) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/arm64, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/arm64, Release) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/ppc64le, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/ppc64le, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/amd64, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/amd64, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/arm64, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/arm64, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/ppc64le, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/ppc64le, Release) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, ADDRESS) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, THREAD) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, UNDEFINED) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Waiting to run
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Waiting to run
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Waiting to run
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Waiting to run
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Waiting to run
CI / windows-blas (Win32, ON, Release, x86, 2.28.5, ON) (push) Waiting to run
CI / windows-blas (x64, ON, Release, x64, 2.28.5, ON) (push) Waiting to run
CI / emscripten (Release) (push) Waiting to run
CI / ios-xcode-build (Release) (push) Waiting to run
CI / android (push) Waiting to run
CI / quantize (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64,linux/arm64 tag:main]) (push) Waiting to run
Some checks are pending
Bindings Tests (Ruby) / ubuntu-latest (push) Waiting to run
CI / ubuntu-latest (linux/amd64) (push) Waiting to run
CI / ubuntu-latest (linux/arm/v7) (push) Waiting to run
CI / ubuntu-latest (linux/arm64) (push) Waiting to run
CI / ubuntu-latest (linux/ppc64le) (push) Waiting to run
CI / macOS-latest (push) Waiting to run
CI / ubuntu-latest-gcc (linux/amd64, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/amd64, Release) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/arm/v7, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/arm/v7, Release) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/arm64, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/arm64, Release) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/ppc64le, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/ppc64le, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/amd64, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/amd64, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/arm64, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/arm64, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/ppc64le, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/ppc64le, Release) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, ADDRESS) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, THREAD) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, UNDEFINED) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Waiting to run
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Waiting to run
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Waiting to run
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Waiting to run
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Waiting to run
CI / windows-blas (Win32, ON, Release, x86, 2.28.5, ON) (push) Waiting to run
CI / windows-blas (x64, ON, Release, x64, 2.28.5, ON) (push) Waiting to run
CI / emscripten (Release) (push) Waiting to run
CI / ios-xcode-build (Release) (push) Waiting to run
CI / android (push) Waiting to run
CI / quantize (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64,linux/arm64 tag:main]) (push) Waiting to run
* Implement no_speech_thold no_speech_thold functionality is on par with OpenAI's whisper * Addressed review comments
This commit is contained in:
parent
2f2841bfce
commit
f897eb7670
@ -534,7 +534,7 @@ extern "C" {
|
|||||||
float temperature_inc;
|
float temperature_inc;
|
||||||
float entropy_thold; // similar to OpenAI's "compression_ratio_threshold"
|
float entropy_thold; // similar to OpenAI's "compression_ratio_threshold"
|
||||||
float logprob_thold;
|
float logprob_thold;
|
||||||
float no_speech_thold; // TODO: not implemented
|
float no_speech_thold;
|
||||||
|
|
||||||
struct {
|
struct {
|
||||||
int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
|
int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
|
||||||
|
@ -867,6 +867,7 @@ struct whisper_state {
|
|||||||
whisper_token tid_last;
|
whisper_token tid_last;
|
||||||
|
|
||||||
std::vector<float> energy; // PCM signal energy
|
std::vector<float> energy; // PCM signal energy
|
||||||
|
float no_speech_prob = 0.0f;
|
||||||
|
|
||||||
// [EXPERIMENTAL] Token-level timestamps with DTW
|
// [EXPERIMENTAL] Token-level timestamps with DTW
|
||||||
whisper_aheads_masks aheads_masks;
|
whisper_aheads_masks aheads_masks;
|
||||||
@ -4825,6 +4826,42 @@ static const std::vector<std::string> non_speech_tokens = {
|
|||||||
"♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
|
"♪♪♪","♩", "♪", "♫", "♬", "♭", "♮", "♯"
|
||||||
};
|
};
|
||||||
|
|
||||||
|
static void whisper_compute_logprobs(
|
||||||
|
const std::vector<float> & logits,
|
||||||
|
const int n_logits,
|
||||||
|
std::vector<float> & logprobs) {
|
||||||
|
const float logit_max = *std::max_element(logits.begin(), logits.end());
|
||||||
|
float logsumexp = 0.0f;
|
||||||
|
for (int i = 0; i < n_logits; ++i) {
|
||||||
|
if (logits[i] > -INFINITY) {
|
||||||
|
logsumexp += expf(logits[i] - logit_max);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
logsumexp = logf(logsumexp) + logit_max;
|
||||||
|
|
||||||
|
for (int i = 0; i < n_logits; ++i) {
|
||||||
|
if (logits[i] > -INFINITY) {
|
||||||
|
logprobs[i] = logits[i] - logsumexp;
|
||||||
|
} else {
|
||||||
|
logprobs[i] = -INFINITY;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void whisper_compute_probs(
|
||||||
|
const std::vector<float> & logits,
|
||||||
|
const int n_logits,
|
||||||
|
const std::vector<float> & logprobs,
|
||||||
|
std::vector<float> & probs) {
|
||||||
|
for (int i = 0; i < n_logits; ++i) {
|
||||||
|
if (logits[i] == -INFINITY) {
|
||||||
|
probs[i] = 0.0f;
|
||||||
|
} else {
|
||||||
|
probs[i] = expf(logprobs[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// process the logits for the selected decoder
|
// process the logits for the selected decoder
|
||||||
// - applies logit filters
|
// - applies logit filters
|
||||||
// - computes logprobs and probs
|
// - computes logprobs and probs
|
||||||
@ -4886,7 +4923,7 @@ static void whisper_process_logits(
|
|||||||
|
|
||||||
// suppress sot and nosp tokens
|
// suppress sot and nosp tokens
|
||||||
logits[vocab.token_sot] = -INFINITY;
|
logits[vocab.token_sot] = -INFINITY;
|
||||||
logits[vocab.token_nosp] = -INFINITY; // TODO: ignore this token for now
|
logits[vocab.token_nosp] = -INFINITY;
|
||||||
|
|
||||||
// [TDRZ] when tinydiarize is disabled, suppress solm token
|
// [TDRZ] when tinydiarize is disabled, suppress solm token
|
||||||
if (params.tdrz_enable == false) {
|
if (params.tdrz_enable == false) {
|
||||||
@ -4985,24 +5022,7 @@ static void whisper_process_logits(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// populate the logprobs array (log_softmax)
|
// populate the logprobs array (log_softmax)
|
||||||
{
|
whisper_compute_logprobs(logits, n_logits, logprobs);
|
||||||
const float logit_max = *std::max_element(logits.begin(), logits.end());
|
|
||||||
float logsumexp = 0.0f;
|
|
||||||
for (int i = 0; i < n_logits; ++i) {
|
|
||||||
if (logits[i] > -INFINITY) {
|
|
||||||
logsumexp += expf(logits[i] - logit_max);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
logsumexp = logf(logsumexp) + logit_max;
|
|
||||||
|
|
||||||
for (int i = 0; i < n_logits; ++i) {
|
|
||||||
if (logits[i] > -INFINITY) {
|
|
||||||
logprobs[i] = logits[i] - logsumexp;
|
|
||||||
} else {
|
|
||||||
logprobs[i] = -INFINITY;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// if sum of probability over timestamps is above any other token, sample timestamp
|
// if sum of probability over timestamps is above any other token, sample timestamp
|
||||||
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
|
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L431-L437
|
||||||
@ -5060,15 +5080,7 @@ static void whisper_process_logits(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// compute probs
|
// compute probs
|
||||||
{
|
whisper_compute_probs(logits, n_logits, logprobs, probs);
|
||||||
for (int i = 0; i < n_logits; ++i) {
|
|
||||||
if (logits[i] == -INFINITY) {
|
|
||||||
probs[i] = 0.0f;
|
|
||||||
} else {
|
|
||||||
probs[i] = expf(logprobs[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
// print first 100 logits - token string : logit
|
// print first 100 logits - token string : logit
|
||||||
@ -5647,6 +5659,18 @@ int whisper_full_with_state(
|
|||||||
return -8;
|
return -8;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Calculate no_speech probability after first decode.
|
||||||
|
// This has to be done before any logit filtering. Hence we cannot use the probs from the whisper_process_logits.
|
||||||
|
{
|
||||||
|
const int n_logits = ctx->vocab.id_to_token.size();
|
||||||
|
std::vector<float> logprobs(n_logits);
|
||||||
|
std::vector<float> probs(n_logits);
|
||||||
|
|
||||||
|
whisper_compute_logprobs(state->logits, n_logits, logprobs);
|
||||||
|
whisper_compute_probs(state->logits, n_logits, logprobs, probs);
|
||||||
|
state->no_speech_prob = probs[whisper_token_nosp(ctx)];
|
||||||
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
|
||||||
@ -6038,8 +6062,9 @@ int whisper_full_with_state(
|
|||||||
if (it != (int) temperatures.size() - 1) {
|
if (it != (int) temperatures.size() - 1) {
|
||||||
const auto & decoder = state->decoders[best_decoder_id];
|
const auto & decoder = state->decoders[best_decoder_id];
|
||||||
|
|
||||||
if (decoder.failed || decoder.sequence.avg_logprobs < params.logprob_thold) {
|
if (decoder.failed ||
|
||||||
WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold);
|
(decoder.sequence.avg_logprobs < params.logprob_thold && state->no_speech_prob < params.no_speech_thold)) {
|
||||||
|
WHISPER_LOG_DEBUG("%s: failed due to avg_logprobs %8.5f < %8.5f and no_speech_prob %8.5f < %8.5f\n", __func__, decoder.sequence.avg_logprobs, params.logprob_thold, state->no_speech_prob, params.no_speech_thold);
|
||||||
success = false;
|
success = false;
|
||||||
state->n_fail_p++;
|
state->n_fail_p++;
|
||||||
}
|
}
|
||||||
@ -6068,6 +6093,9 @@ int whisper_full_with_state(
|
|||||||
// [EXPERIMENTAL] Token-level timestamps with DTW
|
// [EXPERIMENTAL] Token-level timestamps with DTW
|
||||||
const auto n_segments_before = state->result_all.size();
|
const auto n_segments_before = state->result_all.size();
|
||||||
|
|
||||||
|
const bool is_no_speech = (state->no_speech_prob > params.no_speech_thold &&
|
||||||
|
best_decoder.sequence.avg_logprobs < params.logprob_thold);
|
||||||
|
|
||||||
//WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
|
//WHISPER_LOG_DEBUG("prompt_init.size() = %d, prompt.size() = %d, result_len = %d, seek_delta = %d\n", prompt_init.size(), prompt.size(), result_len, seek_delta);
|
||||||
|
|
||||||
// update prompt_past
|
// update prompt_past
|
||||||
@ -6076,11 +6104,11 @@ int whisper_full_with_state(
|
|||||||
prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
|
prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end() - prompt_init.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < result_len; ++i) {
|
for (int i = 0; i < result_len && !is_no_speech; ++i) {
|
||||||
prompt_past.push_back(tokens_cur[i].id);
|
prompt_past.push_back(tokens_cur[i].id);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!tokens_cur.empty() && ctx->model.n_loaded > 0) {
|
if (!tokens_cur.empty() && ctx->model.n_loaded > 0 && !is_no_speech) {
|
||||||
int i0 = 0;
|
int i0 = 0;
|
||||||
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
|
auto t0 = seek + 2*(tokens_cur.front().tid - whisper_token_beg(ctx));
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user