whisper : enable beam-search by default

This commit is contained in:
Georgi Gerganov 2023-11-15 15:36:45 +02:00
parent 4c245ea108
commit b7c82a37b1
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 31 additions and 24 deletions

View File

@ -62,8 +62,8 @@ struct whisper_params {
int32_t progress_step = 5; int32_t progress_step = 5;
int32_t max_context = -1; int32_t max_context = -1;
int32_t max_len = 0; int32_t max_len = 0;
int32_t best_of = 2; int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
int32_t beam_size = -1; int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
float word_thold = 0.01f; float word_thold = 0.01f;
float entropy_thold = 2.40f; float entropy_thold = 2.40f;
@ -925,9 +925,9 @@ int main(int argc, char ** argv) {
if (params.detect_language) { if (params.detect_language) {
params.language = "auto"; params.language = "auto";
} }
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n", fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n",
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
params.n_threads, params.n_processors, params.n_threads, params.n_processors, params.beam_size, params.best_of,
params.language.c_str(), params.language.c_str(),
params.translate ? "translate" : "transcribe", params.translate ? "translate" : "transcribe",
params.tinydiarize ? "tdrz = 1, " : "", params.tinydiarize ? "tdrz = 1, " : "",

View File

@ -1263,6 +1263,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
word = "[_EOT_]"; word = "[_EOT_]";
} else if (i == vocab.token_sot) { } else if (i == vocab.token_sot) {
word = "[_SOT_]"; word = "[_SOT_]";
} else if (i == vocab.token_translate) {
word = "[_TRANSLATE_]";
} else if (i == vocab.token_transcribe) {
word = "[_TRANSCRIBE_]";
} else if (i == vocab.token_solm) { } else if (i == vocab.token_solm) {
word = "[_SOLM_]"; word = "[_SOLM_]";
} else if (i == vocab.token_prev) { } else if (i == vocab.token_prev) {
@ -1273,6 +1277,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
word = "[_NOT_]"; word = "[_NOT_]";
} else if (i == vocab.token_beg) { } else if (i == vocab.token_beg) {
word = "[_BEG_]"; word = "[_BEG_]";
} else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) {
word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]";
} else { } else {
word = "[_extra_token_" + std::to_string(i) + "]"; word = "[_extra_token_" + std::to_string(i) + "]";
} }
@ -2182,7 +2188,6 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
const auto & model = wctx.model; const auto & model = wctx.model;
const auto & hparams = model.hparams; const auto & hparams = model.hparams;
// TODO: move to wstate
auto & kv_self = wstate.kv_self; auto & kv_self = wstate.kv_self;
WHISPER_ASSERT(!!kv_self.ctx); WHISPER_ASSERT(!!kv_self.ctx);
@ -4385,7 +4390,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.max_initial_ts =*/ 1.0f, /*.max_initial_ts =*/ 1.0f,
/*.length_penalty =*/ -1.0f, /*.length_penalty =*/ -1.0f,
/*.temperature_inc =*/ 0.4f, /*.temperature_inc =*/ 0.2f,
/*.entropy_thold =*/ 2.4f, /*.entropy_thold =*/ 2.4f,
/*.logprob_thold =*/ -1.0f, /*.logprob_thold =*/ -1.0f,
/*.no_speech_thold =*/ 0.6f, /*.no_speech_thold =*/ 0.6f,
@ -4425,13 +4430,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
case WHISPER_SAMPLING_GREEDY: case WHISPER_SAMPLING_GREEDY:
{ {
result.greedy = { result.greedy = {
/*.best_of =*/ 2, // TODO: increase to 5 when we speed-up batch decoding /*.best_of =*/ 5,
}; };
} break; } break;
case WHISPER_SAMPLING_BEAM_SEARCH: case WHISPER_SAMPLING_BEAM_SEARCH:
{ {
result.beam_search = { result.beam_search = {
/*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding /*.beam_size =*/ 5,
/*.patience =*/ -1.0f, /*.patience =*/ -1.0f,
}; };
@ -4713,7 +4718,8 @@ static void whisper_process_logits(
logits[i] = -INFINITY; logits[i] = -INFINITY;
logprobs[i] = -INFINITY; logprobs[i] = -INFINITY;
} }
} else if (params.n_grammar_rules > 0) { } else {
if (params.n_grammar_rules > 0) {
whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar); whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
// populate the logprobs array (log_softmax) // populate the logprobs array (log_softmax)
@ -4738,6 +4744,7 @@ static void whisper_process_logits(
} }
} }
} }
}
// compute probs // compute probs
{ {