mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-21 21:47:47 +00:00
whisper : add initial_prompt param (#645)
This commit is contained in:
parent
aac1710afb
commit
eefed45e37
@ -160,22 +160,6 @@ int run(whisper_params ¶ms, std::vector<std::vector<std::string>> &result) {
|
|||||||
return 3;
|
return 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
// initial prompt
|
|
||||||
std::vector<whisper_token> prompt_tokens;
|
|
||||||
|
|
||||||
if (!params.prompt.empty()) {
|
|
||||||
prompt_tokens.resize(1024);
|
|
||||||
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
|
|
||||||
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
|
|
||||||
fprintf(stderr, "initial tokens: [ ");
|
|
||||||
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
|
|
||||||
fprintf(stderr, "%d ", prompt_tokens[i]);
|
|
||||||
}
|
|
||||||
fprintf(stderr, "]\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
||||||
const auto fname_inp = params.fname_inp[f];
|
const auto fname_inp = params.fname_inp[f];
|
||||||
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
|
const auto fname_out = f < (int)params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
|
||||||
@ -243,8 +227,7 @@ int run(whisper_params ¶ms, std::vector<std::vector<std::string>> &result) {
|
|||||||
wparams.greedy.best_of = params.best_of;
|
wparams.greedy.best_of = params.best_of;
|
||||||
wparams.beam_search.beam_size = params.beam_size;
|
wparams.beam_search.beam_size = params.beam_size;
|
||||||
|
|
||||||
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
|
wparams.initial_prompt = params.prompt.c_str();
|
||||||
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
|
|
||||||
|
|
||||||
whisper_print_user_data user_data = { ¶ms, &pcmf32s };
|
whisper_print_user_data user_data = { ¶ms, &pcmf32s };
|
||||||
|
|
||||||
|
@ -639,22 +639,6 @@ int main(int argc, char ** argv) {
|
|||||||
return 3;
|
return 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
// initial prompt
|
|
||||||
std::vector<whisper_token> prompt_tokens;
|
|
||||||
|
|
||||||
if (!params.prompt.empty()) {
|
|
||||||
prompt_tokens.resize(1024);
|
|
||||||
prompt_tokens.resize(whisper_tokenize(ctx, params.prompt.c_str(), prompt_tokens.data(), prompt_tokens.size()));
|
|
||||||
|
|
||||||
fprintf(stderr, "\n");
|
|
||||||
fprintf(stderr, "initial prompt: '%s'\n", params.prompt.c_str());
|
|
||||||
fprintf(stderr, "initial tokens: [ ");
|
|
||||||
for (int i = 0; i < (int) prompt_tokens.size(); ++i) {
|
|
||||||
fprintf(stderr, "%d ", prompt_tokens[i]);
|
|
||||||
}
|
|
||||||
fprintf(stderr, "]\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
for (int f = 0; f < (int) params.fname_inp.size(); ++f) {
|
||||||
const auto fname_inp = params.fname_inp[f];
|
const auto fname_inp = params.fname_inp[f];
|
||||||
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
|
const auto fname_out = f < (int) params.fname_out.size() && !params.fname_out[f].empty() ? params.fname_out[f] : params.fname_inp[f];
|
||||||
@ -718,8 +702,7 @@ int main(int argc, char ** argv) {
|
|||||||
|
|
||||||
wparams.speed_up = params.speed_up;
|
wparams.speed_up = params.speed_up;
|
||||||
|
|
||||||
wparams.prompt_tokens = prompt_tokens.empty() ? nullptr : prompt_tokens.data();
|
wparams.initial_prompt = params.prompt.c_str();
|
||||||
wparams.prompt_n_tokens = prompt_tokens.empty() ? 0 : prompt_tokens.size();
|
|
||||||
|
|
||||||
wparams.greedy.best_of = params.best_of;
|
wparams.greedy.best_of = params.best_of;
|
||||||
wparams.beam_search.beam_size = params.beam_size;
|
wparams.beam_search.beam_size = params.beam_size;
|
||||||
|
10
whisper.cpp
10
whisper.cpp
@ -3121,6 +3121,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
|
|||||||
/*.speed_up =*/ false,
|
/*.speed_up =*/ false,
|
||||||
/*.audio_ctx =*/ 0,
|
/*.audio_ctx =*/ 0,
|
||||||
|
|
||||||
|
/*.initial_prompt =*/ nullptr,
|
||||||
/*.prompt_tokens =*/ nullptr,
|
/*.prompt_tokens =*/ nullptr,
|
||||||
/*.prompt_n_tokens =*/ 0,
|
/*.prompt_n_tokens =*/ 0,
|
||||||
|
|
||||||
@ -3793,6 +3794,15 @@ int whisper_full_with_state(
|
|||||||
prompt_past.clear();
|
prompt_past.clear();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// initial prompt
|
||||||
|
if (!params.prompt_tokens && params.initial_prompt) {
|
||||||
|
std::vector<whisper_token> prompt_tokens;
|
||||||
|
prompt_tokens.resize(1024);
|
||||||
|
prompt_tokens.resize(whisper_tokenize(ctx, params.initial_prompt, prompt_tokens.data(), prompt_tokens.size()));
|
||||||
|
params.prompt_tokens = prompt_tokens.data();
|
||||||
|
params.prompt_n_tokens = prompt_tokens.size();
|
||||||
|
}
|
||||||
|
|
||||||
// prepend the prompt tokens to the prompt_past
|
// prepend the prompt tokens to the prompt_past
|
||||||
if (params.prompt_tokens && params.prompt_n_tokens > 0) {
|
if (params.prompt_tokens && params.prompt_n_tokens > 0) {
|
||||||
// parse tokens from the pointer
|
// parse tokens from the pointer
|
||||||
|
@ -356,6 +356,7 @@ extern "C" {
|
|||||||
|
|
||||||
// tokens to provide to the whisper decoder as initial prompt
|
// tokens to provide to the whisper decoder as initial prompt
|
||||||
// these are prepended to any existing text context from a previous call
|
// these are prepended to any existing text context from a previous call
|
||||||
|
const char * initial_prompt;
|
||||||
const whisper_token * prompt_tokens;
|
const whisper_token * prompt_tokens;
|
||||||
int prompt_n_tokens;
|
int prompt_n_tokens;
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user