diff --git a/examples/main/main.cpp b/examples/main/main.cpp index e43dfe3f..98af5839 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -62,8 +62,8 @@ struct whisper_params { int32_t progress_step = 5; int32_t max_context = -1; int32_t max_len = 0; - int32_t best_of = 2; - int32_t beam_size = -1; + int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of; + int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size; float word_thold = 0.01f; float entropy_thold = 2.40f; @@ -925,9 +925,9 @@ int main(int argc, char ** argv) { if (params.detect_language) { params.language = "auto"; } - fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n", + fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n", __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE, - params.n_threads, params.n_processors, + params.n_threads, params.n_processors, params.beam_size, params.best_of, params.language.c_str(), params.translate ? "translate" : "transcribe", params.tinydiarize ? "tdrz = 1, " : "", diff --git a/whisper.cpp b/whisper.cpp index edd97c7c..e2bfa41e 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -1263,6 +1263,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con word = "[_EOT_]"; } else if (i == vocab.token_sot) { word = "[_SOT_]"; + } else if (i == vocab.token_translate) { + word = "[_TRANSLATE_]"; + } else if (i == vocab.token_transcribe) { + word = "[_TRANSCRIBE_]"; } else if (i == vocab.token_solm) { word = "[_SOLM_]"; } else if (i == vocab.token_prev) { @@ -1273,6 +1277,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con word = "[_NOT_]"; } else if (i == vocab.token_beg) { word = "[_BEG_]"; + } else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) { + word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]"; } else { word = "[_extra_token_" + std::to_string(i) + "]"; } @@ -2182,7 +2188,6 @@ static struct ggml_cgraph * whisper_build_graph_decoder( const auto & model = wctx.model; const auto & hparams = model.hparams; - // TODO: move to wstate auto & kv_self = wstate.kv_self; WHISPER_ASSERT(!!kv_self.ctx); @@ -4385,7 +4390,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.max_initial_ts =*/ 1.0f, /*.length_penalty =*/ -1.0f, - /*.temperature_inc =*/ 0.4f, + /*.temperature_inc =*/ 0.2f, /*.entropy_thold =*/ 2.4f, /*.logprob_thold =*/ -1.0f, /*.no_speech_thold =*/ 0.6f, @@ -4425,13 +4430,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str case WHISPER_SAMPLING_GREEDY: { result.greedy = { - /*.best_of =*/ 2, // TODO: increase to 5 when we speed-up batch decoding + /*.best_of =*/ 5, }; } break; case WHISPER_SAMPLING_BEAM_SEARCH: { result.beam_search = { - /*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding + /*.beam_size =*/ 5, /*.patience =*/ -1.0f, }; @@ -4713,25 +4718,27 @@ static void whisper_process_logits( logits[i] = -INFINITY; logprobs[i] = -INFINITY; } - } else if (params.n_grammar_rules > 0) { - whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar); + } else { + if (params.n_grammar_rules > 0) { + whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar); - // populate the logprobs array (log_softmax) - { - const float logit_max = *std::max_element(logits.begin(), logits.end()); - float logsumexp = 0.0f; - for (int i = 0; i < n_logits; ++i) { - if (logits[i] > -INFINITY) { - logsumexp += expf(logits[i] - logit_max); + // populate the logprobs array (log_softmax) + { + const float logit_max = *std::max_element(logits.begin(), logits.end()); + float logsumexp = 0.0f; + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logsumexp += expf(logits[i] - logit_max); + } } - } - logsumexp = logf(logsumexp) + logit_max; + logsumexp = logf(logsumexp) + logit_max; - for (int i = 0; i < n_logits; ++i) { - if (logits[i] > -INFINITY) { - logprobs[i] = logits[i] - logsumexp; - } else { - logprobs[i] = -INFINITY; + for (int i = 0; i < n_logits; ++i) { + if (logits[i] > -INFINITY) { + logprobs[i] = logits[i] - logsumexp; + } else { + logprobs[i] = -INFINITY; + } } } }