mirror of https://github.com/ggerganov/whisper.cpp.git
whisper : factor out graph compute in common function
This commit is contained in:
parent b27726da93
commit b618229340
whisper.cpp (61 changed lines)
@@ -155,8 +155,8 @@ static void whisper_log_callback_default(ggml_log_level level, const char * text
 //
 
 static void ggml_graph_compute_helper(
+          struct ggml_cgraph * graph,
         std::vector<uint8_t> & buf,
-                 ggml_cgraph * graph,
                          int   n_threads,
      whisper_abort_callback   abort_callback,
                        void * abort_callback_data) {
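Note: the hunk above only reorders the first two parameters of the existing plan-based helper, so the graph now comes first and the work buffer second. A minimal call-site sketch under the new order, assuming a built graph gf and an n_threads value are in scope (the buffer name work is illustrative, matching the benchmark code later in this diff):

    std::vector<uint8_t> work;  // scratch buffer for the compute plan's work data
    ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);  // nullptr, nullptr: no abort callback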
@@ -173,6 +173,21 @@ static void ggml_graph_compute_helper(
     ggml_graph_compute(graph, &plan);
 }
 
+static void ggml_graph_compute_helper(
+       struct ggml_backend * backend,
+        struct ggml_cgraph * graph,
+                        int   n_threads) {
+    if (ggml_backend_is_cpu(backend)) {
+        ggml_backend_cpu_set_n_threads(backend, n_threads);
+    }
+#ifdef GGML_USE_METAL
+    if (ggml_backend_is_metal(backend)) {
+        ggml_backend_metal_set_n_cb(backend, n_threads);
+    }
+#endif
+    ggml_backend_graph_compute(backend, graph);
+}
+
 // faster matrix multiplications for tensors that do not have dimension 0 divisible by "pad"
 // the idea is to represent the original matrix multiplication:
 //
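Note: the overload added above is the common function the remaining hunks switch to. As a sketch of the effect at each call site (illustrative only, using the same names that appear later in the diff), the repeated per-backend setup collapses to a single call:

    // before: duplicated at every compute site
    if (ggml_backend_is_cpu(wctx.backend)) {
        ggml_backend_cpu_set_n_threads(wctx.backend, n_threads);
    }
    #ifdef GGML_USE_METAL
    if (ggml_backend_is_metal(wctx.backend)) {
        ggml_backend_metal_set_n_cb(wctx.backend, n_threads);
    }
    #endif
    ggml_backend_graph_compute(wctx.backend, gf);

    // after: one call, same behavior
    ggml_graph_compute_helper(wctx.backend, gf, n_threads);

On backends that are neither CPU nor Metal, the helper simply ignores n_threads and goes straight to ggml_backend_graph_compute.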
@@ -1943,15 +1958,7 @@ static bool whisper_encode_internal(
         ggml_allocr_alloc_graph(alloc, gf);
 
         if (!whisper_encode_external(wstate)) {
-            if (ggml_backend_is_cpu(wctx.backend)) {
-                ggml_backend_cpu_set_n_threads(wctx.backend, n_threads);
-            }
-#ifdef GGML_USE_METAL
-            if (ggml_backend_is_metal(wctx.backend)) {
-                ggml_backend_metal_set_n_cb(wctx.backend, n_threads);
-            }
-#endif
-            ggml_backend_graph_compute(wctx.backend, gf);
+            ggml_graph_compute_helper(wctx.backend, gf, n_threads);
         }
     }
 
@@ -1965,15 +1972,7 @@ static bool whisper_encode_internal(
 
         ggml_allocr_alloc_graph(alloc, gf);
 
-        if (ggml_backend_is_cpu(wctx.backend)) {
-            ggml_backend_cpu_set_n_threads(wctx.backend, n_threads);
-        }
-#ifdef GGML_USE_METAL
-        if (ggml_backend_is_metal(wctx.backend)) {
-            ggml_backend_metal_set_n_cb(wctx.backend, n_threads);
-        }
-#endif
-        ggml_backend_graph_compute(wctx.backend, gf);
+        ggml_graph_compute_helper(wctx.backend, gf, n_threads);
     }
 
     // cross
@@ -1986,15 +1985,7 @@ static bool whisper_encode_internal(
 
         ggml_allocr_alloc_graph(alloc, gf);
 
-        if (ggml_backend_is_cpu(wctx.backend)) {
-            ggml_backend_cpu_set_n_threads(wctx.backend, n_threads);
-        }
-#ifdef GGML_USE_METAL
-        if (ggml_backend_is_metal(wctx.backend)) {
-            ggml_backend_metal_set_n_cb(wctx.backend, n_threads);
-        }
-#endif
-        ggml_backend_graph_compute(wctx.backend, gf);
+        ggml_graph_compute_helper(wctx.backend, gf, n_threads);
     }
 
     wstate.t_encode_us += ggml_time_us() - t_start_us;
@@ -2385,15 +2376,7 @@ static bool whisper_decode_internal(
 
         logits = gf->nodes[gf->n_nodes - 1];
 
-        if (ggml_backend_is_cpu(wctx.backend)) {
-            ggml_backend_cpu_set_n_threads(wctx.backend, n_threads);
-        }
-#ifdef GGML_USE_METAL
-        if (ggml_backend_is_metal(wctx.backend)) {
-            ggml_backend_metal_set_n_cb(wctx.backend, n_threads);
-        }
-#endif
-        ggml_backend_graph_compute(wctx.backend, gf);
+        ggml_graph_compute_helper(wctx.backend, gf, n_threads);
     }
 
     // extract logits for all N tokens
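Note: in the decoder hunk above, logits refers to the last node of the graph, captured before the compute call. A hedged sketch of how the result would typically be copied out of backend memory afterwards, per the "extract logits" context line (the names and the copy itself are illustrative, not part of this diff):

    std::vector<float> logits_out(ggml_nelements(logits));                       // logits tensor is F32
    ggml_backend_tensor_get(logits, logits_out.data(), 0, ggml_nbytes(logits));  // backend memory -> host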
@@ -5495,12 +5478,12 @@ WHISPER_API const char * whisper_bench_ggml_mul_mat_str(int n_threads) {
         double tsum = 0.0;
 
         // heat-up
-        ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
+        ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
 
         for (int i = 0; i < n_max; ++i) {
             const int64_t t0 = ggml_time_us();
 
-            ggml_graph_compute_helper(work, gf, n_threads, nullptr, nullptr);
+            ggml_graph_compute_helper(gf, work, n_threads, nullptr, nullptr);
 
             const int64_t t1 = ggml_time_us();
 
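Note: in the benchmark hunk above only the argument order of the plan-based helper changes to match the new signature; the timing pattern is untouched. A sketch of how the per-iteration time would feed the accumulated total (the accumulation statement and the throughput formula are assumptions, not shown in this hunk):

    tsum += (t1 - t0) * 1e-6;  // accumulate elapsed seconds for this mul_mat iteration
    // for an N x N x N multiplication repeated n_max times, throughput can then be
    // estimated roughly as gflops ~= 2.0 * N * N * N * n_max / (tsum * 1e9)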