From eefed45e37095c08e8107cf75fba8481c7105f14 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Thu, 30 Mar 2023 04:23:23 +0800 Subject: [PATCH] whisper : add initial_prompt param (#645) --- examples/addon.node/addon.cpp | 19 +------------------ examples/main/main.cpp | 19 +------------------ whisper.cpp | 10 ++++++++++ whisper.h | 1 + 4 files changed, 13 insertions(+), 36 deletions(-) diff --git a/examples/addon.node/addon.cpp b/examples/addon.node/addon.cpp index 0fa4a8ca..52e80ad8 100644 --- a/examples/addon.node/addon.cpp +++ b/examples/addon.node/addon.cpp @@ -160,22 +160,6 @@ int run(whisper_params ¶ms, std::vector> &result) { return 3; } - // initial prompt - std::vector prompt_tokens; - - if (!params.prompt.empty()) { - prompt_tokens.resize(1024); - prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size())); - - fprintf(stderr, "\n"); - fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str()); - fprintf(stderr, "initial tokens: [ "); - for (int i = 0; i < (int) prompt_tokens.size(); ++i) { - fprintf(stderr, "%d ", prompt_tokens[i]); - } - fprintf(stderr, "]\n"); - } - for (int f = 0; f < (int) params.fname_inp.size(); ++f) { const auto fname_inp = params.fname_inp[f]; const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f]; @@ -243,8 +227,7 @@ int run(whisper_params ¶ms, std::vector> &result) { wparams.greedy.best_of = params.best_of; wparams.beam_search.beam_size = params.beam_size; - wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data(); - wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); + wparams.initial_prompt = params.prompt.c_str(); whisper_print_user_data user_data = { ¶ms, &pcmf32s }; diff --git a/examples/main/main.cpp b/examples/main/main.cpp index dd30ba4c..7131a937 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -639,22 +639,6 @@ int main(int argc, char ** argv) { return 3; } - // initial prompt - std::vector prompt_tokens; - - if (!params.prompt.empty()) { - prompt_tokens.resize(1024); - prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size())); - - fprintf(stderr, "\n"); - fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str()); - fprintf(stderr, "initial tokens: [ "); - for (int i = 0; i < (int) prompt_tokens.size(); ++i) { - fprintf(stderr, "%d ", prompt_tokens[i]); - } - fprintf(stderr, "]\n"); - } - for (int f = 0; f < (int) params.fname_inp.size(); ++f) { const auto fname_inp = params.fname_inp[f]; const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f]; @@ -718,8 +702,7 @@ int main(int argc, char ** argv) { wparams.speed_up = params.speed_up; - wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data(); - wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size(); + wparams.initial_prompt = params.prompt.c_str(); wparams.greedy.best_of = params.best_of; wparams.beam_search.beam_size = params.beam_size; diff --git a/whisper.cpp b/whisper.cpp index f44e5034..13e11141 100644 --- a/whisper.cpp +++ b/whisper.cpp @@ -3121,6 +3121,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str /*.speed_up =*/ false, /*.audio_ctx =*/ 0, + /*.initial_prompt =*/ nullptr, /*.prompt_tokens =*/ nullptr, /*.prompt_n_tokens =*/ 0, @@ -3793,6 +3794,15 @@ int whisper_full_with_state( prompt_past.clear(); } + // initial prompt + if (!params.prompt_tokens && params.initial_prompt) { + std::vector prompt_tokens; + prompt_tokens.resize(1024); + prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size())); + params.prompt_tokens = prompt_tokens.data(); + params.prompt_n_tokens = prompt_tokens.size(); + } + // prepend the prompt tokens to the prompt_past if (params.prompt_tokens && params.prompt_n_tokens > 0) { // parse tokens from the pointer diff --git a/whisper.h b/whisper.h index fc107108..fa6bff4f 100644 --- a/whisper.h +++ b/whisper.h @@ -356,6 +356,7 @@ extern "C" { // tokens to provide to the whisper decoder as initial prompt // these are prepended to any existing text context from a previous call + const char * initial_prompt; const whisper_token * prompt_tokens; int prompt_n_tokens;