mirror of https://github.com/ggerganov/whisper.cpp.git
whisper : avoid starting sampling threads with bs=1
commit 4c245ea108
parent 820f45895e
whisper.cpp | 126 changed lines
@@ -4521,11 +4521,12 @@ static const std::vector<std::string> non_speech_tokens = {
 // process the logits for the selected decoder
 // - applies logit filters
 // - computes logprobs and probs
 // TODO: optimize
 static void whisper_process_logits(
         struct whisper_context & ctx,
         struct whisper_state & state,
-        const struct whisper_full_params params,
         struct whisper_decoder & decoder,
+        const struct whisper_full_params params,
         float temperature) {
     const auto & vocab = ctx.vocab;
     const auto & tokens_cur = decoder.sequence.tokens;
@@ -5297,7 +5298,7 @@ int whisper_full_with_state(
 
         state->decoders[0].i_batch = prompt.size() - 1;
 
-        whisper_process_logits(*ctx, *state, params, state->decoders[0], t_cur);
+        whisper_process_logits(*ctx, *state, state->decoders[0], params, t_cur);
 
         for (int j = 1; j < n_decoders_cur; ++j) {
             auto & decoder = state->decoders[j];
@@ -5322,56 +5323,66 @@ int whisper_full_with_state(
                 }
             }
 
             // sampling
             // TODO: avoid memory allocations, optimize, avoid threads?
             {
                 std::atomic<int> j_cur(0);
 
+                auto process = [&]() {
+                    while (true) {
+                        const int j = j_cur.fetch_add(1);
+
+                        if (j >= n_decoders_cur) {
+                            break;
+                        }
+
+                        auto & decoder = state->decoders[j];
+
+                        if (decoder.completed || decoder.failed) {
+                            continue;
+                        }
+
+                        switch (params.strategy) {
+                            case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
+                                {
+                                    if (t_cur < 1e-6f) {
+                                        decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
+                                    } else {
+                                        decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
+                                    }
+
+                                    decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
+                                } break;
+                            case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
+                                {
+                                    const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
+
+                                    for (const auto & token : tokens_new) {
+                                        bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
+                                        bc_per_dec[j].back().sequence.tokens.push_back(token);
+                                        bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
+                                    }
+                                } break;
+                        };
+                    }
+                };
+
                 const int n_threads = std::min(params.n_threads, n_decoders_cur);
 
-                std::vector<std::thread> threads(n_threads);
+                if (n_threads == 1) {
+                    process();
+                } else {
+                    std::vector<std::thread> threads(n_threads - 1);
 
-                for (int t = 0; t < n_threads; ++t) {
-                    threads[t] = std::thread([&]() {
-                        while (true) {
-                            const int j = j_cur.fetch_add(1);
+                    for (int t = 0; t < n_threads - 1; ++t) {
+                        threads[t] = std::thread(process);
+                    }
 
-                            if (j >= n_decoders_cur) {
-                                break;
-                            }
+                    process();
 
-                            auto & decoder = state->decoders[j];
-
-                            if (decoder.completed || decoder.failed) {
-                                continue;
-                            }
-
-                            switch (params.strategy) {
-                                case whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY:
-                                    {
-                                        if (t_cur < 1e-6f) {
-                                            decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, true));
-                                        } else {
-                                            decoder.sequence.tokens.push_back(whisper_sample_token(*ctx, decoder, false));
-                                        }
-
-                                        decoder.sequence.sum_logprobs_all += decoder.sequence.tokens.back().plog;
-                                    } break;
-                                case whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH:
-                                    {
-                                        const auto tokens_new = whisper_sample_token_topk(*ctx, decoder, params.beam_search.beam_size);
-
-                                        for (const auto & token : tokens_new) {
-                                            bc_per_dec[j].push_back({ j, decoder.seek_delta, decoder.has_ts, decoder.sequence, decoder.grammar, });
-                                            bc_per_dec[j].back().sequence.tokens.push_back(token);
-                                            bc_per_dec[j].back().sequence.sum_logprobs_all += token.plog;
-                                        }
-                                    } break;
-                            };
-                        }
-                    });
-                }
-
-                for (auto & t : threads) {
-                    t.join();
+                    for (int t = 0; t < n_threads - 1; ++t) {
+                        threads[t].join();
+                    }
                 }
             }
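The hunk above replaces a spawn-N-threads-and-wait loop with a shared process lambda plus a single-thread fast path. Below is a minimal, self-contained C++ sketch of that pattern with the whisper-specific work swapped for a placeholder; the names run_jobs and n_jobs are illustrative, not from the patch. An atomic counter hands out job indices, the calling thread always does its share, and extra threads are started only when more than one worker can actually help.

#include <algorithm>
#include <atomic>
#include <cstdio>
#include <thread>
#include <vector>

// Hand out job indices via an atomic counter; each worker loops until the
// counter runs past n_jobs. The calling thread runs process() too, so only
// n_threads - 1 extra threads are ever created, and with n_threads == 1
// (the bs=1 case) no thread is created at all.
static void run_jobs(int n_jobs, int max_threads) {
    std::atomic<int> j_cur(0);

    auto process = [&]() {
        while (true) {
            const int j = j_cur.fetch_add(1); // claim the next job index
            if (j >= n_jobs) {
                break;
            }
            std::printf("job %d done\n", j); // placeholder for per-decoder work
        }
    };

    // Never more workers than jobs; clamp to at least 1 so n_jobs == 0 is safe.
    const int n_threads = std::max(1, std::min(max_threads, n_jobs));

    if (n_threads == 1) {
        process(); // fast path: no std::thread construction or join
    } else {
        std::vector<std::thread> threads(n_threads - 1);
        for (int t = 0; t < n_threads - 1; ++t) {
            threads[t] = std::thread(process);
        }
        process(); // the calling thread participates as well
        for (int t = 0; t < n_threads - 1; ++t) {
            threads[t].join();
        }
    }
}

int main() {
    run_jobs(/*n_jobs =*/ 5, /*max_threads =*/ 4);
    return 0;
}

With bs=1 greedy decoding there is a single active decoder, so n_threads resolves to 1 and the per-token std::thread spawn/join of the old code disappears entirely, which is what the commit title promises.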
@@ -5577,13 +5588,10 @@ int whisper_full_with_state(
 
                 const int64_t t_start_sample_us = ggml_time_us();
 
                 // TODO: avoid memory allocations, optimize, avoid threads?
                 {
                     std::atomic<int> j_cur(0);
 
-                    const int n_threads = std::min(params.n_threads, n_decoders_cur);
-
-                    std::vector<std::thread> threads(n_threads);
 
                     auto process = [&]() {
                         while (true) {
                             const int j = j_cur.fetch_add(1);
@@ -5598,18 +5606,26 @@ int whisper_full_with_state(
                                 continue;
                             }
 
-                            whisper_process_logits(*ctx, *state, params, decoder, t_cur);
+                            whisper_process_logits(*ctx, *state, decoder, params, t_cur);
                         }
                     };
 
-                    for (int t = 0; t < n_threads - 1; ++t) {
-                        threads[t] = std::thread(process);
-                    }
+                    const int n_threads = std::min(params.n_threads, n_decoders_cur);
 
-                    process();
+                    if (n_threads == 1) {
+                        process();
+                    } else {
+                        std::vector<std::thread> threads(n_threads - 1);
 
-                    for (int t = 0; t < n_threads - 1; ++t) {
-                        threads[t].join();
-                    }
+                        for (int t = 0; t < n_threads - 1; ++t) {
+                            threads[t] = std::thread(process);
+                        }
+
+                        process();
+
+                        for (int t = 0; t < n_threads - 1; ++t) {
+                            threads[t].join();
+                        }
+                    }
                 }
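To see the order of magnitude of the overhead the fast path removes, here is a toy benchmark, illustrative only and not part of the patch: it times a trivial task called directly versus the same task wrapped in a freshly spawned and joined std::thread, which is roughly the per-token cost the old sampling loop paid even when only one decoder was active. Absolute numbers vary by OS and standard library; compile with -pthread on POSIX systems.

#include <chrono>
#include <cstdio>
#include <thread>

int main() {
    volatile int sink = 0;
    auto task = [&]() { sink = sink + 1; }; // stand-in for sampling one token

    const int iters = 10000;

    // Direct call: the new n_threads == 1 path.
    auto t0 = std::chrono::steady_clock::now();
    for (int i = 0; i < iters; ++i) {
        task();
    }
    auto t1 = std::chrono::steady_clock::now();

    // Spawn + join per iteration: what the old sampling loop did with bs=1.
    for (int i = 0; i < iters; ++i) {
        std::thread th(task);
        th.join();
    }
    auto t2 = std::chrono::steady_clock::now();

    using us = std::chrono::duration<double, std::micro>;
    std::printf("direct call : %8.3f us/iter\n", us(t1 - t0).count() / iters);
    std::printf("spawn+join  : %8.3f us/iter\n", us(t2 - t1).count() / iters);
    return 0;
}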