mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-12 20:18:08 +00:00
whisper : support speaker segmentation (local diarization) of mono audio via tinydiarize (#1058)
* add HuggingFace mirror to download ggml model * support tdrz via simple hack overriding solm tokens * fix incorrect translate/transcribe token_ids that are not static const * add apollo 13 sample for tdrz demo * render [SPEAKER TURN] consistently in all terminal output using vocab.id_to_token * extend whisper_segment with speaker_turn_next field and save in json output * fix failing go build * slipped in some python syntax whoops * whisper : finalize tinydiarize support (add flag + fixes) * whisper : tdrz support for word-level timestamps (respect max_len) * java : try to fix tests after adding tdrz_enable flag * main : remove TODO leftover * java : fix params order list after adding "tdrz_enable" * whisper : fix solm and add nosp token * main : print tinydiarize help --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
@ -68,28 +68,32 @@ struct whisper_params {
|
||||
float entropy_thold = 2.40f;
|
||||
float logprob_thold = -1.00f;
|
||||
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool detect_language= false;
|
||||
bool diarize = false;
|
||||
bool split_on_word = false;
|
||||
bool no_fallback = false;
|
||||
bool output_txt = false;
|
||||
bool output_vtt = false;
|
||||
bool output_srt = false;
|
||||
bool output_wts = false;
|
||||
bool output_csv = false;
|
||||
bool output_jsn = false;
|
||||
bool output_lrc = false;
|
||||
bool print_special = false;
|
||||
bool print_colors = false;
|
||||
bool print_progress = false;
|
||||
bool no_timestamps = false;
|
||||
bool speed_up = false;
|
||||
bool translate = false;
|
||||
bool detect_language = false;
|
||||
bool diarize = false;
|
||||
bool tinydiarize = false;
|
||||
bool split_on_word = false;
|
||||
bool no_fallback = false;
|
||||
bool output_txt = false;
|
||||
bool output_vtt = false;
|
||||
bool output_srt = false;
|
||||
bool output_wts = false;
|
||||
bool output_csv = false;
|
||||
bool output_jsn = false;
|
||||
bool output_lrc = false;
|
||||
bool print_special = false;
|
||||
bool print_colors = false;
|
||||
bool print_progress = false;
|
||||
bool no_timestamps = false;
|
||||
|
||||
std::string language = "en";
|
||||
std::string language = "en";
|
||||
std::string prompt;
|
||||
std::string font_path = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
std::string model = "models/ggml-base.en.bin";
|
||||
|
||||
// [TDRZ] speaker turn string
|
||||
std::string tdrz_speaker_turn = " [SPEAKER_TURN]"; // TODO: set from command line
|
||||
|
||||
std::vector<std::string> fname_inp = {};
|
||||
std::vector<std::string> fname_out = {};
|
||||
@ -115,41 +119,42 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
|
||||
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
|
||||
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
|
||||
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
|
||||
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
|
||||
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
|
||||
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
|
||||
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
|
||||
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
|
||||
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
|
||||
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
|
||||
else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
|
||||
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
|
||||
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
|
||||
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
|
||||
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if (arg == "-dl" || arg == "--detect-language"){ params.detect_language= true; }
|
||||
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
|
||||
else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); }
|
||||
else if (arg == "-p" || arg == "--processors") { params.n_processors = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ot" || arg == "--offset-t") { params.offset_t_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-on" || arg == "--offset-n") { params.offset_n = std::stoi(argv[++i]); }
|
||||
else if (arg == "-d" || arg == "--duration") { params.duration_ms = std::stoi(argv[++i]); }
|
||||
else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); }
|
||||
else if (arg == "-ml" || arg == "--max-len") { params.max_len = std::stoi(argv[++i]); }
|
||||
else if (arg == "-bo" || arg == "--best-of") { params.best_of = std::stoi(argv[++i]); }
|
||||
else if (arg == "-bs" || arg == "--beam-size") { params.beam_size = std::stoi(argv[++i]); }
|
||||
else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-et" || arg == "--entropy-thold") { params.entropy_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-lpt" || arg == "--logprob-thold") { params.logprob_thold = std::stof(argv[++i]); }
|
||||
else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; }
|
||||
else if (arg == "-tr" || arg == "--translate") { params.translate = true; }
|
||||
else if (arg == "-di" || arg == "--diarize") { params.diarize = true; }
|
||||
else if (arg == "-tdrz" || arg == "--tinydiarize") { params.tinydiarize = true; }
|
||||
else if (arg == "-sow" || arg == "--split-on-word") { params.split_on_word = true; }
|
||||
else if (arg == "-nf" || arg == "--no-fallback") { params.no_fallback = true; }
|
||||
else if (arg == "-otxt" || arg == "--output-txt") { params.output_txt = true; }
|
||||
else if (arg == "-ovtt" || arg == "--output-vtt") { params.output_vtt = true; }
|
||||
else if (arg == "-osrt" || arg == "--output-srt") { params.output_srt = true; }
|
||||
else if (arg == "-owts" || arg == "--output-words") { params.output_wts = true; }
|
||||
else if (arg == "-olrc" || arg == "--output-lrc") { params.output_lrc = true; }
|
||||
else if (arg == "-fp" || arg == "--font-path") { params.font_path = argv[++i]; }
|
||||
else if (arg == "-ocsv" || arg == "--output-csv") { params.output_csv = true; }
|
||||
else if (arg == "-oj" || arg == "--output-json") { params.output_jsn = true; }
|
||||
else if (arg == "-of" || arg == "--output-file") { params.fname_out.emplace_back(argv[++i]); }
|
||||
else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; }
|
||||
else if (arg == "-pc" || arg == "--print-colors") { params.print_colors = true; }
|
||||
else if (arg == "-pp" || arg == "--print-progress") { params.print_progress = true; }
|
||||
else if (arg == "-nt" || arg == "--no-timestamps") { params.no_timestamps = true; }
|
||||
else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; }
|
||||
else if (arg == "-dl" || arg == "--detect-language") { params.detect_language = true; }
|
||||
else if ( arg == "--prompt") { params.prompt = argv[++i]; }
|
||||
else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; }
|
||||
else if (arg == "-f" || arg == "--file") { params.fname_inp.emplace_back(argv[++i]); }
|
||||
else {
|
||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||
whisper_print_usage(argc, argv, params);
|
||||
@ -182,6 +187,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
|
||||
fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false");
|
||||
fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false");
|
||||
fprintf(stderr, " -di, --diarize [%-7s] stereo audio diarization\n", params.diarize ? "true" : "false");
|
||||
fprintf(stderr, " -tdrz, --tinydiarize [%-7s] enable tinydiarize (requires a tdrz model)\n", params.tinydiarize ? "true" : "false");
|
||||
fprintf(stderr, " -nf, --no-fallback [%-7s] do not use temperature fallback while decoding\n", params.no_fallback ? "true" : "false");
|
||||
fprintf(stderr, " -otxt, --output-txt [%-7s] output result in a text file\n", params.output_txt ? "true" : "false");
|
||||
fprintf(stderr, " -ovtt, --output-vtt [%-7s] output result in a vtt file\n", params.output_vtt ? "true" : "false");
|
||||
@ -297,6 +303,12 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
|
||||
printf("%s%s", speaker.c_str(), text);
|
||||
}
|
||||
|
||||
if (params.tinydiarize) {
|
||||
if (whisper_full_get_segment_speaker_turn_next(ctx, i)) {
|
||||
printf("%s", params.tdrz_speaker_turn.c_str());
|
||||
}
|
||||
}
|
||||
|
||||
// with timestamps or speakers: each segment on new line
|
||||
if (!params.no_timestamps || params.diarize) {
|
||||
printf("\n");
|
||||
@ -564,6 +576,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
@ -576,11 +589,15 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
|
||||
value_i("from", t0 * 10, false);
|
||||
value_i("to", t1 * 10, true);
|
||||
end_obj(false);
|
||||
value_s("text", text, !params.diarize);
|
||||
value_s("text", text, !params.diarize && !params.tinydiarize);
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
|
||||
}
|
||||
|
||||
if (params.tinydiarize) {
|
||||
value_b("speaker_turn_next", whisper_full_get_segment_speaker_turn_next(ctx, i), true);
|
||||
}
|
||||
end_obj(i == (n_segments - 1));
|
||||
}
|
||||
|
||||
@ -777,6 +794,12 @@ int main(int argc, char ** argv) {
|
||||
exit(0);
|
||||
}
|
||||
|
||||
if (params.diarize && params.tinydiarize) {
|
||||
fprintf(stderr, "error: cannot use both --diarize and --tinydiarize\n");
|
||||
whisper_print_usage(argc, argv, params);
|
||||
exit(0);
|
||||
}
|
||||
|
||||
// whisper init
|
||||
|
||||
struct whisper_context * ctx = whisper_init_from_file(params.model.c_str());
|
||||
@ -818,11 +841,12 @@ int main(int argc, char ** argv) {
|
||||
if (params.detect_language) {
|
||||
params.language = "auto";
|
||||
}
|
||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, timestamps = %d ...\n",
|
||||
fprintf(stderr, "%s: processing '%s' (%d samples, %.1f sec), %d threads, %d processors, lang = %s, task = %s, %stimestamps = %d ...\n",
|
||||
__func__, fname_inp.c_str(), int(pcmf32.size()), float(pcmf32.size())/WHISPER_SAMPLE_RATE,
|
||||
params.n_threads, params.n_processors,
|
||||
params.language.c_str(),
|
||||
params.translate ? "translate" : "transcribe",
|
||||
params.tinydiarize ? "tdrz = 1, " : "",
|
||||
params.no_timestamps ? 0 : 1);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
@ -853,6 +877,8 @@ int main(int argc, char ** argv) {
|
||||
|
||||
wparams.speed_up = params.speed_up;
|
||||
|
||||
wparams.tdrz_enable = params.tinydiarize; // [TDRZ]
|
||||
|
||||
wparams.initial_prompt = params.prompt.c_str();
|
||||
|
||||
wparams.greedy.best_of = params.best_of;
|
||||
|
Reference in New Issue
Block a user