diff --git a/examples/addon.node/__test__/whisper.spec.js b/examples/addon.node/__test__/whisper.spec.js index 9ba86b62..1ee888a1 100644 --- a/examples/addon.node/__test__/whisper.spec.js +++ b/examples/addon.node/__test__/whisper.spec.js @@ -12,6 +12,7 @@ const whisperParamsMock = { model: path.join(__dirname, "../../../models/ggml-base.en.bin"), fname_inp: path.join(__dirname, "../../../samples/jfk.wav"), use_gpu: true, + flash_attn: false, no_prints: true, comma_in_time: false, translate: true, diff --git a/examples/addon.node/addon.cpp b/examples/addon.node/addon.cpp index 8125e5dd..53bf1abb 100644 --- a/examples/addon.node/addon.cpp +++ b/examples/addon.node/addon.cpp @@ -39,6 +39,7 @@ struct whisper_params { bool no_timestamps = false; bool no_prints = false; bool use_gpu = true; + bool flash_attn = false; bool comma_in_time = true; std::string language = "en"; @@ -146,6 +147,7 @@ int run(whisper_params ¶ms, std::vector> &result) { struct whisper_context_params cparams = whisper_context_default_params(); cparams.use_gpu = params.use_gpu; + cparams.flash_attn = params.flash_attn; struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); if (ctx == nullptr) { @@ -326,6 +328,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) { std::string model = whisper_params.Get("model").As(); std::string input = whisper_params.Get("fname_inp").As(); bool use_gpu = whisper_params.Get("use_gpu").As(); + bool flash_attn = whisper_params.Get("flash_attn").As(); bool no_prints = whisper_params.Get("no_prints").As(); bool no_timestamps = whisper_params.Get("no_timestamps").As(); int32_t audio_ctx = whisper_params.Get("audio_ctx").As(); @@ -346,6 +349,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) { params.model = model; params.fname_inp.emplace_back(input); params.use_gpu = use_gpu; + params.flash_attn = flash_attn; params.no_prints = no_prints; params.no_timestamps = no_timestamps; params.audio_ctx = audio_ctx; diff --git a/examples/addon.node/index.js b/examples/addon.node/index.js index 09b33c54..643ee756 100644 --- a/examples/addon.node/index.js +++ b/examples/addon.node/index.js @@ -12,6 +12,7 @@ const whisperParams = { model: path.join(__dirname, "../../models/ggml-base.en.bin"), fname_inp: path.join(__dirname, "../../samples/jfk.wav"), use_gpu: true, + flash_attn: false, no_prints: true, comma_in_time: false, translate: true,