diff --git a/main.cpp b/main.cpp index b913522e..995eefc1 100644 --- a/main.cpp +++ b/main.cpp @@ -216,7 +216,7 @@ int main(int argc, char ** argv) { // run the inference { - whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY); + whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); wparams.print_realtime = true; wparams.print_progress = false; diff --git a/stream.cpp b/stream.cpp index 86b09a0e..3cfb178e 100644 --- a/stream.cpp +++ b/stream.cpp @@ -282,7 +282,7 @@ int main(int argc, char ** argv) { // run the inference { - whisper_full_params wparams = whisper_full_default_params(WHISPER_DECODE_GREEDY); + whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); wparams.print_progress = false; wparams.print_special_tokens = params.print_special_tokens; diff --git a/whisper.cpp b/whisper.cpp index 61d0a8a2..236fcf1d 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2256,59 +2256,63 @@ void whisper_print_timings(struct whisper_context * ctx) { //////////////////////////////////////////////////////////////////////////// -struct whisper_full_params whisper_full_default_params(enum whisper_decode_strategy strategy) { +struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy) { struct whisper_full_params result; switch (strategy) { - case WHISPER_DECODE_GREEDY: + case WHISPER_SAMPLING_GREEDY: { -#if defined(_MSC_VER) result = { -#else - result = (struct whisper_full_params) { -#endif - .strategy = WHISPER_DECODE_GREEDY, - .n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()), - .offset_ms = 0, + /*.strategy =*/ WHISPER_SAMPLING_GREEDY, - .translate = false, - .no_context = false, - .print_special_tokens = false, - .print_progress = true, - .print_realtime = false, - .print_timestamps = true, + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.offset_ms =*/ 0, - .language = "en", + /*.translate =*/ false, + /*.no_context =*/ false, + /*.print_special_tokens =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, - .greedy = { - .n_past = 0, + /*.language =*/ "en", + + /*.greedy =*/ { + /*.n_past =*/ 0, + }, + + /*.beam_search =*/ { + /*.n_past =*/ -1, + /*.beam_width =*/ -1, + /*.n_best =*/ -1, }, }; } break; - case WHISPER_DECODE_BEAM_SEARCH: + case WHISPER_SAMPLING_BEAM_SEARCH: { -#if defined(_MSC_VER) result = { -#else - result = (struct whisper_full_params) { -#endif - .strategy = WHISPER_DECODE_BEAM_SEARCH, - .n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()), - .offset_ms = 0, + /*.strategy =*/ WHISPER_SAMPLING_BEAM_SEARCH, - .translate = false, - .no_context = false, - .print_special_tokens = false, - .print_progress = true, - .print_realtime = false, - .print_timestamps = true, + /*.n_threads =*/ std::min(4, (int32_t) std::thread::hardware_concurrency()), + /*.offset_ms =*/ 0, - .language = "en", + /*.translate =*/ false, + /*.no_context =*/ false, + /*.print_special_tokens =*/ false, + /*.print_progress =*/ true, + /*.print_realtime =*/ false, + /*.print_timestamps =*/ true, - .beam_search = { - .n_past = 0, - .beam_width = 10, - .n_best = 5, + /*.language =*/ "en", + + /*.greedy =*/ { + /*.n_past =*/ -1, + }, + + /*.beam_search =*/ { + /*.n_past =*/ 0, + /*.beam_width =*/ 10, + /*.n_best =*/ 5, }, }; } break; diff --git a/whisper.h b/whisper.h index 381afd71..cf97e235 100644 --- a/whisper.h +++ b/whisper.h @@ -153,14 +153,14 @@ extern "C" { //////////////////////////////////////////////////////////////////////////// - // Available decoding strategies - enum whisper_decode_strategy { - WHISPER_DECODE_GREEDY, // Always select the most probable token - WHISPER_DECODE_BEAM_SEARCH, // TODO: not implemented yet! + // Available sampling strategies + enum whisper_sampling_strategy { + WHISPER_SAMPLING_GREEDY, // Always select the most probable token + WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet! }; struct whisper_full_params { - enum whisper_decode_strategy strategy; + enum whisper_sampling_strategy strategy; int n_threads; int offset_ms; @@ -174,20 +174,18 @@ extern "C" { const char * language; - union { - struct { - int n_past; - } greedy; + struct { + int n_past; + } greedy; - struct { - int n_past; - int beam_width; - int n_best; - } beam_search; - }; + struct { + int n_past; + int beam_width; + int n_best; + } beam_search; }; - WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_decode_strategy strategy); + WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy); // Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text // Uses the specified decoding strategy to obtain the text.