diff --git a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java index 90d8c157..18c209fc 100644 --- a/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java +++ b/bindings/java/src/main/java/io/github/ggerganov/whispercpp/params/WhisperFullParams.java @@ -181,11 +181,11 @@ public class WhisperFullParams extends Structure { } /** Flag to suppress non-speech tokens. */ - public CBool suppress_non_speech_tokens; + public CBool suppress_nst; /** Flag to suppress non-speech tokens. */ public void suppressNonSpeechTokens(boolean enable) { - suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE; + suppress_nst = enable ? CBool.TRUE : CBool.FALSE; } /** Initial decoding temperature. */ @@ -315,7 +315,7 @@ public class WhisperFullParams extends Structure { "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps", "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx", "tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language", - "suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty", + "suppress_blank", "suppress_nst", "temperature", "max_initial_ts", "length_penalty", "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search", "new_segment_callback", "new_segment_callback_user_data", "progress_callback", "progress_callback_user_data", diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 26e9def4..aa526577 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -979,19 +979,19 @@ static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) { } /* * call-seq: - * suppress_non_speech_tokens = force_suppress -> force_suppress + * suppress_nst = force_suppress -> force_suppress */ -static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) { - BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value) +static VALUE ruby_whisper_params_set_suppress_nst(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, suppress_nst, value) } /* * If true, suppresses non-speech-tokens. * * call-seq: - * suppress_non_speech_tokens -> bool + * suppress_nst -> bool */ -static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) { - BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens) +static VALUE ruby_whisper_params_get_suppress_nst(VALUE self) { + BOOL_PARAMS_GETTER(self, suppress_nst) } /* * If true, enables token-level timestamps. @@ -1832,8 +1832,8 @@ void Init_whisper() { rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1); rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0); rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1); - rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0); - rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1); + rb_define_method(cParams, "suppress_nst", ruby_whisper_params_get_suppress_nst, 0); + rb_define_method(cParams, "suppress_nst=", ruby_whisper_params_set_suppress_nst, 1); rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0); rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1); rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0); diff --git a/bindings/ruby/tests/test_params.rb b/bindings/ruby/tests/test_params.rb index d2667ef0..7981bfaa 100644 --- a/bindings/ruby/tests/test_params.rb +++ b/bindings/ruby/tests/test_params.rb @@ -89,11 +89,11 @@ class TestParams < TestBase assert !@params.suppress_blank end - def test_suppress_non_speech_tokens - @params.suppress_non_speech_tokens = true - assert @params.suppress_non_speech_tokens - @params.suppress_non_speech_tokens = false - assert !@params.suppress_non_speech_tokens + def test_suppress_nst + @params.suppress_nst = true + assert @params.suppress_nst + @params.suppress_nst = false + assert !@params.suppress_nst end def test_token_timestamps diff --git a/examples/lsp/lsp.cpp b/examples/lsp/lsp.cpp index 1afc159f..803cd6d5 100644 --- a/examples/lsp/lsp.cpp +++ b/examples/lsp/lsp.cpp @@ -181,7 +181,7 @@ static json unguided_transcription(struct whisper_context * ctx, audio_async &au wparams.n_threads = params.n_threads; wparams.audio_ctx = params.audio_ctx; - wparams.suppress_non_speech_tokens = true; + wparams.suppress_nst = true; // run the transformer and a single decoding pass if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__); @@ -225,7 +225,7 @@ static json guided_transcription(struct whisper_context * ctx, audio_async &audi wparams.prompt_tokens = cs.prompt_tokens.data(); wparams.prompt_n_tokens = cs.prompt_tokens.size(); // TODO: properly expose as option - wparams.suppress_non_speech_tokens = true; + wparams.suppress_nst = true; // run the transformer and a single decoding pass if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { diff --git a/examples/server/server.cpp b/examples/server/server.cpp index f0484b34..0608bb6b 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -76,7 +76,7 @@ struct whisper_params { bool no_timestamps = false; bool use_gpu = true; bool flash_attn = false; - bool suppress_non_speech_tokens = false; + bool suppress_nst = false; std::string language = "en"; std::string prompt = ""; @@ -136,7 +136,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str()); fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str()); fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false"); - fprintf(stderr, " -sns, --suppress-non-speech [%-7s] suppress non-speech tokens\n", params.suppress_non_speech_tokens ? "true" : "false"); + fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false"); fprintf(stderr, "\n"); } @@ -181,7 +181,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } - else if (arg == "-sns" || arg == "--suppress-non-speech") { params.suppress_non_speech_tokens = true; } + else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; } // server params else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); } else if ( arg == "--host") { sparams.hostname = argv[++i]; } @@ -477,7 +477,11 @@ void get_req_parameters(const Request & req, whisper_params & params) } if (req.has_file("suppress_non_speech")) { - params.suppress_non_speech_tokens = parse_str_to_bool(req.get_file_value("suppress_non_speech").content); + params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_non_speech").content); + } + if (req.has_file("suppress_nst")) + { + params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_nst").content); } } @@ -793,7 +797,7 @@ int main(int argc, char ** argv) { wparams.no_timestamps = params.no_timestamps; wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format; - wparams.suppress_non_speech_tokens = params.suppress_non_speech_tokens; + wparams.suppress_nst = params.suppress_nst; whisper_print_user_data user_data = { ¶ms, &pcmf32s, 0 }; diff --git a/include/whisper.h b/include/whisper.h index 71949bdd..6e0db505 100644 --- a/include/whisper.h +++ b/include/whisper.h @@ -522,8 +522,8 @@ extern "C" { bool detect_language; // common decoding parameters: - bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89 - bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 + bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89 + bool suppress_nst; // non-speech tokens, ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478 float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97 diff --git a/src/whisper.cpp b/src/whisper.cpp index bcc530ae..ff200ea0 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -4676,7 +4676,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.detect_language =*/ false, /*.suppress_blank =*/ true, - /*.suppress_non_speech_tokens =*/ false, + /*.suppress_nst =*/ false, /*.temperature =*/ 0.0f, /*.max_initial_ts =*/ 1.0f, @@ -4960,7 +4960,7 @@ static void whisper_process_logits( // suppress non-speech tokens // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 - if (params.suppress_non_speech_tokens) { + if (params.suppress_nst) { for (const std::string & token : non_speech_tokens) { const std::string suppress_tokens[] = {token, " " + token}; for (const std::string & suppress_token : suppress_tokens) {