diff --git a/whisper.cpp b/whisper.cpp index f31309ed..f3daf5b6 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -147,7 +147,6 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text } \ } while (0) -//#define WHISPER_USE_FLASH_ATTN //#define WHISPER_USE_FLASH_FF #define WHISPER_MAX_DECODERS 8 #define WHISPER_MAX_NODES 4096 @@ -1951,32 +1950,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder( // ------ -#ifdef WHISPER_USE_FLASH_ATTN - struct ggml_tensor * Q = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Qcur, - ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); - - struct ggml_tensor * K = - ggml_permute(ctx0, - ggml_cpy(ctx0, - Kcur, - ggml_new_tensor_3d(ctx0, wctx.itype, n_state/n_head, n_head, n_ctx)), - 0, 2, 1, 3); - - struct ggml_tensor * V = - ggml_cpy(ctx0, - ggml_permute(ctx0, - ggml_reshape_3d(ctx0, - Vcur, - n_state/n_head, n_head, n_ctx), - 1, 2, 0, 3), - ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state/n_head, n_head)); - - struct ggml_tensor * KQV = ggml_flash_attn(ctx0, Q, K, V, false); -#else struct ggml_tensor * Q = ggml_permute(ctx0, ggml_cpy(ctx0, @@ -1994,9 +1967,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder( // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQscale); - - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_scaled); + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); struct ggml_tensor * V = ggml_cpy(ctx0, @@ -2009,7 +1980,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder( ); struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); -#endif + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); cur = ggml_cpy(ctx0, @@ -2323,6 +2294,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder( ggml_set_name(KQ_mask, "KQ_mask"); ggml_set_input(KQ_mask); + struct ggml_tensor * KQ_mask_f16 = ggml_cast(ctx0, KQ_mask, GGML_TYPE_F16); + // token encoding + position encoding struct ggml_tensor * cur = ggml_add(ctx0, @@ -2406,12 +2379,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder( // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - //struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); - - //struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ, n_past); - struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ, KQ_mask); - - struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, KQ_mask_f16, 1.0f, 0.0f); struct ggml_tensor * V = ggml_view_3d(ctx0, kv_self.v,