mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-23 14:32:23 +00:00
whisper : refactor ggml-alloc init
This commit is contained in:
parent
4d9acc60c3
commit
2770d46ef5
139
whisper.cpp
139
whisper.cpp
@ -618,20 +618,26 @@ struct whisper_state {
|
||||
// buffer for swapping KV caches between decoders during beam-search
|
||||
std::vector<kv_buf> kv_swap_bufs;
|
||||
|
||||
// memory buffers used by encode / decode contexts
|
||||
std::vector<uint8_t> buf_compute;
|
||||
|
||||
// reusable buffer for `struct ggml_graph_plan.work_data`
|
||||
std::vector<uint8_t> work_buffer;
|
||||
|
||||
// ggml-alloc
|
||||
std::vector<uint8_t> buf_encode;
|
||||
std::vector<uint8_t> buf_encode_post;
|
||||
std::vector<uint8_t> buf_decode;
|
||||
// ggml-alloc:
|
||||
// - stores meta info about the intermediate tensors into the `meta_*` buffers
|
||||
// - stores the actual tensor data into the `data_*` buffers
|
||||
|
||||
ggml_allocr * alloc_encode = NULL;
|
||||
ggml_allocr * alloc_encode_post = NULL;
|
||||
ggml_allocr * alloc_decode = NULL;
|
||||
ggml_allocr * alloc_encode = NULL;
|
||||
ggml_allocr * alloc_cross = NULL;
|
||||
ggml_allocr * alloc_decode = NULL;
|
||||
|
||||
// meta data
|
||||
std::vector<uint8_t> meta_encode;
|
||||
std::vector<uint8_t> meta_cross;
|
||||
std::vector<uint8_t> meta_decode;
|
||||
|
||||
// tensor data
|
||||
std::vector<uint8_t> data_encode;
|
||||
std::vector<uint8_t> data_cross;
|
||||
std::vector<uint8_t> data_decode;
|
||||
|
||||
// result of the encoder
|
||||
struct ggml_tensor * embd_enc = NULL;
|
||||
@ -1411,8 +1417,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||
const int n_mels = hparams.n_mels;
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ wstate.buf_compute.size(),
|
||||
/*.mem_buffer =*/ wstate.buf_compute.data(),
|
||||
/*.mem_size =*/ wstate.meta_encode.size(),
|
||||
/*.mem_buffer =*/ wstate.meta_encode.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
@ -1746,7 +1752,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
|
||||
}
|
||||
|
||||
// pre-compute cross-attention memory
|
||||
static struct ggml_cgraph * whisper_build_graph_encoder_post(
|
||||
static struct ggml_cgraph * whisper_build_graph_cross(
|
||||
whisper_context & wctx,
|
||||
whisper_state & wstate) {
|
||||
const auto & model = wctx.model;
|
||||
@ -1757,8 +1763,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder_post(
|
||||
const int n_head = hparams.n_audio_head;
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ wstate.buf_compute.size(),
|
||||
/*.mem_buffer =*/ wstate.buf_compute.data(),
|
||||
/*.mem_size =*/ wstate.meta_cross.size(),
|
||||
/*.mem_buffer =*/ wstate.meta_cross.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
@ -1766,7 +1772,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder_post(
|
||||
|
||||
ggml_cgraph * gf = ggml_new_graph(ctx0);
|
||||
|
||||
ggml_allocr * alloc = wstate.alloc_encode_post;
|
||||
ggml_allocr * alloc = wstate.alloc_cross;
|
||||
|
||||
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc);
|
||||
|
||||
@ -1863,13 +1869,13 @@ static bool whisper_encode_internal(
|
||||
//printf("n: %d\n", ggml_nelements(cur));
|
||||
}
|
||||
|
||||
// encoder_post
|
||||
// cross
|
||||
{
|
||||
auto & alloc = wstate.alloc_encode_post;
|
||||
auto & alloc = wstate.alloc_cross;
|
||||
|
||||
ggml_allocr_reset(alloc);
|
||||
|
||||
ggml_cgraph * gf = whisper_build_graph_encoder_post(wctx, wstate);
|
||||
ggml_cgraph * gf = whisper_build_graph_cross(wctx, wstate);
|
||||
|
||||
ggml_allocr_alloc_graph(alloc, gf);
|
||||
|
||||
@ -1924,8 +1930,8 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
||||
//WHISPER_PRINT_DEBUG("%s: n_past = %d, N = %d, M = %d, n_ctx = %d\n", __func__, n_past, N, M, n_ctx);
|
||||
|
||||
struct ggml_init_params params = {
|
||||
/*.mem_size =*/ wstate.buf_compute.size(),
|
||||
/*.mem_buffer =*/ wstate.buf_compute.data(),
|
||||
/*.mem_size =*/ wstate.meta_decode.size(),
|
||||
/*.mem_buffer =*/ wstate.meta_decode.data(),
|
||||
/*.no_alloc =*/ true,
|
||||
};
|
||||
|
||||
@ -2733,8 +2739,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
log("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
|
||||
}
|
||||
|
||||
log("debug CI - checkpoint 0\n");
|
||||
|
||||
if (!kv_cache_init(ctx->model.hparams, state->kv_cross, ctx->itype, ctx->model.hparams.n_audio_ctx)) {
|
||||
log("%s: kv_cache_init() failed for cross-attention cache\n", __func__);
|
||||
delete state;
|
||||
@ -2746,8 +2750,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
log("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
|
||||
}
|
||||
|
||||
log("debug CI - checkpoint 1\n");
|
||||
|
||||
#ifdef WHISPER_USE_COREML
|
||||
const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);
|
||||
|
||||
@ -2765,70 +2767,73 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
}
|
||||
#endif
|
||||
|
||||
log("debug CI - checkpoint 2\n");
|
||||
|
||||
state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
|
||||
|
||||
log("debug CI - checkpoint 3\n");
|
||||
|
||||
state->logits_id.reserve(ctx->model.hparams.n_vocab);
|
||||
|
||||
log("debug CI - checkpoint 4\n");
|
||||
|
||||
// TAGS: WHISPER_DECODER_INIT
|
||||
state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
|
||||
|
||||
log("debug CI - checkpoint 5\n");
|
||||
|
||||
state->decoders[0].probs.reserve(ctx->vocab.n_vocab);
|
||||
state->decoders[0].logits.reserve(ctx->vocab.n_vocab);
|
||||
state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
|
||||
|
||||
log("debug CI - checkpoint 6\n");
|
||||
|
||||
state->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
|
||||
|
||||
log("debug CI - checkpoint 7\n");
|
||||
|
||||
static const size_t tensor_alignment = 32;
|
||||
|
||||
log("debug CI - checkpoint 8\n");
|
||||
|
||||
state->alloc_encode = ggml_allocr_new_measure(tensor_alignment);
|
||||
log("debug CI - checkpoint 9\n");
|
||||
state->alloc_encode_post = ggml_allocr_new_measure(tensor_alignment);
|
||||
log("debug CI - checkpoint 10\n");
|
||||
state->alloc_decode = ggml_allocr_new_measure(tensor_alignment);
|
||||
log("debug CI - checkpoint 11\n");
|
||||
|
||||
// encoder allocator
|
||||
{
|
||||
auto & alloc = state->alloc_encode;
|
||||
auto & meta = state->meta_encode;
|
||||
auto & data = state->data_encode;
|
||||
|
||||
meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
|
||||
|
||||
alloc = ggml_allocr_new_measure(tensor_alignment);
|
||||
|
||||
ggml_cgraph * gf = whisper_build_graph_encoder(*ctx, *state, 0);
|
||||
|
||||
const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_encode, gf) + tensor_alignment;
|
||||
ggml_allocr_free(state->alloc_encode);
|
||||
const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
|
||||
|
||||
log("%s: compute buffer (encode) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
|
||||
ggml_allocr_free(alloc);
|
||||
|
||||
state->buf_encode.resize(alloc_size);
|
||||
state->alloc_encode = ggml_allocr_new(state->buf_encode.data(), state->buf_encode.size(), tensor_alignment);
|
||||
log("%s: compute buffer (encode) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
|
||||
|
||||
data.resize(alloc_size);
|
||||
alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
|
||||
}
|
||||
|
||||
// encoder_post allocator
|
||||
// cross allocator
|
||||
{
|
||||
ggml_cgraph * gf = whisper_build_graph_encoder_post(*ctx, *state);
|
||||
auto & alloc = state->alloc_cross;
|
||||
auto & meta = state->meta_cross;
|
||||
auto & data = state->data_cross;
|
||||
|
||||
const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_encode_post, gf) + tensor_alignment;
|
||||
ggml_allocr_free(state->alloc_encode_post);
|
||||
meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
|
||||
|
||||
log("%s: compute buffer (encode_post) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
|
||||
alloc = ggml_allocr_new_measure(tensor_alignment);
|
||||
|
||||
state->buf_encode_post.resize(alloc_size);
|
||||
state->alloc_encode_post = ggml_allocr_new(state->buf_encode_post.data(), state->buf_encode_post.size(), tensor_alignment);
|
||||
ggml_cgraph * gf = whisper_build_graph_cross(*ctx, *state);
|
||||
|
||||
const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
|
||||
|
||||
ggml_allocr_free(alloc);
|
||||
|
||||
log("%s: compute buffer (cross) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
|
||||
|
||||
data.resize(alloc_size);
|
||||
alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
|
||||
}
|
||||
|
||||
// decoder allocator
|
||||
{
|
||||
auto & alloc = state->alloc_decode;
|
||||
auto & meta = state->meta_decode;
|
||||
auto & data = state->data_decode;
|
||||
|
||||
meta.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
|
||||
|
||||
alloc = ggml_allocr_new_measure(tensor_alignment);
|
||||
|
||||
const auto & hparams = ctx->model.hparams;
|
||||
|
||||
// TODO: make sure this is the worst-case scenario
|
||||
@ -2837,13 +2842,13 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
|
||||
ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], NULL, n_tokens, n_past);
|
||||
|
||||
const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_decode, gf) + tensor_alignment;
|
||||
ggml_allocr_free(state->alloc_decode);
|
||||
const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
|
||||
ggml_allocr_free(alloc);
|
||||
|
||||
log("%s: compute buffer (decode) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
|
||||
log("%s: compute buffer (decode) = %7.2f MB\n", __func__, (meta.size() + alloc_size) / 1024.0 / 1024.0);
|
||||
|
||||
state->buf_decode.resize(alloc_size);
|
||||
state->alloc_decode = ggml_allocr_new(state->buf_decode.data(), state->buf_decode.size(), tensor_alignment);
|
||||
data.resize(alloc_size);
|
||||
alloc = ggml_allocr_new(data.data(), data.size(), tensor_alignment);
|
||||
}
|
||||
|
||||
state->rng = std::mt19937(0);
|
||||
@ -3071,8 +3076,8 @@ void whisper_free_state(struct whisper_state * state)
|
||||
ggml_allocr_free(state->alloc_encode);
|
||||
}
|
||||
|
||||
if (state->alloc_encode_post) {
|
||||
ggml_allocr_free(state->alloc_encode_post);
|
||||
if (state->alloc_cross) {
|
||||
ggml_allocr_free(state->alloc_cross);
|
||||
}
|
||||
|
||||
if (state->alloc_decode) {
|
||||
|
Loading…
Reference in New Issue
Block a user