mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-19 04:37:51 +00:00
whisper : speed-up sampling
This commit is contained in:
parent
d77603578b
commit
9006946e4b
243
whisper.cpp
243
whisper.cpp
@ -20,6 +20,7 @@
|
||||
#include "ggml-alloc.h"
|
||||
#include "ggml-backend.h"
|
||||
|
||||
#include <atomic>
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#define _USE_MATH_DEFINES
|
||||
@ -459,6 +460,68 @@ static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token
|
||||
batch.logits[n_tokens - 1] = 1;
|
||||
}
|
||||
|
||||
// replace std::pair by using customized pair struct (reason: std::pair is very slow)
|
||||
template<typename A, typename B>
|
||||
struct whisper_pair {
|
||||
A first;
|
||||
B second;
|
||||
|
||||
// Define a constructor that takes two arguments.
|
||||
whisper_pair(const A& a, const B& b) : first(a), second(b) {}
|
||||
// Define a constructor that takes no argument.
|
||||
whisper_pair() : first(A()), second(B()) {}
|
||||
};
|
||||
|
||||
// ggml_allocr wrapper for whisper usage
|
||||
struct whisper_allocr {
|
||||
ggml_allocr * alloc = nullptr;
|
||||
|
||||
std::vector<uint8_t> meta;
|
||||
|
||||
ggml_backend_buffer_t buffer;
|
||||
};
|
||||
|
||||
static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
|
||||
return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
|
||||
}
|
||||
|
||||
// measure the memory usage of a graph and prepare the allocr's internal data buffer
|
||||
static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
|
||||
auto & alloc = allocr.alloc;
|
||||
auto & meta = allocr.meta;
|
||||
|
||||
alloc = ggml_allocr_new_measure_from_backend(backend);
|
||||
|
||||
meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
|
||||
|
||||
ggml_allocr_alloc_graph(alloc, get_graph());
|
||||
}
|
||||
|
||||
static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
|
||||
if (allocr.alloc == nullptr) {
|
||||
// this can be null if we use external encoder like CoreML or OpenVINO
|
||||
return;
|
||||
}
|
||||
|
||||
auto & alloc = allocr.alloc;
|
||||
auto & buffer = allocr.buffer;
|
||||
|
||||
size_t size = ggml_allocr_max_size(alloc);
|
||||
|
||||
ggml_allocr_free(alloc);
|
||||
|
||||
buffer = ggml_backend_alloc_buffer(backend, size);
|
||||
alloc = ggml_allocr_new_from_buffer(buffer);
|
||||
}
|
||||
|
||||
static void whisper_allocr_free(struct whisper_allocr & allocr) {
|
||||
if (allocr.alloc) {
|
||||
ggml_allocr_free(allocr.alloc);
|
||||
ggml_backend_buffer_free(allocr.buffer);
|
||||
allocr.alloc = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
// medium
|
||||
// hparams: {
|
||||
// 'n_mels': 80,
|
||||
@ -699,70 +762,13 @@ struct whisper_decoder {
|
||||
std::vector<float> probs;
|
||||
std::vector<float> logits;
|
||||
std::vector<float> logprobs;
|
||||
|
||||
// work container used to avoid memory allocations
|
||||
std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
|
||||
|
||||
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
||||
};
|
||||
|
||||
// replace std::pair by using customized pair struct (reason: std::pair is very slow)
|
||||
template<typename A, typename B>
|
||||
struct whisper_pair {
|
||||
A first;
|
||||
B second;
|
||||
|
||||
// Define a constructor that takes two arguments.
|
||||
whisper_pair(const A& a, const B& b) : first(a), second(b) {}
|
||||
// Define a constructor that takes no argument.
|
||||
whisper_pair() : first(A()), second(B()) {}
|
||||
};
|
||||
|
||||
// ggml_allocr wrapper for whisper usage
|
||||
struct whisper_allocr {
|
||||
ggml_allocr * alloc = nullptr;
|
||||
|
||||
std::vector<uint8_t> meta;
|
||||
|
||||
ggml_backend_buffer_t buffer;
|
||||
};
|
||||
|
||||
static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
|
||||
return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
|
||||
}
|
||||
|
||||
// measure the memory usage of a graph and prepare the allocr's internal data buffer
|
||||
static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
|
||||
auto & alloc = allocr.alloc;
|
||||
auto & meta = allocr.meta;
|
||||
|
||||
alloc = ggml_allocr_new_measure_from_backend(backend);
|
||||
|
||||
meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
|
||||
|
||||
ggml_allocr_alloc_graph(alloc, get_graph());
|
||||
}
|
||||
|
||||
static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
|
||||
if (allocr.alloc == nullptr) {
|
||||
// this can be null if we use external encoder like CoreML or OpenVINO
|
||||
return;
|
||||
}
|
||||
|
||||
auto & alloc = allocr.alloc;
|
||||
auto & buffer = allocr.buffer;
|
||||
|
||||
size_t size = ggml_allocr_max_size(alloc);
|
||||
|
||||
ggml_allocr_free(alloc);
|
||||
|
||||
buffer = ggml_backend_alloc_buffer(backend, size);
|
||||
alloc = ggml_allocr_new_from_buffer(buffer);
|
||||
}
|
||||
|
||||
static void whisper_allocr_free(struct whisper_allocr & allocr) {
|
||||
if (allocr.alloc) {
|
||||
ggml_allocr_free(allocr.alloc);
|
||||
ggml_backend_buffer_free(allocr.buffer);
|
||||
allocr.alloc = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
struct whisper_state {
|
||||
int64_t t_sample_us = 0;
|
||||
int64_t t_encode_us = 0;
|
||||
@ -814,11 +820,6 @@ struct whisper_state {
|
||||
std::vector<whisper_segment> result_all;
|
||||
std::vector<whisper_token> prompt_past;
|
||||
|
||||
// work container used to avoid memory allocations
|
||||
std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
|
||||
|
||||
mutable std::mt19937 rng; // used for sampling at t > 0.0
|
||||
|
||||
int lang_id = 0; // english by default
|
||||
|
||||
std::string path_model; // populated by whisper_init_from_file_with_params()
|
||||
@ -3079,8 +3080,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
|
||||
state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
|
||||
|
||||
state->logits_id.reserve(ctx->model.hparams.n_vocab);
|
||||
|
||||
state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS);
|
||||
|
||||
// TAGS: WHISPER_DECODER_INIT
|
||||
@ -3089,6 +3088,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
|
||||
state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
|
||||
state->decoders[0].logprobs.reserve (ctx->vocab.n_vocab);
|
||||
state->decoders[0].logits_id.reserve(ctx->model.hparams.n_vocab);
|
||||
|
||||
state->decoders[0].rng = std::mt19937(0);
|
||||
|
||||
// conv allocator
|
||||
{
|
||||
@ -3143,8 +3145,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
|
||||
whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
|
||||
whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
|
||||
|
||||
state->rng = std::mt19937(0);
|
||||
|
||||
return state;
|
||||
}
|
||||
|
||||
@ -3624,7 +3624,7 @@ int whisper_lang_auto_detect_with_state(
|
||||
return -7;
|
||||
}
|
||||
|
||||
auto & logits_id = state->logits_id;
|
||||
auto & logits_id = state->decoders[0].logits_id;
|
||||
logits_id.clear();
|
||||
|
||||
for (const auto & kv : g_lang) {
|
||||
@ -4699,7 +4699,6 @@ static void whisper_process_logits(
|
||||
//WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
|
||||
|
||||
if (timestamp_logprob > max_text_token_logprob) {
|
||||
//printf("sampling timestamp\n");
|
||||
for (int i = 0; i < vocab.token_beg; ++i) {
|
||||
logits[i] = -INFINITY;
|
||||
logprobs[i] = -INFINITY;
|
||||
@ -4797,7 +4796,6 @@ static void whisper_process_logits(
|
||||
|
||||
static whisper_token_data whisper_sample_token(
|
||||
whisper_context & ctx,
|
||||
whisper_state & state,
|
||||
const whisper_decoder & decoder,
|
||||
bool best) {
|
||||
whisper_token_data result = {
|
||||
@ -4842,7 +4840,7 @@ static whisper_token_data whisper_sample_token(
|
||||
} else {
|
||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||
|
||||
result.id = dist(state.rng);
|
||||
result.id = dist(decoder.rng);
|
||||
result.p = probs[result.id];
|
||||
result.plog = logprobs[result.id];
|
||||
}
|
||||
@ -4852,15 +4850,12 @@ static whisper_token_data whisper_sample_token(
|
||||
result.pt = result.p;
|
||||
}
|
||||
|
||||
state.n_sample++;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
static std::vector<whisper_token_data> whisper_sample_token_topk(
|
||||
whisper_context & ctx,
|
||||
whisper_state & state,
|
||||
const whisper_decoder & decoder,
|
||||
whisper_decoder & decoder,
|
||||
int k) {
|
||||
const auto & vocab = ctx.vocab;
|
||||
|
||||
@ -4870,7 +4865,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
||||
|
||||
const int n_logits = vocab.n_vocab;
|
||||
|
||||
auto & logits_id = state.logits_id;
|
||||
auto & logits_id = decoder.logits_id;
|
||||
|
||||
logits_id.resize(n_logits);
|
||||
for (int i = 0; i < n_logits; ++i) {
|
||||
@ -4919,7 +4914,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||
|
||||
for (int i = 0; i < k; ++i) {
|
||||
const auto id = dist(state.rng);
|
||||
const auto id = dist(decoder.rng);
|
||||
//printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
|
||||
|
||||
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
|
||||
@ -4930,8 +4925,6 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
|
||||
}
|
||||
}
|
||||
|
||||
state.n_sample++;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
@ -5082,6 +5075,9 @@ int whisper_full_with_state(
|
||||
decoder.probs.resize (ctx->vocab.n_vocab);
|
||||
decoder.logits.resize (ctx->vocab.n_vocab);
|
||||
decoder.logprobs.resize(ctx->vocab.n_vocab);
|
||||
decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
|
||||
|
||||
decoder.rng = std::mt19937(0);
|
||||
}
|
||||
|
||||
// the accumulated text context so far
|
||||
@ -5161,6 +5157,7 @@ int whisper_full_with_state(
|
||||
whisper_grammar grammar;
|
||||
};
|
||||
|
||||
std::vector<std::vector<beam_candidate>> bc_per_dec(n_decoders);
|
||||
std::vector<beam_candidate> beam_candidates;
|
||||
|
||||
// main loop
|
||||
@ -5306,11 +5303,27 @@ int whisper_full_with_state(
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
|
||||
beam_candidates.clear();
|
||||
for (auto & bc : bc_per_dec) {
|
||||
bc.clear();
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
std::atomic<int> j_cur(0);
|
||||
|
||||
const int n_threads = std::min(params.n_threads, n_decoders_cur);
|
||||
|
||||
std::vector<std::thread> threads(n_threads);
|
||||
|
||||
for (int t = 0; t < n_threads; ++t) {
|
||||
threads[t] = std::thread([&]() {
|
||||
while (true) {
|
||||
const int j = j_cur.fetch_add(1);
|
||||
|
||||
if (j >= n_decoders_cur) {
|
||||
break;
|
||||
}
|
||||
|
||||
// generate new sequence candidates for each decoder
|
||||
for (int j = 0; j < n_decoders_cur; ++j) {
|
||||
auto & decoder = state->decoders[j];
|
||||
|
||||
if (decoder.completed || decoder.failed) {
|
||||
@ -5321,27 +5334,41 @@ int whisper_full_with_state(
|
||||
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
|
||||
{
|
||||
if (t_cur < 1e-6f) {
|
||||
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
|
||||
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
|
||||
} else {
|
||||
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
|
||||
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
|
||||
}
|
||||
|
||||
decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
|
||||
} break;
|
||||
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
|
||||
{
|
||||
const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
|
||||
const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
|
||||
|
||||
for (const auto & token : tokens_new) {
|
||||
beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
|
||||
beam_candidates.back().sequence.tokens.push_back(token);
|
||||
beam_candidates.back().sequence.sum_logprobs_all += token.plog;
|
||||
|
||||
//WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
|
||||
bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
|
||||
bc_per_dec[j].back().sequence.tokens.push_back(token);
|
||||
bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
|
||||
}
|
||||
} break;
|
||||
};
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
for (auto & t : threads) {
|
||||
t.join();
|
||||
}
|
||||
}
|
||||
|
||||
beam_candidates.clear();
|
||||
for (const auto & bc : bc_per_dec) {
|
||||
beam_candidates.insert(beam_candidates.end(), bc.begin(), bc.end());
|
||||
|
||||
if (!bc.empty()) {
|
||||
state->n_sample += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// for beam-search, choose the top candidates and update the KV caches
|
||||
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
|
||||
@ -5536,7 +5563,21 @@ int whisper_full_with_state(
|
||||
|
||||
const int64_t t_start_sample_us = ggml_time_us();
|
||||
|
||||
for (int j = 0; j < n_decoders_cur; ++j) {
|
||||
{
|
||||
std::atomic<int> j_cur(0);
|
||||
|
||||
const int n_threads = std::min(params.n_threads, n_decoders_cur);
|
||||
|
||||
std::vector<std::thread> threads(n_threads);
|
||||
|
||||
auto process = [&]() {
|
||||
while (true) {
|
||||
const int j = j_cur.fetch_add(1);
|
||||
|
||||
if (j >= n_decoders_cur) {
|
||||
break;
|
||||
}
|
||||
|
||||
auto & decoder = state->decoders[j];
|
||||
|
||||
if (decoder.failed || decoder.completed) {
|
||||
@ -5545,6 +5586,18 @@ int whisper_full_with_state(
|
||||
|
||||
whisper_process_logits(*ctx, *state, params, decoder, t_cur);
|
||||
}
|
||||
};
|
||||
|
||||
for (int t = 0; t < n_threads - 1; ++t) {
|
||||
threads[t] = std::thread(process);
|
||||
}
|
||||
|
||||
process();
|
||||
|
||||
for (int t = 0; t < n_threads - 1; ++t) {
|
||||
threads[t].join();
|
||||
}
|
||||
}
|
||||
|
||||
state->t_sample_us += ggml_time_us() - t_start_sample_us;
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user