From 925915ae3749d76b347573971ad63dee633c7470 Mon Sep 17 00:00:00 2001 From: Hrishikesh Barman Date: Tue, 25 Jul 2023 21:23:34 +0530 Subject: [PATCH] whisper : move progress calculation out of whisper.cpp (#1081) Current `progress_step` was hardcoded into whisper.cpp, this resulted in bindings having to access progress only at that step even if progress callback was being called at every iteration. With this change we get greater granularity progress reporting from whisper.cpp and bindings/implementations can define their own progress step. --- examples/main/main.cpp | 17 ++++++++++++++++- whisper.cpp | 11 +---------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 8dd31d02..4fbc3f69 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -59,6 +59,7 @@ struct whisper_params { int32_t offset_t_ms = 0; int32_t offset_n = 0; int32_t duration_ms = 0; + int32_t progress_step = 5; int32_t max_context = -1; int32_t max_len = 0; int32_t best_of = 2; @@ -218,6 +219,7 @@ struct whisper_print_user_data { const whisper_params * params; const std::vector> * pcmf32s; + int progress_prev; }; std::string estimate_diarization_speaker(std::vector> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) { @@ -252,6 +254,14 @@ std::string estimate_diarization_speaker(std::vector> pcmf32s return speaker; } +void whisper_print_progress_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int progress, void * user_data) { + int progress_step = ((whisper_print_user_data *) user_data)->params->progress_step; + int * progress_prev = &(((whisper_print_user_data *) user_data)->progress_prev); + if (progress >= *progress_prev + progress_step) { + *progress_prev += progress_step; + fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress); + } +} void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) { const auto & params = *((whisper_print_user_data *) user_data)->params; @@ -895,7 +905,7 @@ int main(int argc, char ** argv) { wparams.entropy_thold = params.entropy_thold; wparams.logprob_thold = params.logprob_thold; - whisper_print_user_data user_data = { ¶ms, &pcmf32s }; + whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 }; // this callback is called on each new segment if (!wparams.print_realtime) { @@ -903,6 +913,11 @@ int main(int argc, char ** argv) { wparams.new_segment_callback_user_data = &user_data; } + if (wparams.print_progress) { + wparams.progress_callback = whisper_print_progress_callback; + wparams.progress_callback_user_data = &user_data; + } + // example for abort mechanism // in this example, we do not abort the processing, but we could if the flag is set to true // the callback is called before every encoder run - if it returns false, the processing is aborted diff --git a/whisper.cpp b/whisper.cpp index 38187457..ab734fea 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -4163,9 +4163,6 @@ int whisper_full_with_state( } } - int progress_prev = 0; - int progress_step = 5; - int seek = seek_start; std::vector prompt; @@ -4193,15 +4190,9 @@ int whisper_full_with_state( // main loop while (true) { const int progress_cur = (100*(seek - seek_start))/(seek_end - seek_start); - while (progress_cur >= progress_prev + progress_step) { - progress_prev += progress_step; - if (params.print_progress) { - fprintf(stderr, "%s: progress = %3d%%\n", __func__, progress_prev); - } - } if (params.progress_callback) { params.progress_callback( - ctx, ctx->state, progress_prev, params.progress_callback_user_data); + ctx, ctx->state, progress_cur, params.progress_callback_user_data); } // of only 1 second left, then stop