node : add flash_attn param (#2170)

This commit is contained in:
Pedro Probst 2024-05-20 03:08:48 -03:00 committed by GitHub
parent 4798be1f9a
commit adee3f9c1f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 6 additions and 0 deletions

View File

@ -12,6 +12,7 @@ const whisperParamsMock = {
model: path.join(__dirname, "../../../models/ggml-base.en.bin"), model: path.join(__dirname, "../../../models/ggml-base.en.bin"),
fname_inp: path.join(__dirname, "../../../samples/jfk.wav"), fname_inp: path.join(__dirname, "../../../samples/jfk.wav"),
use_gpu: true, use_gpu: true,
flash_attn: false,
no_prints: true, no_prints: true,
comma_in_time: false, comma_in_time: false,
translate: true, translate: true,

View File

@ -39,6 +39,7 @@ struct whisper_params {
bool no_timestamps = false; bool no_timestamps = false;
bool no_prints = false; bool no_prints = false;
bool use_gpu = true; bool use_gpu = true;
bool flash_attn = false;
bool comma_in_time = true; bool comma_in_time = true;
std::string language = "en"; std::string language = "en";
@ -146,6 +147,7 @@ int run(whisper_params &params, std::vector<std::vector<std::string>> &result) {
struct whisper_context_params cparams = whisper_context_default_params(); struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu; cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;
struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams);
if (ctx == nullptr) { if (ctx == nullptr) {
@ -326,6 +328,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
std::string model = whisper_params.Get("model").As<Napi::String>(); std::string model = whisper_params.Get("model").As<Napi::String>();
std::string input = whisper_params.Get("fname_inp").As<Napi::String>(); std::string input = whisper_params.Get("fname_inp").As<Napi::String>();
bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>(); bool use_gpu = whisper_params.Get("use_gpu").As<Napi::Boolean>();
bool flash_attn = whisper_params.Get("flash_attn").As<Napi::Boolean>();
bool no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>(); bool no_prints = whisper_params.Get("no_prints").As<Napi::Boolean>();
bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>(); bool no_timestamps = whisper_params.Get("no_timestamps").As<Napi::Boolean>();
int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>(); int32_t audio_ctx = whisper_params.Get("audio_ctx").As<Napi::Number>();
@ -346,6 +349,7 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
params.model = model; params.model = model;
params.fname_inp.emplace_back(input); params.fname_inp.emplace_back(input);
params.use_gpu = use_gpu; params.use_gpu = use_gpu;
params.flash_attn = flash_attn;
params.no_prints = no_prints; params.no_prints = no_prints;
params.no_timestamps = no_timestamps; params.no_timestamps = no_timestamps;
params.audio_ctx = audio_ctx; params.audio_ctx = audio_ctx;

View File

@ -12,6 +12,7 @@ const whisperParams = {
model: path.join(__dirname, "../../models/ggml-base.en.bin"), model: path.join(__dirname, "../../models/ggml-base.en.bin"),
fname_inp: path.join(__dirname, "../../samples/jfk.wav"), fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
use_gpu: true, use_gpu: true,
flash_attn: false,
no_prints: true, no_prints: true,
comma_in_time: false, comma_in_time: false,
translate: true, translate: true,