whisper : rename suppress_non_speech_tokens to suppress_nst (#2653)
Some checks are pending
Bindings Tests (Ruby) / ubuntu-latest (push) Waiting to run
CI / ubuntu-latest (linux/amd64) (push) Waiting to run
CI / ubuntu-latest (linux/arm/v7) (push) Waiting to run
CI / ubuntu-latest (linux/arm64) (push) Waiting to run
CI / ubuntu-latest (linux/ppc64le) (push) Waiting to run
CI / macOS-latest (push) Waiting to run
CI / ubuntu-latest-gcc (linux/amd64, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/amd64, Release) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/arm/v7, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/arm/v7, Release) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/arm64, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/arm64, Release) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/ppc64le, Debug) (push) Waiting to run
CI / ubuntu-latest-gcc (linux/ppc64le, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/amd64, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/amd64, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/arm64, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/arm64, Release) (push) Waiting to run
CI / ubuntu-latest-clang (linux/ppc64le, Debug) (push) Waiting to run
CI / ubuntu-latest-clang (linux/ppc64le, Release) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, ADDRESS) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, THREAD) (push) Waiting to run
CI / ubuntu-latest-gcc-sanitized (linux/amd64, UNDEFINED) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Waiting to run
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Waiting to run
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Waiting to run
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Waiting to run
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Waiting to run
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Waiting to run
CI / windows-blas (Win32, ON, Release, x86, 2.28.5, ON) (push) Waiting to run
CI / windows-blas (x64, ON, Release, x64, 2.28.5, ON) (push) Waiting to run
CI / emscripten (Release) (push) Waiting to run
CI / ios-xcode-build (Release) (push) Waiting to run
CI / android (push) Waiting to run
CI / quantize (push) Waiting to run
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64,linux/arm64 tag:main]) (push) Waiting to run

This commit is contained in:
Georgi Gerganov 2024-12-21 12:54:35 +02:00 committed by GitHub
parent 944ce49439
commit f4668169a0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 31 additions and 27 deletions

View File

