mirror of https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-21 05:33:06 +00:00
whisper : try to fix the parallel whisper_state functionality
This commit is contained in:
parent 40c66036b6
commit d029784fb0
whisper.cpp (45 changes)
@@ -702,6 +702,8 @@ struct whisper_state {
     // buffer for swapping KV caches between decoders during beam-search
     std::vector<kv_buf> kv_swap_bufs;
 
+    ggml_backend_t backend = nullptr;
+
     // ggml-alloc:
     // - stores meta info about the intermediate tensors into the `meta` buffers
     // - stores the actual tensor data into the `data` buffers
@@ -1299,7 +1301,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
         }
     }
 
-    // init backends
+    // init backend
     {
         ggml_backend_t backend_gpu = NULL;
 
@@ -1964,7 +1966,7 @@ static bool whisper_encode_internal(
         ggml_allocr_alloc_graph(alloc, gf);
 
         if (!whisper_encode_external(wstate)) {
-            ggml_graph_compute_helper(wctx.backend, gf, n_threads);
+            ggml_graph_compute_helper(wstate.backend, gf, n_threads);
         }
     }
 
@@ -1978,7 +1980,7 @@ static bool whisper_encode_internal(
 
         ggml_allocr_alloc_graph(alloc, gf);
 
-        ggml_graph_compute_helper(wctx.backend, gf, n_threads);
+        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
     // cross
@@ -1991,7 +1993,7 @@ static bool whisper_encode_internal(
 
         ggml_allocr_alloc_graph(alloc, gf);
 
-        ggml_graph_compute_helper(wctx.backend, gf, n_threads);
+        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
     wstate.t_encode_us += ggml_time_us() - t_start_us;
@@ -2382,7 +2384,7 @@ static bool whisper_decode_internal(
 
         logits = gf->nodes[gf->n_nodes - 1];
 
-        ggml_graph_compute_helper(wctx.backend, gf, n_threads);
+        ggml_graph_compute_helper(wstate.backend, gf, n_threads);
     }
 
     // extract logits for all N tokens
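The four hunks above all make the same substitution: graph computation in the encoder and decoder now runs on the state's own backend (wstate.backend) instead of the shared context backend (wctx.backend), so two states no longer contend for one backend object. As a rough sketch only, assuming the ggml backend API of this period (ggml_backend_is_cpu, ggml_backend_cpu_set_n_threads, ggml_backend_graph_compute), a per-backend compute helper could look roughly like the following; the actual ggml_graph_compute_helper defined elsewhere in whisper.cpp may differ:

#include "ggml-backend.h"

// Hypothetical sketch of a backend-dispatching compute helper.
// Each caller passes the backend it owns, so parallel states never
// share a backend instance for graph execution.
static void graph_compute_on_backend(ggml_backend_t backend, struct ggml_cgraph * gf, int n_threads) {
    if (ggml_backend_is_cpu(backend)) {
        // the thread count only applies to the CPU backend
        ggml_backend_cpu_set_n_threads(backend, n_threads);
    }
    ggml_backend_graph_compute(backend, gf);
}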
@@ -2825,6 +2827,39 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     whisper_state * state = new whisper_state;
 
+    // init backend
+    {
+        ggml_backend_t backend_gpu = NULL;
+
+        // initialize the backends
+#ifdef GGML_USE_CUBLAS
+        if (ctx->params.use_gpu) {
+            WHISPER_LOG_INFO("%s: using CUDA backend\n", __func__);
+            backend_gpu = ggml_backend_cuda_init();
+            if (!backend_gpu) {
+                WHISPER_LOG_ERROR("%s: ggml_backend_cuda_init() failed\n", __func__);
+            }
+        }
+#endif
+
+#ifdef GGML_USE_METAL
+        if (ctx->params.use_gpu) {
+            WHISPER_LOG_INFO("%s: using Metal backend\n", __func__);
+            ggml_metal_log_set_callback(whisper_log_callback_default, nullptr);
+            backend_gpu = ggml_backend_metal_init();
+            if (!backend_gpu) {
+                WHISPER_LOG_ERROR("%s: ggml_backend_metal_init() failed\n", __func__);
+            }
+        }
+#endif
+
+        if (backend_gpu) {
+            state->backend = backend_gpu;
+        } else {
+            state->backend = ggml_backend_cpu_init();
+        }
+    }
+
     if (!kv_cache_init(ctx->model.hparams, state->decoders[0].kv_self, ctx->backend, ctx->itype, ctx->model.hparams.n_text_ctx)) {
         WHISPER_LOG_ERROR("%s: kv_cache_init() failed for self-attention cache\n", __func__);
         delete state;
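The per-state backend created in whisper_init_state is what the commit title's "parallel whisper_state functionality" refers to: several states can now compute independently while sharing a single whisper_context. A hypothetical usage sketch, not part of this commit, assuming the public whisper.h API (whisper_init_state, whisper_full_with_state, whisper_free_state):

#include <thread>
#include <vector>
#include "whisper.h"

// Sketch: transcribe several audio chunks in parallel, one whisper_state
// (and therefore one backend) per worker thread, sharing a single context.
void transcribe_in_parallel(struct whisper_context * ctx,
                            const std::vector<std::vector<float>> & chunks) {
    std::vector<std::thread> workers;

    for (const auto & pcm : chunks) {
        workers.emplace_back([ctx, &pcm]() {
            // each thread owns its state, so KV caches and the backend
            // handle are never shared across threads
            struct whisper_state * state = whisper_init_state(ctx);

            struct whisper_full_params params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

            whisper_full_with_state(ctx, state, params, pcm.data(), (int) pcm.size());

            whisper_free_state(state);
        });
    }

    for (auto & w : workers) {
        w.join();
    }
}

The read-only model weights stay in the shared context; everything mutable during inference (decoder KV caches, work buffers, and now the backend) lives in the per-thread state.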