mirror of https://github.com/ggerganov/whisper.cpp.git, synced 2025-02-28 03:12:48 +00:00
whisper : enable beam-search by default
commit b7c82a37b1 (parent 4c245ea108)
@@ -62,8 +62,8 @@ struct whisper_params {
     int32_t progress_step = 5;
     int32_t max_context = -1;
     int32_t max_len = 0;
-    int32_t best_of = 2;
-    int32_t beam_size = -1;
+    int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
+    int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;

     float word_thold = 0.01f;
     float entropy_thold = 2.40f;
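The example's defaults are now taken from the library instead of being hard-coded. As an illustration only (not part of the patch), the two initializers above can be evaluated directly against whisper.h; with the default changes further down in this commit they both come out as 5:

    // illustrative only: query the library defaults the same way the example now does
    #include <cstdio>
    #include "whisper.h"

    int main() {
        const whisper_full_params gp = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
        const whisper_full_params bp = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);

        // prints "best_of = 5, beam_size = 5" after this commit
        printf("best_of = %d, beam_size = %d\n", gp.greedy.best_of, bp.beam_search.beam_size);
        return 0;
    }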
@@ -925,9 +925,9 @@ int main(int argc, char ** argv) {
         if (params.detect_language) {
             params.language = "auto";
         }
-        fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
+        fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n",
                 __func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
-                params.n_threads, params.n_processors,
+                params.n_threads, params.n_processors, params.beam_size, params.best_of,
                 params.language.c_str(),
                 params.translate ? "translate" : "transcribe",
                 params.tinydiarize ? "tdrz = 1, " : "",
whisper.cpp (47 changes)
@@ -1263,6 +1263,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 word = "[_EOT_]";
             } else if (i == vocab.token_sot) {
                 word = "[_SOT_]";
+            } else if (i == vocab.token_translate) {
+                word = "[_TRANSLATE_]";
+            } else if (i == vocab.token_transcribe) {
+                word = "[_TRANSCRIBE_]";
             } else if (i == vocab.token_solm) {
                 word = "[_SOLM_]";
             } else if (i == vocab.token_prev) {
@@ -1273,6 +1277,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
                 word = "[_NOT_]";
             } else if (i == vocab.token_beg) {
                 word = "[_BEG_]";
+            } else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) {
+                word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]";
             } else {
                 word = "[_extra_token_" + std::to_string(i) + "]";
             }
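The new branch above names the language tokens that sit immediately after the SOT token: token id token_sot + 1 + k corresponds to language index k, and whisper_lang_str(k) returns its short code. A small hedged sketch of the same arithmetic; the concrete token ids below are illustrative placeholders, not values taken from the patch:

    // illustrative sketch of the language-token arithmetic used in the new branch
    #include <cstdio>
    #include "whisper.h"

    int main() {
        const int token_sot = 50258; // placeholder SOT id for a multilingual vocab
        const int token_id  = 50259; // first language token right after SOT

        const int lang_id = token_id - token_sot - 1;
        // whisper_lang_str() returns the short language code ("en", "de", ...) for a valid id
        printf("[_LANG_%s]\n", whisper_lang_str(lang_id));
        return 0;
    }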
@@ -2182,7 +2188,6 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
     const auto & model = wctx.model;
     const auto & hparams = model.hparams;

-    // TODO: move to wstate
     auto & kv_self = wstate.kv_self;

     WHISPER_ASSERT(!!kv_self.ctx);
@@ -4385,7 +4390,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         /*.max_initial_ts =*/ 1.0f,
         /*.length_penalty =*/ -1.0f,

-        /*.temperature_inc =*/ 0.4f,
+        /*.temperature_inc =*/ 0.2f,
         /*.entropy_thold =*/ 2.4f,
         /*.logprob_thold =*/ -1.0f,
         /*.no_speech_thold =*/ 0.6f,
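Lowering temperature_inc from 0.4 to 0.2 gives the temperature fallback finer-grained retries when a decode fails the entropy/logprob checks. A rough sketch of the schedule implied by the new defaults; the real retry logic lives in whisper_full, and this assumes the schedule simply steps from temperature up to 1.0 by temperature_inc:

    // illustrative sketch of the fallback temperatures implied by the new defaults
    #include <cstdio>

    int main() {
        const float temperature     = 0.0f; // default initial temperature
        const float temperature_inc = 0.2f; // new default (was 0.4f)

        // prints: 0.0 0.2 0.4 0.6 0.8 1.0 (more fallback steps than with 0.4f)
        for (float t = temperature; t <= 1.0f + 1e-6f; t += temperature_inc) {
            printf("%.1f ", t);
        }
        printf("\n");
        return 0;
    }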
@@ -4425,13 +4430,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
         case WHISPER_SAMPLING_GREEDY:
             {
                 result.greedy = {
-                    /*.best_of =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
+                    /*.best_of =*/ 5,
                 };
             } break;
         case WHISPER_SAMPLING_BEAM_SEARCH:
             {
                 result.beam_search = {
-                    /*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
+                    /*.beam_size =*/ 5,

                     /*.patience =*/ -1.0f,
                 };
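With greedy.best_of and beam_search.beam_size both raised to 5, requesting the beam-search strategy now gives a 5-beam decode out of the box. A hedged usage sketch; the model path and audio buffer are placeholders, and real I/O is in examples/main:

    // minimal sketch: run a 5-beam decode using the new library defaults
    #include <cstdio>
    #include <vector>
    #include "whisper.h"

    int main() {
        struct whisper_context * ctx = whisper_init_from_file("models/ggml-base.en.bin"); // placeholder path
        if (!ctx) return 1;

        std::vector<float> pcmf32(16000 * 5, 0.0f); // stands in for real 16 kHz mono PCM audio

        whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH);
        // wparams.beam_search.beam_size is now 5 by default; override only if needed

        if (whisper_full(ctx, wparams, pcmf32.data(), (int) pcmf32.size()) != 0) {
            fprintf(stderr, "failed to process audio\n");
        }

        whisper_free(ctx);
        return 0;
    }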
@@ -4713,25 +4718,27 @@ static void whisper_process_logits(
             logits[i] = -INFINITY;
             logprobs[i] = -INFINITY;
         }
-    } else if (params.n_grammar_rules > 0) {
-        whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
-    }
-    // populate the logprobs array (log_softmax)
-    {
-        const float logit_max = *std::max_element(logits.begin(), logits.end());
-        float logsumexp = 0.0f;
-        for (int i = 0; i < n_logits; ++i) {
-            if (logits[i] > -INFINITY) {
-                logsumexp += expf(logits[i] - logit_max);
-            }
-        }
-        logsumexp = logf(logsumexp) + logit_max;
-
-        for (int i = 0; i < n_logits; ++i) {
-            if (logits[i] > -INFINITY) {
-                logprobs[i] = logits[i] - logsumexp;
-            } else {
-                logprobs[i] = -INFINITY;
-            }
-        }
-    }
+    } else {
+        if (params.n_grammar_rules > 0) {
+            whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
+        }
+
+        // populate the logprobs array (log_softmax)
+        {
+            const float logit_max = *std::max_element(logits.begin(), logits.end());
+            float logsumexp = 0.0f;
+            for (int i = 0; i < n_logits; ++i) {
+                if (logits[i] > -INFINITY) {
+                    logsumexp += expf(logits[i] - logit_max);
+                }
+            }
+            logsumexp = logf(logsumexp) + logit_max;
+
+            for (int i = 0; i < n_logits; ++i) {
+                if (logits[i] > -INFINITY) {
+                    logprobs[i] = logits[i] - logsumexp;
+                } else {
+                    logprobs[i] = -INFINITY;
+                }
+            }
+        }
+    }
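For reference, the block that moved into the else branch is a standard max-shifted log-softmax: logprob[i] = logit[i] - (logf(sum_j expf(logit[j] - logit_max)) + logit_max), with -INFINITY entries left untouched so that suppressed tokens keep zero probability. A self-contained sketch of the same computation outside the decoder:

    // standalone sketch of the log_softmax used to fill the logprobs array
    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    int main() {
        std::vector<float> logits   = { 2.0f, 0.5f, -INFINITY, 1.0f }; // -INFINITY marks a suppressed token
        std::vector<float> logprobs(logits.size());

        // subtract the max before exponentiating to avoid overflow
        const float logit_max = *std::max_element(logits.begin(), logits.end());
        float logsumexp = 0.0f;
        for (float l : logits) {
            if (l > -INFINITY) {
                logsumexp += expf(l - logit_max);
            }
        }
        logsumexp = logf(logsumexp) + logit_max;

        for (size_t i = 0; i < logits.size(); ++i) {
            logprobs[i] = logits[i] > -INFINITY ? logits[i] - logsumexp : -INFINITY;
        }

        // expf(logprobs[i]) over the finite entries sums to 1
        for (float lp : logprobs) {
            printf("%.4f ", lp);
        }
        printf("\n");
        return 0;
    }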