mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-02-28 11:20:39 +00:00
whisper : enable beam-search by default
This commit is contained in:
parent
4c245ea108
commit
b7c82a37b1
@ -62,8 +62,8 @@ struct whisper_params {
|
|||||||
int32_t progress_step = 5;
|
int32_t progress_step = 5;
|
||||||
int32_t max_context = -1;
|
int32_t max_context = -1;
|
||||||
int32_t max_len = 0;
|
int32_t max_len = 0;
|
||||||
int32_t best_of = 2;
|
int32_t best_of = whisper_full_default_params(WHISPER_SAMPLING_GREEDY).greedy.best_of;
|
||||||
int32_t beam_size = -1;
|
int32_t beam_size = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH).beam_search.beam_size;
|
||||||
|
|
||||||
float word_thold = 0.01f;
|
float word_thold = 0.01f;
|
||||||
float entropy_thold = 2.40f;
|
float entropy_thold = 2.40f;
|
||||||
@ -925,9 +925,9 @@ int main(int argc, char ** argv) {
|
|||||||
if (params.detect_language) {
|
if (params.detect_language) {
|
||||||
params.language = "auto";
|
params.language = "auto";
|
||||||
}
|
}
|
||||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
|
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, %d beams + best of %d, lang = %s, task = %s, %stimestamps = %d ...\n",
|
||||||
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
|
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
|
||||||
params.n_threads, params.n_processors,
|
params.n_threads, params.n_processors, params.beam_size, params.best_of,
|
||||||
params.language.c_str(),
|
params.language.c_str(),
|
||||||
params.translate ? "translate" : "transcribe",
|
params.translate ? "translate" : "transcribe",
|
||||||
params.tinydiarize ? "tdrz = 1, " : "",
|
params.tinydiarize ? "tdrz = 1, " : "",
|
||||||
|
17
whisper.cpp
17
whisper.cpp
@ -1263,6 +1263,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||||||
word = "[_EOT_]";
|
word = "[_EOT_]";
|
||||||
} else if (i == vocab.token_sot) {
|
} else if (i == vocab.token_sot) {
|
||||||
word = "[_SOT_]";
|
word = "[_SOT_]";
|
||||||
|
} else if (i == vocab.token_translate) {
|
||||||
|
word = "[_TRANSLATE_]";
|
||||||
|
} else if (i == vocab.token_transcribe) {
|
||||||
|
word = "[_TRANSCRIBE_]";
|
||||||
} else if (i == vocab.token_solm) {
|
} else if (i == vocab.token_solm) {
|
||||||
word = "[_SOLM_]";
|
word = "[_SOLM_]";
|
||||||
} else if (i == vocab.token_prev) {
|
} else if (i == vocab.token_prev) {
|
||||||
@ -1273,6 +1277,8 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
|
|||||||
word = "[_NOT_]";
|
word = "[_NOT_]";
|
||||||
} else if (i == vocab.token_beg) {
|
} else if (i == vocab.token_beg) {
|
||||||
word = "[_BEG_]";
|
word = "[_BEG_]";
|
||||||
|
} else if (i > vocab.token_sot && i <= vocab.token_sot + vocab.num_languages()) {
|
||||||
|
word = "[_LANG_" + std::string(whisper_lang_str(i - vocab.token_sot - 1)) + "]";
|
||||||
} else {
|
} else {
|
||||||
word = "[_extra_token_" + std::to_string(i) + "]";
|
word = "[_extra_token_" + std::to_string(i) + "]";
|
||||||
}
|
}
|
||||||
@ -2182,7 +2188,6 @@ static struct ggml_cgraph * whisper_build_graph_decoder(
|
|||||||
const auto & model = wctx.model;
|
const auto & model = wctx.model;
|
||||||
const auto & hparams = model.hparams;
|
const auto & hparams = model.hparams;
|
||||||
|
|
||||||
// TODO: move to wstate
|
|
||||||
auto & kv_self = wstate.kv_self;
|
auto & kv_self = wstate.kv_self;
|
||||||
|
|
||||||
WHISPER_ASSERT(!!kv_self.ctx);
|
WHISPER_ASSERT(!!kv_self.ctx);
|
||||||
@ -4385,7 +4390,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|||||||
/*.max_initial_ts =*/ 1.0f,
|
/*.max_initial_ts =*/ 1.0f,
|
||||||
/*.length_penalty =*/ -1.0f,
|
/*.length_penalty =*/ -1.0f,
|
||||||
|
|
||||||
/*.temperature_inc =*/ 0.4f,
|
/*.temperature_inc =*/ 0.2f,
|
||||||
/*.entropy_thold =*/ 2.4f,
|
/*.entropy_thold =*/ 2.4f,
|
||||||
/*.logprob_thold =*/ -1.0f,
|
/*.logprob_thold =*/ -1.0f,
|
||||||
/*.no_speech_thold =*/ 0.6f,
|
/*.no_speech_thold =*/ 0.6f,
|
||||||
@ -4425,13 +4430,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|||||||
case WHISPER_SAMPLING_GREEDY:
|
case WHISPER_SAMPLING_GREEDY:
|
||||||
{
|
{
|
||||||
result.greedy = {
|
result.greedy = {
|
||||||
/*.best_of =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
|
/*.best_of =*/ 5,
|
||||||
};
|
};
|
||||||
} break;
|
} break;
|
||||||
case WHISPER_SAMPLING_BEAM_SEARCH:
|
case WHISPER_SAMPLING_BEAM_SEARCH:
|
||||||
{
|
{
|
||||||
result.beam_search = {
|
result.beam_search = {
|
||||||
/*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
|
/*.beam_size =*/ 5,
|
||||||
|
|
||||||
/*.patience =*/ -1.0f,
|
/*.patience =*/ -1.0f,
|
||||||
};
|
};
|
||||||
@ -4713,7 +4718,8 @@ static void whisper_process_logits(
|
|||||||
logits[i] = -INFINITY;
|
logits[i] = -INFINITY;
|
||||||
logprobs[i] = -INFINITY;
|
logprobs[i] = -INFINITY;
|
||||||
}
|
}
|
||||||
} else if (params.n_grammar_rules > 0) {
|
} else {
|
||||||
|
if (params.n_grammar_rules > 0) {
|
||||||
whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
|
whisper_suppress_invalid_grammar(ctx, params, logits, decoder.grammar);
|
||||||
|
|
||||||
// populate the logprobs array (log_softmax)
|
// populate the logprobs array (log_softmax)
|
||||||
@ -4738,6 +4744,7 @@ static void whisper_process_logits(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// compute probs
|
// compute probs
|
||||||
{
|
{
|
||||||
|
Loading…
x
Reference in New Issue
Block a user