whisper : refactor ggml-alloc init

Georgi Gerganov 2023-09-11 15:04:33 +03:00
parent 4d9acc60c3
commit 2770d46ef5


@@ -618,20 +618,26 @@ struct whisper_state {
     // buffer for swapping KV caches between decoders during beam-search
     std::vector<kv_buf> kv_swap_bufs;
 
-    // memory buffers used by encode / decode contexts
-    std::vector<uint8_t> buf_compute;
-
     // reusable buffer for `struct ggml_graph_plan.work_data`
     std::vector<uint8_t> work_buffer;
 
-    // ggml-alloc
-    std::vector<uint8_t> buf_encode;
-    std::vector<uint8_t> buf_encode_post;
-    std::vector<uint8_t> buf_decode;
-
-    ggml_allocr * alloc_encode      = NULL;
-    ggml_allocr * alloc_encode_post = NULL;
-    ggml_allocr * alloc_decode      = NULL;
+    // ggml-alloc:
+    // - stores meta info about the intermediate tensors into the `meta_*` buffers
+    // - stores the actual tensor data into the `data_*` buffers
+    ggml_allocr * alloc_encode = NULL;
+    ggml_allocr * alloc_cross  = NULL;
+    ggml_allocr * alloc_decode = NULL;
+
+    // meta data
+    std::vector<uint8_t> meta_encode;
+    std::vector<uint8_t> meta_cross;
+    std::vector<uint8_t> meta_decode;
+
+    // tensor data
+    std::vector<uint8_t> data_encode;
+    std::vector<uint8_t> data_cross;
+    std::vector<uint8_t> data_decode;
 
     // result of the encoder
     struct ggml_tensor * embd_enc = NULL;
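
The hunk above is the heart of the refactor: the single shared `buf_compute` scratch buffer is replaced by a per-graph pair of buffers. The `meta_*` buffer holds only the `ggml_tensor` / `ggml_cgraph` structs (graphs are built with `no_alloc = true`), while the `data_*` buffer receives the actual tensor contents via ggml-alloc. A minimal sketch of the idea, assuming the `ggml_allocr_*` API of this ggml version; the data-buffer size and the F32 tensor are illustrative only (in the real code the size comes from a measure pass, shown further down):

    #include "ggml.h"
    #include "ggml-alloc.h"

    #include <vector>

    // sketch: one meta/data buffer pair, as introduced for encode/cross/decode
    void example_meta_data_split() {
        static const size_t tensor_alignment = 32;

        // meta buffer: sized for the tensor/graph bookkeeping only
        std::vector<uint8_t> meta(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
        // data buffer: placeholder size, normally determined by a measure pass
        std::vector<uint8_t> data(16u*1024u*1024u);

        struct ggml_init_params params = {
            /*.mem_size   =*/ meta.size(),
            /*.mem_buffer =*/ meta.data(),
            /*.no_alloc   =*/ true, // structs land in `meta`, tensor data does not
        };

        struct ggml_context * ctx0 = ggml_init(params);
        ggml_allocr * alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);

        // the struct lives in `meta`; ggml-alloc places its contents inside `data`
        struct ggml_tensor * cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1024);
        ggml_allocr_alloc(alloc, cur);

        ggml_allocr_free(alloc);
        ggml_free(ctx0);
    }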
@@ -1411,8 +1417,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     const int n_mels = hparams.n_mels;
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.buf_compute.size(),
-        /*.mem_buffer =*/ wstate.buf_compute.data(),
+        /*.mem_size   =*/ wstate.meta_encode.size(),
+        /*.mem_buffer =*/ wstate.meta_encode.data(),
         /*.no_alloc   =*/ true,
     };
@@ -1746,7 +1752,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 }
 
 // pre-compute cross-attention memory
-static struct ggml_cgraph * whisper_build_graph_encoder_post(
+static struct ggml_cgraph * whisper_build_graph_cross(
         whisper_context & wctx,
           whisper_state & wstate) {
     const auto & model = wctx.model;
@@ -1757,8 +1763,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder_post(
     const int n_head = hparams.n_audio_head;
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.buf_compute.size(),
-        /*.mem_buffer =*/ wstate.buf_compute.data(),
+        /*.mem_size   =*/ wstate.meta_cross.size(),
+        /*.mem_buffer =*/ wstate.meta_cross.data(),
         /*.no_alloc   =*/ true,
     };
@@ -1766,7 +1772,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder_post(
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    ggml_allocr * alloc = wstate.alloc_encode_post;
+    ggml_allocr * alloc = wstate.alloc_cross;
 
     struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
@@ -1863,13 +1869,13 @@ static bool whisper_encode_internal(
         //printf("n: %d\n", ggml_nelements(cur));
     }
 
-    // encoder_post
+    // cross
     {
-        auto & alloc = wstate.alloc_encode_post;
+        auto & alloc = wstate.alloc_cross;
 
         ggml_allocr_reset(alloc);
 
-        ggml_cgraph * gf = whisper_build_graph_encoder_post(wctx, wstate);
+        ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
 
         ggml_allocr_alloc_graph(alloc, gf);
@@ -1924,8 +1930,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     //WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ wstate.buf_compute.size(),
-        /*.mem_buffer =*/ wstate.buf_compute.data(),
+        /*.mem_size   =*/ wstate.meta_decode.size(),
+        /*.mem_buffer =*/ wstate.meta_decode.data(),
         /*.no_alloc   =*/ true,
     };
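
All three graph builders now follow the same shape: the `ggml_context` is backed by the fixed-size `meta_*` buffer with `no_alloc = true`, and, as the `whisper_encode_internal` hunk above shows, each evaluation resets the allocator, rebuilds the graph, and lets `ggml_allocr_alloc_graph` place the intermediates. A condensed sketch of that per-call pattern; `whisper_build_graph_cross` is the builder from this diff, while the surrounding function is illustrative:

    // per-evaluation pattern used for the encode / cross / decode graphs (sketch)
    void eval_cross(whisper_context & wctx, whisper_state & wstate) {
        auto & alloc = wstate.alloc_cross;

        ggml_allocr_reset(alloc);                                    // reuse data_cross

        ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);  // structs -> meta_cross

        ggml_allocr_alloc_graph(alloc, gf);                          // intermediates -> data_cross

        // ... the graph can now be computed
    }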
@@ -2733,8 +2739,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
-    log("debug CI - checkpoint 0\n");
-
     if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
         log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
         delete state;
@@ -2746,8 +2750,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
 
-    log("debug CI - checkpoint 1\n");
-
 #ifdef WHISPER_USE_COREML
     const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
@@ -2765,70 +2767,73 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     }
 #endif
 
-    log("debug CI - checkpoint 2\n");
-
     state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
 
-    log("debug CI - checkpoint 3\n");
-
     state->logits_id.reserve(ctx->model.hparams.n_vocab);
 
-    log("debug CI - checkpoint 4\n");
-
     // TAGS: WHISPER_DECODER_INIT
     state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
 
-    log("debug CI - checkpoint 5\n");
-
     state->decoders[0].probs.reserve   (ctx->vocab.n_vocab);
     state->decoders[0].logits.reserve  (ctx->vocab.n_vocab);
     state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
 
-    log("debug CI - checkpoint 6\n");
-
-    state->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
-
-    log("debug CI - checkpoint 7\n");
-
     static const size_t tensor_alignment = 32;
 
-    log("debug CI - checkpoint 8\n");
-
-    state->alloc_encode      = ggml_allocr_new_measure(tensor_alignment);
-
-    log("debug CI - checkpoint 9\n");
-
-    state->alloc_encode_post = ggml_allocr_new_measure(tensor_alignment);
-
-    log("debug CI - checkpoint 10\n");
-
-    state->alloc_decode      = ggml_allocr_new_measure(tensor_alignment);
-
-    log("debug CI - checkpoint 11\n");
-
     // encoder allocator
     {
+        auto & alloc = state->alloc_encode;
+        auto & meta  = state->meta_encode;
+        auto & data  = state->data_encode;
+
+        meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+
+        alloc = ggml_allocr_new_measure(tensor_alignment);
+
         ggml_cgraph * gf = whisper_build_graph_encoder(*ctx, *state, 0);
 
-        const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_encode, gf) + tensor_alignment;
+        const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
 
-        ggml_allocr_free(state->alloc_encode);
+        ggml_allocr_free(alloc);
 
-        log("%s: compute buffer (encode) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
+        log("%s: compute buffer (encode) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
 
-        state->buf_encode.resize(alloc_size);
-        state->alloc_encode = ggml_allocr_new(state->buf_encode.data(), state->buf_encode.size(), tensor_alignment);
+        data.resize(alloc_size);
+        alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
     }
 
-    // encoder_post allocator
+    // cross allocator
     {
-        ggml_cgraph * gf = whisper_build_graph_encoder_post(*ctx, *state);
-
-        const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_encode_post, gf) + tensor_alignment;
-
-        ggml_allocr_free(state->alloc_encode_post);
-
-        log("%s: compute buffer (encode_post) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
-
-        state->buf_encode_post.resize(alloc_size);
-        state->alloc_encode_post = ggml_allocr_new(state->buf_encode_post.data(), state->buf_encode_post.size(), tensor_alignment);
+        auto & alloc = state->alloc_cross;
+        auto & meta  = state->meta_cross;
+        auto & data  = state->data_cross;
+
+        meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+
+        alloc = ggml_allocr_new_measure(tensor_alignment);
+
+        ggml_cgraph * gf = whisper_build_graph_cross(*ctx, *state);
+
+        const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
+
+        ggml_allocr_free(alloc);
+
+        log("%s: compute buffer (cross) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
+
+        data.resize(alloc_size);
+        alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
     }
 
     // decoder allocator
     {
+        auto & alloc = state->alloc_decode;
+        auto & meta  = state->meta_decode;
+        auto & data  = state->data_decode;
+
+        meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
+
+        alloc = ggml_allocr_new_measure(tensor_alignment);
+
         const auto & hparams = ctx->model.hparams;
 
         // TODO: make sure this is the worst-case scenario
@@ -2837,13 +2842,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
         ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], NULL, n_tokens, n_past);
 
-        const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_decode, gf) + tensor_alignment;
+        const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
 
-        ggml_allocr_free(state->alloc_decode);
+        ggml_allocr_free(alloc);
 
-        log("%s: compute buffer (decode) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
+        log("%s: compute buffer (decode) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
 
-        state->buf_decode.resize(alloc_size);
-        state->alloc_decode = ggml_allocr_new(state->buf_decode.data(), state->buf_decode.size(), tensor_alignment);
+        data.resize(alloc_size);
+        alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
     }
 
     state->rng = std::mt19937(0);
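
Each allocator block above performs the same measure-then-allocate sequence: size the meta buffer, build the (worst-case) graph once against a measuring allocator to learn how much tensor data it needs, then recreate the allocator over a data buffer of exactly that size. Condensed into a sketch, with `build_worst_case_graph()` standing in for the `whisper_build_graph_*` calls:

    // measure-then-allocate, as done for encode, cross and decode above (sketch)
    meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());

    alloc = ggml_allocr_new_measure(tensor_alignment);   // records offsets, owns no memory

    ggml_cgraph * gf = build_worst_case_graph();         // e.g. whisper_build_graph_encoder

    const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;

    ggml_allocr_free(alloc);                             // drop the measuring allocator

    data.resize(alloc_size);                             // exact worst-case size
    alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);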
@@ -3071,8 +3076,8 @@ void whisper_free_state(struct whisper_state * state)
         ggml_allocr_free(state->alloc_encode);
     }
 
-    if (state->alloc_encode_post) {
-        ggml_allocr_free(state->alloc_encode_post);
+    if (state->alloc_cross) {
+        ggml_allocr_free(state->alloc_cross);
     }
 
     if (state->alloc_decode) {