diff --git a/whisper.cpp b/whisper.cpp index e8d9f0c9..fbcba2c0 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -1457,7 +1457,7 @@ static bool whisper_encode( layer.cross_attn_k_w, cur); - Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); + //Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25))); struct ggml_tensor * Vcross = ggml_mul_mat(ctx0, layer.cross_attn_v_w, @@ -1579,14 +1579,14 @@ static bool whisper_decode( Qcur), Qcur); - Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); // note: no bias for Key struct ggml_tensor * Kcur = ggml_mul_mat(ctxL, layer.attn_k_w, cur); - Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); struct ggml_tensor * Vcur = ggml_mul_mat(ctxL, layer.attn_v_w, @@ -1609,6 +1609,33 @@ static bool whisper_decode( // ------ +#ifdef USE_FLASH_ATTN + struct ggml_tensor * Q = + ggml_permute(ctxL, + ggml_cpy(ctxL, + Qcur, + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), + 0, 2, 1, 3); + + struct ggml_tensor * K = + ggml_permute(ctxL, + ggml_reshape_3d(ctxL, + ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state), + n_state/n_head, n_head, n_past + N), + 0, 2, 1, 3); + + struct ggml_tensor * V = + ggml_cpy(ctxL, + ggml_permute(ctxL, + ggml_reshape_3d(ctxL, + ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state), + n_state/n_head, n_head, n_past + N), + 1, 2, 0, 3), + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_past + N, n_state/n_head, n_head) + ); + + struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, true); +#else struct ggml_tensor * Q = ggml_permute(ctxL, ggml_cpy(ctxL, @@ -1626,13 +1653,13 @@ static bool whisper_decode( // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q); - //struct ggml_tensor * KQ_scaled = - // ggml_scale(ctxL, - // KQ, - // ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) - // ); + struct ggml_tensor * KQ_scaled = + ggml_scale(ctxL, + KQ, + ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head)) + ); - struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past); + struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past); struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked); @@ -1644,6 +1671,7 @@ static bool whisper_decode( 1, 2, 0, 3); struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); +#endif struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3); @@ -1689,7 +1717,7 @@ static bool whisper_decode( Qcur), Qcur); - Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); + //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25))); // Kcross is already scaled struct ggml_tensor * Kcross = @@ -1704,6 +1732,24 @@ static bool whisper_decode( // ------ +#ifdef USE_FLASH_ATTN + struct ggml_tensor * Q = + ggml_permute(ctxL, + ggml_cpy(ctxL, + Qcur, + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)), + 0, 2, 1, 3); + + struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3); + + struct ggml_tensor * V = + ggml_cpy(ctxL, + ggml_permute(ctxL, Vcross, 1, 2, 0, 3), + ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, M, n_state/n_head, n_head) + ); + + struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false); +#else struct ggml_tensor * Q = ggml_permute(ctxL, ggml_cpy(ctxL, @@ -1730,6 +1776,7 @@ static bool whisper_decode( struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3); struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max); +#endif struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);