diff --git a/README.md b/README.md index fdbc65e9..a8888805 100644 --- a/README.md +++ b/README.md @@ -101,13 +101,14 @@ options: -ot N, --offset-t N time offset in milliseconds (default: 0) -on N, --offset-n N segment index offset (default: 0) -mc N, --max-context N maximum number of text context tokens to store (default: max) + -ml N, --max-len N maximum segment length in characters (default: 0) -wt N, --word-thold N word timestamp probability threshold (default: 0.010000) -v, --verbose verbose output --translate translate from source language to english -otxt, --output-txt output result in a text file -ovtt, --output-vtt output result in a vtt file -osrt, --output-srt output result in a srt file - -owts, --output-words output word-level timestamps to a text file + -owts, --output-words output script for generating karaoke video -ps, --print_special print special tokens -pc, --print_colors print colors -nt, --no_timestamps do not print timestamps diff --git a/examples/main/README.md b/examples/main/README.md index 27f47ff5..f2bf2a8d 100644 --- a/examples/main/README.md +++ b/examples/main/README.md @@ -8,7 +8,6 @@ It can be used as a reference for using the `whisper.cpp` library in other proje usage: ./bin/main [options] file0.wav file1.wav ... -options: -h, --help show this help message and exit -s SEED, --seed SEED RNG seed (default: -1) -t N, --threads N number of threads to use during computation (default: 4) @@ -16,18 +15,20 @@ options: -ot N, --offset-t N time offset in milliseconds (default: 0) -on N, --offset-n N segment index offset (default: 0) -mc N, --max-context N maximum number of text context tokens to store (default: max) + -ml N, --max-len N maximum segment length in characters (default: 0) -wt N, --word-thold N word timestamp probability threshold (default: 0.010000) -v, --verbose verbose output --translate translate from source language to english -otxt, --output-txt output result in a text file -ovtt, --output-vtt output result in a vtt file -osrt, --output-srt output result in a srt file - -owts, --output-words output word-level timestamps to a text file + -owts, --output-words output script for generating karaoke video -ps, --print_special print special tokens -pc, --print_colors print colors -nt, --no_timestamps do not print timestamps -l LANG, --language LANG spoken language (default: en) -m FNAME, --model FNAME model path (default: models/ggml-base.en.bin) -f FNAME, --file FNAME input WAV file path + -h, --help show this help message and exit ``` diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 83438921..b5894599 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -36,6 +36,7 @@ std::string to_timestamp(int64_t t, bool comma = false) { return std::string(buf); } +// helper function to replace substrings void replace_all(std::string & s, const std::string & search, const std::string & replace) { for (size_t pos = 0; ; pos += replace.length()) { pos = s.find(search, pos); @@ -45,31 +46,6 @@ void replace_all(std::string & s, const std::string & search, const std::string } } -// a cost-function that is high for text that takes longer to pronounce -float voice_length(const std::string & text) { - float res = 0.0f; - - for (size_t i = 0; i < text.size(); ++i) { - if (text[i] == ' ') { - res += 0.01f; - } else if (text[i] == ',') { - res += 2.00f; - } else if (text[i] == '.') { - res += 3.00f; - } else if (text[i] == '!') { - res += 3.00f; - } else if (text[i] == '?') { - res += 3.00f; - } else if (text[i] >= '0' && text[i] <= '9') { - res += 3.00f; - } else { - res += 1.00f; - } - } - - return res; -} - // command-line parameters struct whisper_params { int32_t seed = -1; // RNG seed, not used currently @@ -78,6 +54,7 @@ struct whisper_params { int32_t offset_t_ms = 0; int32_t offset_n = 0; int32_t max_context = -1; + int32_t max_len = 0; float word_thold = 0.01f; @@ -120,6 +97,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { params.offset_n = std::stoi(argv[++i]); } else if (arg == "-mc" || arg == "--max-context") { params.max_context = std::stoi(argv[++i]); + } else if (arg == "-ml" || arg == "--max-len") { + params.max_len = std::stoi(argv[++i]); } else if (arg == "-wt" || arg == "--word-thold") { params.word_thold = std::stof(argv[++i]); } else if (arg == "-v" || arg == "--verbose") { @@ -176,13 +155,14 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, " -ot N, --offset-t N time offset in milliseconds (default: %d)\n", params.offset_t_ms); fprintf(stderr, " -on N, --offset-n N segment index offset (default: %d)\n", params.offset_n); fprintf(stderr, " -mc N, --max-context N maximum number of text context tokens to store (default: max)\n"); + fprintf(stderr, " -ml N, --max-len N maximum segment length in characters (default: %d)\n", params.max_len); fprintf(stderr, " -wt N, --word-thold N word timestamp probability threshold (default: %f)\n", params.word_thold); fprintf(stderr, " -v, --verbose verbose output\n"); fprintf(stderr, " --translate translate from source language to english\n"); fprintf(stderr, " -otxt, --output-txt output result in a text file\n"); fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n"); fprintf(stderr, " -osrt, --output-srt output result in a srt file\n"); - fprintf(stderr, " -owts, --output-words output word-level timestamps to a text file\n"); + fprintf(stderr, " -owts, --output-words output script for generating karaoke video\n"); fprintf(stderr, " -ps, --print_special print special tokens\n"); fprintf(stderr, " -pc, --print_colors print colors\n"); fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n"); @@ -192,65 +172,67 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params) fprintf(stderr, "\n"); } -void whisper_print_segment_callback(struct whisper_context * ctx, void * user_data) { +void whisper_print_segment_callback(struct whisper_context * ctx, int n_new, void * user_data) { const whisper_params & params = *(whisper_params *) user_data; const int n_segments = whisper_full_n_segments(ctx); - // print the last segment - const int i = n_segments - 1; - if (i == 0) { + // print the last n_new segments + const int s0 = n_segments - n_new; + if (s0 == 0) { printf("\n"); } - if (params.no_timestamps) { - if (params.print_colors) { - for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { - if (params.print_special_tokens == false) { - const whisper_token id = whisper_full_get_token_id(ctx, i, j); - if (id >= whisper_token_eot(ctx)) { - continue; + for (int i = s0; i < n_segments; i++) { + if (params.no_timestamps) { + if (params.print_colors) { + for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { + if (params.print_special_tokens == false) { + const whisper_token id = whisper_full_get_token_id(ctx, i, j); + if (id >= whisper_token_eot(ctx)) { + continue; + } } + + const char * text = whisper_full_get_token_text(ctx, i, j); + const float p = whisper_full_get_token_p (ctx, i, j); + + const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); + + printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); } - - const char * text = whisper_full_get_token_text(ctx, i, j); - const float p = whisper_full_get_token_p (ctx, i, j); - - const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); - - printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); + } else { + const char * text = whisper_full_get_segment_text(ctx, i); + printf("%s", text); } + fflush(stdout); } else { - const char * text = whisper_full_get_segment_text(ctx, i); - printf("%s", text); - } - fflush(stdout); - } else { - const int64_t t0 = whisper_full_get_segment_t0(ctx, i); - const int64_t t1 = whisper_full_get_segment_t1(ctx, i); + const int64_t t0 = whisper_full_get_segment_t0(ctx, i); + const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - if (params.print_colors) { - printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); - for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { - if (params.print_special_tokens == false) { - const whisper_token id = whisper_full_get_token_id(ctx, i, j); - if (id >= whisper_token_eot(ctx)) { - continue; + if (params.print_colors) { + printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str()); + for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) { + if (params.print_special_tokens == false) { + const whisper_token id = whisper_full_get_token_id(ctx, i, j); + if (id >= whisper_token_eot(ctx)) { + continue; + } } + + const char * text = whisper_full_get_token_text(ctx, i, j); + const float p = whisper_full_get_token_p (ctx, i, j); + + const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); + + printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); } + printf("\n"); + } else { + const char * text = whisper_full_get_segment_text(ctx, i); - const char * text = whisper_full_get_token_text(ctx, i, j); - const float p = whisper_full_get_token_p (ctx, i, j); - - const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size())))); - - printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m"); + printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); } - printf("\n"); - } else { - const char * text = whisper_full_get_segment_text(ctx, i); - - printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text); } } } @@ -320,297 +302,41 @@ bool output_srt(struct whisper_context * ctx, const char * fname, const whisper_ return true; } -// word-level timestamps (experimental) -// TODO: make ffmpeg output optional -// TODO: extra pass to detect unused speech and assign to tokens +// karaoke video generation +// outputs a bash script that uses ffmpeg to generate a video with the subtitles // TODO: font parameter adjustments -// TODO: move to whisper.h/whisper.cpp and add parameter to select max line-length of subtitles -bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, const std::vector & pcmf32) { - std::vector pcm_avg(pcmf32.size(), 0); - - // average the fabs of the signal - { - const int hw = 32; - - for (int i = 0; i < pcmf32.size(); i++) { - float sum = 0; - for (int j = -hw; j <= hw; j++) { - if (i + j >= 0 && i + j < pcmf32.size()) { - sum += fabs(pcmf32[i + j]); - } - } - pcm_avg[i] = sum/(2*hw + 1); - } - } - - struct token_info { - int64_t t0 = -1; - int64_t t1 = -1; - - int64_t tt0 = -1; - int64_t tt1 = -1; - - whisper_token id; - whisper_token tid; - - float p = 0.0f; - float pt = 0.0f; - float ptsum = 0.0f; - - std::string text; - float vlen = 0.0f; // voice length of this token - }; - - int64_t t_beg = 0; - int64_t t_last = 0; - - whisper_token tid_last = 0; - +bool output_wts(struct whisper_context * ctx, const char * fname, const char * fname_inp, const whisper_params & params, float t_sec) { std::ofstream fout(fname); fprintf(stderr, "%s: saving output to '%s'\n", __func__, fname); + // TODO: become parameter + static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; + fout << "!/bin/bash" << "\n"; fout << "\n"; - fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE << ":rate=25:color=black -vf \""; - - bool is_first = true; + fout << "ffmpeg -i " << fname_inp << " -f lavfi -i color=size=1200x120:duration=" << t_sec << ":rate=25:color=black -vf \""; for (int i = 0; i < whisper_full_n_segments(ctx); i++) { const int64_t t0 = whisper_full_get_segment_t0(ctx, i); const int64_t t1 = whisper_full_get_segment_t1(ctx, i); - const char *text = whisper_full_get_segment_text(ctx, i); - - const int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100)); - const int s1 = std::min((int) pcmf32.size(), (int) (t1*WHISPER_SAMPLE_RATE/100)); - const int n = whisper_full_n_tokens(ctx, i); - std::vector tokens(n); - - if (n <= 1) { - continue; - } - + std::vector tokens(n); for (int j = 0; j < n; ++j) { - struct whisper_token_data token = whisper_full_get_token_data(ctx, i, j); - - if (j == 0) { - if (token.id == whisper_token_beg(ctx)) { - tokens[j ].t0 = t0; - tokens[j ].t1 = t0; - tokens[j + 1].t0 = t0; - - t_beg = t0; - t_last = t0; - tid_last = whisper_token_beg(ctx); - } else { - tokens[j ].t0 = t_last; - } - } - - const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx)); - - tokens[j].id = token.id; - tokens[j].tid = token.tid; - tokens[j].p = token.p; - tokens[j].pt = token.pt; - tokens[j].ptsum = token.ptsum; - - tokens[j].text = whisper_token_to_str(ctx, token.id); - tokens[j].vlen = voice_length(tokens[j].text); - - if (token.pt > params.word_thold && token.ptsum > 0.01 && token.tid > tid_last && tt <= t1) { - if (j > 0) { - tokens[j - 1].t1 = tt; - } - tokens[j].t0 = tt; - tid_last = token.tid; - } + tokens[j] = whisper_full_get_token_data(ctx, i, j); } - tokens[n - 2].t1 = t1; - tokens[n - 1].t0 = t1; - tokens[n - 1].t1 = t1; - - t_last = t1; - - // find intervals of tokens with unknown timestamps - // fill the timestamps by proportionally splitting the interval based on the token voice lengths - { - int p0 = 0; - int p1 = 0; - while (true) { - while (p1 < n && tokens[p1].t1 < 0) { - p1++; - } - - if (p1 >= n) { - p1--; - } - - if (p1 > p0) { - double psum = 0.0; - for (int j = p0; j <= p1; j++) { - psum += tokens[j].vlen; - } - - //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); - - const double dt = tokens[p1].t1 - tokens[p0].t0; - - // split the time proportionally to the voice length - for (int j = p0 + 1; j <= p1; j++) { - const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum; - - tokens[j - 1].t1 = ct; - tokens[j ].t0 = ct; - } - } - - p1++; - p0 = p1; - if (p1 >= n) { - break; - } - } - } - - // fix up (just in case) - for (int j = 0; j < n - 1; j++) { - if (tokens[j].t1 < 0) { - tokens[j + 1].t0 = tokens[j].t1; - } - - if (j > 0) { - if (tokens[j - 1].t1 > tokens[j].t0) { - tokens[j].t0 = tokens[j - 1].t1; - tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1); - } - } - - tokens[j].tt0 = tokens[j].t0; - tokens[j].tt1 = tokens[j].t1; - } - - // VAD - // expand or contract tokens based on voice activity - { - const int hw = WHISPER_SAMPLE_RATE/8; - - for (int j = 0; j < n; j++) { - if (tokens[j].id >= whisper_token_eot(ctx)) { - continue; - } - - const int64_t t0 = tokens[j].t0; - const int64_t t1 = tokens[j].t1; - - int s0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100)); - int s1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100)); - - const int ss0 = std::max(0, (int) (t0*WHISPER_SAMPLE_RATE/100) - hw); - const int ss1 = std::min((int) pcmf32.size() - 1, (int) (t1*WHISPER_SAMPLE_RATE/100) + hw); - - const int n = ss1 - ss0; - - float sum = 0.0f; - - for (int k = ss0; k < ss1; k++) { - sum += pcm_avg[k]; - } - - const float thold = 0.5*sum/n; - - { - int k = s0; - if (pcm_avg[k] > thold && j > 0) { - while (k > 0 && pcm_avg[k] > thold) { - k--; - } - tokens[j].t0 = (int64_t) (100*k/WHISPER_SAMPLE_RATE); - if (tokens[j].t0 < tokens[j - 1].t1) { - tokens[j].t0 = tokens[j - 1].t1; - } else { - s0 = k; - } - } else { - while (pcm_avg[k] < thold && k < s1) { - k++; - } - s0 = k; - tokens[j].t0 = 100*k/WHISPER_SAMPLE_RATE; - } - } - - { - int k = s1; - if (pcm_avg[k] > thold) { - while (k < (int) pcmf32.size() - 1 && pcm_avg[k] > thold) { - k++; - } - tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE; - if (j < n - 1 && tokens[j].t1 > tokens[j + 1].t0) { - tokens[j].t1 = tokens[j + 1].t0; - } else { - s1 = k; - } - } else { - while (pcm_avg[k] < thold && k > s0) { - k--; - } - s1 = k; - tokens[j].t1 = 100*k/WHISPER_SAMPLE_RATE; - } - } - } - } - - // fixed token expand (optional) - { - const int t_expand = 0; - - for (int j = 0; j < n; j++) { - if (j > 0) { - tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand)); - } - if (j < n - 1) { - tokens[j].t1 = tokens[j].t1 + t_expand; - } - } - } - - // debug info - // TODO: toggle via parameter - for (int j = 0; j < n; ++j) { - const auto & token = tokens[j]; - const auto tt = token.pt > params.word_thold && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]"; - printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__, - tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, token.text.c_str()); - - if (tokens[j].id >= whisper_token_eot(ctx)) { - continue; - } - - //printf("[%s --> %s] %s\n", to_timestamp(token.t0).c_str(), to_timestamp(token.t1).c_str(), whisper_token_to_str(ctx, token.id)); - - //fout << "# " << to_timestamp(token.t0) << " --> " << to_timestamp(token.t1) << " " << whisper_token_to_str(ctx, token.id) << "\n"; - } - - // TODO: become parameters - static const int line_wrap = 60; - static const char * font = "/System/Library/Fonts/Supplemental/Courier New Bold.ttf"; - - if (!is_first) { + if (i > 0) { fout << ","; } // background text fout << "drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='':enable='between(t," << t0/100.0 << "," << t0/100.0 << ")'"; - is_first = false; + bool is_first = true; for (int j = 0; j < n; ++j) { const auto & token = tokens[j]; @@ -654,17 +380,6 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f } ncnt += txt.size(); - - if (ncnt > line_wrap) { - if (k < j) { - txt_bg = "> "; - txt_fg = "> "; - txt_ul = "\\ \\ "; - ncnt = 0; - } else { - break; - } - } } ::replace_all(txt_bg, "'", "’"); @@ -673,8 +388,11 @@ bool output_wts(struct whisper_context * ctx, const char * fname, const char * f ::replace_all(txt_fg, "\"", "\\\""); } - // background text - fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << token.tt0/100.0 << "," << token.tt1/100.0 << ")'"; + if (is_first) { + // background text + fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=gray:x=(w-text_w)/2:y=h/2:text='" << txt_bg << "':enable='between(t," << t0/100.0 << "," << t1/100.0 << ")'"; + is_first = false; + } // foreground text fout << ",drawtext=fontfile='" << font << "':fontsize=24:fontcolor=lightgreen:x=(w-text_w)/2+8:y=h/2:text='" << txt_fg << "':enable='between(t," << token.t0/100.0 << "," << token.t1/100.0 << ")'"; @@ -815,6 +533,10 @@ int main(int argc, char ** argv) { wparams.n_max_text_ctx = params.max_context >= 0 ? params.max_context : wparams.n_max_text_ctx; wparams.offset_ms = params.offset_t_ms; + wparams.token_timestamps = params.output_wts || params.max_len > 0; + wparams.thold_pt = params.word_thold; + wparams.max_len = params.output_wts && params.max_len == 0 ? 60 : params.max_len; + // this callback is called on each new segment if (!wparams.print_realtime) { wparams.new_segment_callback = whisper_print_segment_callback; @@ -852,7 +574,7 @@ int main(int argc, char ** argv) { // output to WTS file if (params.output_wts) { const auto fname_wts = fname_inp + ".wts"; - output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, pcmf32); + output_wts(ctx, fname_wts.c_str(), fname_inp.c_str(), params, float(pcmf32.size() + 1000)/WHISPER_SAMPLE_RATE); } } } diff --git a/whisper.cpp b/whisper.cpp index b230d0c0..02ab5cbc 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -418,6 +418,12 @@ struct whisper_context { std::vector result_all; std::vector prompt_past; + + // [EXPERIMENTAL] token-level timestamps data + int64_t t_beg; + int64_t t_last; + whisper_token tid_last; + std::vector energy; // PCM signal energy }; // load the model from a ggml file @@ -431,7 +437,7 @@ struct whisper_context { // // see the convert-pt-to-ggml.py script for details // -bool whisper_model_load(const std::string & fname, whisper_context & wctx) { +static bool whisper_model_load(const std::string & fname, whisper_context & wctx) { fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str()); auto & model = wctx.model; @@ -1062,7 +1068,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) { // - n_threads: number of threads to use // - mel_offset: offset in the mel spectrogram (i.e. audio offset) // -bool whisper_encode( +static bool whisper_encode( whisper_context & wctx, const int n_threads, const int mel_offset) { @@ -1448,7 +1454,7 @@ bool whisper_encode( // - n_tokens: number of tokens in the prompt // - n_past: number of past tokens to prefix the prompt with // -bool whisper_decode( +static bool whisper_decode( whisper_context & wctx, const int n_threads, const whisper_token * tokens, @@ -1811,10 +1817,12 @@ bool whisper_decode( } // the most basic sampling scheme - select the top token -whisper_token_data whisper_sample_best( +static whisper_token_data whisper_sample_best( const whisper_vocab & vocab, const float * probs) { - whisper_token_data result; + whisper_token_data result = { + 0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f, + }; int n_logits = vocab.id_to_token.size(); @@ -1887,7 +1895,7 @@ whisper_token_data whisper_sample_best( } // samples only from the timestamps tokens -whisper_vocab::id whisper_sample_timestamp( +static whisper_vocab::id whisper_sample_timestamp( const whisper_vocab & vocab, const float * probs) { int n_logits = vocab.id_to_token.size(); @@ -1939,7 +1947,7 @@ static std::string to_timestamp(int64_t t, bool comma = false) { // naive Discrete Fourier Transform // input is real-valued // output is complex-valued -void dft(const std::vector & in, std::vector & out) { +static void dft(const std::vector & in, std::vector & out) { int N = in.size(); out.resize(N*2); @@ -1963,7 +1971,7 @@ void dft(const std::vector & in, std::vector & out) { // poor man's implementation - use something better // input is real-valued // output is complex-valued -void fft(const std::vector & in, std::vector & out) { +static void fft(const std::vector & in, std::vector & out) { out.resize(in.size()*2); int N = in.size(); @@ -2014,7 +2022,7 @@ void fft(const std::vector & in, std::vector & out) { } // ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124 -bool log_mel_spectrogram( +static bool log_mel_spectrogram( const float * samples, const int n_samples, const int sample_rate, @@ -2339,6 +2347,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.print_realtime =*/ false, /*.print_timestamps =*/ true, + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.language =*/ "en", /*.greedy =*/ { @@ -2371,6 +2384,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.print_realtime =*/ false, /*.print_timestamps =*/ true, + /*.token_timestamps =*/ false, + /*.thold_pt =*/ 0.01f, + /*.thold_ptsum =*/ 0.01f, + /*.max_len =*/ 0, + /*.language =*/ "en", /*.greedy =*/ { @@ -2392,6 +2410,68 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str return result; } +// forward declarations +static std::vector get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window); +static void whisper_exp_compute_token_level_timestamps( + struct whisper_context * ctx, + int i_segment, + float thold_pt, + float thold_ptsum); + +// wrap the last segment to max_len characters +// returns the number of new segments +static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) { + auto segment = ctx->result_all.back(); + + int res = 1; + int acc = 0; + + std::string text; + + for (int i = 0; i < (int) segment.tokens.size(); i++) { + const auto & token = segment.tokens[i]; + if (token.id >= whisper_token_eot(ctx)) { + continue; + } + + const auto txt = whisper_token_to_str(ctx, token.id); + + const int cur = strlen(txt); + + if (acc + cur > max_len && i > 0) { + // split here + ctx->result_all.back().text = std::move(text); + ctx->result_all.back().t1 = token.t0; + ctx->result_all.back().tokens.resize(i); + + ctx->result_all.push_back({}); + ctx->result_all.back().t0 = token.t0; + ctx->result_all.back().t1 = segment.t1; + + // add tokens [i, end] to the new segment + ctx->result_all.back().tokens.insert( + ctx->result_all.back().tokens.end(), + segment.tokens.begin() + i, + segment.tokens.end()); + + acc = 0; + text = ""; + + segment = ctx->result_all.back(); + i = -1; + + res++; + } else { + acc += cur; + text += txt; + } + } + + ctx->result_all.back().text = std::move(text); + + return res; +} + int whisper_full( struct whisper_context * ctx, struct whisper_full_params params, @@ -2408,6 +2488,13 @@ int whisper_full( return -1; } + if (params.token_timestamps) { + ctx->t_beg = 0; + ctx->t_last = 0; + ctx->tid_last = 0; + ctx->energy = get_signal_energy(samples, n_samples, 32); + } + const int seek_start = params.offset_ms/10; // if length of spectrogram is less than 1s (100 samples), then return @@ -2557,6 +2644,7 @@ int whisper_full( } } + // shrink down to result_len tokens_cur.resize(result_len); for (const auto & r : tokens_cur) { @@ -2595,8 +2683,19 @@ int whisper_full( for (int j = i0; j <= i; j++) { result_all.back().tokens.push_back(tokens_cur[j]); } + + int n_new = 1; + + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + + if (params.max_len > 0) { + n_new = whisper_wrap_segment(ctx, params.max_len); + } + } if (params.new_segment_callback) { - params.new_segment_callback(ctx, params.new_segment_callback_user_data); + params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); } } text = ""; @@ -2625,8 +2724,19 @@ int whisper_full( for (int j = i0; j < (int) tokens_cur.size(); j++) { result_all.back().tokens.push_back(tokens_cur[j]); } + + int n_new = 1; + + if (params.token_timestamps) { + whisper_exp_compute_token_level_timestamps( + ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum); + + if (params.max_len > 0) { + n_new = whisper_wrap_segment(ctx, params.max_len); + } + } if (params.new_segment_callback) { - params.new_segment_callback(ctx, params.new_segment_callback_user_data); + params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data); } } } @@ -2760,7 +2870,7 @@ int whisper_full_parallel( // call the new_segment_callback for each segment if (params.new_segment_callback) { - params.new_segment_callback(ctx, params.new_segment_callback_user_data); + params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data); } } @@ -2836,3 +2946,304 @@ const char * whisper_print_system_info() { return s.c_str(); } + +// ================================================================================================= + +// +// Experimental stuff below +// +// Not sure if these should be part of the library at all, because the quality of the results is not +// guaranteed. Might get removed at some point unless a robust algorithm implementation is found +// + +// ================================================================================================= + +// +// token-level timestamps +// + +static int timestamp_to_sample(int64_t t, int n_samples) { + return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100))); +} + +static int64_t sample_to_timestamp(int i_sample) { + return (100*i_sample)/WHISPER_SAMPLE_RATE; +} + +// a cost-function / heuristic that is high for text that takes longer to pronounce +// obviously, can be improved +static float voice_length(const std::string & text) { + float res = 0.0f; + + for (size_t i = 0; i < text.size(); ++i) { + if (text[i] == ' ') { + res += 0.01f; + } else if (text[i] == ',') { + res += 2.00f; + } else if (text[i] == '.') { + res += 3.00f; + } else if (text[i] == '!') { + res += 3.00f; + } else if (text[i] == '?') { + res += 3.00f; + } else if (text[i] >= '0' && text[i] <= '9') { + res += 3.00f; + } else { + res += 1.00f; + } + } + + return res; +} + +// average the fabs of the signal +static std::vector get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) { + const int hw = n_samples_per_half_window; + + std::vector result(n_samples); + + for (int i = 0; i < n_samples; i++) { + float sum = 0; + for (int j = -hw; j <= hw; j++) { + if (i + j >= 0 && i + j < n_samples) { + sum += fabs(signal[i + j]); + } + } + result[i] = sum/(2*hw + 1); + } + + return result; +} + +static void whisper_exp_compute_token_level_timestamps( + struct whisper_context * ctx, + int i_segment, + float thold_pt, + float thold_ptsum) { + auto & segment = ctx->result_all[i_segment]; + auto & tokens = segment.tokens; + + const int n_samples = ctx->energy.size(); + + if (n_samples == 0) { + fprintf(stderr, "%s: no signal data available\n", __func__); + return; + } + + const int64_t t0 = segment.t0; + const int64_t t1 = segment.t1; + + const int s0 = timestamp_to_sample(t0, n_samples); + const int s1 = timestamp_to_sample(t1, n_samples); + + const int n = tokens.size(); + + if (n == 0) { + return; + } + + if (n == 1) { + tokens[0].t0 = t0; + tokens[0].t1 = t1; + + return; + } + + auto & t_beg = ctx->t_beg; + auto & t_last = ctx->t_last; + auto & tid_last = ctx->tid_last; + + for (int j = 0; j < n; ++j) { + auto & token = tokens[j]; + + if (j == 0) { + if (token.id == whisper_token_beg(ctx)) { + tokens[j ].t0 = t0; + tokens[j ].t1 = t0; + tokens[j + 1].t0 = t0; + + t_beg = t0; + t_last = t0; + tid_last = whisper_token_beg(ctx); + } else { + tokens[j ].t0 = t_last; + } + } + + const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx)); + + tokens[j].id = token.id; + tokens[j].tid = token.tid; + tokens[j].p = token.p; + tokens[j].pt = token.pt; + tokens[j].ptsum = token.ptsum; + + tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id)); + + if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) { + if (j > 0) { + tokens[j - 1].t1 = tt; + } + tokens[j].t0 = tt; + tid_last = token.tid; + } + } + + tokens[n - 2].t1 = t1; + tokens[n - 1].t0 = t1; + tokens[n - 1].t1 = t1; + + t_last = t1; + + // find intervals of tokens with unknown timestamps + // fill the timestamps by proportionally splitting the interval based on the token voice lengths + { + int p0 = 0; + int p1 = 0; + + while (true) { + while (p1 < n && tokens[p1].t1 < 0) { + p1++; + } + + if (p1 >= n) { + p1--; + } + + if (p1 > p0) { + double psum = 0.0; + for (int j = p0; j <= p1; j++) { + psum += tokens[j].vlen; + } + + //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); + + const double dt = tokens[p1].t1 - tokens[p0].t0; + + // split the time proportionally to the voice length + for (int j = p0 + 1; j <= p1; j++) { + const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum; + + tokens[j - 1].t1 = ct; + tokens[j ].t0 = ct; + } + } + + p1++; + p0 = p1; + if (p1 >= n) { + break; + } + } + } + + // fix up (just in case) + for (int j = 0; j < n - 1; j++) { + if (tokens[j].t1 < 0) { + tokens[j + 1].t0 = tokens[j].t1; + } + + if (j > 0) { + if (tokens[j - 1].t1 > tokens[j].t0) { + tokens[j].t0 = tokens[j - 1].t1; + tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1); + } + } + } + + // VAD + // expand or contract tokens based on voice activity + { + const int hw = WHISPER_SAMPLE_RATE/8; + + for (int j = 0; j < n; j++) { + if (tokens[j].id >= whisper_token_eot(ctx)) { + continue; + } + + int s0 = timestamp_to_sample(tokens[j].t0, n_samples); + int s1 = timestamp_to_sample(tokens[j].t1, n_samples); + + const int ss0 = std::max(s0 - hw, 0); + const int ss1 = std::min(s1 + hw, n_samples); + + const int ns = ss1 - ss0; + + float sum = 0.0f; + + for (int k = ss0; k < ss1; k++) { + sum += ctx->energy[k]; + } + + const float thold = 0.5*sum/ns; + + { + int k = s0; + if (ctx->energy[k] > thold && j > 0) { + while (k > 0 && ctx->energy[k] > thold) { + k--; + } + tokens[j].t0 = sample_to_timestamp(k); + if (tokens[j].t0 < tokens[j - 1].t1) { + tokens[j].t0 = tokens[j - 1].t1; + } else { + s0 = k; + } + } else { + while (ctx->energy[k] < thold && k < s1) { + k++; + } + s0 = k; + tokens[j].t0 = sample_to_timestamp(k); + } + } + + { + int k = s1; + if (ctx->energy[k] > thold) { + while (k < n_samples - 1 && ctx->energy[k] > thold) { + k++; + } + tokens[j].t1 = sample_to_timestamp(k); + if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) { + tokens[j].t1 = tokens[j + 1].t0; + } else { + s1 = k; + } + } else { + while (ctx->energy[k] < thold && k > s0) { + k--; + } + s1 = k; + tokens[j].t1 = sample_to_timestamp(k); + } + } + } + } + + // fixed token expand (optional) + //{ + // const int t_expand = 0; + + // for (int j = 0; j < n; j++) { + // if (j > 0) { + // tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand)); + // } + // if (j < n - 1) { + // tokens[j].t1 = tokens[j].t1 + t_expand; + // } + // } + //} + + // debug info + //for (int j = 0; j < n; ++j) { + // const auto & token = tokens[j]; + // const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]"; + // printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__, + // tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id)); + + // if (tokens[j].id >= whisper_token_eot(ctx)) { + // continue; + // } + //} +} diff --git a/whisper.h b/whisper.h index 5d7c40d0..57ea5db8 100644 --- a/whisper.h +++ b/whisper.h @@ -68,14 +68,21 @@ extern "C" { typedef int whisper_token; - struct whisper_token_data { + typedef struct whisper_token_data { whisper_token id; // token id whisper_token tid; // forced timestamp token id float p; // probability of the token float pt; // probability of the timestamp token float ptsum; // sum of probabilities of all timestamp tokens - }; + + // token-level timestamp data + // do not use if you haven't computed token-level timestamps + int64_t t0; // start time of the token + int64_t t1; // end time of the token + + float vlen; // voice length of the token + } whisper_token_data; // Allocates all memory needed for the model and loads the model from the given file. // Returns NULL on failure. @@ -129,7 +136,7 @@ extern "C" { // You can also implement your own sampling method using the whisper_get_probs() function. // whisper_sample_best() returns the token with the highest probability // whisper_sample_timestamp() returns the most probable timestamp token - WHISPER_API struct whisper_token_data whisper_sample_best(struct whisper_context * ctx); + WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx); WHISPER_API whisper_token whisper_sample_timestamp(struct whisper_context * ctx); // Return the id of the specified language, returns -1 if not found @@ -172,7 +179,7 @@ extern "C" { // Text segment callback // Called on every newly generated text segment // Use the whisper_full_...() functions to obtain the text segments - typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, void * user_data); + typedef void (*whisper_new_segment_callback)(struct whisper_context * ctx, int n_new, void * user_data); struct whisper_full_params { enum whisper_sampling_strategy strategy; @@ -188,6 +195,12 @@ extern "C" { bool print_realtime; bool print_timestamps; + // [EXPERIMENTAL] token-level timestamps + bool token_timestamps; // enable token-level timestamps + float thold_pt; // timestamp token probability threshold (~0.01) + float thold_ptsum; // timestamp token sum probability threshold (~0.01) + int max_len; // max segment length in characters + const char * language; struct { @@ -244,7 +257,7 @@ extern "C" { // Get token data for the specified token in the specified segment. // This contains probabilities, timestamps, etc. - WHISPER_API struct whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token); + WHISPER_API whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token); // Get the probability of the specified token in the specified segment. WHISPER_API float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token);