whisper : avoid starting sampling threads with bs=1

This commit is contained in:
Georgi Gerganov 2023-11-15 13:41:38 +02:00
parent 820f45895e
commit 4c245ea108
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735

View File

@@ -4521,11 +4521,12 @@ static const std::vector<std::string> non_speech_tokens = {
// process the logits for the selected decoder
// - applies logit filters
// - computes logprobs and probs
// TODO: optimize
static void whisper_process_logits(
struct whisper_context & ctx,
struct whisper_state & state,
const struct whisper_full_params params,
struct whisper_decoder & decoder,
const struct whisper_full_params params,
float temperature) {
const auto & vocab = ctx.vocab;
const auto & tokens_cur = decoder.sequence.tokens;
@@ -5297,7 +5298,7 @@ int whisper_full_with_state(
state->decoders[0].i_batch = prompt.size() - 1;
whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur);
for (int j = 1; j < n_decoders_cur; ++j) {
auto & decoder = state->decoders[j];
@@ -5322,15 +5323,12 @@ int whisper_full_with_state(
}
}
// sampling
// TODO: avoid memory allocations, optimize, avoid threads?
{
std::atomic<int> j_cur(0);
const int n_threads = std::min(params.n_threads, n_decoders_cur);
std::vector<std::thread> threads(n_threads);
for (int t = 0; t < n_threads; ++t) {
threads[t] = std::thread([&]() {
auto process = [&]() {
while (true) {
const int j = j_cur.fetch_add(1);
@@ -5367,11 +5365,24 @@ int whisper_full_with_state(
} break;
};
}
});
};
const int n_threads = std::min(params.n_threads, n_decoders_cur);
if (n_threads == 1) {
process();
} else {
std::vector<std::thread> threads(n_threads - 1);
for (int t = 0; t < n_threads - 1; ++t) {
threads[t] = std::thread(process);
}
for (auto & t : threads) {
t.join();
process();
for (int t = 0; t < n_threads - 1; ++t) {
threads[t].join();
}
}
}
@@ -5577,13 +5588,10 @@ int whisper_full_with_state(
const int64_t t_start_sample_us = ggml_time_us();
// TODO: avoid memory allocations, optimize, avoid threads?
{
std::atomic<int> j_cur(0);
const int n_threads = std::min(params.n_threads, n_decoders_cur);
std::vector<std::thread> threads(n_threads);
auto process = [&]() {
while (true) {
const int j = j_cur.fetch_add(1);
@@ -5598,10 +5606,17 @@ int whisper_full_with_state(
continue;
}
whisper_process_logits(*ctx, *state, params, decoder, t_cur);
whisper_process_logits(*ctx, *state, decoder, params, t_cur);
}
};
const int n_threads = std::min(params.n_threads, n_decoders_cur);
if (n_threads == 1) {
process();
} else {
std::vector<std::thread> threads(n_threads - 1);
for (int t = 0; t < n_threads - 1; ++t) {
threads[t] = std::thread(process);
}
@@ -5612,6 +5627,7 @@ int whisper_full_with_state(
threads[t].join();
}
}
}
state->t_sample_us += ggml_time_us() - t_start_sample_us;
}