Merge b87d6691f5b2fa248ba0664caf1898039b63dada into 448f3d3b93f2411045fb3192fa4d5ddb21eaba4e

This commit is contained in:
Pranav Seelam 2025-04-05 05:23:35 +02:00 committed by GitHub
commit bcc497f146


@@ -71,7 +71,8 @@ struct whisper_params {
bool verbose_prompt = false;
bool use_gpu = true;
bool flash_attn = false;
bool reset_cache = false;
bool infinite_inference = false;
std::string person = "Georgi";
std::string bot_name = "LLaMA";
std::string wake_cmd = "";
@@ -120,6 +121,8 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; }
else if (arg == "-sf" || arg == "--speak-file") { params.speak_file = argv[++i]; }
else if (arg == "-inf" || arg == "--infinite_inference") { params.infinite_inference = True; }
else if (arg == "-reset" || arg == "--reset_cache") { params.reset_cache = True; }
else if (arg == "--prompt-file") {
std::ifstream file(argv[++i]);
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
@@ -170,6 +173,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", "");
fprintf(stderr, " --session FNAME file to cache model state in (may be large!) (default: none)\n");
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -inf, --infinite_inference [%-7s] infinite inference\n", params.infinite_inference ? "true" : "false");
fprintf(stderr, " -reset, --reset_cache [%-7s] reset cache after each question\n", params.reset_cache ? "true" : "false");
fprintf(stderr, "\n");
}
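For reference, a hypothetical invocation that enables both new options could look like the line below (the -mw flag and the model paths are placeholders for illustration, not part of this diff):

    ./talk-llama -mw ./models/ggml-base.en.bin -ml ./models/llama-2-7b.gguf -inf -reset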
@@ -199,6 +204,7 @@ static std::string transcribe(
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = true;
wparams.infinite_inference = true;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
@@ -650,16 +656,35 @@ int main(int argc, char ** argv) {
// text inference
bool done = false;
std::string text_to_speak;
if (params.reset_cache) {
    int n_discard = lcparams.n_ctx - n_keep;
    //std::cout << "Number of tokens to discard: " << n_discard << "\n";
    // drop every cached token after the first n_keep, so each question starts from a clean context
    llama_kv_cache_seq_rm(ctx_llama, 0, n_keep, n_keep + n_discard);
    n_past = n_keep;
}
while (true) {
    // predict
    if (embd.size() > 0) {
        if (params.infinite_inference && (n_past + (int) embd.size() > n_ctx) && n_past != n_keep) {
            std::cout << "\nInfinite context enabled\n";

            const int n_left    = n_past - n_keep; // the number of tokens beyond the ones we want to keep
            const int n_discard = n_left/2;        // discard half of the tokens beyond the ones we want to keep

            // evict the discarded tokens from the KV cache, then shift the remaining
            // tokens back by n_discard so the cache stays contiguous
            llama_kv_cache_seq_rm (ctx_llama, 0, n_keep            , n_keep + n_discard);
            llama_kv_cache_seq_add(ctx_llama, 0, n_keep + n_discard, n_past, -n_discard);

            n_past -= n_discard;

            // stop saving session if we run out of context
            path_session.clear();

            continue;
        } else if (n_past + (int) embd.size() > n_ctx) {
            n_past = n_keep;

            // insert n_left/2 tokens at the start of embd from last_n_tokens
            embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());

            // stop saving session if we run out of context
            path_session = "";

            //printf("\n---\n");
            //printf("resetting: '");
            //for (int i = 0; i < (int) embd.size(); i++) {
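To make the context-shift arithmetic above concrete, here is a minimal self-contained sketch (plain C++, no llama.cpp calls; the numbers are illustrative assumptions, not values from this diff) that simulates what happens to the cache positions when the infinite-inference branch fires:

    #include <cstdio>

    int main() {
        // illustrative values (assumptions, not from the diff)
        const int n_ctx  = 2048; // context window size
        const int n_keep = 48;   // tokens always kept (e.g. the initial prompt)
        int       n_past = 2048; // the cache is full

        const int n_left    = n_past - n_keep; // tokens beyond the ones we keep
        const int n_discard = n_left/2;        // discard half of them

        // llama_kv_cache_seq_rm  evicts positions [n_keep, n_keep + n_discard)
        // llama_kv_cache_seq_add shifts [n_keep + n_discard, n_past) back by n_discard
        n_past -= n_discard;

        printf("evicted [%d, %d), shifted the rest back by %d, n_past is now %d\n",
               n_keep, n_keep + n_discard, n_discard, n_past);
        printf("room for %d new tokens before the next shift\n", n_ctx - n_past);
        return 0;
    }

With these numbers the sketch prints that positions [48, 1048) are evicted, the remaining 1000 tokens slide back to occupy [48, 1048), and n_past drops to 1048, freeing half of the window for new tokens before the next shift is needed.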