whisper : use flash attention (#2152)

* whisper : use flash attention in the encoder

* whisper : add kv_pad

* whisper : remove extra backend instance (huh?)

* whisper : use FA for cross-attention

* whisper : use FA for self-attention

* whisper : simplify encoder FA

* whisper : add flash_attn runtime parameter

* scripts : add bench log

* scripts : add M1 Pro bench log
This commit is contained in:
Georgi Gerganov
2024-05-15 09:38:19 +03:00
committed by GitHub
parent 9d5771ae43
commit 7094ea5e75
13 changed files with 657 additions and 172 deletions

View File

@ -75,6 +75,7 @@ struct whisper_params {
bool print_progress = false;
bool no_timestamps = false;
bool use_gpu = true;
bool flash_attn = false;
std::string language = "en";
std::string prompt = "";
@ -178,6 +179,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
// server params
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
else if ( arg == "--host") { sparams.hostname = argv[++i]; }
@ -502,7 +504,10 @@ int main(int argc, char ** argv) {
}
// whisper init
struct whisper_context_params cparams = whisper_context_default_params();
cparams.use_gpu = params.use_gpu;
cparams.use_gpu = params.use_gpu;
cparams.flash_attn = params.flash_attn;
if (!params.dtw.empty()) {
cparams.dtw_token_timestamps = true;
cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;