examples : support progress_callback API for addon.node (#2941)
Some checks failed
CI / determine-tag (push) Waiting to run
CI / ubuntu-22 (linux/amd64) (push) Waiting to run
CI / ubuntu-22 (linux/ppc64le) (push) Waiting to run
CI / ubuntu-22-arm64 (linux/arm64) (push) Waiting to run
CI / ubuntu-22-arm-v7 (linux/arm/v7) (push) Waiting to run
CI / macOS-latest (generic/platform=iOS) (push) Waiting to run
CI / macOS-latest (generic/platform=macOS) (push) Waiting to run
CI / macOS-latest (generic/platform=tvOS) (push) Waiting to run
CI / ubuntu-22-gcc (linux/amd64, Debug) (push) Waiting to run
CI / ubuntu-22-gcc (linux/amd64, Release) (push) Waiting to run
CI / ubuntu-22-gcc (linux/ppc64le, Debug) (push) Waiting to run
CI / ubuntu-22-gcc (linux/ppc64le, Release) (push) Waiting to run
CI / ubuntu-22-gcc-arm64 (linux/arm64, Debug) (push) Waiting to run
CI / ubuntu-22-gcc-arm64 (linux/arm64, Release) (push) Waiting to run
CI / ubuntu-22-gcc-arm-v7 (linux/arm/v7, Debug) (push) Waiting to run
CI / ubuntu-22-gcc-arm-v7 (linux/arm/v7, Release) (push) Waiting to run
CI / ubuntu-22-clang (linux/amd64, Debug) (push) Waiting to run
CI / ubuntu-22-clang (linux/amd64, Release) (push) Waiting to run
CI / ubuntu-22-clang (linux/arm64, Debug) (push) Waiting to run
CI / ubuntu-22-clang (linux/arm64, Release) (push) Waiting to run
CI / ubuntu-22-clang (linux/ppc64le, Debug) (push) Waiting to run
CI / ubuntu-22-clang (linux/ppc64le, Release) (push) Waiting to run
CI / ubuntu-22-gcc-sanitized (linux/amd64, ADDRESS) (push) Waiting to run
CI / ubuntu-22-gcc-sanitized (linux/amd64, THREAD) (push) Waiting to run
CI / ubuntu-22-gcc-sanitized (linux/amd64, UNDEFINED) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Waiting to run
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Waiting to run
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Waiting to run
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Waiting to run
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Waiting to run
CI / windows-blas (Win32, ON, Release, x86, 2.28.5, ON) (push) Waiting to run
CI / windows-blas (x64, ON, Release, x64, 2.28.5, ON) (push) Waiting to run
CI / windows-cublas (x64, Release, ON, 11.8.0, ON, 2.28.5) (push) Waiting to run
CI / windows-cublas (x64, Release, ON, 12.2.0, ON, 2.28.5) (push) Waiting to run
CI / emscripten (Release) (push) Waiting to run
CI / ios-xcode-build (Release) (push) Blocked by required conditions
CI / android (push) Waiting to run
CI / quantize (push) Waiting to run
CI / release (push) Blocked by required conditions
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64 tag:main]) (push) Waiting to run
Examples Tests / addon_node-ubuntu-22 (16.x) (push) Has been cancelled
Examples Tests / addon_node-ubuntu-22 (18.x) (push) Has been cancelled

* feat: progress supported

* fix: missing params

* style: Format the code to improve readability

Unified code indentation ensures consistent coding style, enhancing code readability and maintainability.

* feat: support prompt api

---------

Co-authored-by: linxiaodong <calm.lin@wukongsch.com>
This commit is contained in:
Lin Xiaodong 2025-03-28 13:34:26 +08:00 committed by GitHub
parent f28bf5d186
commit 1279f0d0bc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 240 additions and 181 deletions

View File

