#include "WChess.h"
#include "Chessboard.h"
#include "grammar-parser.h"
#include "common.h"

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Stores the whisper context, the decoding parameters, the callbacks and the
// settings, and creates a new Chessboard for this instance.
WChess::WChess(whisper_context * ctx,
        const whisper_full_params & wparams,
        callbacks cb,
        settings s)
        : m_ctx(ctx)
        , m_wparams(wparams)
        , m_cb(cb)
        , m_settings(s)
        , m_board(new Chessboard())
{}

WChess::~WChess() = default;

// Forwards the processed move string and its probability to the set_move
// callback, if one was provided.
void WChess::set_move(const std::string& moves, float prob) const {
    if (m_cb.set_move) (*m_cb.set_move)(moves, prob);
}

// Forwards the current grammar text to the set_grammar callback, if one was
// provided.
void WChess::set_grammar(const std::string& grammar) const {
    if (m_cb.set_grammar) (*m_cb.set_grammar)(grammar);
}

// Asks the get_audio callback to fill pcmf32 and returns its result.
// Returns false when no callback was provided, which ends the loop in run().
bool WChess::get_audio(std::vector<float>& pcmf32) const {
    if (m_cb.get_audio) return (*m_cb.get_audio)(pcmf32);
    return false;
}

// Returns the board's textual representation as produced by Chessboard.
std::string WChess::stringify_board() const {
    return m_board->stringifyBoard();
}

// Returns the grammar currently reported by the board.
std::string WChess::get_grammar() const {
    return m_board->grammar();
}

// Main loop: repeatedly fetches audio through get_audio(), transcribes it with
// the board's current grammar, and feeds the recognized command to the board.
// The loop ends when get_audio() returns false or when the board reports an
// empty grammar.
void WChess::run() {
    // have_prompt is hard-coded to true, so the prompt-calibration branch below
    // is not taken and k_prompt is the empty string.
    bool have_prompt = true;
    bool ask_prompt  = !have_prompt;

    float logprob_min = 0.0f;
    float logprob_sum = 0.0f;

    int n_tokens = 0;

    std::vector<float> pcmf32_cur;
    std::vector<float> pcmf32_prompt;

    const std::string k_prompt = have_prompt ? "" : "rook to d4, f3";

    int64_t t_ms = 0;

    if (ask_prompt) {
        fprintf(stdout, "\n");
        fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
        fprintf(stdout, "\n");

        ask_prompt = false;
    }

    while (get_audio(pcmf32_cur)) {
        if (!pcmf32_cur.empty()) {
            // fprintf(stdout, "%s: Processing ...\n", __func__);

            if (!have_prompt) {
                // Calibration: the speaker must repeat k_prompt closely enough
                // (length within +/-20% and similarity >= 0.8) before commands
                // are accepted.
                const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));

                fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", static_cast<int>(t_ms));

                const float sim = similarity(txt, k_prompt);

                if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
                    fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
                    ask_prompt = true;
                } else {
                    fprintf(stdout, "\n");
                    fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
                    fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
                    fprintf(stdout, "\n");

                    // save the audio for the prompt
                    pcmf32_prompt = pcmf32_cur;
                    have_prompt = true;
                    m_board->setPrompt(k_prompt);
                }
            } else {
                // Prepend the recorded prompt audio (if any) to the command audio.
                if (!pcmf32_prompt.empty()) {
                    pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
                }

                // Pad the front with zeros so the buffer holds at least
                // 1.2 * WHISPER_SAMPLE_RATE samples.
                constexpr size_t MIN_SIZE = 1.2 * WHISPER_SAMPLE_RATE;
                if (MIN_SIZE > pcmf32_cur.size()) {
                    pcmf32_cur.insert(pcmf32_cur.begin(), MIN_SIZE - pcmf32_cur.size(), 0.0f);
                }

                // fprintf(stdout, "%s: grammar rules:\n'%s'\n", __func__, m_board->grammar().c_str());

                // m_wparams stores a raw pointer into grammar_rules, so both
                // locals must stay alive across the transcribe() call below.
                // NOTE(review): after this scope ends m_wparams.grammar_rules is
                // left pointing at freed storage until the next iteration
                // overwrites it - confirm nothing reads it in between.
                auto grammar_parsed = grammar_parser::parse(m_board->grammar().c_str());
                auto grammar_rules  = grammar_parsed.c_rules();

                m_wparams.grammar_rules   = grammar_rules.data();
                m_wparams.n_grammar_rules = grammar_rules.size();

                m_wparams.i_start_rule    = grammar_parsed.symbol_ids.at("move");

                auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));

                // Reported probability is derived from the lowest token log-probability.
                const float p = 100.0f * std::exp(logprob_min);

                fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());

                // find the prompt in the text
                float  best_sim = 0.0f;
                size_t best_len = 0;

                const size_t n_min = static_cast<size_t>(0.8 * k_prompt.size());
                const size_t n_max = static_cast<size_t>(1.2 * k_prompt.size());

                for (size_t n = n_min; n <= n_max; ++n) {
                    const auto prompt = txt.substr(0, n);

                    const float sim = similarity(prompt, k_prompt);

                    //fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);

                    if (sim > best_sim) {
                        best_sim = sim;
                        best_len = n;
                    }
                }

                fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);

                // best_len can exceed txt.size() when the transcription is shorter
                // than the candidate prefix lengths; clamp it so substr() cannot
                // throw std::out_of_range.
                std::string command = ::trim(txt.substr(std::min(best_len, txt.size())));

                fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", static_cast<int>(t_ms));
                fprintf(stdout, "\n");

                if (!command.empty()) {
                    set_move(m_board->process(command), p);
                    set_grammar(m_board->grammar());
                }

                if (m_board->grammar().empty()) {
                    fprintf(stdout, "%s: No more moves possible\n", __func__);
                    break;
                }
            }
        }

        if (ask_prompt) {
            fprintf(stdout, "\n");
            fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
            fprintf(stdout, "\n");

            ask_prompt = false;
        }
    }
}

