diff --git a/whisper.cpp b/whisper.cpp index eb69f96f..d7bbeb4d 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -155,8 +155,8 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text // static void ggml_graph_compute_helper( + struct ggml_cgraph * graph, std::vector & buf, - ggml_cgraph * graph, int n_threads, whisper_abort_callback abort_callback, void * abort_callback_data) { @@ -173,6 +173,21 @@ static void ggml_graph_compute_helper( ggml_graph_compute(graph, &plan); } +static void ggml_graph_compute_helper( + struct ggml_backend * backend, + struct ggml_cgraph * graph, + int n_threads) { + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, n_threads); + } +#ifdef GGML_USE_METAL + if (ggml_backend_is_metal(backend)) { + ggml_backend_metal_set_n_cb(backend, n_threads); + } +#endif + ggml_backend_graph_compute(backend, graph); +} + // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad" // the idea is to represent the original matrix multiplication: // @@ -1943,15 +1958,7 @@ static bool whisper_encode_internal( ggml_allocr_alloc_graph(alloc, gf); if (!whisper_encode_external(wstate)) { - if (ggml_backend_is_cpu(wctx.backend)) { - ggml_backend_cpu_set_n_threads(wctx.backend, n_threads); - } -#ifdef GGML_USE_METAL - if (ggml_backend_is_metal(wctx.backend)) { - ggml_backend_metal_set_n_cb(wctx.backend, n_threads); - } -#endif - ggml_backend_graph_compute(wctx.backend, gf); + ggml_graph_compute_helper(wctx.backend, gf, n_threads); } } @@ -1965,15 +1972,7 @@ static bool whisper_encode_internal( ggml_allocr_alloc_graph(alloc, gf); - if (ggml_backend_is_cpu(wctx.backend)) { - ggml_backend_cpu_set_n_threads(wctx.backend, n_threads); - } -#ifdef GGML_USE_METAL - if (ggml_backend_is_metal(wctx.backend)) { - ggml_backend_metal_set_n_cb(wctx.backend, n_threads); - } -#endif - ggml_backend_graph_compute(wctx.backend, gf); + ggml_graph_compute_helper(wctx.backend, gf, n_threads); } // cross @@ -1986,15 +1985,7 @@ static bool whisper_encode_internal( ggml_allocr_alloc_graph(alloc, gf); - if (ggml_backend_is_cpu(wctx.backend)) { - ggml_backend_cpu_set_n_threads(wctx.backend, n_threads); - } -#ifdef GGML_USE_METAL - if (ggml_backend_is_metal(wctx.backend)) { - ggml_backend_metal_set_n_cb(wctx.backend, n_threads); - } -#endif - ggml_backend_graph_compute(wctx.backend, gf); + ggml_graph_compute_helper(wctx.backend, gf, n_threads); } wstate.t_encode_us += ggml_time_us() - t_start_us; @@ -2385,15 +2376,7 @@ static bool whisper_decode_internal( logits = gf->nodes[gf->n_nodes - 1]; - if (ggml_backend_is_cpu(wctx.backend)) { - ggml_backend_cpu_set_n_threads(wctx.backend, n_threads); - } -#ifdef GGML_USE_METAL - if (ggml_backend_is_metal(wctx.backend)) { - ggml_backend_metal_set_n_cb(wctx.backend, n_threads); - } -#endif - ggml_backend_graph_compute(wctx.backend, gf); + ggml_graph_compute_helper(wctx.backend, gf, n_threads); } // extract logits for all N tokens @@ -5495,12 +5478,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) { double tsum = 0.0; // heat-up - ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr); + ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr); for (int i = 0; i < n_max; ++i) { const int64_t t0 = ggml_time_us(); - ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr); + ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr); const int64_t t1 = ggml_time_us();