mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-01-25 22:00:25 +00:00
whisper : fix excessive memory usage (#2443)
* whisper : fix KV cache allocation
* whisper : reduce memory overhead from unused input tensors
This commit is contained in:
parent
2944cb72d9
commit
f62a546e03
@ -163,7 +163,6 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
|
|||||||
} \
|
} \
|
||||||
} while (0)
|
} while (0)
|
||||||
|
|
||||||
//#define WHISPER_USE_FLASH_FF
|
|
||||||
#define WHISPER_MAX_DECODERS 8
|
#define WHISPER_MAX_DECODERS 8
|
||||||
#define WHISPER_MAX_NODES 4096
|
#define WHISPER_MAX_NODES 4096
|
||||||
|
|
||||||
@ -817,6 +816,9 @@ struct whisper_state {
|
|||||||
int32_t n_fail_p = 0; // number of logprob threshold failures
|
int32_t n_fail_p = 0; // number of logprob threshold failures
|
||||||
int32_t n_fail_h = 0; // number of entropy threshold failures
|
int32_t n_fail_h = 0; // number of entropy threshold failures
|
||||||
|
|
||||||
|
// number of decoders for which we have constructed the KV cache
|
||||||
|
int32_t kv_self_n_dec = 0;
|
||||||
|
|
||||||
// unified self-attention KV cache for all decoders
|
// unified self-attention KV cache for all decoders
|
||||||
whisper_kv_cache kv_self;
|
whisper_kv_cache kv_self;
|
||||||
|
|
||||||
@ -2096,9 +2098,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||||||
|
|
||||||
struct ggml_tensor * Q =
|
struct ggml_tensor * Q =
|
||||||
ggml_permute(ctx0,
|
ggml_permute(ctx0,
|
||||||
ggml_cpy(ctx0,
|
ggml_reshape_3d(ctx0, Qcur, n_state_head, n_head, n_ctx),
|
||||||
Qcur,
|
|
||||||
ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_state_head, n_head, n_ctx)),
|
|
||||||
0, 2, 1, 3);
|
0, 2, 1, 3);
|
||||||
|
|
||||||
if (wctx.params.flash_attn) {
|
if (wctx.params.flash_attn) {
|
||||||
@ -2125,9 +2125,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||||||
} else {
|
} else {
|
||||||
struct ggml_tensor * K =
|
struct ggml_tensor * K =
|
||||||
ggml_permute(ctx0,
|
ggml_permute(ctx0,
|
||||||
ggml_cpy(ctx0,
|
ggml_cast(ctx0,
|
||||||
Kcur,
|
ggml_reshape_3d(ctx0, Kcur, n_state_head, n_head, n_ctx),
|
||||||
ggml_new_tensor_3d(ctx0, wctx.itype, n_state_head, n_head, n_ctx)),
|
wctx.itype),
|
||||||
0, 2, 1, 3);
|
0, 2, 1, 3);
|
||||||
|
|
||||||
// K * Q
|
// K * Q
|
||||||
@ -2136,22 +2136,19 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||||||
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
|
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
|
||||||
|
|
||||||
struct ggml_tensor * V =
|
struct ggml_tensor * V =
|
||||||
ggml_cpy(ctx0,
|
ggml_cast(ctx0,
|
||||||
ggml_permute(ctx0,
|
ggml_permute(ctx0,
|
||||||
ggml_reshape_3d(ctx0,
|
ggml_reshape_3d(ctx0,
|
||||||
Vcur,
|
Vcur,
|
||||||
n_state_head, n_head, n_ctx),
|
n_state_head, n_head, n_ctx),
|
||||||
1, 2, 0, 3),
|
1, 2, 0, 3),
|
||||||
ggml_new_tensor_3d(ctx0, wctx.itype, n_ctx, n_state_head, n_head)
|
wctx.itype);
|
||||||
);
|
|
||||||
|
|
||||||
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
|
||||||
|
|
||||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||||
|
|
||||||
cur = ggml_cpy(ctx0,
|
cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_ctx);
|
||||||
KQV_merged,
|
|
||||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2181,11 +2178,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||||||
layer.mlp_ln_b);
|
layer.mlp_ln_b);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifdef WHISPER_USE_FLASH_FF
|
|
||||||
cur = ggml_flash_ff(ctx0,
|
|
||||||
ggml_cpy(ctx0, cur, ggml_new_tensor_2d(ctx0, wstate.itype, n_state, n_ctx)),
|
|
||||||
layer.mlp_0_w, layer.mlp_0_b, layer.mlp_1_w, layer.mlp_1_b);
|
|
||||||
#else
|
|
||||||
// fully connected
|
// fully connected
|
||||||
cur = ggml_mul_mat(ctx0,
|
cur = ggml_mul_mat(ctx0,
|
||||||
layer.mlp_0_w,
|
layer.mlp_0_w,
|
||||||
@ -2202,7 +2194,6 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
|||||||
cur);
|
cur);
|
||||||
|
|
||||||
cur = ggml_add(ctx0, cur, layer.mlp_1_b);
|
cur = ggml_add(ctx0, cur, layer.mlp_1_b);
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
inpL = ggml_add(ctx0, cur, inpFF);
|
inpL = ggml_add(ctx0, cur, inpFF);
|
||||||
@ -2578,9 +2569,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
|
|
||||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||||
|
|
||||||
cur = ggml_cpy(ctx0,
|
cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
|
||||||
KQV_merged,
|
|
||||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2687,9 +2676,7 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
|
|
||||||
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
|
||||||
|
|
||||||
cur = ggml_cpy(ctx0,
|
cur = ggml_cont_2d(ctx0, KQV_merged, n_state, n_tokens);
|
||||||
KQV_merged,
|
|
||||||
ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_tokens));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3403,14 +3390,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
|||||||
whisper_mel_init(state->mel, state->backends[0], n_len, n_len, n_mel);
|
whisper_mel_init(state->mel, state->backends[0], n_len, n_len, n_mel);
|
||||||
}
|
}
|
||||||
|
|
||||||
// at this point, we don't know yet how many decoders will be used, so we overallocate 3x ctx
|
// at this point, we don't know yet how many decoders will be used
|
||||||
// in theory, there can be a case where this is not enough, but in practice it should always be enough
|
// later during decoding, if more decoders are used, we will recreate the KV cache respectively
|
||||||
const int factor = 3;
|
state->kv_self_n_dec = 1;
|
||||||
|
|
||||||
if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
|
if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
|
||||||
ctx->model.hparams.n_text_state,
|
ctx->model.hparams.n_text_state,
|
||||||
ctx->model.hparams.n_text_layer,
|
ctx->model.hparams.n_text_layer,
|
||||||
GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
|
GGML_PAD(ctx->model.hparams.n_text_ctx, 256))) {
|
||||||
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
|
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
|
||||||
whisper_free_state(state);
|
whisper_free_state(state);
|
||||||
return nullptr;
|
return nullptr;
|
||||||
@ -5775,13 +5761,34 @@ int whisper_full_with_state(
|
|||||||
}
|
}
|
||||||
WHISPER_LOG_DEBUG("\n\n");
|
WHISPER_LOG_DEBUG("\n\n");
|
||||||
|
|
||||||
|
// recreate the KV cache if the number of decoders has changed
|
||||||
|
if (state->kv_self_n_dec < n_decoders_cur) {
|
||||||
|
WHISPER_LOG_DEBUG("%s: recreating KV cache: n_decoders_cur = %d\n", __func__, n_decoders_cur);
|
||||||
|
|
||||||
|
whisper_kv_cache_free(state->kv_self);
|
||||||
|
|
||||||
|
// overallocate to workaround KV cache fragmentation issues
|
||||||
|
const int factor = n_decoders_cur > 1 ? n_decoders_cur + 2 : 1;
|
||||||
|
|
||||||
|
if (!whisper_kv_cache_init(state->kv_self, state->backends[0], ctx->itype,
|
||||||
|
ctx->model.hparams.n_text_state,
|
||||||
|
ctx->model.hparams.n_text_layer,
|
||||||
|
GGML_PAD(ctx->model.hparams.n_text_ctx, 256)*factor)) {
|
||||||
|
WHISPER_LOG_ERROR("%s: whisper_kv_cache_init() failed for self-attention cache\n", __func__);
|
||||||
|
whisper_free_state(state);
|
||||||
|
return -7;
|
||||||
|
}
|
||||||
|
|
||||||
|
state->kv_self_n_dec = n_decoders_cur;
|
||||||
|
}
|
||||||
|
|
||||||
whisper_kv_cache_clear(state->kv_self);
|
whisper_kv_cache_clear(state->kv_self);
|
||||||
|
|
||||||
whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
|
whisper_batch_prep_legacy(state->batch, prompt.data(), prompt.size(), 0, 0);
|
||||||
|
|
||||||
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
|
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
|
||||||
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
||||||
return -7;
|
return -8;
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
@ -6081,7 +6088,7 @@ int whisper_full_with_state(
|
|||||||
|
|
||||||
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
|
if (!whisper_decode_internal(*ctx, *state, state->batch, params.n_threads, false, params.abort_callback, params.abort_callback_user_data)) {
|
||||||
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
WHISPER_LOG_ERROR("%s: failed to decode\n", __func__);
|
||||||
return -8;
|
return -9;
|
||||||
}
|
}
|
||||||
|
|
||||||
const int64_t t_start_sample_us = ggml_time_us();
|
const int64_t t_start_sample_us = ggml_time_us();
|
||||||
|
Loading…
x
Reference in New Issue
Block a user