mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-20 05:07:52 +00:00
main : add diarization support for all current output types (#1031)
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
bc2dcf85fe
commit
14baf2e7f3
@ -210,6 +210,39 @@ struct whisper_print_user_data {
|
||||
const std::vector<std::vector<float>> * pcmf32s;
|
||||
};
|
||||
|
||||
std::string estimate_diarization_speaker(std::vector<std::vector<float>> pcmf32s, int64_t t0, int64_t t1, bool id_only = false) {
|
||||
std::string speaker = "";
|
||||
const int64_t n_samples = pcmf32s[0].size();
|
||||
|
||||
const int64_t is0 = timestamp_to_sample(t0, n_samples);
|
||||
const int64_t is1 = timestamp_to_sample(t1, n_samples);
|
||||
|
||||
double energy0 = 0.0f;
|
||||
double energy1 = 0.0f;
|
||||
|
||||
for (int64_t j = is0; j < is1; j++) {
|
||||
energy0 += fabs(pcmf32s[0][j]);
|
||||
energy1 += fabs(pcmf32s[1][j]);
|
||||
}
|
||||
|
||||
if (energy0 > 1.1*energy1) {
|
||||
speaker = "0";
|
||||
} else if (energy1 > 1.1*energy0) {
|
||||
speaker = "1";
|
||||
} else {
|
||||
speaker = "?";
|
||||
}
|
||||
|
||||
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, speaker = %s\n", is0, is1, energy0, energy1, speaker.c_str());
|
||||
|
||||
if (!id_only) {
|
||||
speaker.insert(0, "(speaker ");
|
||||
speaker.append(")");
|
||||
}
|
||||
|
||||
return speaker;
|
||||
}
|
||||
|
||||
void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper_state * /*state*/, int n_new, void * user_data) {
|
||||
const auto & params = *((whisper_print_user_data *) user_data)->params;
|
||||
const auto & pcmf32s = *((whisper_print_user_data *) user_data)->pcmf32s;
|
||||
@ -239,28 +272,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
|
||||
}
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
const int64_t n_samples = pcmf32s[0].size();
|
||||
|
||||
const int64_t is0 = timestamp_to_sample(t0, n_samples);
|
||||
const int64_t is1 = timestamp_to_sample(t1, n_samples);
|
||||
|
||||
double energy0 = 0.0f;
|
||||
double energy1 = 0.0f;
|
||||
|
||||
for (int64_t j = is0; j < is1; j++) {
|
||||
energy0 += fabs(pcmf32s[0][j]);
|
||||
energy1 += fabs(pcmf32s[1][j]);
|
||||
}
|
||||
|
||||
if (energy0 > 1.1*energy1) {
|
||||
speaker = "(speaker 0)";
|
||||
} else if (energy1 > 1.1*energy0) {
|
||||
speaker = "(speaker 1)";
|
||||
} else {
|
||||
speaker = "(speaker ?)";
|
||||
}
|
||||
|
||||
//printf("is0 = %lld, is1 = %lld, energy0 = %f, energy1 = %f, %s\n", is0, is1, energy0, energy1, speaker.c_str());
|
||||
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
|
||||
}
|
||||
|
||||
if (params.print_colors) {
|
||||
@ -294,7 +306,7 @@ void whisper_print_segment_callback(struct whisper_context * ctx, struct whisper
|
||||
}
|
||||
}
|
||||
|
||||
bool output_txt(struct whisper_context * ctx, const char * fname) {
|
||||
bool output_txt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
|
||||
std::ofstream fout(fname);
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||
@ -306,13 +318,22 @@ bool output_txt(struct whisper_context * ctx, const char * fname) {
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
fout << text << "\n";
|
||||
std::string speaker = "";
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2)
|
||||
{
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
|
||||
}
|
||||
|
||||
fout << speaker << text << "\n";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool output_vtt(struct whisper_context * ctx, const char * fname) {
|
||||
bool output_vtt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
|
||||
std::ofstream fout(fname);
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||
@ -328,15 +349,23 @@ bool output_vtt(struct whisper_context * ctx, const char * fname) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
std::string speaker = "";
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2)
|
||||
{
|
||||
speaker = estimate_diarization_speaker(pcmf32s, t0, t1, true);
|
||||
speaker.insert(0, "<v Speaker");
|
||||
speaker.append(">");
|
||||
}
|
||||
|
||||
fout << to_timestamp(t0) << " --> " << to_timestamp(t1) << "\n";
|
||||
fout << text << "\n\n";
|
||||
fout << speaker << text << "\n\n";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
|
||||
bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
|
||||
std::ofstream fout(fname);
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||
@ -350,10 +379,16 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
std::string speaker = "";
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2)
|
||||
{
|
||||
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
|
||||
}
|
||||
|
||||
fout << i + 1 + params.offset_n << "\n";
|
||||
fout << to_timestamp(t0, true) << " --> " << to_timestamp(t1, true) << "\n";
|
||||
fout << text << "\n\n";
|
||||
fout << speaker << text << "\n\n";
|
||||
}
|
||||
|
||||
return true;
|
||||
@ -390,7 +425,7 @@ char *escape_double_quotes_and_backslashes(const char *str) {
|
||||
return escaped;
|
||||
}
|
||||
|
||||
bool output_csv(struct whisper_context * ctx, const char * fname) {
|
||||
bool output_csv(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
|
||||
std::ofstream fout(fname);
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||
@ -400,7 +435,13 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
fout << "start,end,text\n";
|
||||
fout << "start,end,";
|
||||
if (params.diarize && pcmf32s.size() == 2)
|
||||
{
|
||||
fout << "speaker,";
|
||||
}
|
||||
fout << "text\n";
|
||||
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
@ -408,13 +449,18 @@ bool output_csv(struct whisper_context * ctx, const char * fname) {
|
||||
char * text_escaped = escape_double_quotes_and_backslashes(text);
|
||||
|
||||
//need to multiply times returned from whisper_full_get_segment_t{0,1}() by 10 to get milliseconds.
|
||||
fout << 10 * t0 << "," << 10 * t1 << ",\"" << text_escaped << "\"\n";
|
||||
fout << 10 * t0 << "," << 10 * t1 << ",";
|
||||
if (params.diarize && pcmf32s.size() == 2)
|
||||
{
|
||||
fout << estimate_diarization_speaker(pcmf32s, t0, t1, true) << ",";
|
||||
}
|
||||
fout << "\"" << text_escaped << "\"\n";
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params) {
|
||||
bool output_json(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
|
||||
std::ofstream fout(fname);
|
||||
int indent = 0;
|
||||
|
||||
@ -530,7 +576,11 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
|
||||
value_i("from", t0 * 10, false);
|
||||
value_i("to", t1 * 10, true);
|
||||
end_obj(false);
|
||||
value_s("text", text, true);
|
||||
value_s("text", text, !params.diarize);
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
value_s("speaker", estimate_diarization_speaker(pcmf32s, t0, t1, true).c_str(), true);
|
||||
}
|
||||
end_obj(i == (n_segments - 1));
|
||||
}
|
||||
|
||||
@ -542,7 +592,7 @@ bool output_json(struct whisper_context * ctx, const char * fname, const whisper
|
||||
// karaoke video generation
|
||||
// outputs a bash script that uses ffmpeg to generate a video with the subtitles
|
||||
// TODO: font parameter adjustments
|
||||
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) {
|
||||
bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec, std::vector<std::vector<float>> pcmf32s) {
|
||||
std::ofstream fout(fname);
|
||||
|
||||
fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname);
|
||||
@ -579,6 +629,11 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
|
||||
fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'";
|
||||
|
||||
bool is_first = true;
|
||||
std::string speaker = "";
|
||||
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
|
||||
}
|
||||
|
||||
for (int j = 0; j < n; ++j) {
|
||||
const auto & token = tokens[j];
|
||||
@ -587,13 +642,19 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
|
||||
continue;
|
||||
}
|
||||
|
||||
std::string txt_bg;
|
||||
std::string txt_fg; // highlight token
|
||||
std::string txt_ul; // underline
|
||||
std::string txt_bg = "";
|
||||
std::string txt_fg = ""; // highlight token
|
||||
std::string txt_ul = ""; // underline
|
||||
|
||||
txt_bg = "> ";
|
||||
txt_fg = "> ";
|
||||
txt_ul = "\\ \\ ";
|
||||
if (params.diarize && pcmf32s.size() == 2) {
|
||||
txt_bg = speaker;
|
||||
txt_fg = speaker;
|
||||
txt_ul = "\\ \\ \\ \\ \\ \\ \\ \\ \\ \\ \\ ";
|
||||
}
|
||||
|
||||
txt_bg.append("> ");
|
||||
txt_fg.append("> ");
|
||||
txt_ul.append("\\ \\ ");
|
||||
|
||||
{
|
||||
for (int k = 0; k < n; ++k) {
|
||||
@ -656,8 +717,7 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f
|
||||
return true;
|
||||
}
|
||||
|
||||
bool output_lrc(struct whisper_context * ctx, const char * fname) {
|
||||
|
||||
bool output_lrc(struct whisper_context * ctx, const char * fname, const whisper_params & params, std::vector<std::vector<float>> pcmf32s) {
|
||||
std::ofstream fout(fname);
|
||||
if (!fout.is_open()) {
|
||||
fprintf(stderr, "%s: failed to open '%s' for writing\n", __func__, fname);
|
||||
@ -682,8 +742,16 @@ bool output_lrc(struct whisper_context * ctx, const char * fname) {
|
||||
char buf[16];
|
||||
snprintf(buf, sizeof(buf), "%02d:%02d.%02d", (int) min, (int) sec, (int) ( msec / 10));
|
||||
std::string timestamp_lrc = std::string(buf);
|
||||
std::string speaker = "";
|
||||
|
||||
fout << '[' << timestamp_lrc << ']' << text << "\n";
|
||||
if (params.diarize && pcmf32s.size() == 2)
|
||||
{
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
speaker = estimate_diarization_speaker(pcmf32s, t0, t1);
|
||||
}
|
||||
|
||||
fout << '[' << timestamp_lrc << ']' << speaker << text << "\n";
|
||||
}
|
||||
|
||||
return true;
|
||||
@ -828,43 +896,43 @@ int main(int argc, char ** argv) {
|
||||
// output to text file
|
||||
if (params.output_txt) {
|
||||
const auto fname_txt = fname_out + ".txt";
|
||||
output_txt(ctx, fname_txt.c_str());
|
||||
output_txt(ctx, fname_txt.c_str(), params, pcmf32s);
|
||||
}
|
||||
|
||||
// output to VTT file
|
||||
if (params.output_vtt) {
|
||||
const auto fname_vtt = fname_out + ".vtt";
|
||||
output_vtt(ctx, fname_vtt.c_str());
|
||||
output_vtt(ctx, fname_vtt.c_str(), params, pcmf32s);
|
||||
}
|
||||
|
||||
// output to SRT file
|
||||
if (params.output_srt) {
|
||||
const auto fname_srt = fname_out + ".srt";
|
||||
output_srt(ctx, fname_srt.c_str(), params);
|
||||
output_srt(ctx, fname_srt.c_str(), params, pcmf32s);
|
||||
}
|
||||
|
||||
// output to WTS file
|
||||
if (params.output_wts) {
|
||||
const auto fname_wts = fname_out + ".wts";
|
||||
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE);
|
||||
output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE, pcmf32s);
|
||||
}
|
||||
|
||||
// output to CSV file
|
||||
if (params.output_csv) {
|
||||
const auto fname_csv = fname_out + ".csv";
|
||||
output_csv(ctx, fname_csv.c_str());
|
||||
output_csv(ctx, fname_csv.c_str(), params, pcmf32s);
|
||||
}
|
||||
|
||||
// output to JSON file
|
||||
if (params.output_jsn) {
|
||||
const auto fname_jsn = fname_out + ".json";
|
||||
output_json(ctx, fname_jsn.c_str(), params);
|
||||
output_json(ctx, fname_jsn.c_str(), params, pcmf32s);
|
||||
}
|
||||
|
||||
// output to LRC file
|
||||
if (params.output_lrc) {
|
||||
const auto fname_lrc = fname_out + ".lrc";
|
||||
output_lrc(ctx, fname_lrc.c_str());
|
||||
output_lrc(ctx, fname_lrc.c_str(), params, pcmf32s);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user