whisper : enable beam-search by default

This commit is contained in:
Georgi Gerganov 2023-11-15 15:36:45 +02:00
parent 4c245ea108
commit b7c82a37b1
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 31 additions and 24 deletions

View File

@ -62,8 +62,8 @@ struct whisper_params {
int32_t progress_step = 5;
int32_t max_context = -1;
int32_t max_len = 0;
int32_t best_of = 2;
int32_t beam_size = -1;
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
float word_thold = 0.01f;
float entropy_thold = 2.40f;
@ -925,9 +925,9 @@ int main(int argc, char ** argv) {
if (params.detect_language) {
params.language = "auto";
}
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors,
params.n_threads, params.n_processors, params.beam_size, params.best_of,
params.language.c_str(),
params.translate ? "translate" : "transcribe",
params.tinydiarize ? "tdrz = 1, " : "",

View File

@ -1263,6 +1263,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
word = "[_EOT_]";
} else if (i == vocab.token_sot) {
word = "[_SOT_]";
} else if (i == vocab.token_translate) {
word = "[_TRANSLATE_]";
} else if (i == vocab.token_transcribe) {
word = "[_TRANSCRIBE_]";
} else if (i == vocab.token_solm) {
word = "[_SOLM_]";
} else if (i == vocab.token_prev) {
@ -1273,6 +1277,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
word = "[_NOT_]";
} else if (i == vocab.token_beg) {
word = "[_BEG_]";
} else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) {
word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]";
} else {
word = "[_extra_token_" + std::to_string(i) + "]";
}
@ -2182,7 +2188,6 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
const auto & model = wctx.model;
const auto & hparams = model.hparams;
// TODO: move to wstate
auto & kv_self = wstate.kv_self;
WHISPER_ASSERT(!!kv_self.ctx);
@ -4385,7 +4390,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.max_initial_ts =*/ 1.0f,
/*.length_penalty =*/ -1.0f,
/*.temperature_inc =*/ 0.4f,
/*.temperature_inc =*/ 0.2f,
/*.entropy_thold =*/ 2.4f,
/*.logprob_thold =*/ -1.0f,
/*.no_speech_thold =*/ 0.6f,
@ -4425,13 +4430,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
case WHISPER_SAMPLING_GREEDY:
{
result.greedy = {
/*.best_of =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
/*.best_of =*/ 5,
};
} break;
case WHISPER_SAMPLING_BEAM_SEARCH:
{
result.beam_search = {
/*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
/*.beam_size =*/ 5,
/*.patience =*/ -1.0f,
};
@ -4713,25 +4718,27 @@ static void whisper_process_logits(
logits[i] = -INFINITY;
logprobs[i] = -INFINITY;
}
} else if (params.n_grammar_rules > 0) {
whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
} else {
if (params.n_grammar_rules > 0) {
whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
// populate the logprobs array (log_softmax)
{
const float logit_max = *std::max_element(logits.begin(), logits.end());
float logsumexp = 0.0f;
for (int i = 0; i < n_logits; ++i) {
if (logits[i] > -INFINITY) {
logsumexp += expf(logits[i] - logit_max);
// populate the logprobs array (log_softmax)
{
const float logit_max = *std::max_element(logits.begin(), logits.end());
float logsumexp = 0.0f;
for (int i = 0; i < n_logits; ++i) {
if (logits[i] > -INFINITY) {
logsumexp += expf(logits[i] - logit_max);
}
}
}
logsumexp = logf(logsumexp) + logit_max;
logsumexp = logf(logsumexp) + logit_max;
for (int i = 0; i < n_logits; ++i) {
if (logits[i] > -INFINITY) {
logprobs[i] = logits[i] - logsumexp;
} else {
logprobs[i] = -INFINITY;
for (int i = 0; i < n_logits; ++i) {
if (logits[i] > -INFINITY) {
logprobs[i] = logits[i] - logsumexp;
} else {
logprobs[i] = -INFINITY;
}
}
}
}