mirror of https://github.com/ggerganov/whisper.cpp.git

commit ec05d7705a
parent d17e7139d8

allow stream prompt
examples/stream/stream.cpp:

@@ -40,6 +40,8 @@ struct whisper_params {
     std::string language  = "en";
     std::string model     = "models/ggml-base.en.bin";
     std::string fname_out;
+
+    std::string initial_prompt;
 };
 
 void whisper_print_usage(int argc, char ** argv, const whisper_params & params);
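Note: the new initial_prompt field default-constructs to an empty string, so existing invocations that never pass --prompt are unaffected; tokenizing the empty prompt below should simply yield zero tokens.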
@@ -72,7 +74,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-sa" || arg == "--save-audio") { params.save_audio     = true; }
         else if (arg == "-ng" || arg == "--no-gpu")     { params.use_gpu        = false; }
         else if (arg == "-fa" || arg == "--flash-attn") { params.flash_attn     = true; }
+        else if (              arg == "--prompt")      { params.initial_prompt = argv[++i]; }
         else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             whisper_print_usage(argc, argv, params);
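Note: like the neighboring options that take a value, the new --prompt branch consumes the next argv entry via argv[++i] without checking that one exists, so "stream --prompt" with no value walks off the end of the argument list. A defensive variant is sketched below; the helper name and its exit-on-error behavior are illustrative, not part of this commit.

    #include <cstdio>
    #include <cstdlib>
    #include <string>

    // Hypothetical helper: fetch the value of an option that takes an
    // argument, failing loudly instead of reading past the end of argv.
    static std::string next_arg_or_die(int argc, char ** argv, int & i, const char * flag) {
        if (i + 1 >= argc) {
            fprintf(stderr, "error: %s requires a value\n", flag);
            exit(1);
        }
        return argv[++i];
    }

    // Usage inside the parser loop:
    //     else if (arg == "--prompt") { params.initial_prompt = next_arg_or_die(argc, argv, i, "--prompt"); }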
@@ -109,6 +111,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
     fprintf(stderr, "  -sa,      --save-audio    [%-7s] save the recorded audio to a file\n",     params.save_audio ? "true" : "false");
     fprintf(stderr, "  -ng,      --no-gpu        [%-7s] disable GPU inference\n",                 params.use_gpu ? "false" : "true");
     fprintf(stderr, "  -fa,      --flash-attn    [%-7s] flash attention during inference\n",      params.flash_attn ? "true" : "false");
+    fprintf(stderr, "            --prompt PROMPT [%-7s] initial prompt (max n_text_ctx/2 tokens)\n", params.initial_prompt.c_str());
     fprintf(stderr, "\n");
 }
 
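Usage then looks like any other stream option, for example (binary and model paths here are illustrative):

    ./stream -m ./models/ggml-base.en.bin --prompt "domain-specific vocabulary goes here"

The "(max n_text_ctx/2 tokens)" hint appears to reflect the cap whisper_full applies when it consumes prompt tokens, so overly long prompts are truncated rather than rejected.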
@@ -164,7 +167,13 @@ int main(int argc, char ** argv) {
     std::vector<float> pcmf32_new(n_samples_30s, 0.0f);
 
     std::vector<whisper_token> prompt_tokens;
+    prompt_tokens.resize(1024);
+    int n = whisper_tokenize(ctx, params.initial_prompt.c_str(), prompt_tokens.data(), 1024);
+    if (n < 0) {
+        fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, params.initial_prompt.c_str());
+        return 3;
+    }
+    prompt_tokens.resize(n);
 
     // print some info about the processing
     {
         fprintf(stderr, "\n");
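For reference, whisper_tokenize writes at most n_max_tokens tokens into the caller's buffer and returns the token count on success or a negative value on failure, which is what the n < 0 branch catches. Since the same resize/tokenize/shrink sequence reappears verbatim inside the transcription loop (last hunk below), it is a natural candidate for a helper; a sketch, assuming the upstream convention that the failure return encodes the required size as its negation:

    #include <string>
    #include <vector>

    #include "whisper.h"

    // Hypothetical helper (not part of this commit): tokenize a prompt into a
    // right-sized vector, growing the buffer once if the 1024-token guess is
    // too small.
    static std::vector<whisper_token> tokenize_prompt(struct whisper_context * ctx, const std::string & prompt) {
        std::vector<whisper_token> tokens(1024);
        int n = whisper_tokenize(ctx, prompt.c_str(), tokens.data(), (int) tokens.size());
        if (n < 0) {
            // Assumption: -n is the number of tokens the text actually needs.
            tokens.resize((size_t) -n);
            n = whisper_tokenize(ctx, prompt.c_str(), tokens.data(), (int) tokens.size());
        }
        tokens.resize(n > 0 ? (size_t) n : 0);
        return tokens;
    }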
@@ -318,8 +327,8 @@ int main(int argc, char ** argv) {
         //wparams.temperature_inc = -1.0f;
         wparams.temperature_inc   = params.no_fallback ? 0.0f : wparams.temperature_inc;
 
-        wparams.prompt_tokens     = params.no_context ? nullptr : prompt_tokens.data();
-        wparams.prompt_n_tokens   = params.no_context ? 0 : prompt_tokens.size();
+        wparams.prompt_tokens     = prompt_tokens.data();
+        wparams.prompt_n_tokens   = prompt_tokens.size();
 
         if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
             fprintf(stderr, "%s: failed to process audio\n", argv[0]);
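Dropping the params.no_context guard means the prompt tokens are now handed to whisper_full unconditionally. When --prompt is absent this is equivalent to the old nullptr/0 pair, since the empty initial prompt tokenizes to zero tokens; when it is present, the prompt now takes effect even without --keep-context, which appears to be the point of the change.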
@@ -397,7 +406,13 @@ int main(int argc, char ** argv) {
                 // Add tokens of the last full length segment as the prompt
                 if (!params.no_context) {
                     prompt_tokens.clear();
+                    prompt_tokens.resize(1024);
+                    int n = whisper_tokenize(ctx, params.initial_prompt.c_str(), prompt_tokens.data(), 1024);
+                    if (n < 0) {
+                        fprintf(stderr, "%s: error: failed to tokenize prompt '%s'\n", __func__, params.initial_prompt.c_str());
+                        return 3;
+                    }
+                    prompt_tokens.resize(n);
 
                     const int n_segments = whisper_full_n_segments(ctx);
                     for (int i = 0; i < n_segments; ++i) {
                         const int token_count = whisper_full_n_tokens(ctx, i);
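This second copy of the tokenization block re-seeds prompt_tokens with the initial prompt on every iteration: prompt_tokens.clear() empties the vector, the prompt tokens are written back in, and the loop below then appends the tokens of the last full-length segments, so with --keep-context the model is prompted with the initial prompt followed by the most recent transcript. The block is identical to the one near the top of main, so both sites could call a single helper such as the tokenize_prompt sketch above.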