Mirror of https://github.com/ggerganov/whisper.cpp.git
whisper : experiments with Flash Attention in the decoder
commit e2aa556a99
parent f30b5d322c
1 changed file: whisper.cpp (67 changed lines)
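The experiment gates the decoder attention on a USE_FLASH_ATTN define: when it is set, the mul_mat / scale / mask / soft_max / mul_mat chain is replaced by a single ggml_flash_attn call, and the separate pre-scaling of Q and K by pow(float(n_state)/n_head, -0.25) is commented out in favour of an explicit 1.0f/sqrt(float(n_state)/n_head) scale on KQ in the non-flash path (ggml_flash_attn applies that scale internally). The two ways of scaling are numerically equivalent, since scaling q and k each by d^-0.25 scales their dot product by d^-0.5. A minimal standalone check of that equivalence, as a sketch only (plain C with hypothetical values, not code from this commit):

#include <math.h>
#include <stdio.h>

// Sketch: scaling q and k each by d^-0.25 gives the same attention logits
// as scaling the dot product q*k by 1/sqrt(d). Hypothetical d and values.
int main(void) {
    const float d = 64.0f;             // head dimension, i.e. n_state/n_head
    const float q = 0.7f, k = -1.3f;

    float pre  = (q * powf(d, -0.25f)) * (k * powf(d, -0.25f)); // pre-scale Q and K
    float post = (q * k) / sqrtf(d);                            // scale QK^T afterwards

    printf("pre = %f, post = %f\n", pre, post);                 // equal up to fp rounding
    return 0;
}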
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -1457,7 +1457,7 @@ static bool whisper_encode(
                     layer.cross_attn_k_w,
                     cur);

-            Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));
+            //Kcross = ggml_scale(ctx0, Kcross, ggml_new_f32(ctx0, pow(float(n_state)/n_head, -0.25)));

             struct ggml_tensor * Vcross = ggml_mul_mat(ctx0,
                     layer.cross_attn_v_w,
@@ -1579,14 +1579,14 @@ static bool whisper_decode(
                         Qcur),
                     Qcur);

-            Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+            //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));

             // note: no bias for Key
             struct ggml_tensor * Kcur = ggml_mul_mat(ctxL,
                     layer.attn_k_w,
                     cur);

-            Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+            //Kcur = ggml_scale(ctxL, Kcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));

             struct ggml_tensor * Vcur = ggml_mul_mat(ctxL,
                     layer.attn_v_w,
@@ -1609,6 +1609,33 @@ static bool whisper_decode(

             // ------

+#ifdef USE_FLASH_ATTN
+            struct ggml_tensor * Q =
+                ggml_permute(ctxL,
+                        ggml_cpy(ctxL,
+                            Qcur,
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+                        0, 2, 1, 3);
+
+            struct ggml_tensor * K =
+                ggml_permute(ctxL,
+                        ggml_reshape_3d(ctxL,
+                            ggml_view_1d(ctxL, model.memory_k, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_k)*n_state),
+                            n_state/n_head, n_head, n_past + N),
+                        0, 2, 1, 3);
+
+            struct ggml_tensor * V =
+                ggml_cpy(ctxL,
+                        ggml_permute(ctxL,
+                            ggml_reshape_3d(ctxL,
+                                ggml_view_1d(ctxL, model.memory_v, (n_past + N)*n_state, il*n_ctx*ggml_element_size(model.memory_v)*n_state),
+                                n_state/n_head, n_head, n_past + N),
+                            1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_past + N, n_state/n_head, n_head)
+                        );
+
+            struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, true);
+#else
             struct ggml_tensor * Q =
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
@@ -1626,13 +1653,13 @@ static bool whisper_decode(
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctxL, K, Q);

-            //struct ggml_tensor * KQ_scaled =
-            //    ggml_scale(ctxL,
-            //            KQ,
-            //            ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
-            //            );
+            struct ggml_tensor * KQ_scaled =
+                ggml_scale(ctxL,
+                        KQ,
+                        ggml_new_f32(ctxL, 1.0f/sqrt(float(n_state)/n_head))
+                        );

-            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ, n_past);
+            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctxL, KQ_scaled, n_past);

             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctxL, KQ_masked);

@@ -1644,6 +1671,7 @@ static bool whisper_decode(
                         1, 2, 0, 3);

             struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+#endif

             struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);

@@ -1689,7 +1717,7 @@ static bool whisper_decode(
                         Qcur),
                     Qcur);

-            Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));
+            //Qcur = ggml_scale(ctxL, Qcur, ggml_new_f32(ctxL, pow(float(n_state)/n_head, -0.25)));

             // Kcross is already scaled
             struct ggml_tensor * Kcross =
@@ -1704,6 +1732,24 @@ static bool whisper_decode(

             // ------

+#ifdef USE_FLASH_ATTN
+            struct ggml_tensor * Q =
+                ggml_permute(ctxL,
+                        ggml_cpy(ctxL,
+                            Qcur,
+                            ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, n_state/n_head, n_head, N)),
+                        0, 2, 1, 3);
+
+            struct ggml_tensor * K = ggml_permute(ctxL, Kcross, 0, 2, 1, 3);
+
+            struct ggml_tensor * V =
+                ggml_cpy(ctxL,
+                        ggml_permute(ctxL, Vcross, 1, 2, 0, 3),
+                        ggml_new_tensor_3d(ctxL, GGML_TYPE_F16, M, n_state/n_head, n_head)
+                        );
+
+            struct ggml_tensor * KQV = ggml_flash_attn(ctxL, Q, K, V, false);
+#else
             struct ggml_tensor * Q =
                 ggml_permute(ctxL,
                         ggml_cpy(ctxL,
@@ -1730,6 +1776,7 @@ static bool whisper_decode(
             struct ggml_tensor * V_trans = ggml_permute(ctxL, Vcross, 1, 2, 0, 3);

             struct ggml_tensor * KQV = ggml_mul_mat(ctxL, V_trans, KQ_soft_max);
+#endif

             struct ggml_tensor * KQV_merged = ggml_permute(ctxL, KQV, 0, 2, 1, 3);

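Both branches of each attention block express the same computation: out = softmax(QK^T / sqrt(d), with a causal mask for self-attention) * V. The masked flag passed to ggml_flash_attn (true for self-attention, false for cross-attention) mirrors the ggml_diag_mask_inf(..., n_past) step of the non-flash path. A naive per-head reference of that computation, as a sketch only (hypothetical row-major float buffers and an attn_ref name; not the ggml kernels):

#include <math.h>

/* Naive per-head reference of what both branches compute:
 *   out = softmax(QK^T / sqrt(d) [+ causal mask]) * V
 * q: n_q x d, k: n_kv x d, v: n_kv x d, out: n_q x d (row-major).
 * Sketch only -- hypothetical layout, not the ggml kernels. */
void attn_ref(const float *q, const float *k, const float *v, float *out,
              int n_q, int n_kv, int d, int masked, int n_past) {
    const float scale = 1.0f / sqrtf((float) d);
    for (int i = 0; i < n_q; ++i) {
        float s[1024];                       /* assumes n_kv <= 1024 for the sketch */
        float smax = -INFINITY;
        for (int j = 0; j < n_kv; ++j) {
            float dot = 0.0f;
            for (int c = 0; c < d; ++c) dot += q[i*d + c] * k[j*d + c];
            s[j] = dot * scale;
            /* causal mask: query i may only attend to keys 0..n_past+i */
            if (masked && j > n_past + i) s[j] = -INFINITY;
            if (s[j] > smax) smax = s[j];
        }
        float sum = 0.0f;
        for (int j = 0; j < n_kv; ++j) { s[j] = expf(s[j] - smax); sum += s[j]; }
        for (int c = 0; c < d; ++c) {
            float acc = 0.0f;
            for (int j = 0; j < n_kv; ++j) acc += (s[j] / sum) * v[j*d + c];
            out[i*d + c] = acc;
        }
    }
}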