diff --git a/examples/addon.node/__test__/whisper.spec.js b/examples/addon.node/__test__/whisper.spec.js
index 1ee888a1..066f382b 100644
--- a/examples/addon.node/__test__/whisper.spec.js
+++ b/examples/addon.node/__test__/whisper.spec.js
@@ -18,6 +18,7 @@ const whisperParamsMock = {
   translate: true,
   no_timestamps: false,
   audio_ctx: 0,
+  max_len: 0,
 };
 
 describe("Run whisper.node", () => {
diff --git a/examples/addon.node/addon.cpp b/examples/addon.node/addon.cpp
index cc054503..d4773ce0 100644
--- a/examples/addon.node/addon.cpp
+++ b/examples/addon.node/addon.cpp
@@ -128,192 +128,227 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
 
 void cb_log_disable(enum ggml_log_level, const char *, void *) {}
 
-int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
-    if (params.no_prints) {
-        whisper_log_set(cb_log_disable, NULL);
-    }
-
-    if (params.fname_inp.empty() && params.pcmf32.empty()) {
-        fprintf(stderr, "error: no input files or audio buffer specified\n");
-        return 2;
-    }
-
-    if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
-        fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
-        exit(0);
-    }
-
-    // whisper init
-
-    struct whisper_context_params cparams = whisper_context_default_params();
-    cparams.use_gpu = params.use_gpu;
-    cparams.flash_attn = params.flash_attn;
-    struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
-
-    if (ctx == nullptr) {
-        fprintf(stderr, "error: failed to initialize whisper context\n");
-        return 3;
-    }
-
-    // if params.pcmf32 is provided, set params.fname_inp to "buffer"
-    // this is simpler than further modifications in the code
-    if (!params.pcmf32.empty()) {
-        fprintf(stderr, "info: using audio buffer as input\n");
-        params.fname_inp.clear();
-        params.fname_inp.emplace_back("buffer");
-    }
-
-    for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
-        const auto fname_inp = params.fname_inp[f];
-        const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
-
-        std::vector<float> pcmf32;               // mono-channel F32 PCM
-        std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
-
-        // read the input audio file if params.pcmf32 is not provided
-        if (params.pcmf32.empty()) {
-            if (!::read_audio_data(fname_inp, pcmf32, pcmf32s, params.diarize)) {
-                fprintf(stderr, "error: failed to read audio file '%s'\n", fname_inp.c_str());
-                continue;
-            }
-        } else {
-            pcmf32 = params.pcmf32;
-        }
-
-        // print system information
-        if (!params.no_prints) {
-            fprintf(stderr, "\n");
-            fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
-                    params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
-        }
-
-        // print some info about the processing
-        if (!params.no_prints) {
-            fprintf(stderr, "\n");
-            if (!whisper_is_multilingual(ctx)) {
-                if (params.language != "en" || params.translate) {
-                    params.language = "en";
-                    params.translate = false;
-                    fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
-                }
-            }
-            fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d, audio_ctx = %d ...\n",
-                    __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
-                    params.n_threads, params.n_processors,
-                    params.language.c_str(),
-                    params.translate ? "translate" : "transcribe",
-                    params.no_timestamps ? 0 : 1,
-                    params.audio_ctx);
-
-            fprintf(stderr, "\n");
-        }
-
-        // run the inference
-        {
-            whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
-
-            wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
-
-            wparams.print_realtime   = false;
-            wparams.print_progress   = params.print_progress;
-            wparams.print_timestamps = !params.no_timestamps;
-            wparams.print_special    = params.print_special;
-            wparams.translate        = params.translate;
-            wparams.language         = params.language.c_str();
-            wparams.n_threads        = params.n_threads;
-            wparams.n_max_text_ctx   = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
-            wparams.offset_ms        = params.offset_t_ms;
-            wparams.duration_ms      = params.duration_ms;
-
-            wparams.token_timestamps = params.output_wts || params.max_len > 0;
-            wparams.thold_pt         = params.word_thold;
-            wparams.entropy_thold    = params.entropy_thold;
-            wparams.logprob_thold    = params.logprob_thold;
-            wparams.max_len          = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
-            wparams.audio_ctx        = params.audio_ctx;
-
-            wparams.greedy.best_of        = params.best_of;
-            wparams.beam_search.beam_size = params.beam_size;
-
-            wparams.initial_prompt = params.prompt.c_str();
-
-            wparams.no_timestamps = params.no_timestamps;
-
-            whisper_print_user_data user_data = { &params, &pcmf32s };
-
-            // this callback is called on each new segment
-            if (!wparams.print_realtime) {
-                wparams.new_segment_callback           = whisper_print_segment_callback;
-                wparams.new_segment_callback_user_data = &user_data;
-            }
-
-            // example for abort mechanism
-            // in this example, we do not abort the processing, but we could if the flag is set to true
-            // the callback is called before every encoder run - if it returns false, the processing is aborted
-            {
-                static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
-
-                wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
-                    bool is_aborted = *(bool*)user_data;
-                    return !is_aborted;
-                };
-                wparams.encoder_begin_callback_user_data = &is_aborted;
-            }
-
-            if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
-                fprintf(stderr, "failed to process audio\n");
-                return 10;
-            }
-        }
-    }
-
-    const int n_segments = whisper_full_n_segments(ctx);
-    result.resize(n_segments);
-    for (int i = 0; i < n_segments; ++i) {
-        const char * text = whisper_full_get_segment_text(ctx, i);
-        const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
-        const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
-
-        result[i].emplace_back(to_timestamp(t0, params.comma_in_time));
-        result[i].emplace_back(to_timestamp(t1, params.comma_in_time));
-        result[i].emplace_back(text);
-    }
-
-    whisper_print_timings(ctx);
-    whisper_free(ctx);
-
-    return 0;
-}
-
-class Worker : public Napi::AsyncWorker {
+class ProgressWorker : public Napi::AsyncWorker {
  public:
-  Worker(Napi::Function& callback, whisper_params params)
-      : Napi::AsyncWorker(callback), params(params) {}
-
-  void Execute() override {
-    run(params, result);
-  }
-
-  void OnOK() override {
-    Napi::HandleScope scope(Env());
-    Napi::Object res = Napi::Array::New(Env(), result.size());
-    for (uint64_t i = 0; i < result.size(); ++i) {
-      Napi::Object tmp = Napi::Array::New(Env(), 3);
-      for (uint64_t j = 0; j < 3; ++j) {
-        tmp[j] = Napi::String::New(Env(), result[i][j]);
-      }
-      res[i] = tmp;
+    ProgressWorker(Napi::Function& callback, whisper_params params, Napi::Function progress_callback, Napi::Env env)
+        : Napi::AsyncWorker(callback), params(params), env(env) {
+        // Create the thread-safe function
+        if (!progress_callback.IsEmpty()) {
+            tsfn = Napi::ThreadSafeFunction::New(
+                env,
+                progress_callback,
+                "Progress Callback",
+                0,
+                1
+            );
+        }
+    }
+
+    ~ProgressWorker() {
+        if (tsfn) {
+            // Make sure to release the thread-safe function on destruction
+            tsfn.Release();
+        }
+    }
+
+    void Execute() override {
+        // Use the custom run function with progress callback support
+        run_with_progress(params, result);
+    }
+
+    void OnOK() override {
+        Napi::HandleScope scope(Env());
+        Napi::Object res = Napi::Array::New(Env(), result.size());
+        for (uint64_t i = 0; i < result.size(); ++i) {
+            Napi::Object tmp = Napi::Array::New(Env(), 3);
+            for (uint64_t j = 0; j < 3; ++j) {
+                tmp[j] = Napi::String::New(Env(), result[i][j]);
+            }
+            res[i] = tmp;
+        }
+        Callback().Call({Env().Null(), res});
+    }
+
+    // Progress callback function - using the thread-safe function
+    void OnProgress(int progress) {
+        if (tsfn) {
+            // Use the thread-safe function to call the JavaScript callback
+            auto callback = [progress](Napi::Env env, Napi::Function jsCallback) {
+                jsCallback.Call({Napi::Number::New(env, progress)});
+            };
+
+            tsfn.BlockingCall(callback);
+        }
     }
-    Callback().Call({Env().Null(), res});
-  }
 
 private:
-  whisper_params params;
-  std::vector<std::vector<std::string>> result;
+    whisper_params params;
+    std::vector<std::vector<std::string>> result;
+    Napi::Env env;
+    Napi::ThreadSafeFunction tsfn;
+
+    // Custom run function with progress callback support
+    int run_with_progress(whisper_params &params, std::vector<std::vector<std::string>> &result) {
+        if (params.no_prints) {
+            whisper_log_set(cb_log_disable, NULL);
+        }
+
+        if (params.fname_inp.empty() && params.pcmf32.empty()) {
+            fprintf(stderr, "error: no input files or audio buffer specified\n");
+            return 2;
+        }
+
+        if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
+            fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
+            exit(0);
+        }
+
+        // whisper init
+        struct whisper_context_params cparams = whisper_context_default_params();
+        cparams.use_gpu = params.use_gpu;
+        cparams.flash_attn = params.flash_attn;
+        struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
+
+        if (ctx == nullptr) {
+            fprintf(stderr, "error: failed to initialize whisper context\n");
+            return 3;
+        }
+
+        // If params.pcmf32 is provided, set params.fname_inp to "buffer"
+        if (!params.pcmf32.empty()) {
+            fprintf(stderr, "info: using audio buffer as input\n");
+            params.fname_inp.clear();
+            params.fname_inp.emplace_back("buffer");
+        }
+
+        for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
+            const auto fname_inp = params.fname_inp[f];
+            const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
+
+            std::vector<float> pcmf32;               // mono-channel F32 PCM
+            std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
+
+            // If params.pcmf32 is empty, read the input audio file
+            if (params.pcmf32.empty()) {
+                if (!::read_audio_data(fname_inp, pcmf32, pcmf32s, params.diarize)) {
+                    fprintf(stderr, "error: failed to read audio file '%s'\n", fname_inp.c_str());
+                    continue;
+                }
+            } else {
+                pcmf32 = params.pcmf32;
+            }
+
+            // Print system info
+            if (!params.no_prints) {
+                fprintf(stderr, "\n");
+                fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
+                        params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
+            }
+
+            // Print processing info
+            if (!params.no_prints) {
+                fprintf(stderr, "\n");
+                if (!whisper_is_multilingual(ctx)) {
+                    if (params.language != "en" || params.translate) {
+                        params.language = "en";
+                        params.translate = false;
+                        fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
+                    }
+                }
+                fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d, audio_ctx = %d ...\n",
+                        __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
+                        params.n_threads, params.n_processors,
+                        params.language.c_str(),
+                        params.translate ? "translate" : "transcribe",
+                        params.no_timestamps ? 0 : 1,
+                        params.audio_ctx);
+
+                fprintf(stderr, "\n");
+            }
+
+            // Run inference
+            {
+                whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
+
+                wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
+
+                wparams.print_realtime   = false;
+                wparams.print_progress   = params.print_progress;
+                wparams.print_timestamps = !params.no_timestamps;
+                wparams.print_special    = params.print_special;
+                wparams.translate        = params.translate;
+                wparams.language         = params.language.c_str();
+                wparams.n_threads        = params.n_threads;
+                wparams.n_max_text_ctx   = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
+                wparams.offset_ms        = params.offset_t_ms;
+                wparams.duration_ms      = params.duration_ms;
+
+                wparams.token_timestamps = params.output_wts || params.max_len > 0;
+                wparams.thold_pt         = params.word_thold;
+                wparams.entropy_thold    = params.entropy_thold;
+                wparams.logprob_thold    = params.logprob_thold;
+                wparams.max_len          = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
+                wparams.audio_ctx        = params.audio_ctx;
+
+                wparams.greedy.best_of        = params.best_of;
+                wparams.beam_search.beam_size = params.beam_size;
+
+                wparams.initial_prompt = params.prompt.c_str();
+
+                wparams.no_timestamps = params.no_timestamps;
+
+                whisper_print_user_data user_data = { &params, &pcmf32s };
+
+                // This callback is called for each new segment
+                if (!wparams.print_realtime) {
+                    wparams.new_segment_callback           = whisper_print_segment_callback;
+                    wparams.new_segment_callback_user_data = &user_data;
+                }
+
+                // Set the progress callback
+                wparams.progress_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) {
+                    ProgressWorker* worker = static_cast<ProgressWorker*>(user_data);
+                    worker->OnProgress(progress);
+                };
+                wparams.progress_callback_user_data = this;
+
+                // Abort mechanism example
+                {
+                    static bool is_aborted = false; // Note: this should be atomic to avoid data races
+
+                    wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
+                        bool is_aborted = *(bool*)user_data;
+                        return !is_aborted;
+                    };
+                    wparams.encoder_begin_callback_user_data = &is_aborted;
+                }
+
+                if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
+                    fprintf(stderr, "failed to process audio\n");
+                    return 10;
+                }
+            }
+        }
+
+        const int n_segments = whisper_full_n_segments(ctx);
+        result.resize(n_segments);
+        for (int i = 0; i < n_segments; ++i) {
+            const char * text = whisper_full_get_segment_text(ctx, i);
+            const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
+            const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
+
+            result[i].emplace_back(to_timestamp(t0, params.comma_in_time));
+            result[i].emplace_back(to_timestamp(t1, params.comma_in_time));
+            result[i].emplace_back(text);
+        }
+
+        whisper_print_timings(ctx);
+        whisper_free(ctx);
+
+        return 0;
+    }
 };
-
-
 Napi::Value whisper(const Napi::CallbackInfo& info) {
     Napi::Env env = info.Env();
     if (info.Length() <= 0 || !info[0].IsObject()) {
@@ -332,6 +367,23 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
     int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
     bool comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
     int32_t max_len = whisper_params.Get("max_len").As<Napi::Number>();
+
+    // Add support for prompt
+    std::string prompt = "";
+    if (whisper_params.Has("prompt") && whisper_params.Get("prompt").IsString()) {
+        prompt = whisper_params.Get("prompt").As<Napi::String>();
+    }
+
+    // Add support for print_progress
+    bool print_progress = false;
+    if (whisper_params.Has("print_progress")) {
+        print_progress = whisper_params.Get("print_progress").As<Napi::Boolean>();
+    }
+    // Add support for progress_callback
+    Napi::Function progress_callback;
+    if (whisper_params.Has("progress_callback") && whisper_params.Get("progress_callback").IsFunction()) {
+        progress_callback = whisper_params.Get("progress_callback").As<Napi::Function>();
+    }
 
     Napi::Value pcmf32Value = whisper_params.Get("pcmf32");
     std::vector<float> pcmf32_vec;
@@ -355,9 +407,12 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
     params.pcmf32 = pcmf32_vec;
     params.comma_in_time = comma_in_time;
     params.max_len = max_len;
+    params.print_progress = print_progress;
+    params.prompt = prompt;
 
     Napi::Function callback = info[1].As<Napi::Function>();
-    Worker* worker = new Worker(callback, params);
+    // Create the worker with progress callback support
+    ProgressWorker* worker = new ProgressWorker(callback, params, progress_callback, env);
     worker->Queue();
     return env.Undefined();
 }
diff --git a/examples/addon.node/index.js b/examples/addon.node/index.js
index 65fa31f8..408d6d33 100644
--- a/examples/addon.node/index.js
+++ b/examples/addon.node/index.js
@@ -19,6 +19,9 @@ const whisperParams = {
   no_timestamps: false,
   audio_ctx: 0,
   max_len: 0,
+  progress_callback: (progress) => {
+    console.log(`progress: ${progress}%`);
+  }
 };
 
 const arguments = process.argv.slice(2);
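
For reference, a minimal sketch of how the new `progress_callback` option can be consumed from JavaScript. It assumes the addon has been built to `build/Release/addon.node` and that the model and sample paths point at the stock example assets; the parameter object otherwise mirrors `whisperParams` from `index.js`:

```js
const path = require("path");
const { promisify } = require("util");

// Load the compiled addon and wrap its (params, callback) API in a promise.
const { whisper } = require(path.join(__dirname, "../../build/Release/addon.node"));
const whisperAsync = promisify(whisper);

whisperAsync({
  language: "en",
  model: path.join(__dirname, "../../models/ggml-base.en.bin"),
  fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
  use_gpu: true,
  flash_attn: false,
  no_prints: true,
  comma_in_time: false,
  translate: true,
  no_timestamps: false,
  audio_ctx: 0,
  max_len: 0,
  // Invoked from ProgressWorker via the Napi::ThreadSafeFunction;
  // `progress` is the integer percentage reported by whisper.cpp.
  progress_callback: (progress) => {
    console.log(`progress: ${progress}%`);
  },
}).then((result) => {
  // result is an array of [t0, t1, text] triples, as assembled in OnOK().
  console.log(result);
});
```

Because the C++ side forwards progress through a `Napi::ThreadSafeFunction` rather than calling into JavaScript directly from the worker thread, the callback fires safely on the main loop while transcription runs in the background.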