whisper : avoid starting sampling threads with bs=1

This commit is contained in:
Georgi Gerganov 2023-11-15 13:41:38 +02:00
parent 820f45895e
commit 4c245ea108
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@ -4521,11 +4521,12 @@ static const std::vector<std::string> non_speech_tokens = {
// process the logits for the selected decoder
// - applies logit filters
// - computes logprobs and probs
// TODO: optimize
static void whisper_process_logits(
struct whisper_context & ctx,
struct whisper_state & state,
const struct whisper_full_params params,
struct whisper_decoder & decoder,
const struct whisper_full_params params,
float temperature) {
const auto & vocab = ctx.vocab;
const auto & tokens_cur = decoder.sequence.tokens;
@ -5297,7 +5298,7 @@ int whisper_full_with_state(
state->decoders[0].i_batch = prompt.size() - 1;
whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur);
for (int j = 1; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
@ -5322,56 +5323,66 @@ int whisper_full_with_state(
}
}
// sampling
// TODO: avoid memory allocations, optimize, avoid threads?
{
std::atomic<int> j_cur(0);
auto process = [&]() {
while (true) {
const int j = j_cur.fetch_add(1);
if (j >= n_decoders_cur) {
break;
}
auto & decoder = state->decoders[j];
if (decoder.completed || decoder.failed) {
continue;
}
switch (params.strategy) {
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
{
if (t_cur < 1e-6f) {
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
} else {
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
}
decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
} break;
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
{
const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
for (const auto & token : tokens_new) {
bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
bc_per_dec[j].back().sequence.tokens.push_back(token);
bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
}
} break;
};
}
};
const int n_threads = std::min(params.n_threads, n_decoders_cur);
std::vector<std::thread> threads(n_threads);
if (n_threads == 1) {
process();
} else {
std::vector<std::thread> threads(n_threads - 1);
for (int t = 0; t < n_threads; ++t) {
threads[t] = std::thread([&]() {
while (true) {
const int j = j_cur.fetch_add(1);
for (int t = 0; t < n_threads - 1; ++t) {
threads[t] = std::thread(process);
}
if (j >= n_decoders_cur) {
break;
}
process();
auto & decoder = state->decoders[j];
if (decoder.completed || decoder.failed) {
continue;
}
switch (params.strategy) {
case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
{
if (t_cur < 1e-6f) {
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
} else {
decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
}
decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
} break;
case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
{
const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
for (const auto & token : tokens_new) {
bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
bc_per_dec[j].back().sequence.tokens.push_back(token);
bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
}
} break;
};
}
});
}
for (auto & t : threads) {
t.join();
for (int t = 0; t < n_threads - 1; ++t) {
threads[t].join();
}
}
}
@ -5577,13 +5588,10 @@ int whisper_full_with_state(
const int64_t t_start_sample_us = ggml_time_us();
// TODO: avoid memory allocations, optimize, avoid threads?
{
std::atomic<int> j_cur(0);
const int n_threads = std::min(params.n_threads, n_decoders_cur);
std::vector<std::thread> threads(n_threads);
auto process = [&]() {
while (true) {
const int j = j_cur.fetch_add(1);
@ -5598,18 +5606,26 @@ int whisper_full_with_state(
continue;
}
whisper_process_logits(*ctx, *state, params, decoder, t_cur);
whisper_process_logits(*ctx, *state, decoder, params, t_cur);
}
};
for (int t = 0; t < n_threads - 1; ++t) {
threads[t] = std::thread(process);
}
const int n_threads = std::min(params.n_threads, n_decoders_cur);
process();
if (n_threads == 1) {
process();
} else {
std::vector<std::thread> threads(n_threads - 1);
for (int t = 0; t < n_threads - 1; ++t) {
threads[t].join();
for (int t = 0; t < n_threads - 1; ++t) {
threads[t] = std::thread(process);
}
process();
for (int t = 0; t < n_threads - 1; ++t) {
threads[t].join();
}
}
}