mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-13 12:38:08 +00:00
whisper : use flash attention (#2152)
* whisper : use flash attention in the encoder * whisper : add kv_pad * whisper : remove extra backend instance (huh?) * whisper : use FA for cross-attention * whisper : use FA for self-attention * whisper : simplify encoder FA * whisper : add flash_attn runtime parameter * scripts : add bench log * scripts : add M1 Pro bench log
This commit is contained in:
@ -75,6 +75,7 @@ struct whisper_params {
|
||||
bool print_progress = false;
|
||||
bool no_timestamps = false;
|
||||
bool use_gpu = true;
|
||||
bool flash_attn = false;
|
||||
|
||||
std::string language = "en";
|
||||
std::string prompt = "";
|
||||
@ -178,6 +179,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
|
||||
else if (arg == "-oved" || arg == "--ov-e-device") { params.openvino_encode_device = argv[++i]; }
|
||||
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
|
||||
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
|
||||
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
|
||||
// server params
|
||||
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
|
||||
else if ( arg == "--host") { sparams.hostname = argv[++i]; }
|
||||
@ -502,7 +504,10 @@ int main(int argc, char ** argv) {
|
||||
}
|
||||
// whisper init
|
||||
struct whisper_context_params cparams = whisper_context_default_params();
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
|
||||
cparams.use_gpu = params.use_gpu;
|
||||
cparams.flash_attn = params.flash_attn;
|
||||
|
||||
if (!params.dtw.empty()) {
|
||||
cparams.dtw_token_timestamps = true;
|
||||
cparams.dtw_aheads_preset = WHISPER_AHEADS_NONE;
|
||||
|
Reference in New Issue
Block a user