diff --git a/examples/wchess/libwchess/Chessboard.cpp b/examples/wchess/libwchess/Chessboard.cpp index 0bd24113..a47ed120 100644 --- a/examples/wchess/libwchess/Chessboard.cpp +++ b/examples/wchess/libwchess/Chessboard.cpp @@ -108,14 +108,15 @@ Chessboard::Chessboard() std::sort(blackMoves.begin(), blackMoves.end()); } -std::string Chessboard::getRules() const { +std::string Chessboard::getRules(const std::string& prompt) const { // leading space is very important! std::string result = "\n" "# leading space is very important!\n" "\n" - "move ::= \" \" ((piece | frompos) \" \" \"to \"?)? topos\n" - "\n"; + "move ::= prompt \" \" ((piece | frompos) \" \" \"to \"?)? topos\n" + "\n" + "prompt ::= \" " + prompt + "\"\n"; std::set pieces; std::set from_pos; diff --git a/examples/wchess/libwchess/Chessboard.h b/examples/wchess/libwchess/Chessboard.h index 7808a704..0d4e847a 100644 --- a/examples/wchess/libwchess/Chessboard.h +++ b/examples/wchess/libwchess/Chessboard.h @@ -8,7 +8,7 @@ public: Chessboard(); std::string process(const std::string& t); std::string stringifyBoard(); - std::string getRules() const; + std::string getRules(const std::string & prompt) const; using Move = std::pair; private: bool move(const Move& move); diff --git a/examples/wchess/libwchess/WChess.cpp b/examples/wchess/libwchess/WChess.cpp index f85c21db..e41e32a7 100644 --- a/examples/wchess/libwchess/WChess.cpp +++ b/examples/wchess/libwchess/WChess.cpp @@ -45,8 +45,8 @@ std::string WChess::stringify_board() const { void WChess::run() { set_status("loading data ..."); - bool have_prompt = true; - bool ask_prompt = false; + bool have_prompt = false; + bool ask_prompt = true; float logprob_min0 = 0.0f; float logprob_min = 0.0f; @@ -60,13 +60,26 @@ void WChess::run() { std::vector pcmf32_cur; std::vector pcmf32_prompt; - std::string prompt = ""; - float prompt_prop = 0.0f; + const std::string k_prompt = "King bishop rook queen knight"; while (check_running()) { // delay std::this_thread::sleep_for(std::chrono::milliseconds(100)); + if (ask_prompt) { + fprintf(stdout, "\n"); + fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m"); + fprintf(stdout, "\n"); + + { + char txt[1024]; + snprintf(txt, sizeof(txt), "Say the following phrase: '%s'", k_prompt.c_str()); + set_status(txt); + } + + ask_prompt = false; + } + int64_t t_ms = 0; { @@ -76,12 +89,37 @@ void WChess::run() { fprintf(stdout, "%s: Processing ...\n", __func__); set_status("Processing ..."); - { + if (!have_prompt) { + const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms)); + + fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms); + + const float sim = similarity(txt, k_prompt); + + if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) { + fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__); + ask_prompt = true; + } else { + fprintf(stdout, "\n"); + fprintf(stdout, "%s: The prompt has been recognized!\n", __func__); + fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__); + fprintf(stdout, "\n"); + + { + char txt[1024]; + snprintf(txt, sizeof(txt), "Success! Waiting for voice commands ..."); + set_status(txt); + } + + // save the audio for the prompt + pcmf32_prompt = pcmf32_cur; + have_prompt = true; + } + } else { + pcmf32_cur.insert(pcmf32_cur.begin(), WHISPER_SAMPLE_RATE, 0.0f); if (!pcmf32_prompt.empty()) pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); - if (WHISPER_SAMPLE_RATE > pcmf32_cur.size()) pcmf32_cur.insert(pcmf32_cur.begin(), WHISPER_SAMPLE_RATE - pcmf32_cur.size(), 0.0f); - - std::string rules = m_board->getRules(); + std::string rules = m_board->getRules(k_prompt); fprintf(stdout, "%s: grammar rules:\n'%s'\n", __func__, rules.c_str()); auto grammar_parsed = grammar_parser::parse(rules.c_str()); @@ -100,18 +138,16 @@ void WChess::run() { // find the prompt in the text float best_sim = 0.0f; size_t best_len = 0; - if (!prompt.empty()) { - auto pos = txt.find_first_of('.'); + for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) { + const auto prompt = txt.substr(0, n); - const auto header = txt.substr(0, pos); - - const float sim = similarity(prompt, header); + const float sim = similarity(prompt, k_prompt); //fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim); if (sim > best_sim) { best_sim = sim; - best_len = pos + 1; + best_len = n; } }