diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index eb791a21..e6f837f3 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -60,7 +60,6 @@ if (EMSCRIPTEN) add_subdirectory(command.wasm) add_subdirectory(talk.wasm) add_subdirectory(bench.wasm) - add_subdirectory(chess.wasm) elseif(CMAKE_JS_VERSION) add_subdirectory(addon.node) else() @@ -74,3 +73,5 @@ else() add_subdirectory(talk-llama) add_subdirectory(lsp) endif() + +add_subdirectory(wchess) diff --git a/examples/chess.wasm/emscripten.cpp b/examples/chess.wasm/emscripten.cpp deleted file mode 100644 index 7b6d49fb..00000000 --- a/examples/chess.wasm/emscripten.cpp +++ /dev/null @@ -1,704 +0,0 @@ -#include "ggml.h" -#include "common.h" -#include "whisper.h" -#include "grammar-parser.h" - -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -constexpr int N_THREAD = 8; - -std::vector g_contexts(4, nullptr); - -std::mutex g_mutex; -std::thread g_worker; - -std::atomic g_running(false); - -std::string g_status = ""; -std::string g_status_forced = ""; -std::string g_transcribed = ""; - -std::vector g_pcmf32; - -void command_set_status(const std::string & status) { - std::lock_guard lock(g_mutex); - g_status = status; -} - -std::string command_transcribe( - whisper_context * ctx, - const whisper_full_params & wparams, - const std::vector & pcmf32, - float & logprob_min, - float & logprob_sum, - int & n_tokens, - int64_t & t_ms) { - const auto t_start = std::chrono::high_resolution_clock::now(); - - logprob_min = 0.0f; - logprob_sum = 0.0f; - n_tokens = 0; - t_ms = 0; - - if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { - return ""; - } - - std::string result; - - const int n_segments = whisper_full_n_segments(ctx); - for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(ctx, i); - - result += text; - - const int n = whisper_full_n_tokens(ctx, i); - for (int j = 0; j < n; ++j) { - const auto token = whisper_full_get_token_data(ctx, i, j); - - if(token.plog > 0.0f) exit(0); // todo: check for emscripten - logprob_min = std::min(logprob_min, token.plog); - logprob_sum += token.plog; - ++n_tokens; - } - } - - const auto t_end = std::chrono::high_resolution_clock::now(); - t_ms = std::chrono::duration_cast(t_end - t_start).count(); - - return result; -} - -void command_get_audio(int ms, int sample_rate, std::vector & audio) { - const int64_t n_samples = (ms * sample_rate) / 1000; - - int64_t n_take = 0; - if (n_samples > (int) g_pcmf32.size()) { - n_take = g_pcmf32.size(); - } else { - n_take = n_samples; - } - - audio.resize(n_take); - std::copy(g_pcmf32.end() - n_take, g_pcmf32.end(), audio.begin()); -} - -static constexpr std::array positions = { - "a1", "b1", "c1", "d1", "e1", "f1", "g1", "h1", - "a2", "b2", "c2", "d2", "e2", "f2", "g2", "h2", - "a3", "b3", "c3", "d3", "e3", "f3", "g3", "h3", - "a4", "b4", "c4", "d4", "e4", "f4", "g4", "h4", - "a5", "b5", "c5", "d5", "e5", "f5", "g5", "h5", - "a6", "b6", "c6", "d6", "e6", "f6", "g6", "h6", - "a7", "b7", "c7", "d7", "e7", "f7", "g7", "h7", - "a8", "b8", "c8", "d8", "e8", "f8", "g8", "h8", -}; - -static constexpr std::array pieceNames = { - "pawn", "knight", "bishop", "rook", "queen", "king", -}; - - -class Board { -public: - struct Piece { - enum Types { - Pawn, - Knight, - Bishop, - Rook, - Queen, - King, - Taken, - }; - static_assert(pieceNames.size() == Piece::Taken, "Mismatch between piece names and types"); - - enum Colors { - Black, - White - }; - - Types type; - Colors color; - int pos; - }; - - - std::array blackPieces = {{ - {Piece::Pawn, Piece::Black, 48 }, - {Piece::Pawn, Piece::Black, 49 }, - {Piece::Pawn, Piece::Black, 50 }, - {Piece::Pawn, Piece::Black, 51 }, - {Piece::Pawn, Piece::Black, 52 }, - {Piece::Pawn, Piece::Black, 53 }, - {Piece::Pawn, Piece::Black, 54 }, - {Piece::Pawn, Piece::Black, 55 }, - {Piece::Rook, Piece::Black, 56 }, - {Piece::Knight, Piece::Black, 57 }, - {Piece::Bishop, Piece::Black, 58 }, - {Piece::Queen, Piece::Black, 59 }, - {Piece::King, Piece::Black, 60 }, - {Piece::Bishop, Piece::Black, 61 }, - {Piece::Knight, Piece::Black, 62 }, - {Piece::Rook, Piece::Black, 63 }, - }}; - - std::array whitePieces = {{ - {Piece::Pawn, Piece::White, 8 }, - {Piece::Pawn, Piece::White, 9 }, - {Piece::Pawn, Piece::White, 10 }, - {Piece::Pawn, Piece::White, 11 }, - {Piece::Pawn, Piece::White, 12 }, - {Piece::Pawn, Piece::White, 13 }, - {Piece::Pawn, Piece::White, 14 }, - {Piece::Pawn, Piece::White, 15 }, - {Piece::Rook, Piece::White, 0 }, - {Piece::Knight, Piece::White, 1 }, - {Piece::Bishop, Piece::White, 2 }, - {Piece::Queen, Piece::White, 3 }, - {Piece::King, Piece::White, 4 }, - {Piece::Bishop, Piece::White, 5 }, - {Piece::Knight, Piece::White, 6 }, - {Piece::Rook, Piece::White, 7 }, - }}; - - using BB = std::array; - BB board = {{ - &whitePieces[ 8], &whitePieces[ 9], &whitePieces[10], &whitePieces[11], &whitePieces[12], &whitePieces[13], &whitePieces[14], &whitePieces[15], - &whitePieces[ 0], &whitePieces[ 1], &whitePieces[ 2], &whitePieces[ 3], &whitePieces[ 4], &whitePieces[ 5], &whitePieces[ 6], &whitePieces[ 7], - nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, - &blackPieces[ 0], &blackPieces[ 1], &blackPieces[ 2], &blackPieces[ 3], &blackPieces[ 4], &blackPieces[ 5], &blackPieces[ 6], &blackPieces[ 7], - &blackPieces[ 8], &blackPieces[ 9], &blackPieces[10], &blackPieces[11], &blackPieces[12], &blackPieces[13], &blackPieces[14], &blackPieces[15], - }}; - - bool checkNext(const Piece& piece, int pos, bool kingCheck = false) { - if (piece.type == Piece::Taken) return false; - if (piece.pos == pos) return false; - int i = piece.pos / 8; - int j = piece.pos - i * 8; - - int ii = pos / 8; - int jj = pos - ii * 8; - - if (piece.type == Piece::Pawn) { - if (piece.color == Piece::White) { - int direction = piece.color == Piece::White ? 1 : -1; - if (j == jj) { - if (i == ii - direction) return board[pos] == nullptr; - if (i == ii - direction * 2) return board[(ii - direction) * 8 + jj] == nullptr && board[pos] == nullptr; - } - else if (j + 1 == jj || j - 1 == jj) { - if (i == ii - direction) return board[pos] != nullptr && board[pos]->color != piece.color; - } - } - return false; - } - if (piece.type == Piece::Knight) { - int di = std::abs(i - ii); - int dj = std::abs(j - jj); - if ((di == 2 && dj == 1) || (di == 1 && dj == 2)) return board[pos] == nullptr || board[pos]->color != piece.color; - return false; - } - if (piece.type == Piece::Bishop) { - if (i - j == ii - jj) { - int direction = i < ii ? 1 : -1; - i += direction; - j += direction; - while (i != ii) { - if (board[i * 8 + j]) return false; - i += direction; - j += direction; - } - return board[pos] == nullptr || board[pos]->color != piece.color; - } - if (i + j == ii + jj) { - int direction = i < ii ? 1 : -1; - i += direction; - j -= direction; - while (i != ii) { - if (board[i * 8 + j]) return false; - i += direction; - j -= direction; - } - return board[pos] == nullptr || board[pos]->color != piece.color; - } - return false; - } - if (piece.type == Piece::Rook) { - if (i == ii) { - int direction = j < jj ? 1 : -1; - j += direction; - while (j != jj) { - if (board[i * 8 + j]) return false; - j += direction; - } - return board[pos] == nullptr || board[pos]->color != piece.color; - } - if (j == jj) { - int direction = i < ii ? 1 : -1; - i += direction; - while (i != ii) { - if (board[i * 8 + j]) return false; - i += direction; - } - return board[pos] == nullptr || board[pos]->color != piece.color; - } - return false; - } - if (piece.type == Piece::Queen) { - if (i - j == ii - jj) { - int direction = i < ii ? 1 : -1; - i += direction; - j += direction; - while (i != ii) { - if (board[i * 8 + j]) return false; - i += direction; - j += direction; - } - return board[pos] == nullptr || board[pos]->color != piece.color; - } - if (i + j == ii + jj) { - int direction = i < ii ? 1 : -1; - i += direction; - j -= direction; - while (i != ii) { - if (board[i * 8 + j]) return false; - i += direction; - j -= direction; - } - return board[pos] == nullptr || board[pos]->color != piece.color; - } - if (i == ii) { - int direction = j < jj ? 1 : -1; - j += direction; - while (j != jj) { - if (board[i * 8 + j]) return false; - j += direction; - } - return board[pos] == nullptr || board[pos]->color != piece.color; - } - if (j == jj) { - int direction = i < ii ? 1 : -1; - i += direction; - while (i != ii) { - if (board[i * 8 + j]) return false; - i += direction; - } - return board[pos] == nullptr || board[pos]->color != piece.color; - } - return false; - } - if (piece.type == Piece::King) { - if (std::abs(i - ii) < 2 && std::abs(j - jj) < 2) { - auto& pieces = piece.color == Piece::White ? whitePieces : blackPieces; - for (auto& enemyPiece: pieces) { - if (!kingCheck && piece.type != Piece::Taken && checkNext(enemyPiece, pos, true)) return false; - } - return board[pos] == nullptr || board[pos]->color != piece.color; - } - } - return false; - } - - - int moveCount = 0; - - - void addMoves(const std::string& t) { - - std::vector moves; - size_t cur = 0; - size_t last = 0; - while (cur != std::string::npos) { - cur = t.find(',', last); - moves.push_back(t.substr(last, cur)); - last = cur + 1; - } - - // fixme: lookup depends on grammar - int count = moveCount; - for (auto& move : moves) { - fprintf(stdout, "%s: Move '%s%s%s'\n", __func__, "\033[1m", move.c_str(), "\033[0m"); - if (move.empty()) continue; - auto pieceIndex = 0u; - for (; pieceIndex < pieceNames.size(); ++pieceIndex) { - if (std::string::npos != move.find(pieceNames[pieceIndex])) break; - } - auto posIndex = 0u; - for (; posIndex < positions.size(); ++posIndex) { - if (std::string::npos != move.find(positions[posIndex])) break; - } - if (pieceIndex >= pieceNames.size() || posIndex >= positions.size()) continue; - - auto& pieces = count % 2 ? blackPieces : whitePieces; - auto type = Piece::Types(pieceIndex); - pieceIndex = 0; - for (; pieceIndex < pieces.size(); ++pieceIndex) { - if (pieces[pieceIndex].type == type && checkNext(pieces[pieceIndex], posIndex)) break; - } - if (pieceIndex < pieces.size()) { - m_pendingMoves.push_back({&pieces[pieceIndex], posIndex}); - } - } - } - - std::string stringifyMoves() { - std::string res; - for (auto& m : m_pendingMoves) { - res.append(positions[m.first->pos]); - res.push_back('-'); - res.append(positions[m.second]); - res.push_back(' '); - } - if (!res.empty()) res.pop_back(); - return res; - } - - void commitMoves() { - for (auto& m : m_pendingMoves) { - if (board[m.second]) board[m.second]->type = Piece::Taken; - board[m.first->pos] = nullptr; - m.first->pos = m.second; - board[m.second] = m.first; - } - m_pendingMoves.clear(); - } - - std::vector> m_pendingMoves; -}; - -Board g_board; - -void command_main(size_t index) { - command_set_status("loading data ..."); - - struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY); - - wparams.n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency()); - wparams.offset_ms = 0; - wparams.translate = false; - wparams.no_context = true; - wparams.single_segment = true; - wparams.print_realtime = false; - wparams.print_progress = false; - wparams.print_timestamps = true; - wparams.print_special = false; - - wparams.max_tokens = 32; - // wparams.audio_ctx = 768; // partial encoder context for better performance - - wparams.temperature = 0.4f; - wparams.temperature_inc = 1.0f; - wparams.greedy.best_of = 1; - - wparams.beam_search.beam_size = 5; - - wparams.language = "en"; - - printf("command: using %d threads\n", wparams.n_threads); - - bool have_prompt = false; - bool ask_prompt = true; - bool print_energy = false; - - float logprob_min0 = 0.0f; - float logprob_min = 0.0f; - - float logprob_sum0 = 0.0f; - float logprob_sum = 0.0f; - - int n_tokens0 = 0; - int n_tokens = 0; - - std::vector pcmf32_cur; - std::vector pcmf32_prompt; - - // todo: grammar to be based on js input - const std::string k_prompt = "rook to b4, f3,"; - wparams.initial_prompt = "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8,"; - - auto grammar_parsed = grammar_parser::parse( -"\n" -"root ::= init move move? move? \".\"\n" -"prompt ::= init \".\"\n" -"\n" -"# leading space is very important!\n" -"init ::= \" rook to b4, f3\"\n" -"\n" -"move ::= \", \" ((piece | pawn | king) \" \" \"to \"?)? [a-h] [1-8]\n" -"\n" -"piece ::= \"bishop\" | \"rook\" | \"knight\" | \"queen\"\n" -"king ::= \"king\"\n" -"pawn ::= \"pawn\"\n" -"\n" - ); - auto grammar_rules = grammar_parsed.c_rules(); - - if (grammar_parsed.rules.empty()) { - fprintf(stdout, "%s: Failed to parse grammar ...\n", __func__); - } - else { - wparams.grammar_rules = grammar_rules.data(); - wparams.n_grammar_rules = grammar_rules.size(); - wparams.grammar_penalty = 100.0; - } - - // whisper context - auto & ctx = g_contexts[index]; - - const int32_t vad_ms = 2000; - const int32_t prompt_ms = 5000; - const int32_t command_ms = 4000; - - const float vad_thold = 0.1f; - const float freq_thold = -1.0f; - - while (g_running) { - // delay - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - if (ask_prompt) { - fprintf(stdout, "\n"); - fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m"); - fprintf(stdout, "\n"); - - { - char txt[1024]; - snprintf(txt, sizeof(txt), "Say the following phrase: '%s'", k_prompt.c_str()); - command_set_status(txt); - } - - ask_prompt = false; - } - - int64_t t_ms = 0; - - { - command_get_audio(vad_ms, WHISPER_SAMPLE_RATE, pcmf32_cur); - - if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) { - fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); - command_set_status("Speech detected! Processing ..."); - - if (!have_prompt) { - command_get_audio(prompt_ms, WHISPER_SAMPLE_RATE, pcmf32_cur); - - wparams.i_start_rule = grammar_parsed.symbol_ids.at("prompt"); - const auto txt = ::trim(::command_transcribe(ctx, wparams, pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms)); - - fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms); - - const float sim = similarity(txt, k_prompt); - - if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) { - fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__); - ask_prompt = true; - } else { - fprintf(stdout, "\n"); - fprintf(stdout, "%s: The prompt has been recognized!\n", __func__); - fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__); - fprintf(stdout, "\n"); - - { - char txt[1024]; - snprintf(txt, sizeof(txt), "Success! Waiting for voice commands ..."); - command_set_status(txt); - } - - // save the audio for the prompt - pcmf32_prompt = pcmf32_cur; - have_prompt = true; - } - } else { - command_get_audio(command_ms, WHISPER_SAMPLE_RATE, pcmf32_cur); - - // prepend 3 second of silence - pcmf32_cur.insert(pcmf32_cur.begin(), 3*WHISPER_SAMPLE_RATE, 0.0f); - - // prepend the prompt audio - pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); - - wparams.i_start_rule = grammar_parsed.symbol_ids.at("root"); - const auto txt = ::trim(::command_transcribe(ctx, wparams, pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms)); - - const float p = 100.0f * std::exp(logprob_min); - - fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str()); - - // find the prompt in the text - float best_sim = 0.0f; - size_t best_len = 0; - for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) { - if (n >= int(txt.size())) { - break; - } - - const auto prompt = txt.substr(0, n); - - const float sim = similarity(prompt, k_prompt); - - //fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim); - - if (sim > best_sim) { - best_sim = sim; - best_len = n; - } - } - - fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p); - std::string command = ::trim(txt.substr(best_len)); - - fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms); - fprintf(stdout, "\n"); - - { - char txt[1024]; - snprintf(txt, sizeof(txt), "Command '%s', (t = %d ms)", command.c_str(), (int) t_ms); - command_set_status(txt); - } - { - std::lock_guard lock(g_mutex); - if (!command.empty()) { - g_board.addMoves(command); - } - g_transcribed = std::move(command); - } - } - - g_pcmf32.clear(); - } - } - } - - if (index < g_contexts.size()) { - whisper_free(g_contexts[index]); - g_contexts[index] = nullptr; - } -} - -EMSCRIPTEN_BINDINGS(command) { - emscripten::function("init", emscripten::optional_override([](const std::string & path_model) { - for (size_t i = 0; i < g_contexts.size(); ++i) { - if (g_contexts[i] == nullptr) { - g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params()); - if (g_contexts[i] != nullptr) { - g_running = true; - if (g_worker.joinable()) { - g_worker.join(); - } - g_worker = std::thread([i]() { - command_main(i); - }); - - return i + 1; - } else { - return (size_t) 0; - } - } - } - - return (size_t) 0; - })); - - emscripten::function("free", emscripten::optional_override([](size_t index) { - if (g_running) { - g_running = false; - } - })); - - emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) { - --index; - - if (index >= g_contexts.size()) { - return -1; - } - - if (g_contexts[index] == nullptr) { - return -2; - } - - { - std::lock_guard lock(g_mutex); - const int n = audio["length"].as(); - - emscripten::val heap = emscripten::val::module_property("HEAPU8"); - emscripten::val memory = heap["buffer"]; - - g_pcmf32.resize(n); - - emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast(g_pcmf32.data()), n); - memoryView.call("set", audio); - } - - return 0; - })); - - emscripten::function("get_transcribed", emscripten::optional_override([]() { - std::string transcribed; - - { - std::lock_guard lock(g_mutex); - transcribed = std::move(g_transcribed); - } - - return transcribed; - })); - - - emscripten::function("get_moves", emscripten::optional_override([]() { - std::string moves; - - { - std::lock_guard lock(g_mutex); - moves = g_board.stringifyMoves(); - - fprintf(stdout, "%s: Moves '%s%s%s'\n", __func__, "\033[1m", moves.c_str(), "\033[0m"); - } - - return moves; - })); - - emscripten::function("commit_moves", emscripten::optional_override([]() { - { - std::lock_guard lock(g_mutex); - g_board.commitMoves(); - } - - })); - - emscripten::function("discard_moves", emscripten::optional_override([]() { - { - std::lock_guard lock(g_mutex); - g_board.m_pendingMoves.clear(); - } - - })); - - emscripten::function("get_status", emscripten::optional_override([]() { - std::string status; - - { - std::lock_guard lock(g_mutex); - status = g_status_forced.empty() ? g_status : g_status_forced; - } - - return status; - })); - - emscripten::function("set_status", emscripten::optional_override([](const std::string & status) { - { - std::lock_guard lock(g_mutex); - g_status_forced = status; - } - })); -} diff --git a/examples/wchess/CMakeLists.txt b/examples/wchess/CMakeLists.txt new file mode 100644 index 00000000..9f2dbb9a --- /dev/null +++ b/examples/wchess/CMakeLists.txt @@ -0,0 +1,8 @@ + +add_subdirectory(libwchess) + +if (EMSCRIPTEN) + add_subdirectory(wchess.wasm) +else() + add_subdirectory(wchess.cmd) +endif() \ No newline at end of file diff --git a/examples/wchess/libwchess/CMakeLists.txt b/examples/wchess/libwchess/CMakeLists.txt new file mode 100644 index 00000000..7c89883d --- /dev/null +++ b/examples/wchess/libwchess/CMakeLists.txt @@ -0,0 +1,16 @@ +add_library(libwchess + WChess.cpp + WChess.h + Chessboard.cpp + Chessboard.h + ) + +target_link_libraries(libwchess + PUBLIC + whisper +) + +target_include_directories(libwchess + PUBLIC + "$" + ) diff --git a/examples/wchess/libwchess/Chessboard.cpp b/examples/wchess/libwchess/Chessboard.cpp new file mode 100644 index 00000000..9fc87b50 --- /dev/null +++ b/examples/wchess/libwchess/Chessboard.cpp @@ -0,0 +1,291 @@ +#include "Chessboard.h" +#include + +static constexpr std::array positions = { + "a1", "b1", "c1", "d1", "e1", "f1", "g1", "h1", + "a2", "b2", "c2", "d2", "e2", "f2", "g2", "h2", + "a3", "b3", "c3", "d3", "e3", "f3", "g3", "h3", + "a4", "b4", "c4", "d4", "e4", "f4", "g4", "h4", + "a5", "b5", "c5", "d5", "e5", "f5", "g5", "h5", + "a6", "b6", "c6", "d6", "e6", "f6", "g6", "h6", + "a7", "b7", "c7", "d7", "e7", "f7", "g7", "h7", + "a8", "b8", "c8", "d8", "e8", "f8", "g8", "h8", +}; + +static constexpr std::array pieceNames = { + "pawn", "knight", "bishop", "rook", "queen", "king", +}; + +Chessboard::Chessboard() + : blackPieces {{ + {Piece::Pawn, Piece::Black, 48 }, + {Piece::Pawn, Piece::Black, 49 }, + {Piece::Pawn, Piece::Black, 50 }, + {Piece::Pawn, Piece::Black, 51 }, + {Piece::Pawn, Piece::Black, 52 }, + {Piece::Pawn, Piece::Black, 53 }, + {Piece::Pawn, Piece::Black, 54 }, + {Piece::Pawn, Piece::Black, 55 }, + {Piece::Rook, Piece::Black, 56 }, + {Piece::Knight, Piece::Black, 57 }, + {Piece::Bishop, Piece::Black, 58 }, + {Piece::Queen, Piece::Black, 59 }, + {Piece::King, Piece::Black, 60 }, + {Piece::Bishop, Piece::Black, 61 }, + {Piece::Knight, Piece::Black, 62 }, + {Piece::Rook, Piece::Black, 63 }, + }} + , whitePieces {{ + {Piece::Pawn, Piece::White, 8 }, + {Piece::Pawn, Piece::White, 9 }, + {Piece::Pawn, Piece::White, 10 }, + {Piece::Pawn, Piece::White, 11 }, + {Piece::Pawn, Piece::White, 12 }, + {Piece::Pawn, Piece::White, 13 }, + {Piece::Pawn, Piece::White, 14 }, + {Piece::Pawn, Piece::White, 15 }, + {Piece::Rook, Piece::White, 0 }, + {Piece::Knight, Piece::White, 1 }, + {Piece::Bishop, Piece::White, 2 }, + {Piece::Queen, Piece::White, 3 }, + {Piece::King, Piece::White, 4 }, + {Piece::Bishop, Piece::White, 5 }, + {Piece::Knight, Piece::White, 6 }, + {Piece::Rook, Piece::White, 7 }, + }} + , board {{ + &whitePieces[ 8], &whitePieces[ 9], &whitePieces[10], &whitePieces[11], &whitePieces[12], &whitePieces[13], &whitePieces[14], &whitePieces[15], + &whitePieces[ 0], &whitePieces[ 1], &whitePieces[ 2], &whitePieces[ 3], &whitePieces[ 4], &whitePieces[ 5], &whitePieces[ 6], &whitePieces[ 7], + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, + &blackPieces[ 0], &blackPieces[ 1], &blackPieces[ 2], &blackPieces[ 3], &blackPieces[ 4], &blackPieces[ 5], &blackPieces[ 6], &blackPieces[ 7], + &blackPieces[ 8], &blackPieces[ 9], &blackPieces[10], &blackPieces[11], &blackPieces[12], &blackPieces[13], &blackPieces[14], &blackPieces[15], + }} +{ + static_assert(pieceNames.size() == Chessboard::Piece::Taken, "Mismatch between piece names and types"); +} + +std::string Chessboard::stringifyBoard() { + static constexpr std::array blackShort = { + 'p', 'n', 'b', 'r', 'q', 'k', + }; + static constexpr std::array whiteShort = { + 'P', 'N', 'B', 'R', 'Q', 'K', + }; + + std::string result; + result.reserve(16 + 2 * 64 + 16); + for (char rank = 'a'; rank <= 'h'; ++rank) { + result.push_back(rank); + result.push_back(' '); + } + result.back() = '\n'; + for (int i = 7; i >= 0; --i) { + for (int j = 0; j < 8; ++j) { + if (auto p = board[i * 8 + j]; p) result.push_back(p->color == Piece::White ? whiteShort[p->type] : blackShort[p->type]); + else result.push_back('.'); + result.push_back(' '); + } + result.push_back('0' + i + 1); + result.push_back('\n'); + } + return result; +} + +std::string Chessboard::processTranscription(const std::string& t) { + std::vector moves; + size_t cur = 0; + size_t last = 0; + while (cur != std::string::npos) { + cur = t.find(',', last); + moves.push_back(t.substr(last, cur)); + last = cur + 1; + } + + // fixme: lookup depends on grammar + int count = m_moveCounter; + std::vector pendingMoves; + for (auto& move : moves) { + fprintf(stdout, "%s: Move '%s%s%s'\n", __func__, "\033[1m", move.c_str(), "\033[0m"); + if (move.empty()) continue; + auto pieceIndex = 0u; + for (; pieceIndex < pieceNames.size(); ++pieceIndex) { + if (std::string::npos != move.find(pieceNames[pieceIndex])) break; + } + auto posIndex = 0u; + for (; posIndex < positions.size(); ++posIndex) { + if (std::string::npos != move.find(positions[posIndex])) break; + } + if (pieceIndex >= pieceNames.size() || posIndex >= positions.size()) continue; + + auto& pieces = count % 2 ? blackPieces : whitePieces; + auto type = Piece::Types(pieceIndex); + pieceIndex = 0; + for (; pieceIndex < pieces.size(); ++pieceIndex) { + if (pieces[pieceIndex].type == type && checkNext(pieces[pieceIndex], posIndex)) break; + } + + if (pieceIndex < pieces.size()) { + pendingMoves.emplace_back(pieces[pieceIndex].pos, posIndex); + } + ++count; + } + auto result = stringifyMoves(pendingMoves); + commitMoves(pendingMoves); + m_moveCounter = count; + return result; + } + + bool Chessboard::checkNext(const Piece& piece, int pos, bool kingCheck) { + if (piece.type == Piece::Taken) return false; + if (piece.pos == pos) return false; + int i = piece.pos / 8; + int j = piece.pos - i * 8; + + int ii = pos / 8; + int jj = pos - ii * 8; + + if (piece.type == Piece::Pawn) { + if (piece.color == Piece::White) { + int direction = piece.color == Piece::White ? 1 : -1; + if (j == jj) { + if (i == ii - direction) return board[pos] == nullptr; + if (i == ii - direction * 2) return board[(ii - direction) * 8 + jj] == nullptr && board[pos] == nullptr; + } + else if (j + 1 == jj || j - 1 == jj) { + if (i == ii - direction) return board[pos] != nullptr && board[pos]->color != piece.color; + } + } + return false; + } + if (piece.type == Piece::Knight) { + int di = std::abs(i - ii); + int dj = std::abs(j - jj); + if ((di == 2 && dj == 1) || (di == 1 && dj == 2)) return board[pos] == nullptr || board[pos]->color != piece.color; + return false; + } + if (piece.type == Piece::Bishop) { + if (i - j == ii - jj) { + int direction = i < ii ? 1 : -1; + i += direction; + j += direction; + while (i != ii) { + if (board[i * 8 + j]) return false; + i += direction; + j += direction; + } + return board[pos] == nullptr || board[pos]->color != piece.color; + } + if (i + j == ii + jj) { + int direction = i < ii ? 1 : -1; + i += direction; + j -= direction; + while (i != ii) { + if (board[i * 8 + j]) return false; + i += direction; + j -= direction; + } + return board[pos] == nullptr || board[pos]->color != piece.color; + } + return false; + } + if (piece.type == Piece::Rook) { + if (i == ii) { + int direction = j < jj ? 1 : -1; + j += direction; + while (j != jj) { + if (board[i * 8 + j]) return false; + j += direction; + } + return board[pos] == nullptr || board[pos]->color != piece.color; + } + if (j == jj) { + int direction = i < ii ? 1 : -1; + i += direction; + while (i != ii) { + if (board[i * 8 + j]) return false; + i += direction; + } + return board[pos] == nullptr || board[pos]->color != piece.color; + } + return false; + } + if (piece.type == Piece::Queen) { + if (i - j == ii - jj) { + int direction = i < ii ? 1 : -1; + i += direction; + j += direction; + while (i != ii) { + if (board[i * 8 + j]) return false; + i += direction; + j += direction; + } + return board[pos] == nullptr || board[pos]->color != piece.color; + } + if (i + j == ii + jj) { + int direction = i < ii ? 1 : -1; + i += direction; + j -= direction; + while (i != ii) { + if (board[i * 8 + j]) return false; + i += direction; + j -= direction; + } + return board[pos] == nullptr || board[pos]->color != piece.color; + } + if (i == ii) { + int direction = j < jj ? 1 : -1; + j += direction; + while (j != jj) { + if (board[i * 8 + j]) return false; + j += direction; + } + return board[pos] == nullptr || board[pos]->color != piece.color; + } + if (j == jj) { + int direction = i < ii ? 1 : -1; + i += direction; + while (i != ii) { + if (board[i * 8 + j]) return false; + i += direction; + } + return board[pos] == nullptr || board[pos]->color != piece.color; + } + return false; + } + if (piece.type == Piece::King) { + if (std::abs(i - ii) < 2 && std::abs(j - jj) < 2) { + auto& pieces = piece.color == Piece::White ? whitePieces : blackPieces; + for (auto& enemyPiece: pieces) { + if (!kingCheck && piece.type != Piece::Taken && checkNext(enemyPiece, pos, true)) return false; + } + return board[pos] == nullptr || board[pos]->color != piece.color; + } + } + return false; + } + + + std::string Chessboard::stringifyMoves(const std::vector& pendingMoves) { + std::string res; + for (auto& m : pendingMoves) { + res.append(positions[m.first]); + res.push_back('-'); + res.append(positions[m.second]); + res.push_back(' '); + } + if (!res.empty()) res.pop_back(); + return res; + } + + void Chessboard::commitMoves(std::vector& pendingMoves) { + for (auto& m : pendingMoves) { + if (!board[m.first] || (board[m.second] && board[m.first]->type == board[m.second]->type)) continue; + if (board[m.second]) board[m.second]->type = Piece::Taken; + board[m.second] = board[m.first]; + board[m.first] = nullptr; + } + pendingMoves.clear(); + } \ No newline at end of file diff --git a/examples/wchess/libwchess/Chessboard.h b/examples/wchess/libwchess/Chessboard.h new file mode 100644 index 00000000..91ed80d1 --- /dev/null +++ b/examples/wchess/libwchess/Chessboard.h @@ -0,0 +1,47 @@ +#pragma once +#include +#include +#include + +class Chessboard { +public: + Chessboard(); + std::string processTranscription(const std::string& t); + std::string stringifyBoard(); +private: + using Move = std::pair; + std::string stringifyMoves(const std::vector&); + void commitMoves(std::vector&); + + struct Piece { + enum Types { + Pawn, + Knight, + Bishop, + Rook, + Queen, + King, + Taken, + }; + + enum Colors { + Black, + White + }; + + Types type; + Colors color; + int pos; + }; + + using PieceSet = std::array; + + PieceSet blackPieces; + PieceSet whitePieces; + int m_moveCounter; + + using Board = std::array; + Board board; + + bool checkNext(const Piece& piece, int pos, bool kingCheck = false); +}; diff --git a/examples/wchess/libwchess/WChess.cpp b/examples/wchess/libwchess/WChess.cpp new file mode 100644 index 00000000..249b7b53 --- /dev/null +++ b/examples/wchess/libwchess/WChess.cpp @@ -0,0 +1,252 @@ +#include "WChess.h" +#include "grammar-parser.h" +#include "common.h" +#include + +Chess::Chess(whisper_context * ctx, + const whisper_full_params & wparams, + StatusSetter status_setter, + ISRunning running, + AudioGetter audio, + MovesSetter m_moveSetter) + : m_ctx(ctx) + , m_wparams(wparams) + , m_status_setter(status_setter) + , m_running(running) + , m_audio(audio) + , m_moveSetter( m_moveSetter) +{} + +void Chess::set_status(const char * msg) { + if (m_status_setter) (*m_status_setter)(msg); +} + +void Chess::set_moves(const std::string& moves) { + if (m_moveSetter) (*m_moveSetter)(moves); +} + +bool Chess::check_running() { + if (m_running) return (*m_running)(); + return false; +} + +void Chess::get_audio(int ms, std::vector& pcmf32) { + if (m_audio) (*m_audio)(ms, pcmf32); +} + +std::string Chess::stringifyBoard() { + return m_board.stringifyBoard(); +} + +void Chess::run() { + set_status("loading data ..."); + + bool have_prompt = false; + bool ask_prompt = true; + bool print_energy = false; + + float logprob_min0 = 0.0f; + float logprob_min = 0.0f; + + float logprob_sum0 = 0.0f; + float logprob_sum = 0.0f; + + int n_tokens0 = 0; + int n_tokens = 0; + + std::vector pcmf32_cur; + std::vector pcmf32_prompt; + + // todo: grammar to be based on js input + const std::string k_prompt = "rook to b4, f3,"; + m_wparams.initial_prompt = "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8,"; + + auto grammar_parsed = grammar_parser::parse( +"\n" +"root ::= init move move? move? \".\"\n" +"prompt ::= init \".\"\n" +"\n" +"# leading space is very important!\n" +"init ::= \" rook to b4, f3\"\n" +"\n" +"move ::= \", \" ((piece | pawn | king) \" \" \"to \"?)? [a-h] [1-8]\n" +"\n" +"piece ::= \"bishop\" | \"rook\" | \"knight\" | \"queen\"\n" +"king ::= \"king\"\n" +"pawn ::= \"pawn\"\n" +"\n" + ); + auto grammar_rules = grammar_parsed.c_rules(); + + if (grammar_parsed.rules.empty()) { + fprintf(stdout, "%s: Failed to parse grammar ...\n", __func__); + } + else { + m_wparams.grammar_rules = grammar_rules.data(); + m_wparams.n_grammar_rules = grammar_rules.size(); + m_wparams.grammar_penalty = 100.0; + } + + const int32_t vad_ms = 2000; + const int32_t prompt_ms = 5000; + const int32_t command_ms = 4000; + + const float vad_thold = 0.1f; + const float freq_thold = -1.0f; + + while (check_running()) { + // delay + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + if (ask_prompt) { + fprintf(stdout, "\n"); + fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m"); + fprintf(stdout, "\n"); + + { + char txt[1024]; + snprintf(txt, sizeof(txt), "Say the following phrase: '%s'", k_prompt.c_str()); + set_status(txt); + } + + ask_prompt = false; + } + + int64_t t_ms = 0; + + { + get_audio(vad_ms, pcmf32_cur); + + if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) { + fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); + set_status("Speech detected! Processing ..."); + + if (!have_prompt) { + get_audio(prompt_ms, pcmf32_cur); + + m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("prompt"); + const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms)); + + fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms); + + const float sim = similarity(txt, k_prompt); + + if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) { + fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__); + ask_prompt = true; + } else { + fprintf(stdout, "\n"); + fprintf(stdout, "%s: The prompt has been recognized!\n", __func__); + fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__); + fprintf(stdout, "\n"); + + { + char txt[1024]; + snprintf(txt, sizeof(txt), "Success! Waiting for voice commands ..."); + set_status(txt); + } + + // save the audio for the prompt + pcmf32_prompt = pcmf32_cur; + have_prompt = true; + } + } else { + get_audio(command_ms, pcmf32_cur); + + // prepend 3 second of silence + pcmf32_cur.insert(pcmf32_cur.begin(), 3*WHISPER_SAMPLE_RATE, 0.0f); + + // prepend the prompt audio + pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); + + m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("root"); + const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms)); + + const float p = 100.0f * std::exp(logprob_min); + + fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str()); + + // find the prompt in the text + float best_sim = 0.0f; + size_t best_len = 0; + for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) { + if (n >= int(txt.size())) { + break; + } + + const auto prompt = txt.substr(0, n); + + const float sim = similarity(prompt, k_prompt); + + //fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim); + + if (sim > best_sim) { + best_sim = sim; + best_len = n; + } + } + + fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p); + std::string command = ::trim(txt.substr(best_len)); + + fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms); + fprintf(stdout, "\n"); + + { + char txt[1024]; + snprintf(txt, sizeof(txt), "Command '%s', (t = %d ms)", command.c_str(), (int) t_ms); + set_status(txt); + } + if (!command.empty()) { + set_moves(m_board.processTranscription(command)); + } + + } + + + } + } + } +} + +std::string Chess::transcribe( + const std::vector & pcmf32, + float & logprob_min, + float & logprob_sum, + int & n_tokens, + int64_t & t_ms) { + const auto t_start = std::chrono::high_resolution_clock::now(); + + logprob_min = 0.0f; + logprob_sum = 0.0f; + n_tokens = 0; + t_ms = 0; + + if (whisper_full(m_ctx, m_wparams, pcmf32.data(), pcmf32.size()) != 0) { + return ""; + } + + std::string result; + + const int n_segments = whisper_full_n_segments(m_ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(m_ctx, i); + + result += text; + + const int n = whisper_full_n_tokens(m_ctx, i); + for (int j = 0; j < n; ++j) { + const auto token = whisper_full_get_token_data(m_ctx, i, j); + + if(token.plog > 0.0f) return {}; + logprob_min = std::min(logprob_min, token.plog); + logprob_sum += token.plog; + ++n_tokens; + } + } + + const auto t_end = std::chrono::high_resolution_clock::now(); + t_ms = std::chrono::duration_cast(t_end - t_start).count(); + + return result; +} diff --git a/examples/wchess/libwchess/WChess.h b/examples/wchess/libwchess/WChess.h new file mode 100644 index 00000000..f97980e8 --- /dev/null +++ b/examples/wchess/libwchess/WChess.h @@ -0,0 +1,39 @@ +#pragma once +#include "Chessboard.h" +#include "whisper.h" +#include +#include + +class Chess { +public: + using StatusSetter = void (*)(const std::string & status); + using ISRunning = bool (*)(); + using AudioGetter = void (*)(int, std::vector&); + using MovesSetter = void (*)(const std::string & moves); + Chess( whisper_context * ctx, + const whisper_full_params & wparams, + StatusSetter status_setter, + ISRunning running, + AudioGetter audio, + MovesSetter moveSetter); + void run(); + std::string stringifyBoard(); +private: + void get_audio(int ms, std::vector& pcmf32); + void set_status(const char* msg); + void set_moves(const std::string& moves); + bool check_running(); + std::string transcribe( + const std::vector & pcmf32, + float & logprob_min, + float & logprob_sum, + int & n_tokens, + int64_t & t_ms); + whisper_context * m_ctx; + whisper_full_params m_wparams; + StatusSetter m_status_setter; + ISRunning m_running; + AudioGetter m_audio; + MovesSetter m_moveSetter; + Chessboard m_board; +}; diff --git a/examples/wchess/wchess.cmd/CMakeLists.txt b/examples/wchess/wchess.cmd/CMakeLists.txt new file mode 100644 index 00000000..4cd93ac2 --- /dev/null +++ b/examples/wchess/wchess.cmd/CMakeLists.txt @@ -0,0 +1,8 @@ +if (WHISPER_SDL2) + set(TARGET wchess) + add_executable(${TARGET} wchess.cmd.cpp) + + include(DefaultTargetOptions) + + target_link_libraries(${TARGET} PRIVATE libwchess common common-sdl ${CMAKE_THREAD_LIBS_INIT}) +endif () \ No newline at end of file diff --git a/examples/wchess/wchess.cmd/wchess.cmd.cpp b/examples/wchess/wchess.cmd/wchess.cmd.cpp new file mode 100644 index 00000000..32e126a8 --- /dev/null +++ b/examples/wchess/wchess.cmd/wchess.cmd.cpp @@ -0,0 +1,208 @@ +// Voice assistant example +// +// Speak short text commands to the microphone. +// This program will detect your voice command and convert them to text. +// +// ref: https://github.com/ggerganov/whisper.cpp/issues/171 +// + +#include "common-sdl.h" +#include "common.h" +#include "WChess.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +bool file_exists(const std::string & fname) { + std::ifstream f(fname.c_str()); + return f.good(); +} + +// command-line parameters +struct whisper_params { + int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency()); + int32_t prompt_ms = 5000; + int32_t command_ms = 8000; + int32_t capture_id = -1; + int32_t max_tokens = 32; + int32_t audio_ctx = 0; + + float vad_thold = 0.6f; + float freq_thold = 100.0f; + + float grammar_penalty = 100.0f; + + bool speed_up = false; + bool translate = false; + bool print_special = false; + bool print_energy = false; + bool no_timestamps = true; + bool use_gpu = true; + + std::string language = "en"; + std::string model = "models/ggml-base.en.bin"; + std::string fname_out; + std::string commands; + std::string prompt; + std::string context; + std::string grammar; +}; + +void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & params) { + fprintf(stderr, "\n"); + fprintf(stderr, "usage: %s [options]\n", argv[0]); + fprintf(stderr, "\n"); + fprintf(stderr, "options:\n"); + fprintf(stderr, " -h, --help [default] show this help message and exit\n"); + fprintf(stderr, " -t N, --threads N [%-7d] number of threads to use during computation\n", params.n_threads); + fprintf(stderr, " -pms N, --prompt-ms N [%-7d] prompt duration in milliseconds\n", params.prompt_ms); + fprintf(stderr, " -cms N, --command-ms N [%-7d] command duration in milliseconds\n", params.command_ms); + fprintf(stderr, " -c ID, --capture ID [%-7d] capture device ID\n", params.capture_id); + fprintf(stderr, " -mt N, --max-tokens N [%-7d] maximum number of tokens per audio chunk\n", params.max_tokens); + fprintf(stderr, " -ac N, --audio-ctx N [%-7d] audio context size (0 - all)\n", params.audio_ctx); + fprintf(stderr, " -vth N, --vad-thold N [%-7.2f] voice activity detection threshold\n", params.vad_thold); + fprintf(stderr, " -fth N, --freq-thold N [%-7.2f] high-pass frequency cutoff\n", params.freq_thold); + fprintf(stderr, " -su, --speed-up [%-7s] speed up audio by x2 (reduced accuracy)\n", params.speed_up ? "true" : "false"); + fprintf(stderr, " -tr, --translate [%-7s] translate from source language to english\n", params.translate ? "true" : "false"); + fprintf(stderr, " -ps, --print-special [%-7s] print special tokens\n", params.print_special ? "true" : "false"); + fprintf(stderr, " -pe, --print-energy [%-7s] print sound energy (for debugging)\n", params.print_energy ? "true" : "false"); + fprintf(stderr, " -ng, --no-gpu [%-7s] disable GPU\n", params.use_gpu ? "false" : "true"); + fprintf(stderr, " -l LANG, --language LANG [%-7s] spoken language\n", params.language.c_str()); + fprintf(stderr, " -m FNAME, --model FNAME [%-7s] model path\n", params.model.c_str()); + fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str()); + fprintf(stderr, " -cmd FNAME, --commands FNAME [%-7s] text file with allowed commands\n", params.commands.c_str()); + fprintf(stderr, " -p, --prompt [%-7s] the required activation prompt\n", params.prompt.c_str()); + fprintf(stderr, " -ctx, --context [%-7s] sample text to help the transcription\n", params.context.c_str()); + fprintf(stderr, " --grammar GRAMMAR [%-7s] GBNF grammar to guide decoding\n", params.grammar.c_str()); + fprintf(stderr, " --grammar-penalty N [%-7.1f] scales down logits of nongrammar tokens\n", params.grammar_penalty); + fprintf(stderr, "\n"); +} + +bool whisper_params_parse(int argc, char ** argv, whisper_params & params) { + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "-h" || arg == "--help") { + whisper_print_usage(argc, argv, params); + exit(0); + } + else if (arg == "-t" || arg == "--threads") { params.n_threads = std::stoi(argv[++i]); } + else if (arg == "-pms" || arg == "--prompt-ms") { params.prompt_ms = std::stoi(argv[++i]); } + else if (arg == "-cms" || arg == "--command-ms") { params.command_ms = std::stoi(argv[++i]); } + else if (arg == "-c" || arg == "--capture") { params.capture_id = std::stoi(argv[++i]); } + else if (arg == "-mt" || arg == "--max-tokens") { params.max_tokens = std::stoi(argv[++i]); } + else if (arg == "-ac" || arg == "--audio-ctx") { params.audio_ctx = std::stoi(argv[++i]); } + else if (arg == "-vth" || arg == "--vad-thold") { params.vad_thold = std::stof(argv[++i]); } + else if (arg == "-fth" || arg == "--freq-thold") { params.freq_thold = std::stof(argv[++i]); } + else if (arg == "-su" || arg == "--speed-up") { params.speed_up = true; } + else if (arg == "-tr" || arg == "--translate") { params.translate = true; } + else if (arg == "-ps" || arg == "--print-special") { params.print_special = true; } + else if (arg == "-pe" || arg == "--print-energy") { params.print_energy = true; } + else if (arg == "-ng" || arg == "--no-gpu") { params.use_gpu = false; } + else if (arg == "-l" || arg == "--language") { params.language = argv[++i]; } + else if (arg == "-m" || arg == "--model") { params.model = argv[++i]; } + else if (arg == "-f" || arg == "--file") { params.fname_out = argv[++i]; } + else if (arg == "-cmd" || arg == "--commands") { params.commands = argv[++i]; } + else if (arg == "-p" || arg == "--prompt") { params.prompt = argv[++i]; } + else if (arg == "-ctx" || arg == "--context") { params.context = argv[++i]; } + else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(argv[++i]); } + else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + whisper_print_usage(argc, argv, params); + exit(0); + } + } + + return true; +} + + +std::unique_ptr g_chess; +void set_moves(const std::string & /* moves */) { + fprintf(stdout, "%s", g_chess->stringifyBoard().c_str()); +} + +audio_async g_audio(30*1000); +void get_audio(int ms, std::vector & pcmf32_cur) { + g_audio.get(ms, pcmf32_cur); +} + +bool check_running() { + return sdl_poll_events(); +} + +int main(int argc, char ** argv) { + whisper_params params; + + if (whisper_params_parse(argc, argv, params) == false) { + return 1; + } + + if (whisper_lang_id(params.language.c_str()) == -1) { + fprintf(stderr, "error: unknown language '%s'\n", params.language.c_str()); + whisper_print_usage(argc, argv, params); + exit(0); + } + + // whisper init + + struct whisper_context_params cparams; + cparams.use_gpu = params.use_gpu; + + struct whisper_context * ctx = whisper_init_from_file_with_params(params.model.c_str(), cparams); + + // init audio + + if (!g_audio.init(params.capture_id, WHISPER_SAMPLE_RATE)) { + fprintf(stderr, "%s: audio.init() failed!\n", __func__); + return 1; + } + + whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_BEAM_SEARCH); + + wparams.print_progress = false; + wparams.print_special = params.print_special; + wparams.print_realtime = false; + wparams.print_timestamps = !params.no_timestamps; + wparams.translate = params.translate; + wparams.no_context = true; + wparams.no_timestamps = params.no_timestamps; + wparams.single_segment = true; + wparams.max_tokens = params.max_tokens; + wparams.language = params.language.c_str(); + wparams.n_threads = params.n_threads; + + wparams.audio_ctx = params.audio_ctx; + wparams.speed_up = params.speed_up; + + wparams.temperature = 0.4f; + wparams.temperature_inc = 1.0f; + wparams.greedy.best_of = 5; + + wparams.beam_search.beam_size = 5; + + g_audio.resume(); + + // wait for 1 second to avoid any buffered noise + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + g_audio.clear(); + + g_chess.reset(new Chess(ctx, wparams, nullptr, sdl_poll_events, get_audio, set_moves)); + set_moves({}); + g_chess->run(); + + g_audio.pause(); + + whisper_print_timings(ctx); + whisper_free(ctx); + + return 0; +} diff --git a/examples/chess.wasm/CMakeLists.txt b/examples/wchess/wchess.wasm/CMakeLists.txt similarity index 70% rename from examples/chess.wasm/CMakeLists.txt rename to examples/wchess/wchess.wasm/CMakeLists.txt index 4488134d..588a50e0 100644 --- a/examples/chess.wasm/CMakeLists.txt +++ b/examples/wchess/wchess.wasm/CMakeLists.txt @@ -1,18 +1,14 @@ -# -# libchess -# - -set(TARGET libchess) +set(TARGET wchess.wasm) add_executable(${TARGET} - emscripten.cpp + wchess.wasm.cpp ) include(DefaultTargetOptions) target_link_libraries(${TARGET} PRIVATE common - whisper + libwchess ) unset(EXTRA_FLAGS) @@ -24,8 +20,8 @@ if (WHISPER_WASM_SINGLE_FILE) add_custom_command( TARGET ${TARGET} POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy - ${CMAKE_BINARY_DIR}/bin/libchess.js - ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/chess.wasm/js/chess.js + ${CMAKE_BINARY_DIR}/bin/${TARGET}.js + ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/js/chess.js ) endif() @@ -45,16 +41,11 @@ add_custom_command( TARGET ${TARGET} POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_directory ${CMAKE_CURRENT_SOURCE_DIR}/chessboardjs-1.0.0 - ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/chess.wasm/ + ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/ COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/jquery-3.7.1.min.js - ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/chess.wasm/js/ + ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/js/ ) -# -# chess.wasm -# - -set(TARGET chess.wasm) configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY) -configure_file(${CMAKE_CURRENT_SOURCE_DIR}/../helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/js/helpers.js @ONLY) +configure_file(${CMAKE_SOURCE_DIR}/examples/helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/js/helpers.js @ONLY) diff --git a/examples/chess.wasm/chessboardjs-1.0.0/CHANGELOG.md b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/CHANGELOG.md similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/CHANGELOG.md rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/CHANGELOG.md diff --git a/examples/chess.wasm/chessboardjs-1.0.0/LICENSE.md b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/LICENSE.md similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/LICENSE.md rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/LICENSE.md diff --git a/examples/chess.wasm/chessboardjs-1.0.0/README.md b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/README.md similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/README.md rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/README.md diff --git a/examples/chess.wasm/chessboardjs-1.0.0/css/chessboard-1.0.0.css b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/css/chessboard-1.0.0.css similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/css/chessboard-1.0.0.css rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/css/chessboard-1.0.0.css diff --git a/examples/chess.wasm/chessboardjs-1.0.0/css/chessboard-1.0.0.min.css b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/css/chessboard-1.0.0.min.css similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/css/chessboard-1.0.0.min.css rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/css/chessboard-1.0.0.min.css diff --git a/examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bB.png b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bB.png similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bB.png rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bB.png diff --git a/examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bK.png b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bK.png similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bK.png rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bK.png diff --git a/examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bN.png b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bN.png similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bN.png rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bN.png diff --git a/examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bP.png b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bP.png similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bP.png rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bP.png diff --git a/examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bQ.png b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bQ.png similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bQ.png rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bQ.png diff --git a/examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bR.png b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bR.png similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bR.png rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/bR.png diff --git a/examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wB.png b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wB.png similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wB.png rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wB.png diff --git a/examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wK.png b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wK.png similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wK.png rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wK.png diff --git a/examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wN.png b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wN.png similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wN.png rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wN.png diff --git a/examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wP.png b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wP.png similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wP.png rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wP.png diff --git a/examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wQ.png b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wQ.png similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wQ.png rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wQ.png diff --git a/examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wR.png b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wR.png similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wR.png rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/img/chesspieces/wikipedia/wR.png diff --git a/examples/chess.wasm/chessboardjs-1.0.0/js/chessboard-1.0.0.js b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/js/chessboard-1.0.0.js similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/js/chessboard-1.0.0.js rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/js/chessboard-1.0.0.js diff --git a/examples/chess.wasm/chessboardjs-1.0.0/js/chessboard-1.0.0.min.js b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/js/chessboard-1.0.0.min.js similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/js/chessboard-1.0.0.min.js rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/js/chessboard-1.0.0.min.js diff --git a/examples/chess.wasm/chessboardjs-1.0.0/package.json b/examples/wchess/wchess.wasm/chessboardjs-1.0.0/package.json similarity index 100% rename from examples/chess.wasm/chessboardjs-1.0.0/package.json rename to examples/wchess/wchess.wasm/chessboardjs-1.0.0/package.json diff --git a/examples/chess.wasm/index-tmpl.html b/examples/wchess/wchess.wasm/index-tmpl.html similarity index 95% rename from examples/chess.wasm/index-tmpl.html rename to examples/wchess/wchess.wasm/index-tmpl.html index 60bcf2a3..90aecbf1 100644 --- a/examples/chess.wasm/index-tmpl.html +++ b/examples/wchess/wchess.wasm/index-tmpl.html @@ -69,7 +69,7 @@
Status: not started -
[The recognized voice commands will be displayed here]
+
[The moves will be displayed here]


