From d029784fb0914ef975dff47af9bca8bad4fa5408 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 11 Nov 2023 18:37:14 +0200 Subject: [PATCH] whisper : try to fix the parallel whisper_state functionality --- whisper.cpp | 45 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/whisper.cpp b/whisper.cpp index 471d9a85..ccc7aaa8 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -702,6 +702,8 @@ struct whisper_state { // buffer for swapping KV caches between decoders during beam-search std::vector kv_swap_bufs; + ggml_backend_t backend = nullptr; + // ggml-alloc: // - stores meta info about the intermediate tensors into the `meta` buffers // - stores the actual tensor data into the `data` buffers @@ -1299,7 +1301,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con } } - // init backends + // init backend { ggml_backend_t backend_gpu = NULL; @@ -1964,7 +1966,7 @@ static bool whisper_encode_internal( ggml_allocr_alloc_graph(alloc, gf); if (!whisper_encode_external(wstate)) { - ggml_graph_compute_helper(wctx.backend, gf, n_threads); + ggml_graph_compute_helper(wstate.backend, gf, n_threads); } } @@ -1978,7 +1980,7 @@ static bool whisper_encode_internal( ggml_allocr_alloc_graph(alloc, gf); - ggml_graph_compute_helper(wctx.backend, gf, n_threads); + ggml_graph_compute_helper(wstate.backend, gf, n_threads); } // cross @@ -1991,7 +1993,7 @@ static bool whisper_encode_internal( ggml_allocr_alloc_graph(alloc, gf); - ggml_graph_compute_helper(wctx.backend, gf, n_threads); + ggml_graph_compute_helper(wstate.backend, gf, n_threads); } wstate.t_encode_us += ggml_time_us() - t_start_us; @@ -2382,7 +2384,7 @@ static bool whisper_decode_internal( logits = gf->nodes[gf->n_nodes - 1]; - ggml_graph_compute_helper(wctx.backend, gf, n_threads); + ggml_graph_compute_helper(wstate.backend, gf, n_threads); } // extract logits for all N tokens @@ -2825,6 +2827,39 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) { whisper_state * state = new whisper_state; + // init backend + { + ggml_backend_t backend_gpu = NULL; + + // initialize the backends +#ifdef GGML_USE_CUBLAS + if (ctx->params.use_gpu) { + WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__); + backend_gpu = ggml_backend_cuda_init(); + if (!backend_gpu) { + WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__); + } + } +#endif + +#ifdef GGML_USE_METAL + if (ctx->params.use_gpu) { + WHISPER_LOG_INFO("%s: using Metal backend\n", __func__); + ggml_metal_log_set_callback(whisper_log_callback_default, nullptr); + backend_gpu = ggml_backend_metal_init(); + if (!backend_gpu) { + WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__); + } + } +#endif + + if (backend_gpu) { + state->backend = backend_gpu; + } else { + state->backend = ggml_backend_cpu_init(); + } + } + if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) { WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__); delete state;