@ -181,11 +181,11 @@ public class WhisperFullParams extends Structure {
} }
/** Flag to suppress non-speech tokens. */ /** Flag to suppress non-speech tokens. */
public CBool suppress_non_speech_tokens; public CBool suppress_nst;
/** Flag to suppress non-speech tokens. */ /** Flag to suppress non-speech tokens. */
public void suppressNonSpeechTokens(boolean enable) { public void suppressNonSpeechTokens(boolean enable) {
suppress_non_speech_tokens = enable ? CBool.TRUE : CBool.FALSE; suppress_nst = enable ? CBool.TRUE : CBool.FALSE;
} }
/** Initial decoding temperature. */ /** Initial decoding temperature. */
@ -315,7 +315,7 @@ public class WhisperFullParams extends Structure {
"print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps", "print_special", "print_progress", "print_realtime", "print_timestamps", "token_timestamps",
"thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx", "thold_pt", "thold_ptsum", "max_len", "split_on_word", "max_tokens", "audio_ctx",
"tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language", "tdrz_enable", "suppress_regex", "initial_prompt", "prompt_tokens", "prompt_n_tokens", "language", "detect_language",
"suppress_blank", "suppress_non_speech_tokens", "temperature", "max_initial_ts", "length_penalty", "suppress_blank", "suppress_nst", "temperature", "max_initial_ts", "length_penalty",
"temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search", "temperature_inc", "entropy_thold", "logprob_thold", "no_speech_thold", "greedy", "beam_search",
"new_segment_callback", "new_segment_callback_user_data", "new_segment_callback", "new_segment_callback_user_data",
"progress_callback", "progress_callback_user_data", "progress_callback", "progress_callback_user_data",

View File

@ -979,19 +979,19 @@ static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) {
} }
/* /*
* call-seq: * call-seq:
* suppress_non_speech_tokens = force_suppress -> force_suppress * suppress_nst = force_suppress -> force_suppress
*/ */
static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) { static VALUE ruby_whisper_params_set_suppress_nst(VALUE self, VALUE value) {
BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value) BOOL_PARAMS_SETTER(self, suppress_nst, value)
} }
/* /*
* If true, suppresses non-speech-tokens. * If true, suppresses non-speech-tokens.
* *
* call-seq: * call-seq:
* suppress_non_speech_tokens -> bool * suppress_nst -> bool
*/ */
static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) { static VALUE ruby_whisper_params_get_suppress_nst(VALUE self) {
BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens) BOOL_PARAMS_GETTER(self, suppress_nst)
} }
/* /*
* If true, enables token-level timestamps. * If true, enables token-level timestamps.
@ -1832,8 +1832,8 @@ void Init_whisper() {
rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1); rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1);
rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0); rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0);
rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1); rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1);
rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0); rb_define_method(cParams, "suppress_nst", ruby_whisper_params_get_suppress_nst, 0);
rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1); rb_define_method(cParams, "suppress_nst=", ruby_whisper_params_set_suppress_nst, 1);
rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0); rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0);
rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1); rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1);
rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0); rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0);

View File

@ -89,11 +89,11 @@ class TestParams < TestBase
assert !@params.suppress_blank assert !@params.suppress_blank
end end
def test_suppress_non_speech_tokens def test_suppress_nst
@params.suppress_non_speech_tokens = true @params.suppress_nst = true
assert @params.suppress_non_speech_tokens assert @params.suppress_nst
@params.suppress_non_speech_tokens = false @params.suppress_nst = false
assert !@params.suppress_non_speech_tokens assert !@params.suppress_nst
end end
def test_token_timestamps def test_token_timestamps

View File

@ -181,7 +181,7 @@ static json unguided_transcription(struct whisper_context * ctx, audio_async &au
wparams.n_threads = params.n_threads; wparams.n_threads = params.n_threads;
wparams.audio_ctx = params.audio_ctx; wparams.audio_ctx = params.audio_ctx;
wparams.suppress_non_speech_tokens = true; wparams.suppress_nst = true;
// run the transformer and a single decoding pass // run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__); fprintf(stderr, "%s: ERROR: whisper_full() failed\n", __func__);
@ -225,7 +225,7 @@ static json guided_transcription(struct whisper_context * ctx, audio_async &audi
wparams.prompt_tokens = cs.prompt_tokens.data(); wparams.prompt_tokens = cs.prompt_tokens.data();
wparams.prompt_n_tokens = cs.prompt_tokens.size(); wparams.prompt_n_tokens = cs.prompt_tokens.size();
// TODO: properly expose as option // TODO: properly expose as option
wparams.suppress_non_speech_tokens = true; wparams.suppress_nst = true;
// run the transformer and a single decoding pass // run the transformer and a single decoding pass
if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {

View File

@ -76,7 +76,7 @@ struct whisper_params {
bool no_timestamps = false; bool no_timestamps = false;
bool use_gpu = true; bool use_gpu = true;
bool flash_attn = false; bool flash_attn = false;
bool suppress_non_speech_tokens = false; bool suppress_nst = false;
std::string language = "en"; std::string language = "en";
std::string prompt = ""; std::string prompt = "";
@ -136,7 +136,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str()); fprintf(stderr, " --request-path PATH, [%-7s] Request path for all requests\n", sparams.request_path.c_str());
fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str()); fprintf(stderr, " --inference-path PATH, [%-7s] Inference path for all requests\n", sparams.inference_path.c_str());
fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false"); fprintf(stderr, " --convert, [%-7s] Convert audio to WAV, requires ffmpeg on the server", sparams.ffmpeg_converter ? "true" : "false");
fprintf(stderr, " -sns, --suppress-non-speech [%-7s] suppress non-speech tokens\n", params.suppress_non_speech_tokens ? "true" : "false"); fprintf(stderr, " -sns, --suppress-nst [%-7s] suppress non-speech tokens\n", params.suppress_nst ? "true" : "false");
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
@ -181,7 +181,7 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params, serve
else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; } else if (arg == "-dtw" || arg == "--dtw") { params.dtw = argv[++i]; }
else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; }
else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; } else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn = true; }
else if (arg == "-sns" || arg == "--suppress-non-speech") { params.suppress_non_speech_tokens = true; } else if (arg == "-sns" || arg == "--suppress-nst") { params.suppress_nst = true; }
// server params // server params
else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); } else if ( arg == "--port") { sparams.port = std::stoi(argv[++i]); }
else if ( arg == "--host") { sparams.hostname = argv[++i]; } else if ( arg == "--host") { sparams.hostname = argv[++i]; }
@ -477,7 +477,11 @@ void get_req_parameters(const Request & req, whisper_params & params)
} }
if (req.has_file("suppress_non_speech")) if (req.has_file("suppress_non_speech"))
{ {
params.suppress_non_speech_tokens = parse_str_to_bool(req.get_file_value("suppress_non_speech").content); params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_non_speech").content);
}
if (req.has_file("suppress_nst"))
{
params.suppress_nst = parse_str_to_bool(req.get_file_value("suppress_nst").content);
} }
} }
@ -793,7 +797,7 @@ int main(int argc, char ** argv) {
wparams.no_timestamps = params.no_timestamps; wparams.no_timestamps = params.no_timestamps;
wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format; wparams.token_timestamps = !params.no_timestamps && params.response_format == vjson_format;
wparams.suppress_non_speech_tokens = params.suppress_non_speech_tokens; wparams.suppress_nst = params.suppress_nst;
whisper_print_user_data user_data = { &params, &pcmf32s, 0 }; whisper_print_user_data user_data = { &params, &pcmf32s, 0 };

View File

@ -522,8 +522,8 @@ extern "C" {
bool detect_language; bool detect_language;
// common decoding parameters: // common decoding parameters:
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89 bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
bool suppress_non_speech_tokens; // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 bool suppress_nst; // non-speech tokens, ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478 float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97 float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97

View File

@ -4676,7 +4676,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.detect_language =*/ false, /*.detect_language =*/ false,
/*.suppress_blank =*/ true, /*.suppress_blank =*/ true,
/*.suppress_non_speech_tokens =*/ false, /*.suppress_nst =*/ false,
/*.temperature =*/ 0.0f, /*.temperature =*/ 0.0f,
/*.max_initial_ts =*/ 1.0f, /*.max_initial_ts =*/ 1.0f,
@ -4960,7 +4960,7 @@ static void whisper_process_logits(
// suppress non-speech tokens // suppress non-speech tokens
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 // ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
if (params.suppress_non_speech_tokens) { if (params.suppress_nst) {
for (const std::string & token : non_speech_tokens) { for (const std::string & token : non_speech_tokens) {
const std::string suppress_tokens[] = {token, " " + token}; const std::string suppress_tokens[] = {token, " " + token};
for (const std::string & suppress_token : suppress_tokens) { for (const std::string & suppress_token : suppress_tokens) {