From f141b2b938ebf56384dbbe04e595e1b5a9d1104c Mon Sep 17 00:00:00 2001 From: Daniel Ziegenberg <daniel@ziegenberg.at> Date: Mon, 13 May 2024 13:59:44 +0200 Subject: [PATCH] main : add options for temperature control (#2088) Add two options: ``` -tp, --temperature N [0.00 ] The sampling temperature, between 0 and 1 -tpi, --temperature-inc N [0.20 ] The increment of temperature, between 0 and 1 ``` The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use log probability to automatically increase the temperature until certain thresholds are hit. Signed-off-by: Daniel Ziegenberg <daniel@ziegenberg.at> --- examples/main/main.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6a3db73d..bb193186 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -44,6 +44,8 @@ struct whisper_params { float entropy_thold = 2.40f; float logprob_thold = -1.00f; float grammar_penalty = 100.0f; + float temperature = 0.0f; + float temperature_inc = 0.2f; bool speed_up = false; bool debug_mode = false; @@ -133,6 +135,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); } else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); } + else if (arg == "-tp" || arg == "--temperature") { params.temperature = std::stof(argv[++i]); } + else if (arg == "-tpi" || arg == "--temperature-inc") { params.temperature_inc = std::stof(argv[++i]); } // else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } else if (arg == "-debug"|| arg == "--debug-mode") { params.debug_mode = true; } else if (arg == "-tr" || arg == "--translate") { params.translate = true; } @@ -198,6 +202,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " -wt N, --word-thold N [%-7.2f] word timestamp probability threshold\n", params.word_thold); fprintf(stderr, " -et N, --entropy-thold N [%-7.2f] entropy threshold for decoder fail\n", params.entropy_thold); fprintf(stderr, " -lpt N, --logprob-thold N [%-7.2f] log probability threshold for decoder fail\n", params.logprob_thold); + fprintf(stderr, " -tp, --temperature N [%-7.2f] The sampling temperature, between 0 and 1\n", params.temperature); + fprintf(stderr, " -tpi, --temperature-inc N [%-7.2f] The increment of temperature, between 0 and 1\n",params.temperature_inc); // fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); fprintf(stderr, " -debug, --debug-mode [%-7s] enable debug mode (eg. dump log_mel)\n", params.debug_mode ? "true" : "false"); fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); @@ -1107,7 +1113,9 @@ int main(int argc, char ** argv) { wparams.greedy.best_of = params.best_of; wparams.beam_search.beam_size = params.beam_size; - wparams.temperature_inc = params.no_fallback ? 0.0f : wparams.temperature_inc; + wparams.temperature_inc = params.no_fallback ? 0.0f : params.temperature_inc; + wparams.temperature = params.temperature; + wparams.entropy_thold = params.entropy_thold; wparams.logprob_thold = params.logprob_thold;