whisper : allocate encoder and decoder using ggml-alloc

commit bed5ad69dd (parent 949ab6328d)
Author: Georgi Gerganov
Date: 2023-09-10 19:50:34 +03:00


@@ -606,12 +606,18 @@ struct whisper_state {
// memory buffers used by encode / decode contexts
std::vector<uint8_t> buf_compute;
// ggml-alloc
std::vector<uint8_t> buf_encode;
std::vector<uint8_t> buf_encode_post;
std::vector<uint8_t> buf_decode;
ggml_allocr * alloc_encode = NULL;
ggml_allocr * alloc_encode_post = NULL;
ggml_allocr * alloc_decode = NULL;
// result of the encoder
struct ggml_tensor * embd_enc = NULL;
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
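Note: the new buf_*/alloc_* members support the two-phase ggml-alloc pattern used throughout this commit. A minimal sketch of that lifecycle (build_graph stands in for the graph builders below):

// phase 1: measure — build the graph once against a dummy allocator to learn the required size
ggml_allocr * alloc = ggml_allocr_new_measure(/*alignment =*/ 32);
ggml_cgraph * gf = build_graph();                       // tensor data pointers are not real yet
const size_t size = ggml_allocr_alloc_graph(alloc, gf); // worst-case compute buffer size
ggml_allocr_free(alloc);

// phase 2: create the real allocator on top of one of the buf_* vectors above
std::vector<uint8_t> buf(size);
alloc = ggml_allocr_new(buf.data(), buf.size(), /*alignment =*/ 32);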
@@ -701,7 +707,7 @@ static bool kv_cache_init(
const int64_t n_mem = n_text_layer*n_ctx;
const int64_t n_elements = n_text_state*n_mem;
- const size_t mem_bytes = ggml_type_size(wtype)*n_elements;
+ const size_t mem_bytes = 2*(ggml_type_size(wtype)*n_elements + ggml_tensor_overhead());
cache.buf.resize(mem_bytes);
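Note: the factor 2 covers the k and v tensors, and ggml_tensor_overhead() their metadata, which now lives in the same buffer. A worked example with tiny-model-like text hparams (illustrative numbers only, F16 cache):

const int64_t n_text_layer = 4, n_text_ctx = 448, n_text_state = 384;
const int64_t n_mem      = n_text_layer*n_text_ctx; // 1792
const int64_t n_elements = n_text_state*n_mem;      // 688128
// mem_bytes = 2*(2*688128 + ggml_tensor_overhead()) ≈ 2.6 MB for k and v together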
@@ -1385,6 +1391,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
const int n_layer = hparams.n_audio_layer;
const int n_mels = hparams.n_mels;
assert(mel_inp.n_mel == n_mels);
struct ggml_init_params params = {
@@ -1397,9 +1404,11 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
ggml_cgraph * gf = ggml_new_graph(ctx0);
ggml_allocr * alloc = wstate.alloc_encode;
struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
ggml_allocr_alloc(alloc, mel);

assert(mel->type == GGML_TYPE_F32);
- {
+ if (!ggml_allocr_is_measure(alloc)) {
float * dst = (float *) mel->data;
memset(dst, 0, ggml_nbytes(mel));
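Note: during the measure pass a tensor's data pointer is not backed by real memory, so input writes must be guarded. The general shape of the pattern (a sketch, not verbatim code from the commit):

struct ggml_tensor * t = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n);
ggml_allocr_alloc(alloc, t);            // in measure mode this only records the size
if (!ggml_allocr_is_measure(alloc)) {
    memset(t->data, 0, ggml_nbytes(t)); // touch t->data only on the real pass
}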
@@ -1689,6 +1698,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
}
#endif
wstate.embd_enc = cur;
////////////////////////////////////////////////////////////////////////////
//printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
@@ -1706,8 +1717,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
// pre-compute cross-attention memory
static struct ggml_cgraph * whisper_build_graph_encoder_post(
whisper_context & wctx,
- whisper_state & wstate,
- struct ggml_tensor * embd_enc) {
+ whisper_state & wstate) {
const auto & model = wctx.model;
const auto & hparams = model.hparams;
@@ -1725,7 +1735,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder_post(
ggml_cgraph * gf = ggml_new_graph(ctx0);
- struct ggml_tensor * cur = embd_enc;
+ //ggml_allocr * alloc = wstate.alloc_encode_post;
+ struct ggml_tensor * cur = wstate.embd_enc;
// TODO: hack to disconnect the encoded features from the previous graph
cur->op = GGML_OP_NONE;
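Note: the intent of the hack is that a tensor whose op is GGML_OP_NONE is classified as a leaf (an input) when the cross-attention graph is built, so ggml-alloc does not treat wstate.embd_enc as an intermediate result to be placed in this graph's buffer; its data stays where the encoder graph left it. Roughly:

struct ggml_tensor * cur = wstate.embd_enc; // produced and allocated by the encoder graph
cur->op = GGML_OP_NONE;                     // no parents => a leaf/input of the new graph
// the graph built below can now consume cur like any external input tensor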
@@ -1826,13 +1838,19 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
ggml_cgraph * gf = ggml_new_graph(ctx0);
ggml_allocr * alloc = wstate.alloc_decode;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
ggml_allocr_alloc(alloc, embd);
if (!ggml_allocr_is_measure(alloc)) {
memcpy(embd->data, tokens, N*ggml_element_size(embd));
}
struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
ggml_allocr_alloc(alloc, position);
if (!ggml_allocr_is_measure(alloc)) {
for (int i = 0; i < N; ++i) {
((int32_t *) position->data)[i] = n_past + i;
}
}
// token encoding + position encoding
struct ggml_tensor * cur =
@@ -2637,9 +2655,55 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
state->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
static const size_t tensor_alignment = 32;
state->alloc_encode = ggml_allocr_new_measure(tensor_alignment);
state->alloc_encode_post = ggml_allocr_new_measure(tensor_alignment);
state->alloc_decode = ggml_allocr_new_measure(tensor_alignment);
// encoder allocator
{
ggml_cgraph * gf = whisper_build_graph_encoder(*ctx, *state, 0);
const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_encode, gf) + tensor_alignment;
ggml_allocr_free(state->alloc_encode);
log("%s: compute buffer (encode) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
state->buf_encode.resize(alloc_size);
state->alloc_encode = ggml_allocr_new(state->buf_encode.data(), state->buf_encode.size(), tensor_alignment);
}
// encoder_post allocator
{
ggml_cgraph * gf = whisper_build_graph_encoder_post(*ctx, *state);
const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_encode_post, gf) + tensor_alignment;
ggml_allocr_free(state->alloc_encode_post);
log("%s: compute buffer (encode_post) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
state->buf_encode_post.resize(alloc_size);
state->alloc_encode_post = ggml_allocr_new(state->buf_encode_post.data(), state->buf_encode_post.size(), tensor_alignment);
}
// decoder allocator
{
const auto & hparams = ctx->model.hparams;
const int n_tokens = hparams.n_text_ctx/2;
const int n_past = hparams.n_text_ctx/2; // TODO: double-check
ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], NULL, n_tokens, n_past);
const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_decode, gf) + tensor_alignment;
ggml_allocr_free(state->alloc_decode);
log("%s: compute buffer (decode) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
state->buf_decode.resize(alloc_size);
state->alloc_decode = ggml_allocr_new(state->buf_decode.data(), state->buf_decode.size(), tensor_alignment);
}
state->rng = std::mt19937(0);
return state;
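Note: at run time each pre-sized allocator is then reused for every graph evaluation; the calling pattern is roughly (a sketch, compute details elided):

ggml_allocr_reset(wstate.alloc_encode);           // drop the previous graph's layout
ggml_cgraph * gf = whisper_build_graph_encoder(wctx, wstate, mel_offset);
ggml_allocr_alloc_graph(wstate.alloc_encode, gf); // place all tensors inside buf_encode
// ... evaluate gf with ggml_graph_compute (or the internal helper)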
@@ -2862,6 +2926,18 @@ void whisper_free_state(struct whisper_state * state)
}
#endif
if (state->alloc_encode) {
ggml_allocr_free(state->alloc_encode);
}
if (state->alloc_encode_post) {
ggml_allocr_free(state->alloc_encode_post);
}
if (state->alloc_decode) {
ggml_allocr_free(state->alloc_decode);
}
delete state;
}
}