whisper : restore decoder temperature fallbacks

I disabled the temperature fallbacks because there were many complaints about slow decoding.
The current implementation does not allow batching the decoders when
using the "best of" or "beam size" parameters, so the decoding time is
proportional to the number of decoders, which is obviously not great.

However, now there are even more complaints about wrong decodings and
repetition.

So, this is a compromise: re-enable the fallbacks, but default to
just 2 "best of" / "beam size" decoders. Also, the temperature step is
increased from 0.2 to 0.4 - i.e. from a maximum of 5 fallbacks to a
maximum of 2.

Also, the stream example now has fallbacks enabled by default.

close #471 #477 #508 #612 #719 #731
This commit is contained in:
Georgi Gerganov 2023-04-15 16:04:07 +03:00
parent ea1f8a50d4
commit f19e23fbd1
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
3 changed files with 25 additions and 21 deletions

View File

@ -57,7 +57,7 @@ struct whisper_params {
int32_t duration_ms = 0; int32_t duration_ms = 0;
int32_t max_context = -1; int32_t max_context = -1;
int32_t max_len = 0; int32_t max_len = 0;
int32_t best_of = 5; int32_t best_of = 2;
int32_t beam_size = -1; int32_t beam_size = -1;
float word_thold = 0.01f; float word_thold = 0.01f;

View File

@ -43,6 +43,7 @@ struct whisper_params {
bool speed_up = false; bool speed_up = false;
bool translate = false; bool translate = false;
bool no_fallback = false;
bool print_special = false; bool print_special = false;
bool no_context = true; bool no_context = true;
bool no_timestamps = false; bool no_timestamps = false;
@ -73,6 +74,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); }
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
else if (arg == "-tr" || arg == "--translate") { params.translate = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; } else if (arg == "-kc" || arg == "--keep-context") { params.no_context = false; }
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
@ -94,22 +96,23 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "options:\n"); fprintf(stderr, "options:\n");
fprintf(stderr, " -h, --help [default] show this help message and exit\n"); fprintf(stderr, " -h, --help [default] show this help message and exit\n");
fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads);
fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms); fprintf(stderr, " --step N [%-7d] audio step size in milliseconds\n", params.step_ms);
fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms); fprintf(stderr, " --length N [%-7d] audio length in milliseconds\n", params.length_ms);
fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms); fprintf(stderr, " --keep N [%-7d] audio to keep from previous step in ms\n", params.keep_ms);
fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id); fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id);
fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens); fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens);
fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx);
fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold);
fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold);
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true"); fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false");
fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); fprintf(stderr, " -kc, --keep-context [%-7s] keep context between audio chunks\n", params.no_context ? "false" : "true");
fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str());
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
@ -297,7 +300,8 @@ int main(int argc, char ** argv) {
wparams.speed_up = params.speed_up; wparams.speed_up = params.speed_up;
// disable temperature fallback // disable temperature fallback
wparams.temperature_inc = -1.0f; //wparams.temperature_inc = -1.0f;
wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc;
wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data(); wparams.prompt_tokens = params.no_context ? nullptr : prompt_tokens.data();
wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size(); wparams.prompt_n_tokens = params.no_context ? 0 : prompt_tokens.size();

View File

@ -3220,7 +3220,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.max_initial_ts =*/ 1.0f, /*.max_initial_ts =*/ 1.0f,
/*.length_penalty =*/ -1.0f, /*.length_penalty =*/ -1.0f,
/*.temperature_inc =*/ 0.0f, // TODO: temporary disabled until improve performance /*.temperature_inc =*/ 0.4f,
/*.entropy_thold =*/ 2.4f, /*.entropy_thold =*/ 2.4f,
/*.logprob_thold =*/ -1.0f, /*.logprob_thold =*/ -1.0f,
/*.no_speech_thold =*/ 0.6f, /*.no_speech_thold =*/ 0.6f,
@ -3252,13 +3252,13 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
case WHISPER_SAMPLING_GREEDY: case WHISPER_SAMPLING_GREEDY:
{ {
result.greedy = { result.greedy = {
/*.best_of =*/ 1, /*.best_of =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
}; };
} break; } break;
case WHISPER_SAMPLING_BEAM_SEARCH: case WHISPER_SAMPLING_BEAM_SEARCH:
{ {
result.beam_search = { result.beam_search = {
/*.beam_size =*/ 5, /*.beam_size =*/ 2, // TODO: increase to 5 when we speed-up batch decoding
/*.patience =*/ -1.0f, /*.patience =*/ -1.0f,
}; };