From 2770d46ef5a2062dc2f97af4640d8c97d838efba Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 11 Sep 2023 15:04:33 +0300 Subject: [PATCH] whisper : refactor ggml-alloc init --- whisper.cpp | 139 +++++++++++++++++++++++++++------------------------- 1 file changed, 72 insertions(+), 67 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 0463ecdb..10980cfd 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -618,20 +618,26 @@ struct whisper_state { // buffer for swapping KV caches between decoders during beam-search std::vector kv_swap_bufs; - // memory buffers used by encode / decode contexts - std::vector buf_compute; - // reusable buffer for `struct ggml_graph_plan.work_data` std::vector work_buffer; - // ggml-alloc - std::vector buf_encode; - std::vector buf_encode_post; - std::vector buf_decode; + // ggml-alloc: + // - stores meta info about the intermediate tensors into the `meta_*` buffers + // - stores the actual tensor data into the `data_*` buffers - ggml_allocr * alloc_encode = NULL; - ggml_allocr * alloc_encode_post = NULL; - ggml_allocr * alloc_decode = NULL; + ggml_allocr * alloc_encode = NULL; + ggml_allocr * alloc_cross = NULL; + ggml_allocr * alloc_decode = NULL; + + // meta data + std::vector meta_encode; + std::vector meta_cross; + std::vector meta_decode; + + // tensor data + std::vector data_encode; + std::vector data_cross; + std::vector data_decode; // result of the encoder struct ggml_tensor * embd_enc = NULL; @@ -1411,8 +1417,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder( const int n_mels = hparams.n_mels; struct ggml_init_params params = { - /*.mem_size =*/ wstate.buf_compute.size(), - /*.mem_buffer =*/ wstate.buf_compute.data(), + /*.mem_size =*/ wstate.meta_encode.size(), + /*.mem_buffer =*/ wstate.meta_encode.data(), /*.no_alloc =*/ true, }; @@ -1746,7 +1752,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder( } // pre-compute cross-attention memory -static struct ggml_cgraph * whisper_build_graph_encoder_post( +static struct ggml_cgraph * whisper_build_graph_cross( whisper_context & wctx, whisper_state & wstate) { const auto & model = wctx.model; @@ -1757,8 +1763,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder_post( const int n_head = hparams.n_audio_head; struct ggml_init_params params = { - /*.mem_size =*/ wstate.buf_compute.size(), - /*.mem_buffer =*/ wstate.buf_compute.data(), + /*.mem_size =*/ wstate.meta_cross.size(), + /*.mem_buffer =*/ wstate.meta_cross.data(), /*.no_alloc =*/ true, }; @@ -1766,7 +1772,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder_post( ggml_cgraph * gf = ggml_new_graph(ctx0); - ggml_allocr * alloc = wstate.alloc_encode_post; + ggml_allocr * alloc = wstate.alloc_cross; struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc); @@ -1863,13 +1869,13 @@ static bool whisper_encode_internal( //printf("n: %d\n", ggml_nelements(cur)); } - // encoder_post + // cross { - auto & alloc = wstate.alloc_encode_post; + auto & alloc = wstate.alloc_cross; ggml_allocr_reset(alloc); - ggml_cgraph * gf = whisper_build_graph_encoder_post(wctx, wstate); + ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate); ggml_allocr_alloc_graph(alloc, gf); @@ -1924,8 +1930,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder( //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx); struct ggml_init_params params = { - /*.mem_size =*/ wstate.buf_compute.size(), - /*.mem_buffer =*/ wstate.buf_compute.data(), + /*.mem_size =*/ wstate.meta_decode.size(), + /*.mem_buffer =*/ wstate.meta_decode.data(), /*.no_alloc =*/ true, }; @@ -2733,8 +2739,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } - log("debug CI - checkpoint 0\n"); - if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) { log("%s: kv_cache_init() failed for cross-attention cache\n", __func__); delete state; @@ -2746,8 +2750,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0); } - log("debug CI - checkpoint 1\n"); - #ifdef WHISPER_USE_COREML const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model); @@ -2765,70 +2767,73 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { } #endif - log("debug CI - checkpoint 2\n"); - state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx); - log("debug CI - checkpoint 3\n"); - state->logits_id.reserve(ctx->model.hparams.n_vocab); - log("debug CI - checkpoint 4\n"); - // TAGS: WHISPER_DECODER_INIT state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx); - log("debug CI - checkpoint 5\n"); - state->decoders[0].probs.reserve(ctx->vocab.n_vocab); state->decoders[0].logits.reserve(ctx->vocab.n_vocab); state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab); - log("debug CI - checkpoint 6\n"); - - state->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead()); - - log("debug CI - checkpoint 7\n"); - static const size_t tensor_alignment = 32; - log("debug CI - checkpoint 8\n"); - - state->alloc_encode = ggml_allocr_new_measure(tensor_alignment); - log("debug CI - checkpoint 9\n"); - state->alloc_encode_post = ggml_allocr_new_measure(tensor_alignment); - log("debug CI - checkpoint 10\n"); - state->alloc_decode = ggml_allocr_new_measure(tensor_alignment); - log("debug CI - checkpoint 11\n"); - // encoder allocator { + auto & alloc = state->alloc_encode; + auto & meta = state->meta_encode; + auto & data = state->data_encode; + + meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead()); + + alloc = ggml_allocr_new_measure(tensor_alignment); + ggml_cgraph * gf = whisper_build_graph_encoder(*ctx, *state, 0); - const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_encode, gf) + tensor_alignment; - ggml_allocr_free(state->alloc_encode); + const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment; - log("%s: compute buffer (encode) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0); + ggml_allocr_free(alloc); - state->buf_encode.resize(alloc_size); - state->alloc_encode = ggml_allocr_new(state->buf_encode.data(), state->buf_encode.size(), tensor_alignment); + log("%s: compute buffer (encode) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0); + + data.resize(alloc_size); + alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment); } - // encoder_post allocator + // cross allocator { - ggml_cgraph * gf = whisper_build_graph_encoder_post(*ctx, *state); + auto & alloc = state->alloc_cross; + auto & meta = state->meta_cross; + auto & data = state->data_cross; - const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_encode_post, gf) + tensor_alignment; - ggml_allocr_free(state->alloc_encode_post); + meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead()); - log("%s: compute buffer (encode_post) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0); + alloc = ggml_allocr_new_measure(tensor_alignment); - state->buf_encode_post.resize(alloc_size); - state->alloc_encode_post = ggml_allocr_new(state->buf_encode_post.data(), state->buf_encode_post.size(), tensor_alignment); + ggml_cgraph * gf = whisper_build_graph_cross(*ctx, *state); + + const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment; + + ggml_allocr_free(alloc); + + log("%s: compute buffer (cross) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0); + + data.resize(alloc_size); + alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment); } // decoder allocator { + auto & alloc = state->alloc_decode; + auto & meta = state->meta_decode; + auto & data = state->data_decode; + + meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead()); + + alloc = ggml_allocr_new_measure(tensor_alignment); + const auto & hparams = ctx->model.hparams; // TODO: make sure this is the worst-case scenario @@ -2837,13 +2842,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], NULL, n_tokens, n_past); - const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_decode, gf) + tensor_alignment; - ggml_allocr_free(state->alloc_decode); + const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment; + ggml_allocr_free(alloc); - log("%s: compute buffer (decode) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0); + log("%s: compute buffer (decode) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0); - state->buf_decode.resize(alloc_size); - state->alloc_decode = ggml_allocr_new(state->buf_decode.data(), state->buf_decode.size(), tensor_alignment); + data.resize(alloc_size); + alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment); } state->rng = std::mt19937(0); @@ -3071,8 +3076,8 @@ void whisper_free_state(struct whisper_state * state) ggml_allocr_free(state->alloc_encode); } - if (state->alloc_encode_post) { - ggml_allocr_free(state->alloc_encode_post); + if (state->alloc_cross) { + ggml_allocr_free(state->alloc_cross); } if (state->alloc_decode) {