From 63ae03b8e09acb5a351f1973e871d7e2a0518529 Mon Sep 17 00:00:00 2001 From: "M. Eren Akbiyik" Date: Tue, 22 Nov 2022 17:10:35 +0100 Subject: [PATCH] Prompt previous tokens for streaming (#163) * feat: prompt previous tokens for streaming I used a vector pointer instead of vector itself because it gave weird errors, and why not * convert vector to use with C api * feat: remove old refs, check for prompt size * feat: use better way of getting the pointer --- examples/stream/stream.cpp | 14 ++++++++++++++ whisper.cpp | 15 +++++++++++++++ whisper.h | 4 ++++ 3 files changed, 33 insertions(+) diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index 6f3634b7..32f93d6f 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -234,6 +234,7 @@ int main(int argc, char ** argv) { std::vector<float> pcmf32(n_samples_30s, 0.0f); std::vector<float> pcmf32_old; + std::vector<whisper_token> prompt_tokens; const int n_new_line = params.length_ms / params.step_ms - 1; // print some info about the processing @@ -344,6 +345,9 @@ int main(int argc, char ** argv) { wparams.audio_ctx = params.audio_ctx; wparams.speed_up = params.speed_up; + wparams.prompt_tokens = prompt_tokens.data(); + wparams.prompt_n_tokens = prompt_tokens.size(); + if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { fprintf(stderr, "%s: failed to process audio\n", argv[0]); return 6; @@ -393,6 +397,16 @@ int main(int argc, char ** argv) { // keep part of the audio for next iteration to try to mitigate word boundary issues pcmf32_old = std::vector<float>(pcmf32.end() - n_samples_keep, pcmf32.end()); + + // Add tokens of the last full length segment as the prompt + prompt_tokens.clear(); + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const int token_count = whisper_full_n_tokens(ctx, i); + for (int j = 0; j < token_count; ++j) { + prompt_tokens.push_back(whisper_full_get_token_id(ctx, i, j)); + } + } } } } diff --git a/whisper.cpp b/whisper.cpp index 
7052355b..28c5d26a 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -2412,6 +2412,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.speed_up =*/ false, /*.audio_ctx =*/ 0, + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, + /*.language =*/ "en", /*.greedy =*/ { @@ -2455,6 +2458,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.speed_up =*/ false, /*.audio_ctx =*/ 0, + /*.prompt_tokens =*/ nullptr, + /*.prompt_n_tokens =*/ 0, + /*.language =*/ "en", /*.greedy =*/ { @@ -2584,6 +2590,15 @@ int whisper_full( prompt_past.clear(); } + // Prepend the prompt tokens to the prompt_past + if (params.prompt_tokens && params.prompt_n_tokens > 0) { + // Copy the tokens from the C array behind the pointer (e.g. the data() of a std::vector<whisper_token>) + for (int i = 0; i < params.prompt_n_tokens; i++) { + prompt_past.push_back(params.prompt_tokens[i]); + } + std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end()); + } + // overwrite audio_ctx ctx->exp_n_audio_ctx = params.audio_ctx; diff --git a/whisper.h b/whisper.h index 88cc7113..1b2a042b 100644 --- a/whisper.h +++ b/whisper.h @@ -208,6 +208,10 @@ extern "C" { bool speed_up; // speed-up the audio by 2x using Phase Vocoder int audio_ctx; // overwrite the audio context size (0 = use default) + // tokens to provide the whisper model as initial prompt: a C array of prompt_n_tokens entries (e.g. std::vector<whisper_token>::data()) + const whisper_token * prompt_tokens; + int prompt_n_tokens; + const char * language; struct {