// Runs whisper_full() on the given samples using m_wparams and concatenates the
// text of all resulting segments.
//
// Output parameters (all reset to zero on entry):
//   logprob_min - smallest token log-probability seen
//   logprob_sum - sum of all token log-probabilities
//   n_tokens    - number of tokens counted
//   t_ms        - elapsed wall time in milliseconds; stays 0 on early return
//
// Returns an empty string if whisper_full() fails or if any token reports a
// positive log-probability; in the latter case the output statistics hold the
// values accumulated up to that token.
std::string WChess::transcribe(
        const std::vector<float> & pcmf32,
        float & logprob_min,
        float & logprob_sum,
        int & n_tokens,
        int64_t & t_ms) {
    const auto t_start = std::chrono::high_resolution_clock::now();

    logprob_min = 0.0f;
    logprob_sum = 0.0f;
    n_tokens    = 0;
    t_ms        = 0;

    if (whisper_full(m_ctx, m_wparams, pcmf32.data(), pcmf32.size()) != 0) {
        return {};
    }

    std::string result;

    const int n_segments = whisper_full_n_segments(m_ctx);
    for (int i = 0; i < n_segments; ++i) {
        const char * text = whisper_full_get_segment_text(m_ctx, i);

        result += text;

        const int n = whisper_full_n_tokens(m_ctx, i);
        for (int j = 0; j < n; ++j) {
            const auto token = whisper_full_get_token_data(m_ctx, i, j);

            // A positive log-probability is treated as an invalid result.
            if (token.plog > 0.0f) return {};

            logprob_min = std::min(logprob_min, token.plog);
            logprob_sum += token.plog;
            ++n_tokens;
        }
    }

    const auto t_end = std::chrono::high_resolution_clock::now();
    t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();

    return result;
}