From 015e562d7eba2c0f696b84880398e6501f1e092d Mon Sep 17 00:00:00 2001
From: Pranav Aditya Seelam
Date: Mon, 11 Nov 2024 16:39:09 -0800
Subject: [PATCH 1/3] Add infinite conversation feature to talk-llama

---
 examples/talk-llama/talk-llama.cpp | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp
index 1b9de94d..a8121c77 100644
--- a/examples/talk-llama/talk-llama.cpp
+++ b/examples/talk-llama/talk-llama.cpp
@@ -66,6 +66,7 @@ struct whisper_params {
     bool verbose_prompt = false;
     bool use_gpu        = true;
     bool flash_attn     = false;
+    bool inf            = false;
 
     std::string person   = "Georgi";
     std::string bot_name = "LLaMA";
@@ -105,6 +106,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-vp" || arg == "--verbose-prompt") { params.verbose_prompt = true; }
         else if (arg == "-ng" || arg == "--no-gpu")         { params.use_gpu        = false; }
         else if (arg == "-fa" || arg == "--flash-attn")     { params.flash_attn     = true; }
+        else if (arg == "-inf" || arg == "--infinite")      { params.inf            = true; }
         else if (arg == "-p"  || arg == "--person")         { params.person         = argv[++i]; }
         else if (arg == "-bn" || arg == "--bot-name")       { params.bot_name       = argv[++i]; }
         else if (arg == "--session")                        { params.path_session   = argv[++i]; }
@@ -165,6 +167,7 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
     fprintf(stderr, "  --prompt-file FNAME     [%-7s] file with custom prompt to start dialog\n", "");
     fprintf(stderr, "  --session FNAME             file to cache model state in (may be large!) (default: none)\n");
     fprintf(stderr, "  -f FNAME, --file FNAME  [%-7s] text output file name\n", params.fname_out.c_str());
+    fprintf(stderr, "  -inf, --infinite        [%-7s] infinite conversation mode\n", params.inf ? "true" : "false");
    fprintf(stderr, "\n");
 }
 
@@ -647,12 +650,23 @@ int main(int argc, char ** argv) {
             // predict
             if (embd.size() > 0) {
                 if (n_past + (int) embd.size() > n_ctx) {
-                    n_past = n_keep;
+                    //n_past = n_keep;
+                    if (params.inf && (n_past + (int) embd.size() >= n_ctx - 10) && (n_past != n_keep)) { // infinite conversation mode: shift the KV cache once the tokens processed so far plus the size of embd come within 10 tokens of the context size
+                        //std::cout << "\nCache will be adjusted\n";
+                        //std::cout << "Number of tokens already processed (n_past): " << n_past << "\n";
+                        //std::cout << "Number of tokens in n_keep: " << (int) n_keep << "\n";
+                        const int n_left    = n_past - n_keep; // the number of tokens beyond the ones we want to keep
+                        const int n_discard = n_left;          // discard all of the tokens beyond the ones we want to keep
+                        //std::cout << "Number of tokens to discard: " << n_discard << "\n";
+                        llama_kv_cache_seq_rm (ctx_llama, 0, n_keep, n_keep + n_discard);              // remove the discarded tokens from the KV cache
+                        llama_kv_cache_seq_add(ctx_llama, 0, n_keep + n_discard, n_past, -n_discard); // shift the positions of the remaining tokens back by n_discard
+                        n_past -= n_discard;
+                        continue;
 
                     // insert n_left/2 tokens at the start of embd from last_n_tokens
-                    embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());
+                    //embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());
                     // stop saving session if we run out of context
-                    path_session = "";
+                    //path_session = "";
                     //printf("\n---\n");
                     //printf("resetting: '");
                     //for (int i = 0; i < (int) embd.size(); i++) {
From e57c38994046ccb0e474ae9451498271cd7d010a Mon Sep 17 00:00:00 2001
From: Pranav Aditya Seelam
Date: Thu, 14 Nov 2024 14:46:34 -0800
Subject: [PATCH 2/3] Add option to reset the KV cache after each question and
 shift the KV cache when the context is full

---
 examples/talk-llama/talk-llama.cpp | 45 +++++++++++++++++++-----------
 1 file changed, 28 insertions(+), 17 deletions(-)

diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp
index a8121c77..42ee2be3 100644
--- a/examples/talk-llama/talk-llama.cpp
+++ b/examples/talk-llama/talk-llama.cpp
@@ -66,8 +66,8 @@ struct whisper_params {
     bool verbose_prompt = false;
     bool use_gpu        = true;
     bool flash_attn     = false;
-    bool inf            = false;
-
+    bool reset_cache        = false;
+    bool infinite_inference = false;
     std::string person   = "Georgi";
     std::string bot_name = "LLaMA";
     std::string wake_cmd = "";
@@ -105,7 +106,6 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-vp" || arg == "--verbose-prompt") { params.verbose_prompt = true; }
         else if (arg == "-ng" || arg == "--no-gpu")         { params.use_gpu        = false; }
         else if (arg == "-fa" || arg == "--flash-attn")     { params.flash_attn     = true; }
-        else if (arg == "-inf" || arg == "--infinite")      { params.inf            = true; }
         else if (arg == "-p"  || arg == "--person")         { params.person         = argv[++i]; }
         else if (arg == "-bn" || arg == "--bot-name")       { params.bot_name       = argv[++i]; }
         else if (arg == "--session")                        { params.path_session   = argv[++i]; }
@@ -117,6 +116,8 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-ml" || arg == "--model-llama")    { params.model_llama    = argv[++i]; }
         else if (arg == "-s"  || arg == "--speak")          { params.speak          = argv[++i]; }
         else if (arg == "-sf" || arg == "--speak-file")     { params.speak_file     = argv[++i]; }
+        else if (arg == "-inf"   || arg == "--infinite_inference") { params.infinite_inference = true; }
+        else if (arg == "-reset" || arg == "--reset_cache")        { params.reset_cache        = true; }
         else if (arg == "--prompt-file") {
             std::ifstream file(argv[++i]);
             std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
@@ -167,7 +168,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) {
     fprintf(stderr, "  --prompt-file FNAME     [%-7s] file with custom prompt to start dialog\n", "");
     fprintf(stderr, "  --session FNAME             file to cache model state in (may be large!) (default: none)\n");
     fprintf(stderr, "  -f FNAME, --file FNAME  [%-7s] text output file name\n", params.fname_out.c_str());
-    fprintf(stderr, "  -inf, --infinite        [%-7s] infinite conversation mode\n", params.inf ? "true" : "false");
+    fprintf(stderr, "  -inf,   --infinite_inference [%-7s] infinite inference\n", params.infinite_inference ? "true" : "false");
+    fprintf(stderr, "  -reset, --reset_cache        [%-7s] reset the KV cache after each question\n", params.reset_cache ? "true" : "false");
     fprintf(stderr, "\n");
 }
 
@@ -197,6 +199,7 @@ static std::string transcribe(
     wparams.translate        = params.translate;
     wparams.no_context       = true;
     wparams.single_segment   = true;
+    wparams.infinite_inference = true;
     wparams.max_tokens       = params.max_tokens;
     wparams.language         = params.language.c_str();
     wparams.n_threads        = params.n_threads;
@@ -646,23 +649,31 @@ int main(int argc, char ** argv) {
         // text inference
         bool done = false;
         std::string text_to_speak;
+        if (params.reset_cache) {
+            int n_discard = lcparams.n_ctx - n_keep; // everything beyond the tokens we want to keep
+            //std::cout << "Number of tokens to discard: " << n_discard << "\n";
+
+            llama_kv_cache_seq_rm(ctx_llama, 0, n_keep, n_keep + n_discard); // clear the KV cache past the kept prompt tokens
+            n_past = n_keep;
+        }
         while (true) {
             // predict
             if (embd.size() > 0) {
-                if (n_past + (int) embd.size() > n_ctx) {
-                    //n_past = n_keep;
-                    if (params.inf && (n_past + (int) embd.size() >= n_ctx - 10) && (n_past != n_keep)) { // infinite conversation mode: shift the KV cache once the tokens processed so far plus the size of embd come within 10 tokens of the context size
-                        //std::cout << "\nCache will be adjusted\n";
-                        //std::cout << "Number of tokens already processed (n_past): " << n_past << "\n";
-                        //std::cout << "Number of tokens in n_keep: " << (int) n_keep << "\n";
-                        const int n_left    = n_past - n_keep; // the number of tokens beyond the ones we want to keep
-                        const int n_discard = n_left;          // discard all of the tokens beyond the ones we want to keep
-                        //std::cout << "Number of tokens to discard: " << n_discard << "\n";
-                        llama_kv_cache_seq_rm (ctx_llama, 0, n_keep, n_keep + n_discard);              // remove the discarded tokens from the KV cache
-                        llama_kv_cache_seq_add(ctx_llama, 0, n_keep + n_discard, n_past, -n_discard); // shift the positions of the remaining tokens back by n_discard
-                        n_past -= n_discard;
-                        continue;
+                if (params.infinite_inference && (n_past + (int) embd.size() > n_ctx)) {
+                    std::cout << "\nInfinite context enabled\n";
+                    const int n_left    = n_past - n_keep; // the number of tokens beyond the ones we want to keep
+                    const int n_discard = n_left/2;        // discard half of the tokens beyond the ones we want to keep
+
+                    llama_kv_cache_seq_rm (ctx_llama, 0, n_keep, n_keep + n_discard);              // remove the discarded tokens from the KV cache
+                    llama_kv_cache_seq_add(ctx_llama, 0, n_keep + n_discard, n_past, -n_discard); // shift the positions of the remaining tokens back by n_discard
+
+                    n_past -= n_discard;
+                    // insert n_left/2 tokens at the start of embd from last_n_tokens
+                    //embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());
+                    // stop saving session if we run out of context
+                    path_session.clear();
+                    continue;
 
                     // insert n_left/2 tokens at the start of embd from last_n_tokens
                     //embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());
                     // stop saving session if we run out of context
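Note: the reset_cache path above makes each question independent: it clears every cache position in [n_keep, n_ctx) and rewinds n_past to n_keep, so the system prompt stays warm in the cache while the previous exchanges are forgotten. For example, with a hypothetical n_ctx = 2048 and n_keep = 64, it discards positions 64..2047 and the next question is evaluated right after the prompt. The two modes would be selected roughly like this (model paths are placeholders, not from the patch):

    # shift the KV cache and keep talking past the context limit
    ./talk-llama -mw ./models/ggml-base.en.bin -ml ./models/llama-model.gguf -p "Georgi" --infinite_inference

    # start every question from just the initial prompt
    ./talk-llama -mw ./models/ggml-base.en.bin -ml ./models/llama-model.gguf -p "Georgi" --reset_cache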
From bb3d903820cd04ce2de51db9b981f5f801fa9c18 Mon Sep 17 00:00:00 2001
From: Pranav Aditya Seelam
Date: Thu, 14 Nov 2024 15:25:36 -0800
Subject: [PATCH 3/3] Fix infinite loop in the KV cache shift

---
 examples/talk-llama/talk-llama.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp
index 42ee2be3..3ca7f32b 100644
--- a/examples/talk-llama/talk-llama.cpp
+++ b/examples/talk-llama/talk-llama.cpp
@@ -659,7 +659,7 @@ int main(int argc, char ** argv) {
         while (true) {
             // predict
             if (embd.size() > 0) {
-                if (params.infinite_inference && (n_past + (int) embd.size() > n_ctx)) {
+                if (params.infinite_inference && (n_past + (int) embd.size() > n_ctx) && n_past != n_keep) {
                     std::cout << "\nInfinite context enabled\n";
                     const int n_left    = n_past - n_keep; // the number of tokens beyond the ones we want to keep
                     const int n_discard = n_left/2;        // discard half of the tokens beyond the ones we want to keep
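Note: without the n_past != n_keep guard, the shift branch can livelock: if the cache was just reset or shifted so that n_past == n_keep while n_past + embd.size() still exceeds n_ctx, then n_left = 0 and n_discard = 0, the two llama_kv_cache calls change nothing, and the continue re-enters the branch with identical state forever. The extra condition skips the shift when there is nothing beyond the kept prefix to discard.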