mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-18 20:27:53 +00:00
ref #10 : option to keep context in "stream" example
Seems the results become worse when we keep the context, so by default this is not enabled
This commit is contained in:
parent
3f15bb8a08
commit
481cd685d5
@ -40,6 +40,7 @@ struct whisper_params {
|
|||||||
|
|
||||||
bool verbose = false;
|
bool verbose = false;
|
||||||
bool translate = false;
|
bool translate = false;
|
||||||
|
bool no_context = true;
|
||||||
bool print_special_tokens = false;
|
bool print_special_tokens = false;
|
||||||
bool no_timestamps = true;
|
bool no_timestamps = true;
|
||||||
|
|
||||||
@ -64,6 +65,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
|||||||
params.verbose = true;
|
params.verbose = true;
|
||||||
} else if (arg == "--translate") {
|
} else if (arg == "--translate") {
|
||||||
params.translate = true;
|
params.translate = true;
|
||||||
|
} else if (arg == "-kc" || arg == "--keep-context") {
|
||||||
|
params.no_context = false;
|
||||||
} else if (arg == "-l" || arg == "--language") {
|
} else if (arg == "-l" || arg == "--language") {
|
||||||
params.language = argv[++i];
|
params.language = argv[++i];
|
||||||
if (whisper_lang_id(params.language.c_str()) == -1) {
|
if (whisper_lang_id(params.language.c_str()) == -1) {
|
||||||
@ -103,6 +106,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
|
|||||||
fprintf(stderr, " --step N audio step size in milliseconds (default: %d)\n", params.step_ms);
|
fprintf(stderr, " --step N audio step size in milliseconds (default: %d)\n", params.step_ms);
|
||||||
fprintf(stderr, " -v, --verbose verbose output\n");
|
fprintf(stderr, " -v, --verbose verbose output\n");
|
||||||
fprintf(stderr, " --translate translate from source language to english\n");
|
fprintf(stderr, " --translate translate from source language to english\n");
|
||||||
|
fprintf(stderr, " -nc, --no-context disable context from earlier audio (default: false)\n");
|
||||||
fprintf(stderr, " -ps, --print_special print special tokens\n");
|
fprintf(stderr, " -ps, --print_special print special tokens\n");
|
||||||
fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
|
fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
|
||||||
fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
|
fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
|
||||||
@ -273,6 +277,7 @@ int main(int argc, char ** argv) {
|
|||||||
wparams.print_realtime = false;
|
wparams.print_realtime = false;
|
||||||
wparams.print_timestamps = !params.no_timestamps;
|
wparams.print_timestamps = !params.no_timestamps;
|
||||||
wparams.translate = params.translate;
|
wparams.translate = params.translate;
|
||||||
|
wparams.no_context = params.no_context;
|
||||||
wparams.language = params.language.c_str();
|
wparams.language = params.language.c_str();
|
||||||
wparams.n_threads = params.n_threads;
|
wparams.n_threads = params.n_threads;
|
||||||
|
|
||||||
|
19
whisper.cpp
19
whisper.cpp
@ -405,6 +405,8 @@ struct whisper_context {
|
|||||||
|
|
||||||
std::vector<whisper_result> result_cur;
|
std::vector<whisper_result> result_cur;
|
||||||
std::vector<whisper_segment> result_all;
|
std::vector<whisper_segment> result_all;
|
||||||
|
|
||||||
|
std::vector<whisper_token> prompt_past;
|
||||||
};
|
};
|
||||||
|
|
||||||
// load the model from a ggml file
|
// load the model from a ggml file
|
||||||
@ -1020,8 +1022,6 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
|
|||||||
// - model: the model
|
// - model: the model
|
||||||
// - n_threads: number of threads to use
|
// - n_threads: number of threads to use
|
||||||
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
|
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
|
||||||
// - mel_inp: input mel spectrogram
|
|
||||||
// - features: output encoded features
|
|
||||||
//
|
//
|
||||||
bool whisper_encode(
|
bool whisper_encode(
|
||||||
whisper_context & wctx,
|
whisper_context & wctx,
|
||||||
@ -1405,10 +1405,9 @@ bool whisper_encode(
|
|||||||
//
|
//
|
||||||
// - model: the model
|
// - model: the model
|
||||||
// - n_threads: number of threads to use
|
// - n_threads: number of threads to use
|
||||||
// - n_past: prompt length
|
// - tokens: text prompt
|
||||||
// - prompt: text prompt
|
// - n_tokens: number of tokens in the prompt
|
||||||
// - logits_out: output logits
|
// - n_past: number of past tokens to prefix the prompt with
|
||||||
// - probs_out: output probabilities
|
|
||||||
//
|
//
|
||||||
bool whisper_decode(
|
bool whisper_decode(
|
||||||
whisper_context & wctx,
|
whisper_context & wctx,
|
||||||
@ -2259,6 +2258,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat
|
|||||||
.offset_ms = 0,
|
.offset_ms = 0,
|
||||||
|
|
||||||
.translate = false,
|
.translate = false,
|
||||||
|
.no_context = false,
|
||||||
.print_special_tokens = false,
|
.print_special_tokens = false,
|
||||||
.print_progress = true,
|
.print_progress = true,
|
||||||
.print_realtime = false,
|
.print_realtime = false,
|
||||||
@ -2279,6 +2279,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat
|
|||||||
.offset_ms = 0,
|
.offset_ms = 0,
|
||||||
|
|
||||||
.translate = false,
|
.translate = false,
|
||||||
|
.no_context = false,
|
||||||
.print_special_tokens = false,
|
.print_special_tokens = false,
|
||||||
.print_progress = true,
|
.print_progress = true,
|
||||||
.print_realtime = false,
|
.print_realtime = false,
|
||||||
@ -2297,6 +2298,7 @@ struct whisper_full_params whisper_full_default_params(enum whisper_decode_strat
|
|||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
int whisper_full(
|
int whisper_full(
|
||||||
struct whisper_context * ctx,
|
struct whisper_context * ctx,
|
||||||
struct whisper_full_params params,
|
struct whisper_full_params params,
|
||||||
@ -2309,7 +2311,10 @@ int whisper_full(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// the accumulated text context so far
|
// the accumulated text context so far
|
||||||
std::vector<whisper_token> prompt_past = { };
|
auto & prompt_past = ctx->prompt_past;
|
||||||
|
if (params.no_context) {
|
||||||
|
prompt_past.clear();
|
||||||
|
}
|
||||||
|
|
||||||
// these tokens determine the task that will be performed
|
// these tokens determine the task that will be performed
|
||||||
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
|
std::vector<whisper_token> prompt_init = { whisper_token_sot(ctx) };
|
||||||
|
Loading…
Reference in New Issue
Block a user