diff --git a/examples/wchess/libwchess/WChess.cpp b/examples/wchess/libwchess/WChess.cpp index 9ed49d44..9293180c 100644 --- a/examples/wchess/libwchess/WChess.cpp +++ b/examples/wchess/libwchess/WChess.cpp @@ -6,20 +6,20 @@ static constexpr auto RULES = "\n" -"root ::= init move move? move? \".\"\n" +"root ::= move \".\"\n" "prompt ::= init \".\"\n" "\n" "# leading space is very important!\n" "init ::= \" rook to b4, f3\"\n" "\n" -"move ::= \", \" ((piece | pawn | king) \" \" \"to \"?)? [a-h] [1-8]\n" +"move ::= \" \" ((piece | pawn | king) \" \" \"to \"?)? [a-h] [1-8]\n" "\n" "piece ::= \"bishop\" | \"rook\" | \"knight\" | \"queen\"\n" "king ::= \"king\"\n" "pawn ::= \"pawn\"\n" "\n"; -static constexpr auto PROMPT = "rook to b4, f3,"; +static constexpr auto PROMPT = "rook to b4, f3"; static constexpr auto CONTEXT = "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8,"; WChess::WChess(whisper_context * ctx, @@ -63,8 +63,8 @@ std::string WChess::stringify_board() const { void WChess::run() { set_status("loading data ..."); - bool have_prompt = false; - bool ask_prompt = true; + bool have_prompt = true; + bool ask_prompt = false; float logprob_min0 = 0.0f; float logprob_min = 0.0f; @@ -79,7 +79,6 @@ void WChess::run() { std::vector pcmf32_prompt; const std::string k_prompt = PROMPT; - m_wparams.initial_prompt = CONTEXT; auto grammar_parsed = grammar_parser::parse(RULES); auto grammar_rules = grammar_parsed.c_rules(); @@ -149,13 +148,16 @@ void WChess::run() { } } else { // prepend 3 second of silence - pcmf32_cur.insert(pcmf32_cur.begin(), 3*WHISPER_SAMPLE_RATE, 0.0f); + // pcmf32_cur.insert(pcmf32_cur.begin(), 3*WHISPER_SAMPLE_RATE, 0.0f); // prepend the prompt audio - pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); + // pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); + + if (WHISPER_SAMPLE_RATE > pcmf32_cur.size()) pcmf32_cur.insert(pcmf32_cur.begin(), WHISPER_SAMPLE_RATE - pcmf32_cur.size(), 0.0f); m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("root"); - const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms)); + auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms)); + txt = PROMPT + txt; const float p = 100.0f * std::exp(logprob_min);