@ -18,6 +18,7 @@ const whisperParamsMock = {
translate: true,
no_timestamps: false,
audio_ctx: 0,
max_len: 0,
};
describe("Run whisper.node", () => {

View File

@ -128,192 +128,227 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
void cb_log_disable(enum ggml_log_level, const char *, void *) {}
int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
if (params.no_prints) {
whisper_log_set(cb_log_disable, NULL);
}
if (params.fname_inp.empty() && params.pcmf32.empty()) {
fprintf(stderr, "error: no input files or audio buffer specified\n");
return 2;
}
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
exit(0);
}
// whisper init
struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
return 3;
}
// if params.pcmf32 is provided, set params.fname_inp to "buffer"
// this is simpler than further modifications in the code
if (!params.pcmf32.empty()) {
fprintf(stderr, "info: using audio buffer as input\n");
params.fname_inp.clear();
params.fname_inp.emplace_back("buffer");
}
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// read the input audio file if params.pcmf32 is not provided
if (params.pcmf32.empty()) {
if (!::read_audio_data(fname_inp, pcmf32, pcmf32s, params.diarize)) {
fprintf(stderr, "error: failed to read audio file '%s'\n", fname_inp.c_str());
continue;
}
} else {
pcmf32 = params.pcmf32;
}
// print system information
if (!params.no_prints) {
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
}
// print some info about the processing
if (!params.no_prints) {
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d, audio_ctx = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1,
params.audio_ctx);
fprintf(stderr, "\n");
}
// run the inference
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special = params.print_special;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms;
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.audio_ctx = params.audio_ctx;
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
wparams.initial_prompt = params.prompt.c_str();
wparams.no_timestamps = params.no_timestamps;
whisper_print_user_data user_data = { &params, &pcmf32s };
// this callback is called on each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &user_data;
}
// example for abort mechanism
// in this example, we do not abort the processing, but we could if the flag is set to true
// the callback is called before every encoder run - if it returns false, the processing is aborted
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
wparams.encoder_begin_callback_user_data = &is_aborted;
}
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "failed to process audio\n");
return 10;
}
}
}
const int n_segments = whisper_full_n_segments(ctx);
result.resize(n_segments);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
result[i].emplace_back(to_timestamp(t0, params.comma_in_time));
result[i].emplace_back(to_timestamp(t1, params.comma_in_time));
result[i].emplace_back(text);
}
whisper_print_timings(ctx);
whisper_free(ctx);
return 0;
}
class Worker : public Napi::AsyncWorker {
class ProgressWorker : public Napi::AsyncWorker {
public:
Worker(Napi::Function& callback, whisper_params params)
: Napi::AsyncWorker(callback), params(params) {}
void Execute() override {
run(params, result);
}
void OnOK() override {
Napi::HandleScope scope(Env());
Napi::Object res = Napi::Array::New(Env(), result.size());
for (uint64_t i = 0; i < result.size(); ++i) {
Napi::Object tmp = Napi::Array::New(Env(), 3);
for (uint64_t j = 0; j < 3; ++j) {
tmp[j] = Napi::String::New(Env(), result[i][j]);
}
res[i] = tmp;
ProgressWorker(Napi::Function& callback, whisper_params params, Napi::Function progress_callback, Napi::Env env)
: Napi::AsyncWorker(callback), params(params), env(env) {
// Create thread-safe function
if (!progress_callback.IsEmpty()) {
tsfn = Napi::ThreadSafeFunction::New(
env,
progress_callback,
"Progress Callback",
0,
1
);
}
}
~ProgressWorker() {
if (tsfn) {
// Make sure to release the thread-safe function on destruction
tsfn.Release();
}
}
void Execute() override {
// Use custom run function with progress callback support
run_with_progress(params, result);
}
void OnOK() override {
Napi::HandleScope scope(Env());
Napi::Object res = Napi::Array::New(Env(), result.size());
for (uint64_t i = 0; i < result.size(); ++i) {
Napi::Object tmp = Napi::Array::New(Env(), 3);
for (uint64_t j = 0; j < 3; ++j) {
tmp[j] = Napi::String::New(Env(), result[i][j]);
}
res[i] = tmp;
}
Callback().Call({Env().Null(), res});
}
// Progress callback function - using thread-safe function
void OnProgress(int progress) {
if (tsfn) {
// Use thread-safe function to call JavaScript callback
auto callback = [progress](Napi::Env env, Napi::Function jsCallback) {
jsCallback.Call({Napi::Number::New(env, progress)});
};
tsfn.BlockingCall(callback);
}
}
Callback().Call({Env().Null(), res});
}
private:
whisper_params params;
std::vector<std::vector<std::string>> result;
whisper_params params;
std::vector<std::vector<std::string>> result;
Napi::Env env;
Napi::ThreadSafeFunction tsfn;
// Custom run function with progress callback support
int run_with_progress(whisper_params &params, std::vector<std::vector<std::string>> &result) {
if (params.no_prints) {
whisper_log_set(cb_log_disable, NULL);
}
if (params.fname_inp.empty() && params.pcmf32.empty()) {
fprintf(stderr, "error: no input files or audio buffer specified\n");
return 2;
}
if (params.language != "auto" && whisper_lang_id(params.language.c_str()) == -1) {
fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str());
exit(0);
}
// whisper init
struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
if (ctx == nullptr) {
fprintf(stderr, "error: failed to initialize whisper context\n");
return 3;
}
// If params.pcmf32 provides, set params.fname_inp as "buffer"
if (!params.pcmf32.empty()) {
fprintf(stderr, "info: using audio buffer as input\n");
params.fname_inp.clear();
params.fname_inp.emplace_back("buffer");
}
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
const auto fname_inp = params.fname_inp[f];
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// If params.pcmf32 is empty, read input audio file
if (params.pcmf32.empty()) {
if (!::read_audio_data(fname_inp, pcmf32, pcmf32s, params.diarize)) {
fprintf(stderr, "error: failed to read audio file '%s'\n", fname_inp.c_str());
continue;
}
} else {
pcmf32 = params.pcmf32;
}
// Print system info
if (!params.no_prints) {
fprintf(stderr, "\n");
fprintf(stderr, "system_info: n_threads = %d / %d | %s\n",
params.n_threads*params.n_processors, std::thread::hardware_concurrency(), whisper_print_system_info());
}
// Print processing info
if (!params.no_prints) {
fprintf(stderr, "\n");
if (!whisper_is_multilingual(ctx)) {
if (params.language != "en" || params.translate) {
params.language = "en";
params.translate = false;
fprintf(stderr, "%s: WARNING: model is not multilingual, ignoring language and translation options\n", __func__);
}
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d, audio_ctx = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.no_timestamps ? 0 : 1,
params.audio_ctx);
fprintf(stderr, "\n");
}
// Run inference
{
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
wparams.strategy = params.beam_size > 1 ? WHISPER_SAMPLING_BEAM_SEARCH : WHISPER_SAMPLING_GREEDY;
wparams.print_realtime = false;
wparams.print_progress = params.print_progress;
wparams.print_timestamps = !params.no_timestamps;
wparams.print_special = params.print_special;
wparams.translate = params.translate;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx;
wparams.offset_ms = params.offset_t_ms;
wparams.duration_ms = params.duration_ms;
wparams.token_timestamps = params.output_wts || params.max_len > 0;
wparams.thold_pt = params.word_thold;
wparams.entropy_thold = params.entropy_thold;
wparams.logprob_thold = params.logprob_thold;
wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len;
wparams.audio_ctx = params.audio_ctx;
wparams.greedy.best_of = params.best_of;
wparams.beam_search.beam_size = params.beam_size;
wparams.initial_prompt = params.prompt.c_str();
wparams.no_timestamps = params.no_timestamps;
whisper_print_user_data user_data = { &params, &pcmf32s };
// This callback is called for each new segment
if (!wparams.print_realtime) {
wparams.new_segment_callback = whisper_print_segment_callback;
wparams.new_segment_callback_user_data = &user_data;
}
// Set progress callback
wparams.progress_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, int progress, void * user_data) {
ProgressWorker* worker = static_cast<ProgressWorker*>(user_data);
worker->OnProgress(progress);
};
wparams.progress_callback_user_data = this;
// Abort mechanism example
{
static bool is_aborted = false; // Note: this should be atomic to avoid data races
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
wparams.encoder_begin_callback_user_data = &is_aborted;
}
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
fprintf(stderr, "failed to process audio\n");
return 10;
}
}
}
const int n_segments = whisper_full_n_segments(ctx);
result.resize(n_segments);
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(ctx, i);
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
result[i].emplace_back(to_timestamp(t0, params.comma_in_time));
result[i].emplace_back(to_timestamp(t1, params.comma_in_time));
result[i].emplace_back(text);
}
whisper_print_timings(ctx);
whisper_free(ctx);
return 0;
}
};
Napi::Value whisper(const Napi::CallbackInfo& info) {
Napi::Env env = info.Env();
if (info.Length() <= 0 || !info[0].IsObject()) {
@ -332,6 +367,23 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
bool comma_in_time = whisper_params.Get("comma_in_time").As<Napi::Boolean>();
int32_t max_len = whisper_params.Get("max_len").As<Napi::Number>();
// support prompt
std::string prompt = "";
if (whisper_params.Has("prompt") && whisper_params.Get("prompt").IsString()) {
prompt = whisper_params.Get("prompt").As<Napi::String>();
}
// Add support for print_progress
bool print_progress = false;
if (whisper_params.Has("print_progress")) {
print_progress = whisper_params.Get("print_progress").As<Napi::Boolean>();
}
// Add support for progress_callback
Napi::Function progress_callback;
if (whisper_params.Has("progress_callback") && whisper_params.Get("progress_callback").IsFunction()) {
progress_callback = whisper_params.Get("progress_callback").As<Napi::Function>();
}
Napi::Value pcmf32Value = whisper_params.Get("pcmf32");
std::vector<float> pcmf32_vec;
@ -355,9 +407,12 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
params.pcmf32 = pcmf32_vec;
params.comma_in_time = comma_in_time;
params.max_len = max_len;
params.print_progress = print_progress;
params.prompt = prompt;
Napi::Function callback = info[1].As<Napi::Function>();
Worker* worker = new Worker(callback, params);
// Create a new Worker class with progress callback support
ProgressWorker* worker = new ProgressWorker(callback, params, progress_callback, env);
worker->Queue();
return env.Undefined();
}

View File

@ -19,6 +19,9 @@ const whisperParams = {
no_timestamps: false,
audio_ctx: 0,
max_len: 0,
progress_callback: (progress) => {
console.log(`progress: ${progress}%`);
}
};
const arguments = process.argv.slice(2);