whisper : speed-up sampling

Georgi Gerganov 2023-11-14 22:05:31 +02:00
parent d77603578b
commit 9006946e4b


@@ -20,6 +20,7 @@
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include <atomic>
#include <algorithm>
#include <cassert>
#define _USE_MATH_DEFINES
@@ -459,6 +460,68 @@ static void whisper_batch_prep_legacy(whisper_batch & batch, const whisper_token
batch.logits[n_tokens - 1] = 1;
}
// replace std::pair by using customized pair struct (reason: std::pair is very slow)
template<typename A, typename B>
struct whisper_pair {
A first;
B second;
// Define a constructor that takes two arguments.
whisper_pair(const A& a, const B& b) : first(a), second(b) {}
// Define a constructor that takes no argument.
whisper_pair() : first(A()), second(B()) {}
};
// ggml_allocr wrapper for whisper usage
struct whisper_allocr {
ggml_allocr * alloc = nullptr;
std::vector<uint8_t> meta;
ggml_backend_buffer_t buffer;
};
static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
}
// measure the memory usage of a graph and prepare the allocr's internal data buffer
static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
auto & alloc = allocr.alloc;
auto & meta = allocr.meta;
alloc = ggml_allocr_new_measure_from_backend(backend);
meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
ggml_allocr_alloc_graph(alloc, get_graph());
}
static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
if (allocr.alloc == nullptr) {
// this can be null if we use external encoder like CoreML or OpenVINO
return;
}
auto & alloc = allocr.alloc;
auto & buffer = allocr.buffer;
size_t size = ggml_allocr_max_size(alloc);
ggml_allocr_free(alloc);
buffer = ggml_backend_alloc_buffer(backend, size);
alloc = ggml_allocr_new_from_buffer(buffer);
}
static void whisper_allocr_free(struct whisper_allocr & allocr) {
if (allocr.alloc) {
ggml_allocr_free(allocr.alloc);
ggml_backend_buffer_free(allocr.buffer);
allocr.alloc = nullptr;
}
}
// medium
// hparams: {
// 'n_mels': 80,
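The whisper_pair struct added above replaces std::pair in the sampling hot path, where tens of thousands of (logit, token id) entries are built and partially sorted on every step; a plain aggregate with trivial constructors avoids the overhead the comment attributes to std::pair. Below is a minimal, self-contained sketch of the top-k selection pattern that the logits_id work buffer is used for later in this diff (the toy vocabulary, the value of k, and main() are illustrative only and not part of the commit):

```cpp
#include <algorithm>
#include <cstdio>
#include <vector>

// Same shape as the struct added in this commit.
template<typename A, typename B>
struct whisper_pair {
    A first;
    B second;

    whisper_pair(const A & a, const B & b) : first(a), second(b) {}
    whisper_pair() : first(A()), second(B()) {}
};

int main() {
    // Illustrative logits for a tiny vocabulary (the real vocabulary has tens of thousands of tokens).
    const std::vector<float> logits = { 0.1f, 2.3f, -1.0f, 0.7f, 1.9f };
    const int k = 3;

    // Build (logit, token id) pairs in a reusable work buffer ...
    std::vector<whisper_pair<double, int>> logits_id;
    logits_id.reserve(logits.size());
    for (int i = 0; i < (int) logits.size(); ++i) {
        logits_id.emplace_back(logits[i], i);
    }

    // ... and move only the k best entries to the front.
    std::partial_sort(
            logits_id.begin(), logits_id.begin() + k, logits_id.end(),
            [](const whisper_pair<double, int> & a, const whisper_pair<double, int> & b) {
                return a.first > b.first;
            });

    for (int i = 0; i < k; ++i) {
        printf("top-%d: token %d (logit %.2f)\n", i + 1, logits_id[i].second, logits_id[i].first);
    }

    return 0;
}
```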
@@ -699,70 +762,13 @@ struct whisper_decoder {
std::vector<float> probs;
std::vector<float> logits;
std::vector<float> logprobs;
// work container used to avoid memory allocations
std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
mutable std::mt19937 rng; // used for sampling at t > 0.0
};
// replace std::pair by using customized pair struct (reason: std::pair is very slow)
template<typename A, typename B>
struct whisper_pair {
A first;
B second;
// Define a constructor that takes two arguments.
whisper_pair(const A& a, const B& b) : first(a), second(b) {}
// Define a constructor that takes no argument.
whisper_pair() : first(A()), second(B()) {}
};
// ggml_allocr wrapper for whisper usage
struct whisper_allocr {
ggml_allocr * alloc = nullptr;
std::vector<uint8_t> meta;
ggml_backend_buffer_t buffer;
};
static size_t whisper_allocr_size(struct whisper_allocr & allocr) {
return allocr.meta.size() + ggml_allocr_max_size(allocr.alloc);
}
// measure the memory usage of a graph and prepare the allocr's internal data buffer
static void whisper_allocr_graph_init(struct whisper_allocr & allocr, ggml_backend_t backend, std::function<struct ggml_cgraph *()> && get_graph) {
auto & alloc = allocr.alloc;
auto & meta = allocr.meta;
alloc = ggml_allocr_new_measure_from_backend(backend);
meta.resize(ggml_tensor_overhead()*WHISPER_MAX_NODES + ggml_graph_overhead());
ggml_allocr_alloc_graph(alloc, get_graph());
}
static void whisper_allocr_graph_realloc(struct whisper_allocr & allocr, ggml_backend_t backend) {
if (allocr.alloc == nullptr) {
// this can be null if we use external encoder like CoreML or OpenVINO
return;
}
auto & alloc = allocr.alloc;
auto & buffer = allocr.buffer;
size_t size = ggml_allocr_max_size(alloc);
ggml_allocr_free(alloc);
buffer = ggml_backend_alloc_buffer(backend, size);
alloc = ggml_allocr_new_from_buffer(buffer);
}
static void whisper_allocr_free(struct whisper_allocr & allocr) {
if (allocr.alloc) {
ggml_allocr_free(allocr.alloc);
ggml_backend_buffer_free(allocr.buffer);
allocr.alloc = nullptr;
}
}
struct whisper_state {
int64_t t_sample_us = 0;
int64_t t_encode_us = 0;
@@ -814,11 +820,6 @@ struct whisper_state {
std::vector<whisper_segment> result_all;
std::vector<whisper_token> prompt_past;
// work container used to avoid memory allocations
std::vector<whisper_pair<double, whisper_vocab::id>> logits_id;
mutable std::mt19937 rng; // used for sampling at t > 0.0
int lang_id = 0; // english by default
std::string path_model; // populated by whisper_init_from_file_with_params()
@@ -3079,8 +3080,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
state->logits.reserve(ctx->vocab.n_vocab * ctx->model.hparams.n_text_ctx);
state->logits_id.reserve(ctx->model.hparams.n_vocab);
state->batch = whisper_batch_init(ctx->model.hparams.n_text_ctx, WHISPER_MAX_DECODERS);
// TAGS: WHISPER_DECODER_INIT
@@ -3089,6 +3088,9 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
state->decoders[0].probs.reserve (ctx->vocab.n_vocab);
state->decoders[0].logits.reserve (ctx->vocab.n_vocab);
state->decoders[0].logprobs.reserve (ctx->vocab.n_vocab);
state->decoders[0].logits_id.reserve(ctx->model.hparams.n_vocab);
state->decoders[0].rng = std::mt19937(0);
// conv allocator
{
@@ -3143,8 +3145,6 @@ struct whisper_state * whisper_init_state(whisper_context * ctx) {
whisper_allocr_graph_realloc(state->alloc_cross, ctx->backend);
whisper_allocr_graph_realloc(state->alloc_decode, ctx->backend);
state->rng = std::mt19937(0);
return state;
}
@@ -3624,7 +3624,7 @@ int whisper_lang_auto_detect_with_state(
return -7;
}
auto & logits_id = state->logits_id;
auto & logits_id = state->decoders[0].logits_id;
logits_id.clear();
for (const auto & kv : g_lang) {
@@ -4699,7 +4699,6 @@ static void whisper_process_logits(
//WHISPER_LOG_INFO("timestamp_logprob=%f max_text_token_logprob=%f\n", timestamp_logprob, max_text_token_logprob);
if (timestamp_logprob > max_text_token_logprob) {
//printf("sampling timestamp\n");
for (int i = 0; i < vocab.token_beg; ++i) {
logits[i] = -INFINITY;
logprobs[i] = -INFINITY;
@@ -4797,7 +4796,6 @@ static void whisper_process_logits(
static whisper_token_data whisper_sample_token(
whisper_context & ctx,
whisper_state & state,
const whisper_decoder & decoder,
bool best) {
whisper_token_data result = {
@@ -4842,7 +4840,7 @@ static whisper_token_data whisper_sample_token(
} else {
std::discrete_distribution<> dist(probs.begin(), probs.end());
result.id = dist(state.rng);
result.id = dist(decoder.rng);
result.p = probs[result.id];
result.plog = logprobs[result.id];
}
@@ -4852,15 +4850,12 @@ static whisper_token_data whisper_sample_token(
result.pt = result.p;
}
state.n_sample++;
return result;
}
static std::vector<whisper_token_data> whisper_sample_token_topk(
whisper_context & ctx,
whisper_state & state,
const whisper_decoder & decoder,
whisper_decoder & decoder,
int k) {
const auto & vocab = ctx.vocab;
@@ -4870,7 +4865,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
const int n_logits = vocab.n_vocab;
auto & logits_id = state.logits_id;
auto & logits_id = decoder.logits_id;
logits_id.resize(n_logits);
for (int i = 0; i < n_logits; ++i) {
@@ -4919,7 +4914,7 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
std::discrete_distribution<> dist(probs.begin(), probs.end());
for (int i = 0; i < k; ++i) {
const auto id = dist(state.rng);
const auto id = dist(decoder.rng);
//printf("XXX %d %d %f %f %f %f\n", id, tid, probs[id], logprobs[id], pt, ptsum);
result.push_back({ id, tid, probs[id], logprobs[id], pt, ptsum, -1, -1, 0.0f, });
@@ -4930,8 +4925,6 @@ static std::vector<whisper_token_data> whisper_sample_token_topk(
}
}
state.n_sample++;
return result;
}
@@ -5082,6 +5075,9 @@ int whisper_full_with_state(
decoder.probs.resize (ctx->vocab.n_vocab);
decoder.logits.resize (ctx->vocab.n_vocab);
decoder.logprobs.resize(ctx->vocab.n_vocab);
decoder.logits_id.reserve(ctx->model.hparams.n_vocab);
decoder.rng = std::mt19937(0);
}
// the accumulated text context so far
@@ -5161,6 +5157,7 @@ int whisper_full_with_state(
whisper_grammar grammar;
};
std::vector<std::vector<beam_candidate>> bc_per_dec(n_decoders);
std::vector<beam_candidate> beam_candidates;
// main loop
@@ -5306,11 +5303,27 @@ int whisper_full_with_state(
const int64_t t_start_sample_us = ggml_time_us();
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
beam_candidates.clear();
for (auto & bc : bc_per_dec) {
bc.clear();
}
}
{
std::atomic<int> j_cur(0);
const int n_threads = std::min(params.n_threads, n_decoders_cur);
std::vector<std::thread> threads(n_threads);
for (int t = 0; t < n_threads; ++t) {
threads[t] = std::thread([&]() {
while (true) {
const int j = j_cur.fetch_add(1);
if (j >= n_decoders_cur) {
break;
}
// generate new sequence candidates for each decoder
for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
if (decoder.completed || decoder.failed) {
@@ -5321,27 +5334,41 @@
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
{
if (t_cur < 1e-6f) {
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, true));
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
} else {
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, *state, decoder, false));
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
}
decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
} break;
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
{
const auto tokens_new = whisper_sample_token_topk(*ctx, *state, decoder, params.beam_search.beam_size);
const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
for (const auto & token : tokens_new) {
beam_candidates.push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
beam_candidates.back().sequence.tokens.push_back(token);
beam_candidates.back().sequence.sum_logprobs_all += token.plog;
//WHISPER_PRINT_DEBUG("%s: beam candidate: %s (%f, %f)\n", __func__, ctx->vocab.id_to_token.at(token.id).c_str(), token.plog, beam_candidates.back().sequence.sum_logprobs_all);
bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
bc_per_dec[j].back().sequence.tokens.push_back(token);
bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
}
} break;
};
}
});
}
for (auto & t : threads) {
t.join();
}
}
beam_candidates.clear();
for (const auto & bc : bc_per_dec) {
beam_candidates.insert(beam_candidates.end(), bc.begin(), bc.end());
if (!bc.empty()) {
state->n_sample += 1;
}
}
// for beam-search, choose the top candidates and update the KV caches
if (params.strategy == whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH) {
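The block above is the core of the speed-up: instead of a serial loop over decoders, a shared std::atomic<int> counter hands out decoder indices to a small pool of threads, and because each decoder now carries its own rng and logits_id (and writes its beam candidates into its own bc_per_dec[j] slot), the workers never touch shared mutable state. A self-contained sketch of that claim-by-fetch_add pattern, with an arbitrary stand-in for the per-decoder sampling work (the counts and the squaring are illustrative, not the commit's code):

```cpp
#include <algorithm>
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    const int n_decoders = 5;                        // illustrative
    const int n_threads  = std::min(4, n_decoders);  // like std::min(params.n_threads, n_decoders_cur)

    std::vector<int> results(n_decoders, 0);

    std::atomic<int> j_cur(0);                       // shared work counter

    std::vector<std::thread> threads(n_threads);
    for (int t = 0; t < n_threads; ++t) {
        threads[t] = std::thread([&]() {
            while (true) {
                const int j = j_cur.fetch_add(1);    // claim the next decoder index
                if (j >= n_decoders) {
                    break;                           // no work left
                }
                // stand-in for sampling with decoder j's private state
                results[j] = j * j;
            }
        });
    }

    for (auto & t : threads) {
        t.join();
    }

    for (int j = 0; j < n_decoders; ++j) {
        printf("decoder %d -> %d\n", j, results[j]);
    }

    return 0;
}
```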
@@ -5536,7 +5563,21 @@ int whisper_full_with_state(
const int64_t t_start_sample_us = ggml_time_us();
for (int j = 0; j < n_decoders_cur; ++j) {
{
std::atomic<int> j_cur(0);
const int n_threads = std::min(params.n_threads, n_decoders_cur);
std::vector<std::thread> threads(n_threads);
auto process = [&]() {
while (true) {
const int j = j_cur.fetch_add(1);
if (j >= n_decoders_cur) {
break;
}
auto & decoder = state->decoders[j];
if (decoder.failed || decoder.completed) {
@@ -5545,6 +5586,18 @@ int whisper_full_with_state(
whisper_process_logits(*ctx, *state, params, decoder, t_cur);
}
};
for (int t = 0; t < n_threads - 1; ++t) {
threads[t] = std::thread(process);
}
process();
for (int t = 0; t < n_threads - 1; ++t) {
threads[t].join();
}
}
state->t_sample_us += ggml_time_us() - t_start_sample_us;
}
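The second parallel section, around whisper_process_logits, reuses the same atomic-counter idea with one refinement: only n_threads - 1 worker threads are spawned and the calling thread runs the process lambda itself, so when n_threads is 1 no thread is created at all and the single-threaded path keeps its original cost.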