From 9006946e4b21ac72423ccb16609af18a744b068b Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Tue, 14 Nov 2023 22:05:31 +0200
Subject: [PATCH] whisper : speed-up sampling

---
 whisper.cpp | 291 +++++++++++++++++++++++++++++++---------------------
 1 file changed, 172 insertions(+), 119 deletions(-)

diff --git a/whisper.cpp b/whisper.cpp
index 502ac64d..584e3945 100644
--- a/whisper.cpp
+++ b/whisper.cpp
@@ -20,6 +20,7 @@
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 
+#include <atomic>
 #include <algorithm>
 #include <cassert>
 #define _USE_MATH_DEFINES
@@ -459,6 +460,68 @@ static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token
     batch.logits[n_tokens - 1] = 1;
 }
 
+// replace std::pair by using customized pair struct (reason: std::pair is very slow)
+template<typename A, typename B>
+struct whisper_pair {
+    A first;
+    B second;
+
+    // Define a constructor that takes two arguments.
+    whisper_pair(const A& a, const B& b) : first(a), second(b) {}
+    // Define a constructor that takes no argument.
+    whisper_pair() : first(A()), second(B()) {}
+};
+
+// ggml_allocr wrapper for whisper usage
+struct whisper_allocr {
+    ggml_allocr * alloc = nullptr;
+
+    std::vector<uint8_t> meta;
+
+    ggml_backend_buffer_t buffer;
+};
+
+static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
+    return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
+}
+
+// measure the memory usage of a graph and prepare the allocr's internal data buffer
+static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
+    auto & alloc = allocr.alloc;
+    auto & meta  = allocr.meta;
+
+    alloc = ggml_allocr_new_measure_from_backend(backend);
+
+    meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
+
+    ggml_allocr_alloc_graph(alloc, get_graph());
+}
+
+static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
+    if (allocr.alloc == nullptr) {
+        // this can be null if we use external encoder like CoreML or OpenVINO
+        return;
+    }
+
+    auto & alloc  = allocr.alloc;
+    auto & buffer = allocr.buffer;
+
+    size_t size = ggml_allocr_max_size(alloc);
+
+    ggml_allocr_free(alloc);
+
+    buffer = ggml_backend_alloc_buffer(backend, size);
+    alloc  = ggml_allocr_new_from_buffer(buffer);
+}
+
+static void whisper_allocr_free(struct whisper_allocr & allocr) {
+    if (allocr.alloc) {
+        ggml_allocr_free(allocr.alloc);
+        ggml_backend_buffer_free(allocr.buffer);
+        allocr.alloc = nullptr;
+    }
+}
+
 // medium
 // hparams: {
 //   'n_mels': 80,
@@ -699,70 +762,13 @@ struct whisper_decoder {
     std::vector<float> probs;
     std::vector<float> logits;
     std::vector<float> logprobs;
+
+    // work container used to avoid memory allocations
+    std::vector<whisper_pair<double, whisper_token>> logits_id;
+
+    mutable std::mt19937 rng; // used for sampling at t > 0.0
 };
 
-// replace std::pair by using customized pair struct (reason: std::pair is very slow)
-template<typename A, typename B>
-struct whisper_pair {
-    A first;
-    B second;
-
-    // Define a constructor that takes two arguments.
-    whisper_pair(const A& a, const B& b) : first(a), second(b) {}
-    // Define a constructor that takes no argument.
-    whisper_pair() : first(A()), second(B()) {}
-};
-
-// ggml_allocr wrapper for whisper usage
-struct whisper_allocr {
-    ggml_allocr * alloc = nullptr;
-
-    std::vector<uint8_t> meta;
-
-    ggml_backend_buffer_t buffer;
-};
-
-static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
-    return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
-}
-
-// measure the memory usage of a graph and prepare the allocr's internal data buffer
-static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
-    auto & alloc = allocr.alloc;
-    auto & meta  = allocr.meta;
-
-    alloc = ggml_allocr_new_measure_from_backend(backend);
-
-    meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
-
-    ggml_allocr_alloc_graph(alloc, get_graph());
-}
-
-static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
-    if (allocr.alloc == nullptr) {
-        // this can be null if we use external encoder like CoreML or OpenVINO
-        return;
-    }
-
-    auto & alloc  = allocr.alloc;
-    auto & buffer = allocr.buffer;
-
-    size_t size = ggml_allocr_max_size(alloc);
-
-    ggml_allocr_free(alloc);
-
-    buffer = ggml_backend_alloc_buffer(backend, size);
-    alloc  = ggml_allocr_new_from_buffer(buffer);
-}
-
-static void whisper_allocr_free(struct whisper_allocr & allocr) {
-    if (allocr.alloc) {
-        ggml_allocr_free(allocr.alloc);
-        ggml_backend_buffer_free(allocr.buffer);
-        allocr.alloc = nullptr;
-    }
-}
-
 struct whisper_state {
     int64_t t_sample_us = 0;
     int64_t t_encode_us = 0;
@@ -814,11 +820,6 @@ struct whisper_state {
     std::vector<whisper_segment> result_all;
     std::vector<whisper_token>   prompt_past;
 
-    // work container used to avoid memory allocations
-    std::vector<whisper_pair<double, whisper_token>> logits_id;
-
-    mutable std::mt19937 rng; // used for sampling at t > 0.0
-
     int lang_id = 0; // english by default
 
     std::string path_model; // populated by whisper_init_from_file_with_params()
@@ -3079,16 +3080,17 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
 
     state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
 
-    state->logits_id.reserve(ctx->model.hparams.n_vocab);
-
     state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS);
 
     // TAGS: WHISPER_DECODER_INIT
     state->decoders[0].sequence.tokens.reserve(ctx->model.hparams.n_text_ctx);
 
-    state->decoders[0].probs.reserve   (ctx->vocab.n_vocab);
-    state->decoders[0].logits.reserve  (ctx->vocab.n_vocab);
-    state->decoders[0].logprobs.reserve(ctx->vocab.n_vocab);
+    state->decoders[0].probs.reserve    (ctx->vocab.n_vocab);
+    state->decoders[0].logits.reserve   (ctx->vocab.n_vocab);
+    state->decoders[0].logprobs.reserve (ctx->vocab.n_vocab);
+    state->decoders[0].logits_id.reserve(ctx->model.hparams.n_vocab);
+
+    state->decoders[0].rng = std::mt19937(0);
 
     // conv allocator
     {
@@ -3143,8 +3145,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
     whisper_allocr_graph_realloc(state->alloc_cross,  ctx->backend);
     whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
 
-    state->rng = std::mt19937(0);
-
     return state;
 }
 
@@ -3624,7 +3624,7 @@ int whisper_lang_auto_detect_with_state(
         return -7;
     }
 
-    auto & logits_id = state->logits_id;
+    auto & logits_id = state->decoders[0].logits_id;
     logits_id.clear();
 
     for (const auto & kv : g_lang) {
@@ -4699,7 +4699,6 @@ static void whisper_process_logits(
         //WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
 
         if (timestamp_logprob > max_text_token_logprob) {
-            //printf("sampling timestamp\n");
timestamp\n"); for (int i = 0; i < vocab.token_beg; ++i) { logits[i] = -INFINITY; logprobs[i] = -INFINITY; @@ -4797,7 +4796,6 @@ static void whisper_process_logits( static whisper_token_data whisper_sample_token( whisper_context & ctx, - whisper_state & state, const whisper_decoder & decoder, bool best) { whisper_token_data result = { @@ -4842,7 +4840,7 @@ static whisper_token_data whisper_sample_token( } else { std::discrete_distribution<> dist(probs.begin(), probs.end()); - result.id = dist(state.rng); + result.id = dist(decoder.rng); result.p = probs[result.id]; result.plog = logprobs[result.id]; } @@ -4852,15 +4850,12 @@ static whisper_token_data whisper_sample_token( result.pt = result.p; } - state.n_sample++; - return result; } static std::vector whisper_sample_token_topk( whisper_context & ctx, - whisper_state & state, - const whisper_decoder & decoder, + whisper_decoder & decoder, int k) { const auto & vocab = ctx.vocab; @@ -4870,7 +4865,7 @@ static std::vector whisper_sample_token_topk( const int n_logits = vocab.n_vocab; - auto & logits_id = state.logits_id; + auto & logits_id = decoder.logits_id; logits_id.resize(n_logits); for (int i = 0; i < n_logits; ++i) { @@ -4919,7 +4914,7 @@ static std::vector whisper_sample_token_topk( std::discrete_distribution<> dist(probs.begin(), probs.end()); for (int i = 0; i < k; ++i) { - const auto id = dist(state.rng); + const auto id = dist(decoder.rng); //printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum); result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, }); @@ -4930,8 +4925,6 @@ static std::vector whisper_sample_token_topk( } } - state.n_sample++; - return result; } @@ -5082,6 +5075,9 @@ int whisper_full_with_state( decoder.probs.resize (ctx->vocab.n_vocab); decoder.logits.resize (ctx->vocab.n_vocab); decoder.logprobs.resize(ctx->vocab.n_vocab); + decoder.logits_id.reserve(ctx->model.hparams.n_vocab); + + decoder.rng = std::mt19937(0); } // the accumulated text context so far @@ -5161,6 +5157,7 @@ int whisper_full_with_state( whisper_grammar grammar; }; + std::vector> bc_per_dec(n_decoders); std::vector beam_candidates; // main loop @@ -5306,41 +5303,71 @@ int whisper_full_with_state( const int64_t t_start_sample_us = ggml_time_us(); if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) { - beam_candidates.clear(); + for (auto & bc : bc_per_dec) { + bc.clear(); + } } - // generate new sequence candidates for each decoder - for (int j = 0; j < n_decoders_cur; ++j) { - auto & decoder = state->decoders[j]; + { + std::atomic j_cur(0); - if (decoder.completed || decoder.failed) { - continue; + const int n_threads = std::min(params.n_threads, n_decoders_cur); + + std::vector threads(n_threads); + + for (int t = 0; t < n_threads; ++t) { + threads[t] = std::thread([&]() { + while (true) { + const int j = j_cur.fetch_add(1); + + if (j >= n_decoders_cur) { + break; + } + + auto & decoder = state->decoders[j]; + + if (decoder.completed || decoder.failed) { + continue; + } + + switch (params.strategy) { + case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY: + { + if (t_cur < 1e-6f) { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true)); + } else { + decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false)); + } + + decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog; + } break; + case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH: + { + const auto tokens_new = whisper_sample_token_topk(*ctx, 
+
+                                        for (const auto & token : tokens_new) {
+                                            bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
+                                            bc_per_dec[j].back().sequence.tokens.push_back(token);
+                                            bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
+                                        }
+                                    } break;
+                            };
+                        }
+                    });
                 }
 
-                switch (params.strategy) {
-                    case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
-                        {
-                            if (t_cur < 1e-6f) {
-                                decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
-                            } else {
-                                decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
-                            }
+                for (auto & t : threads) {
+                    t.join();
+                }
+            }
 
-                            decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
-                        } break;
-                    case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
-                        {
-                            const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
+            beam_candidates.clear();
+            for (const auto & bc : bc_per_dec) {
+                beam_candidates.insert(beam_candidates.end(), bc.begin(), bc.end());
 
-                            for (const auto & token : tokens_new) {
-                                beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
-                                beam_candidates.back().sequence.tokens.push_back(token);
-                                beam_candidates.back().sequence.sum_logprobs_all += token.plog;
-
-                                //WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
-                            }
-                        } break;
-                };
+                if (!bc.empty()) {
+                    state->n_sample += 1;
+                }
             }
 
             // for beam-search, choose the top candidates and update the KV caches
@@ -5536,14 +5563,40 @@ int whisper_full_with_state(
 
                 const int64_t t_start_sample_us = ggml_time_us();
 
-                for (int j = 0; j < n_decoders_cur; ++j) {
-                    auto & decoder = state->decoders[j];
+                {
+                    std::atomic<int> j_cur(0);
 
-                    if (decoder.failed || decoder.completed) {
-                        continue;
+                    const int n_threads = std::min(params.n_threads, n_decoders_cur);
+
+                    std::vector<std::thread> threads(n_threads);
+
+                    auto process = [&]() {
+                        while (true) {
+                            const int j = j_cur.fetch_add(1);
+
+                            if (j >= n_decoders_cur) {
+                                break;
+                            }
+
+                            auto & decoder = state->decoders[j];
+
+                            if (decoder.failed || decoder.completed) {
+                                continue;
+                            }
+
+                            whisper_process_logits(*ctx, *state, params, decoder, t_cur);
+                        }
+                    };
+
+                    for (int t = 0; t < n_threads - 1; ++t) {
+                        threads[t] = std::thread(process);
                     }
 
-                    whisper_process_logits(*ctx, *state, params, decoder, t_cur);
+                    process();
+
+                    for (int t = 0; t < n_threads - 1; ++t) {
+                        threads[t].join();
+                    }
                 }
 
                 state->t_sample_us += ggml_time_us() - t_start_sample_us;
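
Notes:

Both parallelized sections above use the same scheduling pattern: the next
decoder index lives in a shared std::atomic<int>, and each worker claims one
decoder at a time with fetch_add until the counter passes n_decoders_cur, so
no locks are needed and no decoder is processed twice. The following is a
minimal, self-contained sketch of that pattern, not code from whisper.cpp;
the names n_items, n_threads and the printf standing in for the per-decoder
work are illustrative assumptions.

#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    const int n_items   = 8; // stand-in for n_decoders_cur
    const int n_threads = 4; // stand-in for params.n_threads

    std::atomic<int> i_cur(0); // next unclaimed item, shared by all workers

    auto process = [&]() {
        while (true) {
            const int i = i_cur.fetch_add(1); // atomically claim one item
            if (i >= n_items) {
                break;
            }
            std::printf("processing item %d\n", i); // stand-in for the real per-decoder work
        }
    };

    // spawn n_threads - 1 workers and let the calling thread participate as well,
    // mirroring the logits-processing loop in the patch
    std::vector<std::thread> workers;
    for (int t = 0; t < n_threads - 1; ++t) {
        workers.emplace_back(process);
    }
    process();

    for (auto & w : workers) {
        w.join();
    }

    return 0;
}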
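The patch also moves the sampling RNG and the logits_id scratch buffer from
whisper_state into each whisper_decoder. Once sampling runs on several threads,
a single shared std::mt19937 would be mutated concurrently, which is a data
race, and a shared scratch vector would need locking; one engine and one buffer
per decoder sidesteps both, and seeding every engine with 0 keeps runs
deterministic regardless of thread scheduling. Below is a minimal sketch of
per-decoder engines driving std::discrete_distribution, again with
illustrative names (decoder_rngs, sampled, probs) rather than whisper.cpp code.

#include <random>
#include <thread>
#include <vector>

int main() {
    const int n_decoders = 4;

    // one engine per decoder, each seeded with 0 as in the patch
    std::vector<std::mt19937> decoder_rngs(n_decoders, std::mt19937(0));
    std::vector<int>          sampled(n_decoders, -1);

    // toy token probabilities; the real code builds these from the logits
    const std::vector<double> probs = { 0.1, 0.2, 0.3, 0.4 };

    std::vector<std::thread> threads;
    for (int j = 0; j < n_decoders; ++j) {
        threads.emplace_back([&, j]() {
            // each thread touches only its own engine: no shared mutable state
            std::discrete_distribution<> dist(probs.begin(), probs.end());
            sampled[j] = dist(decoder_rngs[j]);
        });
    }
    for (auto & t : threads) {
        t.join();
    }

    return 0;
}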