mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-12 20:18:08 +00:00
whisper : use flash attention (#2152)
* whisper : use flash attention in the encoder * whisper : add kv_pad * whisper : remove extra backend instance (huh?) * whisper : use FA for cross-attention * whisper : use FA for self-attention * whisper : simplify encoder FA * whisper : add flash_attn runtime parameter * scripts : add bench log * scripts : add M1 Pro bench log
This commit is contained in:
@ -70,6 +70,7 @@ struct whisper_params {
|
||||
bool no_timestamps = false;
|
||||
bool log_score = false;
|
||||
bool use_gpu = true;
|
||||
bool flash_attn = false;
|
||||
|
||||
std::string language = "en";
|
||||
std::string prompt;
|
||||
@ -168,7 +169,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
|
||||
else if (arg == "-ls" || arg == "--log-score") { params.log_score = true; }
|
||||
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
||||
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
|
||||
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
|
||||
else if ( arg == "--suppress-regex") { params.suppress_regex = argv[++i]; }
|
||||
else if ( arg == "--grammar") { params.grammar = argv[++i]; }
|
||||
else if ( arg == "--grammar-rule") { params.grammar_rule = argv[++i]; }
|
||||
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); }
|
||||
@ -234,6 +236,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -dtw MODEL --dtw MODEL [%-7s] compute token-level timestamps\n", params.dtw.c_str());
|
||||
fprintf(stderr, " -ls, --log-score [%-7s] log best decoder scores of tokens\n", params.log_score?"true":"false");
|
||||
fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
|
||||
fprintf(stderr, " -fa, --flash-attn [%-7s] flash attention\n", params.flash_attn ? "true" : "false");
|
||||
fprintf(stderr, " --suppress-regex REGEX [%-7s] regular expression matching tokens to suppress\n", params.suppress_regex.c_str());
|
||||
fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str());
|
||||
fprintf(stderr, " --grammar-rule RULE [%-7s] top-level GBNF grammar rule name\n", params.grammar_rule.c_str());
|
||||
@ -977,7 +980,9 @@ int main(int argc, char ** argv) {
|
||||
// whisper init
|
||||
|
||||
struct whisper_context_params cparams = whisper_context_default_params();
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
|
||||
if (!params.dtw.empty()) {
|
||||
cparams.dtw_token_timestamps = true;
|
||||
|
Reference in New Issue
Block a user