diff --git a/examples/addon.node/__test__/whisper.spec.js b/examples/addon.node/__test__/whisper.spec.js index 2f264fd3..9ba86b62 100644 --- a/examples/addon.node/__test__/whisper.spec.js +++ b/examples/addon.node/__test__/whisper.spec.js @@ -16,6 +16,7 @@ const whisperParamsMock = { comma_in_time: false, translate: true, no_timestamps: false, + audio_ctx: 0, }; describe("Run whisper.node", () => { diff --git a/examples/addon.node/addon.cpp b/examples/addon.node/addon.cpp index 85576311..8125e5dd 100644 --- a/examples/addon.node/addon.cpp +++ b/examples/addon.node/addon.cpp @@ -19,6 +19,7 @@ struct whisper_params { int32_t max_len = 0; int32_t best_of = 5; int32_t beam_size = -1; + int32_t audio_ctx = 0; float word_thold = 0.01f; float entropy_thold = 2.4f; @@ -46,6 +47,8 @@ struct whisper_params { std::vector fname_inp = {}; std::vector fname_out = {}; + + std::vector pcmf32 = {}; // mono-channel F32 PCM }; struct whisper_print_user_data { @@ -125,13 +128,12 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper void cb_log_disable(enum ggml_log_level, const char *, void *) {} int run(whisper_params ¶ms, std::vector> &result) { - if (params.no_prints) { whisper_log_set(cb_log_disable, NULL); } - if (params.fname_inp.empty()) { - fprintf(stderr, "error: no input files specified\n"); + if (params.fname_inp.empty() && params.pcmf32.empty()) { + fprintf(stderr, "error: no input files or audio buffer specified\n"); return 2; } @@ -151,6 +153,14 @@ int run(whisper_params ¶ms, std::vector> &result) { return 3; } + // if params.pcmf32 is provided, set params.fname_inp to "buffer" + // this is simpler than further modifications in the code + if (!params.pcmf32.empty()) { + fprintf(stderr, "info: using audio buffer as input\n"); + params.fname_inp.clear(); + params.fname_inp.emplace_back("buffer"); + } + for (int f = 0; f < (int) params.fname_inp.size(); ++f) { const auto fname_inp = params.fname_inp[f]; const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f]; @@ -158,9 +168,14 @@ int run(whisper_params ¶ms, std::vector> &result) { std::vector pcmf32; // mono-channel F32 PCM std::vector> pcmf32s; // stereo-channel F32 PCM - if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) { - fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str()); - continue; + // read the input audio file if params.pcmf32 is not provided + if (params.pcmf32.empty()) { + if (!::read_wav(fname_inp, pcmf32, pcmf32s, params.diarize)) { + fprintf(stderr, "error: failed to read WAV file '%s'\n", fname_inp.c_str()); + continue; + } + } else { + pcmf32 = params.pcmf32; } // print system information @@ -180,12 +195,13 @@ int run(whisper_params ¶ms, std::vector> &result) { fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__); } } - fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n", + fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d, audio_ctx = %d ...\n", __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, params.n_threads, params.n_processors, params.language.c_str(), params.translate ? "translate" : "transcribe", - params.no_timestamps ? 0 : 1); + params.no_timestamps ? 0 : 1, + params.audio_ctx); fprintf(stderr, "\n"); } @@ -212,6 +228,7 @@ int run(whisper_params ¶ms, std::vector> &result) { wparams.entropy_thold = params.entropy_thold; wparams.logprob_thold = params.logprob_thold; wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; + wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; @@ -311,14 +328,28 @@ Napi::Value whisper(const Napi::CallbackInfo& info) { bool use_gpu = whisper_params.Get("use_gpu").As(); bool no_prints = whisper_params.Get("no_prints").As(); bool no_timestamps = whisper_params.Get("no_timestamps").As(); + int32_t audio_ctx = whisper_params.Get("audio_ctx").As(); bool comma_in_time = whisper_params.Get("comma_in_time").As(); + Napi::Value pcmf32Value = whisper_params.Get("pcmf32"); + std::vector pcmf32_vec; + if (pcmf32Value.IsTypedArray()) { + Napi::Float32Array pcmf32 = pcmf32Value.As(); + size_t length = pcmf32.ElementLength(); + pcmf32_vec.reserve(length); + for (size_t i = 0; i < length; i++) { + pcmf32_vec.push_back(pcmf32[i]); + } + } + params.language = language; params.model = model; params.fname_inp.emplace_back(input); params.use_gpu = use_gpu; params.no_prints = no_prints; params.no_timestamps = no_timestamps; + params.audio_ctx = audio_ctx; + params.pcmf32 = pcmf32_vec; params.comma_in_time = comma_in_time; Napi::Function callback = info[1].As(); diff --git a/examples/addon.node/index.js b/examples/addon.node/index.js index 90bd6fc2..09b33c54 100644 --- a/examples/addon.node/index.js +++ b/examples/addon.node/index.js @@ -16,13 +16,20 @@ const whisperParams = { comma_in_time: false, translate: true, no_timestamps: false, + audio_ctx: 0, }; const arguments = process.argv.slice(2); const params = Object.fromEntries( arguments.reduce((pre, item) => { if (item.startsWith("--")) { - return [...pre, item.slice(2).split("=")]; + const [key, value] = item.slice(2).split("="); + if (key === "audio_ctx") { + whisperParams[key] = parseInt(value); + } else { + whisperParams[key] = value; + } + return pre; } return pre; }, [])