Mirror of https://github.com/ggerganov/whisper.cpp.git

whisper : allocate encoder and decoder using ggml-alloc
commit bed5ad69dd (parent 949ab6328d)

--- a/whisper.cpp
+++ b/whisper.cpp
@@ -606,11 +606,17 @@ struct whisper_state {
     // memory buffers used by encode / decode contexts
     std::vector<uint8_t> buf_compute;
 
+    // ggml-alloc
     std::vector<uint8_t> buf_encode;
+    std::vector<uint8_t> buf_encode_post;
     std::vector<uint8_t> buf_decode;
 
-    ggml_allocr * alloc_encode = NULL;
-    ggml_allocr * alloc_decode = NULL;
+    ggml_allocr * alloc_encode      = NULL;
+    ggml_allocr * alloc_encode_post = NULL;
+    ggml_allocr * alloc_decode      = NULL;
+
+    // result of the encoder
+    struct ggml_tensor * embd_enc = NULL;
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
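Note: the struct changes above give each compute graph (encoder, cross-attention pre-compute, decoder) its own data buffer and ggml-alloc allocator, while buf_compute keeps holding only graph and tensor metadata (see its resize to ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead() further down). embd_enc stores a handle to the encoder output so the encoder_post graph can pick it up from the state instead of receiving it as a parameter.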
@@ -701,7 +707,7 @@ static bool kv_cache_init(
     const int64_t n_mem      = n_text_layer*n_ctx;
     const int64_t n_elements = n_text_state*n_mem;
 
-    const size_t mem_bytes = ggml_type_size(wtype)*n_elements;
+    const size_t mem_bytes = 2*(ggml_type_size(wtype)*n_elements + ggml_tensor_overhead());
 
     cache.buf.resize(mem_bytes);
 
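The new mem_bytes sizing accounts for the cache context holding two tensors, keys and values, each with n_elements elements of type wtype plus ggml's per-tensor metadata. A minimal sketch of the arithmetic with hypothetical dimensions (all constants below are illustrative stand-ins, not values from the commit):

    #include <cstdint>
    #include <cstdio>

    int main() {
        // hypothetical text-model dimensions
        const int64_t n_text_layer = 24;
        const int64_t n_ctx        = 448;
        const int64_t n_text_state = 1024;

        const int64_t n_mem      = n_text_layer*n_ctx;
        const int64_t n_elements = n_text_state*n_mem;

        const size_t type_size       = 2;   // e.g. an F16 element
        const size_t tensor_overhead = 256; // stand-in for ggml_tensor_overhead()

        // old sizing: data for a single tensor, no metadata
        const size_t old_bytes = type_size*n_elements;

        // new sizing: data plus metadata for both the K and the V tensor
        const size_t new_bytes = 2*(type_size*n_elements + tensor_overhead);

        printf("old: %zu bytes, new: %zu bytes\n", old_bytes, new_bytes);
        return 0;
    }

The factor of 2 covers the K and V cache tensors created in this context, and the ggml_tensor_overhead() term covers each tensor's metadata now that the buffer must hold everything.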
@@ -1385,6 +1391,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     const int n_layer = hparams.n_audio_layer;
 
     const int n_mels = hparams.n_mels;
+
     assert(mel_inp.n_mel == n_mels);
 
     struct ggml_init_params params = {
@@ -1397,9 +1404,11 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
+    ggml_allocr * alloc = wstate.alloc_encode;
+
     struct ggml_tensor * mel = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 2*n_ctx, n_mels);
     assert(mel->type == GGML_TYPE_F32);
-    {
+    if (!ggml_allocr_is_measure(alloc)) {
         float * dst = (float *) mel->data;
         memset(dst, 0, ggml_nbytes(mel));
 
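The build function now runs in two modes. During the measure pass, ggml_allocr_is_measure() returns true: the allocator is only recording how much memory the graph needs and mel->data does not point at usable storage, so input initialization must be skipped. On the real pass the same code fills the tensor. The pattern in isolation, as a sketch (init_input and value are illustrative names, not from the commit):

    #include "ggml.h"
    #include "ggml-alloc.h"

    // Fill an input tensor, but only when real memory is behind it.
    static void init_input(struct ggml_allocr * alloc, struct ggml_tensor * t, float value) {
        ggml_allocr_alloc(alloc, t); // during measure this only records the size

        if (!ggml_allocr_is_measure(alloc)) {
            float * data = (float *) t->data; // valid only on the real pass
            for (int64_t i = 0; i < ggml_nelements(t); ++i) {
                data[i] = value;
            }
        }
    }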
@@ -1689,6 +1698,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
     }
 #endif
 
+    wstate.embd_enc = cur;
+
     ////////////////////////////////////////////////////////////////////////////
 
     //printf("%s: used_mem = %f MB, %f MB, %f MB %f MB %f MB\n", __func__,
@@ -1706,8 +1717,7 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
 // pre-compute cross-attention memory
 static struct ggml_cgraph * whisper_build_graph_encoder_post(
         whisper_context & wctx,
-        whisper_state & wstate,
-        struct ggml_tensor * embd_enc) {
+        whisper_state & wstate) {
     const auto & model   = wctx.model;
     const auto & hparams = model.hparams;
 
@@ -1725,7 +1735,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder_post(
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
-    struct ggml_tensor * cur = embd_enc;
+    //ggml_allocr * alloc = wstate.alloc_encode_post;
+
+    struct ggml_tensor * cur = wstate.embd_enc;
 
     // TODO: hack to disconnect the encoded features from the previous graph
     cur->op = GGML_OP_NONE;
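With the encoder output stashed in wstate.embd_enc (previous hunk), whisper_build_graph_encoder_post no longer needs an embd_enc parameter and reads the tensor from the state. The existing TODO hack, resetting cur->op to GGML_OP_NONE, detaches the tensor from the encoder graph so the cross-attention pre-compute graph treats it as a plain input instead of dragging in the encoder's history.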
@@ -1826,12 +1838,18 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
 
     ggml_cgraph * gf = ggml_new_graph(ctx0);
 
+    ggml_allocr * alloc = wstate.alloc_decode;
+
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-    memcpy(embd->data, tokens, N*ggml_element_size(embd));
+    if (!ggml_allocr_is_measure(alloc)) {
+        memcpy(embd->data, tokens, N*ggml_element_size(embd));
+    }
 
     struct ggml_tensor * position = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
-    for (int i = 0; i < N; ++i) {
-        ((int32_t *) position->data)[i] = n_past + i;
+    if (!ggml_allocr_is_measure(alloc)) {
+        for (int i = 0; i < N; ++i) {
+            ((int32_t *) position->data)[i] = n_past + i;
+        }
     }
 
     // token encoding + position encoding
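Same guard as in the encoder graph: the decoder's token and position inputs are written only outside the measure pass, since during measurement the tensors have no backing memory yet.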
@@ -2637,8 +2655,54 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     state->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
 
     static const size_t tensor_alignment = 32;
-    state->alloc_encode = ggml_allocr_new_measure(tensor_alignment);
-    state->alloc_decode = ggml_allocr_new_measure(tensor_alignment);
+
+    state->alloc_encode      = ggml_allocr_new_measure(tensor_alignment);
+    state->alloc_encode_post = ggml_allocr_new_measure(tensor_alignment);
+    state->alloc_decode      = ggml_allocr_new_measure(tensor_alignment);
+
+    // encoder allocator
+    {
+        ggml_cgraph * gf = whisper_build_graph_encoder(*ctx, *state, 0);
+
+        const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_encode, gf) + tensor_alignment;
+        ggml_allocr_free(state->alloc_encode);
+
+        log("%s: compute buffer (encode) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
+
+        state->buf_encode.resize(alloc_size);
+        state->alloc_encode = ggml_allocr_new(state->buf_encode.data(), state->buf_encode.size(), tensor_alignment);
+    }
+
+    // encoder_post allocator
+    {
+        ggml_cgraph * gf = whisper_build_graph_encoder_post(*ctx, *state);
+
+        const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_encode_post, gf) + tensor_alignment;
+        ggml_allocr_free(state->alloc_encode_post);
+
+        log("%s: compute buffer (encode_post) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
+
+        state->buf_encode_post.resize(alloc_size);
+        state->alloc_encode_post = ggml_allocr_new(state->buf_encode_post.data(), state->buf_encode_post.size(), tensor_alignment);
+    }
+
+    // decoder allocator
+    {
+        const auto & hparams = ctx->model.hparams;
+
+        const int n_tokens = hparams.n_text_ctx/2;
+        const int n_past   = hparams.n_text_ctx/2; // TODO: double-check
+
+        ggml_cgraph * gf = whisper_build_graph_decoder(*ctx, *state, state->decoders[0], NULL, n_tokens, n_past);
+
+        const size_t alloc_size = ggml_allocr_alloc_graph(state->alloc_decode, gf) + tensor_alignment;
+        ggml_allocr_free(state->alloc_decode);
+
+        log("%s: compute buffer (decode) = %7.2f MB\n", __func__, (state->buf_compute.size() + alloc_size) / 1024.0 / 1024.0);
+
+        state->buf_decode.resize(alloc_size);
+        state->alloc_decode = ggml_allocr_new(state->buf_decode.data(), state->buf_decode.size(), tensor_alignment);
+    }
+
     state->rng = std::mt19937(0);
 
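Each of the three allocators goes through the same measure-then-allocate lifecycle: build the graph once against a measuring allocator to learn its peak memory requirement, size a dedicated buffer to exactly that, then recreate the allocator on top of the real buffer for all subsequent evaluations. The decoder's worst case is approximated with n_tokens = n_past = n_text_ctx/2, per the TODO. Distilled into one function, as a sketch (build_worst_case_graph stands in for the whisper_build_graph_* calls with worst-case parameters):

    #include "ggml.h"
    #include "ggml-alloc.h"

    #include <cstdint>
    #include <vector>

    // stand-in for whisper_build_graph_encoder / _encoder_post / _decoder
    extern ggml_cgraph * build_worst_case_graph();

    static ggml_allocr * create_graph_allocator(std::vector<uint8_t> & buf) {
        static const size_t tensor_alignment = 32;

        // 1) measure: walk the graph with a measuring allocator to find the
        //    peak amount of tensor memory it needs
        ggml_allocr * alloc = ggml_allocr_new_measure(tensor_alignment);
        ggml_cgraph * gf = build_worst_case_graph();
        const size_t alloc_size = ggml_allocr_alloc_graph(alloc, gf) + tensor_alignment;
        ggml_allocr_free(alloc);

        // 2) size the backing buffer to exactly the measured requirement
        buf.resize(alloc_size);

        // 3) the real allocator hands out ranges of this buffer on every
        //    subsequent evaluation of the same graph shape
        return ggml_allocr_new(buf.data(), buf.size(), tensor_alignment);
    }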
@@ -2862,6 +2926,18 @@ void whisper_free_state(struct whisper_state * state)
         }
 #endif
 
+        if (state->alloc_encode) {
+            ggml_allocr_free(state->alloc_encode);
+        }
+
+        if (state->alloc_encode_post) {
+            ggml_allocr_free(state->alloc_encode_post);
+        }
+
+        if (state->alloc_decode) {
+            ggml_allocr_free(state->alloc_decode);
+        }
+
         delete state;
     }
 }