From f6614155e40198bad739fed9400d0de8de9cc311 Mon Sep 17 00:00:00 2001
From: Benjamin Heiniger
Date: Tue, 16 Jan 2024 14:52:01 +0100
Subject: [PATCH] talk-llama : optional wake-up command and audio confirmation
 (#1765)

* talk-llama: add optional wake-word detection from command

* talk-llama: add optional audio confirmation before generating answer

* talk-llama: fix small formatting issue in output

* talk-llama.cpp: fix Windows build
---
 examples/talk-llama/talk-llama.cpp | 64 +++++++++++++++++++++++++++++-
 1 file changed, 62 insertions(+), 2 deletions(-)

diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp
index 5eef1f4e..d418d0c3 100644
--- a/examples/talk-llama/talk-llama.cpp
+++ b/examples/talk-llama/talk-llama.cpp
@@ -14,6 +14,7 @@
 #include <string>
 #include <thread>
 #include <vector>
+#include <sstream>
 
 std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::string & text, bool add_bos) {
     auto * model = llama_get_model(ctx);
@@ -68,6 +69,8 @@ struct whisper_params {
 
     std::string person      = "Georgi";
     std::string bot_name    = "LLaMA";
+    std::string wake_cmd    = "";
+    std::string heard_ok    = "";
     std::string language    = "en";
     std::string model_wsp   = "models/ggml-base.en.bin";
     std::string model_llama = "models/ggml-llama-7B.bin";
@@ -104,6 +107,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
         else if (arg == "-p"   || arg == "--person")        { params.person       = argv[++i]; }
         else if (arg == "-bn"  || arg == "--bot-name")      { params.bot_name     = argv[++i]; }
         else if (                 arg == "--session")       { params.path_session = argv[++i]; }
+        else if (arg == "-w"   || arg == "--wake-command")  { params.wake_cmd     = argv[++i]; }
+        else if (arg == "-ho"  || arg == "--heard-ok")      { params.heard_ok     = argv[++i]; }
         else if (arg == "-l"   || arg == "--language")      { params.language     = argv[++i]; }
         else if (arg == "-mw"  || arg == "--model-whisper") { params.model_wsp    = argv[++i]; }
         else if (arg == "-ml"  || arg == "--model-llama")   { params.model_llama  = argv[++i]; }
@@ -149,6 +154,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  -ng,      --no-gpu         [%-7s] disable GPU\n", params.use_gpu ? "false" : "true");
     fprintf(stderr, "  -p NAME,  --person NAME    [%-7s] person name (for prompt selection)\n", params.person.c_str());
     fprintf(stderr, "  -bn NAME, --bot-name NAME  [%-7s] bot name (to display)\n", params.bot_name.c_str());
+    fprintf(stderr, "  -w TEXT,  --wake-command T [%-7s] wake-up command to listen for\n", params.wake_cmd.c_str());
+    fprintf(stderr, "  -ho TEXT, --heard-ok TEXT  [%-7s] said by TTS before generating reply\n", params.heard_ok.c_str());
     fprintf(stderr, "  -l LANG,  --language LANG  [%-7s] spoken language\n", params.language.c_str());
     fprintf(stderr, "  -mw FILE, --model-whisper  [%-7s] whisper model file\n", params.model_wsp.c_str());
     fprintf(stderr, "  -ml FILE, --model-llama    [%-7s] llama model file\n", params.model_llama.c_str());
@@ -227,6 +234,18 @@ std::string transcribe(
     return result;
 }
 
+std::vector<std::string> get_words(const std::string &txt) {
+    std::vector<std::string> words;
+
+    std::istringstream iss(txt);
+    std::string word;
+    while (iss >> word) {
+        words.push_back(word);
+    }
+
+    return words;
+}
+
 const std::string k_prompt_whisper = R"(A conversation with a person called {1}.)";
 
 const std::string k_prompt_llama = R"(Text transcript of a never ending dialog, where {0} interacts with an AI assistant named {1}.
@@ -441,6 +460,16 @@ int main(int argc, char ** argv) {
     bool need_to_save_session = !path_session.empty() && n_matching_session_tokens < (embd_inp.size() * 3 / 4);
 
     printf("%s : done! start speaking in the microphone\n", __func__);
+
+    // show wake command if enabled
+    const std::string wake_cmd = params.wake_cmd;
+    const int wake_cmd_length = get_words(wake_cmd).size();
+    const bool use_wake_cmd = wake_cmd_length > 0;
+
+    if (use_wake_cmd) {
+        printf("%s : the wake-up command is: '%s%s%s'\n", __func__, "\033[1m", wake_cmd.c_str(), "\033[0m");
+    }
+
     printf("\n");
     printf("%s%s", params.person.c_str(), chat_symb.c_str());
     fflush(stdout);
@@ -486,10 +515,41 @@ int main(int argc, char ** argv) {
 
         audio.get(params.voice_ms, pcmf32_cur);
 
-        std::string text_heard;
+        std::string all_heard;
 
         if (!force_speak) {
-            text_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms));
+            all_heard = ::trim(::transcribe(ctx_wsp, params, pcmf32_cur, prompt_whisper, prob0, t_ms));
+        }
+
+        const auto words = get_words(all_heard);
+
+        std::string wake_cmd_heard;
+        std::string text_heard;
+
+        for (int i = 0; i < (int) words.size(); ++i) {
+            if (i < wake_cmd_length) {
+                wake_cmd_heard += words[i] + " ";
+            } else {
+                text_heard += words[i] + " ";
+            }
+        }
+
+        // check if audio starts with the wake-up command if enabled
+        if (use_wake_cmd) {
+            const float sim = similarity(wake_cmd_heard, wake_cmd);
+
+            if ((sim < 0.7f) || (text_heard.empty())) {
+                audio.clear();
+                continue;
+            }
+        }
+
+        // optionally give audio feedback that the current text is being processed
+        if (!params.heard_ok.empty()) {
+            int ret = system((params.speak + " " + std::to_string(voice_id) + " '" + params.heard_ok + "'").c_str());
+            if (ret != 0) {
+                fprintf(stderr, "%s: failed to speak\n", __func__);
+            }
         }
 
         // remove text between brackets using regex