@@ -350,7 +350,7 @@ var nLines = 0; var intervalUpdate = null; - var transcribedAll = ''; + var movesAll = ''; function onStart() { if (!instance) { @@ -369,31 +369,29 @@ startRecording(); intervalUpdate = setInterval(function() { - var transcribed = Module.get_transcribed(); + var moves = Module.get_moves(); - if (transcribed != null && transcribed.length > 1) { - var moves = Module.get_moves(); - for (move of moves) { + if (moves != null && moves.length > 1) { + + for (move of moves.split(' ')) { board.move(move); } - Module.commit_moves(); - - transcribedAll += transcribed + '
'; + movesAll += moves + '
'; nLines++; // if more than 10 lines, remove the first line if (nLines > 10) { - var i = transcribedAll.indexOf('
'); + var i = movesAll.indexOf('
'); if (i > 0) { - transcribedAll = transcribedAll.substring(i + 4); + movesAll = movesAll.substring(i + 4); nLines--; } } } document.getElementById('state-status').innerHTML = Module.get_status(); - document.getElementById('state-transcribed').innerHTML = transcribedAll; + document.getElementById('state-moves').innerHTML = movesAll; }, 100); } diff --git a/examples/chess.wasm/jquery-3.7.1.min.js b/examples/wchess/wchess.wasm/jquery-3.7.1.min.js similarity index 100% rename from examples/chess.wasm/jquery-3.7.1.min.js rename to examples/wchess/wchess.wasm/jquery-3.7.1.min.js diff --git a/examples/wchess/wchess.wasm/wchess.wasm.cpp b/examples/wchess/wchess.wasm/wchess.wasm.cpp new file mode 100644 index 00000000..856efad6 --- /dev/null +++ b/examples/wchess/wchess.wasm/wchess.wasm.cpp @@ -0,0 +1,185 @@ +#include "ggml.h" +#include "common.h" + + +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +constexpr int N_THREAD = 8; + +std::vector g_contexts(4, nullptr); + +std::mutex g_mutex; +std::thread g_worker; + +std::atomic g_running(false); + +std::string g_status = ""; +std::string g_status_forced = ""; +std::string g_moves = ""; + +std::vector g_pcmf32; + +void set_status(const std::string & status) { + std::lock_guard lock(g_mutex); + g_status = status; +} + +void set_moves(const std::string & moves) { + std::lock_guard lock(g_mutex); + g_moves = moves; +} + +void get_audio(int ms, std::vector & audio) { + const int64_t n_samples = (ms * WHISPER_SAMPLE_RATE) / 1000; + + int64_t n_take = 0; + if (n_samples > (int) g_pcmf32.size()) { + n_take = g_pcmf32.size(); + } else { + n_take = n_samples; + } + + audio.resize(n_take); + std::copy(g_pcmf32.end() - n_take, g_pcmf32.end(), audio.begin()); +} + +bool check_running() { + g_pcmf32.clear(); + return g_running; +} + +EMSCRIPTEN_BINDINGS(command) { + emscripten::function("init", emscripten::optional_override([](const std::string & path_model) { + for (size_t i = 0; i < g_contexts.size(); ++i) { + if (g_contexts[i] == nullptr) { + g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params()); + if (g_contexts[i] != nullptr) { + g_running = true; + if (g_worker.joinable()) { + g_worker.join(); + } + g_worker = std::thread([i]() { + + struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY); + + wparams.n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency()); + wparams.offset_ms = 0; + wparams.translate = false; + wparams.no_context = true; + wparams.single_segment = true; + wparams.print_realtime = false; + wparams.print_progress = false; + wparams.print_timestamps = true; + wparams.print_special = false; + + wparams.max_tokens = 32; + // wparams.audio_ctx = 768; // partial encoder context for better performance + + wparams.temperature = 0.4f; + wparams.temperature_inc = 1.0f; + wparams.greedy.best_of = 1; + + wparams.beam_search.beam_size = 5; + + wparams.language = "en"; + + printf("command: using %d threads\n", wparams.n_threads); + + Chess(g_contexts[i], + wparams, + set_status, + check_running, + get_audio, + set_moves).run(); + + if (i < g_contexts.size()) { + whisper_free(g_contexts[i]); + g_contexts[i] = nullptr; + } + + }); + + return i + 1; + } else { + return (size_t) 0; + } + } + } + + return (size_t) 0; + })); + + emscripten::function("free", emscripten::optional_override([](size_t /* index */) { + if (g_running) { + g_running = false; + } + })); + + emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) { + --index; + + if (index >= g_contexts.size()) { + return -1; + } + + if (g_contexts[index] == nullptr) { + return -2; + } + + { + std::lock_guard lock(g_mutex); + const int n = audio["length"].as(); + + emscripten::val heap = emscripten::val::module_property("HEAPU8"); + emscripten::val memory = heap["buffer"]; + + g_pcmf32.resize(n); + + emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast(g_pcmf32.data()), n); + memoryView.call("set", audio); + } + + return 0; + })); + + emscripten::function("get_moves", emscripten::optional_override([]() { + std::string moves; + + { + std::lock_guard lock(g_mutex); + moves = std::move(g_moves); + } + + + if (!moves.empty()) fprintf(stdout, "%s: Moves '%s%s%s'\n", __func__, "\033[1m", moves.c_str(), "\033[0m"); + + return moves; + })); + + emscripten::function("get_status", emscripten::optional_override([]() { + std::string status; + + { + std::lock_guard lock(g_mutex); + status = g_status_forced.empty() ? g_status : g_status_forced; + } + + return status; + })); + + emscripten::function("set_status", emscripten::optional_override([](const std::string & status) { + std::lock_guard lock(g_mutex); + g_status_forced = status; + })); +}