whisper : allocate encoder results in dedicated buffer

This commit is contained in:
Georgi Gerganov 2024-03-16 16:02:48 +02:00
parent de4d067f1e
commit c6c94de43a
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -793,6 +793,9 @@ struct whisper_state {
struct ggml_tensor * embd_conv = nullptr; struct ggml_tensor * embd_conv = nullptr;
struct ggml_tensor * embd_enc = nullptr; struct ggml_tensor * embd_enc = nullptr;
ggml_context * ctx_embd = nullptr;
ggml_backend_buffer_t buffer_embd = nullptr;
// helpers for GPU offloading // helpers for GPU offloading
std::vector<float> inp_mel; std::vector<float> inp_mel;
std::vector<float> inp_mask; std::vector<float> inp_mask;
@ -1669,15 +1672,9 @@ static struct ggml_cgraph * whisper_build_graph_conv(
cur = ggml_gelu(ctx0, cur); cur = ggml_gelu(ctx0, cur);
} }
ggml_set_name(cur, "embd_conv"); cur = ggml_cpy(ctx0, cur, wstate.embd_conv);
wstate.embd_conv = cur;
} else { } else {
ggml_build_forward_expand(gf, mel); ggml_build_forward_expand(gf, mel);
cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_state, n_ctx);
ggml_set_name(cur, "embd_enc");
wstate.embd_enc = cur;
} }
ggml_build_forward_expand(gf, cur); ggml_build_forward_expand(gf, cur);
@ -1708,7 +1705,10 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false); ggml_cgraph * gf = ggml_new_graph_custom(ctx0, WHISPER_MAX_NODES, false);
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv); // TODO: this still triggers the assert:
//struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_conv);
struct ggml_tensor * cur = wstate.embd_conv;
const float KQscale = 1.0f/sqrtf(float(n_state)/n_head); const float KQscale = 1.0f/sqrtf(float(n_state)/n_head);
@ -1908,9 +1908,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
model.e_ln_b); model.e_ln_b);
} }
ggml_build_forward_expand(gf, cur); cur = ggml_cpy(ctx0, cur, wstate.embd_enc);
wstate.embd_enc = cur; ggml_build_forward_expand(gf, cur);
//ggml_graph_print(gf); //ggml_graph_print(gf);
@ -1949,7 +1949,7 @@ static struct ggml_cgraph * whisper_build_graph_cross(
ggml_cgraph * gf = ggml_new_graph(ctx0); ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_tensor * cur = ggml_view_tensor(ctx0, wstate.embd_enc); struct ggml_tensor * cur = wstate.embd_enc;
const float Kscale = pow(float(n_state) / n_head, -0.25); const float Kscale = pow(float(n_state) / n_head, -0.25);
@ -3001,6 +3001,27 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6); WHISPER_LOG_INFO("%s: kv cross size = %7.2f MB\n", __func__, memory_size / 1e6);
} }
// encoder results
{
ggml_init_params init_params = {
/* .mem_size */ 2*ggml_tensor_overhead(),
/* .mem_buffer */ nullptr,
/* .no_alloc */ true,
};
state->ctx_embd = ggml_init(init_params);
state->embd_enc = ggml_new_tensor_2d(state->ctx_embd, GGML_TYPE_F32, ctx->model.hparams.n_audio_state, ctx->model.hparams.n_audio_ctx);
state->embd_conv = ggml_new_tensor_2d(state->ctx_embd, GGML_TYPE_F32, ctx->model.hparams.n_audio_ctx, ctx->model.hparams.n_audio_state);
ggml_set_name(state->embd_enc, "embd_enc");
ggml_set_name(state->embd_conv, "embd_conv");
state->buffer_embd = ggml_backend_alloc_ctx_tensors_from_buft(state->ctx_embd, ggml_backend_get_default_buffer_type(ctx->backend));
WHISPER_LOG_INFO("%s: %s enc results size = %.2f MiB\n", __func__,
ggml_backend_buffer_name(state->buffer_embd), ggml_backend_buffer_get_size(state->buffer_embd) / 1e6);
}
#ifdef WHISPER_USE_COREML #ifdef WHISPER_USE_COREML
const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model); const auto path_coreml = whisper_get_coreml_path_encoder(ctx->path_model);