From 14baf2e7f38d356c60d3c65daf48c836727724fd Mon Sep 17 00:00:00 2001 From: Colin Date: Sun, 25 Jun 2023 07:07:57 -0500 Subject: [PATCH] main : add diarization support for all current output types (#1031) Co-authored-by: Georgi Gerganov --- examples/main/main.cpp | 168 +++++++++++++++++++++++++++++------------ 1 file changed, 118 insertions(+), 50 deletions(-) diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 07a7591f..ff62f74b 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -210,6 +210,39 @@ struct whisper_print_user_data { const std::vector> * pcmf32s; }; +std::string estimate_diarization_speaker(std::vector> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) { + std::string speaker = ""; + const int64_t n_samples = pcmf32s[0].size(); + + const int64_t is0 = timestamp_to_sample(t0, n_samples); + const int64_t is1 = timestamp_to_sample(t1, n_samples); + + double energy0 = 0.0f; + double energy1 = 0.0f; + + for (int64_t j = is0; j < is1; j++) { + energy0 += fabs(pcmf32s[0][j]); + energy1 += fabs(pcmf32s[1][j]); + } + + if (energy0 > 1.1*energy1) { + speaker = "0"; + } else if (energy1 > 1.1*energy0) { + speaker = "1"; + } else { + speaker = "?"; + } + + //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str()); + + if (!id_only) { + speaker.insert(0, "(speaker "); + speaker.append(")"); + } + + return speaker; +} + void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) { const auto & params = *((whisper_print_user_data *) user_data)->params; const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s; @@ -239,28 +272,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper } if (params.diarize && pcmf32s.size() == 2) { - const int64_t n_samples = pcmf32s[0].size(); - - const int64_t is0 = timestamp_to_sample(t0, n_samples); - const int64_t is1 = timestamp_to_sample(t1, n_samples); - - double energy0 = 0.0f; - double energy1 = 0.0f; - - for (int64_t j = is0; j < is1; j++) { - energy0 += fabs(pcmf32s[0][j]); - energy1 += fabs(pcmf32s[1][j]); - } - - if (energy0 > 1.1*energy1) { - speaker = "(speaker 0)"; - } else if (energy1 > 1.1*energy0) { - speaker = "(speaker 1)"; - } else { - speaker = "(speaker ?)"; - } - - //printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str()); + speaker = estimate_diarization_speaker(pcmf32s, t0, t1); } if (params.print_colors) { @@ -294,7 +306,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper } } -bool output_txt(struct whisper_context * ctx, const char * fname) { +bool output_txt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { std::ofstream fout(fname); if (!fout.is_open()) { fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); @@ -306,13 +318,22 @@ bool output_txt(struct whisper_context * ctx, const char * fname) { const int n_segments = whisper_full_n_segments(ctx); for (int i = 0; i < n_segments; ++i) { const char * text = whisper_full_get_segment_text(ctx, i); - fout << text << "\n"; + std::string speaker = ""; + + if (params.diarize && pcmf32s.size() == 2) + { + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + speaker = estimate_diarization_speaker(pcmf32s, t0, t1); + } + + fout << speaker << text << "\n"; } return true; } -bool output_vtt(struct whisper_context * ctx, const char * fname) { +bool output_vtt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { std::ofstream fout(fname); if (!fout.is_open()) { fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); @@ -328,15 +349,23 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) { const char * text = whisper_full_get_segment_text(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + std::string speaker = ""; + + if (params.diarize && pcmf32s.size() == 2) + { + speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true); + speaker.insert(0, ""); + } fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n"; - fout << text << "\n\n"; + fout << speaker << text << "\n\n"; } return true; } -bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params) { +bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { std::ofstream fout(fname); if (!fout.is_open()) { fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); @@ -350,10 +379,16 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_ const char * text = whisper_full_get_segment_text(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + std::string speaker = ""; + + if (params.diarize && pcmf32s.size() == 2) + { + speaker = estimate_diarization_speaker(pcmf32s, t0, t1); + } fout << i + 1 + params.offset_n << "\n"; fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n"; - fout << text << "\n\n"; + fout << speaker << text << "\n\n"; } return true; @@ -390,7 +425,7 @@ char *escape_double_quotes_and_backslashes(const char *str) { return escaped; } -bool output_csv(struct whisper_context * ctx, const char * fname) { +bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { std::ofstream fout(fname); if (!fout.is_open()) { fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); @@ -400,7 +435,13 @@ bool output_csv(struct whisper_context * ctx, const char * fname) { fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); const int n_segments = whisper_full_n_segments(ctx); - fout << "start,end,text\n"; + fout << "start,end,"; + if (params.diarize && pcmf32s.size() == 2) + { + fout << "speaker,"; + } + fout << "text\n"; + for (int i = 0; i < n_segments; ++i) { const char * text = whisper_full_get_segment_text(ctx, i); const int64_t t0 = whisper_full_get_segment_t0(ctx, i); @@ -408,13 +449,18 @@ bool output_csv(struct whisper_context * ctx, const char * fname) { char * text_escaped = escape_double_quotes_and_backslashes(text); //need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds. - fout << 10 * t0 << "," << 10 * t1 << ",\"" << text_escaped << "\"\n"; + fout << 10 * t0 << "," << 10 * t1 << ","; + if (params.diarize && pcmf32s.size() == 2) + { + fout << estimate_diarization_speaker(pcmf32s, t0, t1, true) << ","; + } + fout << "\"" << text_escaped << "\"\n"; } return true; } -bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params) { +bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { std::ofstream fout(fname); int indent = 0; @@ -530,7 +576,11 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper value_i("from", t0 * 10, false); value_i("to", t1 * 10, true); end_obj(false); - value_s("text", text, true); + value_s("text", text, !params.diarize); + + if (params.diarize && pcmf32s.size() == 2) { + value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true); + } end_obj(i == (n_segments - 1)); } @@ -542,7 +592,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper // karaoke video generation // outputs a bash script that uses ffmpeg to generate a video with the subtitles // TODO: font parameter adjustments -bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) { +bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec, std::vector> pcmf32s) { std::ofstream fout(fname); fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); @@ -579,6 +629,11 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'"; bool is_first = true; + std::string speaker = ""; + + if (params.diarize && pcmf32s.size() == 2) { + speaker = estimate_diarization_speaker(pcmf32s, t0, t1); + } for (int j = 0; j < n; ++j) { const auto & token = tokens[j]; @@ -587,13 +642,19 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f continue; } - std::string txt_bg; - std::string txt_fg; // highlight token - std::string txt_ul; // underline + std::string txt_bg = ""; + std::string txt_fg = ""; // highlight token + std::string txt_ul = ""; // underline - txt_bg = "> "; - txt_fg = "> "; - txt_ul = "\\ \\ "; + if (params.diarize && pcmf32s.size() == 2) { + txt_bg = speaker; + txt_fg = speaker; + txt_ul = "\\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ "; + } + + txt_bg.append("> "); + txt_fg.append("> "); + txt_ul.append("\\ \\ "); { for (int k = 0; k < n; ++k) { @@ -656,8 +717,7 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f return true; } -bool output_lrc(struct whisper_context * ctx, const char * fname) { - +bool output_lrc(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector> pcmf32s) { std::ofstream fout(fname); if (!fout.is_open()) { fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname); @@ -682,8 +742,16 @@ bool output_lrc(struct whisper_context * ctx, const char * fname) { char buf[16]; snprintf(buf, sizeof(buf), "%02d:%02d.%02d", (int) min, (int) sec, (int) ( msec / 10)); std::string timestamp_lrc = std::string(buf); + std::string speaker = ""; - fout << '[' << timestamp_lrc << ']' << text << "\n"; + if (params.diarize && pcmf32s.size() == 2) + { + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + speaker = estimate_diarization_speaker(pcmf32s, t0, t1); + } + + fout << '[' << timestamp_lrc << ']' << speaker << text << "\n"; } return true; @@ -828,43 +896,43 @@ int main(int argc, char ** argv) { // output to text file if (params.output_txt) { const auto fname_txt = fname_out + ".txt"; - output_txt(ctx, fname_txt.c_str()); + output_txt(ctx, fname_txt.c_str(), params, pcmf32s); } // output to VTT file if (params.output_vtt) { const auto fname_vtt = fname_out + ".vtt"; - output_vtt(ctx, fname_vtt.c_str()); + output_vtt(ctx, fname_vtt.c_str(), params, pcmf32s); } // output to SRT file if (params.output_srt) { const auto fname_srt = fname_out + ".srt"; - output_srt(ctx, fname_srt.c_str(), params); + output_srt(ctx, fname_srt.c_str(), params, pcmf32s); } // output to WTS file if (params.output_wts) { const auto fname_wts = fname_out + ".wts"; - output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE); + output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE, pcmf32s); } // output to CSV file if (params.output_csv) { const auto fname_csv = fname_out + ".csv"; - output_csv(ctx, fname_csv.c_str()); + output_csv(ctx, fname_csv.c_str(), params, pcmf32s); } // output to JSON file if (params.output_jsn) { const auto fname_jsn = fname_out + ".json"; - output_json(ctx, fname_jsn.c_str(), params); + output_json(ctx, fname_jsn.c_str(), params, pcmf32s); } // output to LRC file if (params.output_lrc) { const auto fname_lrc = fname_out + ".lrc"; - output_lrc(ctx, fname_lrc.c_str()); + output_lrc(ctx, fname_lrc.c_str(), params, pcmf32s); } } }