mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-19 12:47:52 +00:00
whisper : allocate encoder and decoder using ggml-alloc
This commit is contained in:
parent
949ab6328d
commit
bed5ad69dd
86
whisper.cpp
86
whisper.cpp
@ -606,12 +606,18 @@ struct whisper_state {
|
||||
// memory buffers used by encode / decode contexts
|
||||
std::vector<uint8_t> buf_compute;
|
||||
|
||||
// ggml-alloc
|
||||
std::vector<uint8_t> buf_encode;
|
||||
std::vector<uint8_t> buf_encode_post;
|
||||
std::vector<uint8_t> buf_decode;
|
||||
|
||||
ggml_allocr * alloc_encode = NULL;
|
||||
ggml_allocr * alloc_encode_post = NULL;
|
||||
ggml_allocr * alloc_decode = NULL;
|
||||
|
||||
// result of the encoder
|
||||
struct ggml_tensor * embd_enc = NULL;
|
||||
|
||||
// decode output (2-dimensional array: [n_tokens][n_vocab])
|
||||
std::vector<float> logits;
|
||||
|
||||
@ -701,7 +707,7 @@ static bool kv_cache_init(
|
||||
const int64_t n_mem = n_text_layer*n_ctx;
|
||||
const int64_t n_elements = n_text_state*n_mem;
|
||||
|
||||
const size_t mem_bytes = ggml_type_size(wtype)*n_elements;
|
||||
const size_t mem_bytes = 2*(ggml_type_size(wtype)*n_elements + ggml_tensor_overhead());
|
||||
|
||||
cache.buf.resize(mem_bytes);
|
||||
|
||||
@ -1385,6 +1391,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||
const int n_layer = hparams.n_audio_layer;
|
||||
|
||||
const int n_mels = hparams.n_mels;
|
||||
|
||||
assert(mel_inp.n_mel == n_mels);
|
||||
|
||||
struct ggml_init_params params = {
|
||||
@ -1397,9 +1404,11 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
ggml_allocr * alloc = wstate.alloc_encode;
|
||||
|
||||
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
|
||||
assert(mel->type == GGML_TYPE_F32);
|
||||
{
|
||||
if (!ggml_allocr_is_measure(alloc)) {
|
||||
float * dst = (float *) mel->data;
|
||||
memset(dst, 0, ggml_nbytes(mel));
|
||||
|
||||
@ -1689,6 +1698,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||
}
|
||||
#endif
|
||||
|
||||
wstate.embd_enc = cur;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
|
||||
@ -1706,8 +1717,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||
// pre-compute cross-attention memory
|
||||
static struct ggml_cgraph * whisper_build_graph_encoder_post(
|
||||
whisper_context & wctx,
|
||||
whisper_state & wstate,
|
||||
struct ggml_tensor * embd_enc) {
|
||||
whisper_state & wstate) {
|
||||
const auto & model = wctx.model;
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
@ -1725,7 +1735,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder_post(
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
struct ggml_tensor * cur = embd_enc;
|
||||
//ggml_allocr * alloc = wstate.alloc_encode_post;
|
||||
|
||||
struct ggml_tensor * cur = wstate.embd_enc;
|
||||
|
||||
// TODO: hack to disconnect the encoded features from the previous graph
|
||||
cur->op = GGML_OP_NONE;
|
||||
@ -1826,13 +1838,19 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
ggml_allocr * alloc = wstate.alloc_decode;
|
||||
|
||||
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
if (!ggml_allocr_is_measure(alloc)) {
|
||||
memcpy(embd->data, tokens, N*ggml_element_size(embd));
|
||||
}
|
||||
|
||||
struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
if (!ggml_allocr_is_measure(alloc)) {
|
||||
for (int i = 0; i < N; ++i) {
|
||||
((int32_t *) position->data)[i] = n_past + i;
|
||||
}
|
||||
}
|
||||
|
||||
// token encoding + position encoding
|
||||
struct ggml_tensor * cur =
|
||||
@ -2637,9 +2655,55 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
state->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
|
||||
|
||||
static const size_t tensor_alignment = 32;
|
||||
|
||||
state->alloc_encode = ggml_allocr_new_measure(tensor_alignment);
|
||||
state->alloc_encode_post = ggml_allocr_new_measure(tensor_alignment);
|
||||
state->alloc_decode = ggml_allocr_new_measure(tensor_alignment);
|
||||
|
||||
// encoder allocator
|
||||
{
|
||||
ggml_cgraph * gf = whisper_build_graph_encoder(*ctx, *state, 0);
|
||||
|
||||
const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_encode, gf) + tensor_alignment;
|
||||
ggml_allocr_free(state->alloc_encode);
|
||||
|
||||
log("%s: compute buffer (encode) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
|
||||
|
||||
state->buf_encode.resize(alloc_size);
|
||||
state->alloc_encode = ggml_allocr_new(state->buf_encode.data(), state->buf_encode.size(), tensor_alignment);
|
||||
}
|
||||
|
||||
// encoder_post allocator
|
||||
{
|
||||
ggml_cgraph * gf = whisper_build_graph_encoder_post(*ctx, *state);
|
||||
|
||||
const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_encode_post, gf) + tensor_alignment;
|
||||
ggml_allocr_free(state->alloc_encode_post);
|
||||
|
||||
log("%s: compute buffer (encode_post) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
|
||||
|
||||
state->buf_encode_post.resize(alloc_size);
|
||||
state->alloc_encode_post = ggml_allocr_new(state->buf_encode_post.data(), state->buf_encode_post.size(), tensor_alignment);
|
||||
}
|
||||
|
||||
// decoder allocator
|
||||
{
|
||||
const auto & hparams = ctx->model.hparams;
|
||||
|
||||
const int n_tokens = hparams.n_text_ctx/2;
|
||||
const int n_past = hparams.n_text_ctx/2; // TODO: double-check
|
||||
|
||||
ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], NULL, n_tokens, n_past);
|
||||
|
||||
const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_decode, gf) + tensor_alignment;
|
||||
ggml_allocr_free(state->alloc_decode);
|
||||
|
||||
log("%s: compute buffer (decode) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
|
||||
|
||||
state->buf_decode.resize(alloc_size);
|
||||
state->alloc_decode = ggml_allocr_new(state->buf_decode.data(), state->buf_decode.size(), tensor_alignment);
|
||||
}
|
||||
|
||||
state->rng = std::mt19937(0);
|
||||
|
||||
return state;
|
||||
@ -2862,6 +2926,18 @@ void whisper_free_state(struct whisper_state * state)
|
||||
}
|
||||
#endif
|
||||
|
||||
if (state->alloc_encode) {
|
||||
ggml_allocr_free(state->alloc_encode);
|
||||
}
|
||||
|
||||
if (state->alloc_encode_post) {
|
||||
ggml_allocr_free(state->alloc_encode_post);
|
||||
}
|
||||
|
||||
if (state->alloc_decode) {
|
||||
ggml_allocr_free(state->alloc_decode);
|
||||
}
|
||||
|
||||
delete state;
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user