diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index c6a2c305..eb791a21 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -60,6 +60,7 @@ if (EMSCRIPTEN) add_subdirectory(command.wasm) add_subdirectory(talk.wasm) add_subdirectory(bench.wasm) + add_subdirectory(chess.wasm) elseif(CMAKE_JS_VERSION) add_subdirectory(addon.node) else() diff --git a/examples/chess.wasm/CMakeLists.txt b/examples/chess.wasm/CMakeLists.txt new file mode 100644 index 00000000..06a625ec --- /dev/null +++ b/examples/chess.wasm/CMakeLists.txt @@ -0,0 +1,50 @@ +# +# libchess +# + +set(TARGET libchess) + +add_executable(${TARGET} + emscripten.cpp + ) + +include(DefaultTargetOptions) + +target_link_libraries(${TARGET} PRIVATE + common + whisper + ) + +unset(EXTRA_FLAGS) + +if (WHISPER_WASM_SINGLE_FILE) + set(EXTRA_FLAGS "-s SINGLE_FILE=1") + message(STATUS "Embedding WASM inside chess.js") + + add_custom_command( + TARGET ${TARGET} POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy + ${CMAKE_BINARY_DIR}/bin/libchess.js + ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/chess.wasm/chess.js + ) +endif() + +set_target_properties(${TARGET} PROPERTIES LINK_FLAGS " \ + --bind \ + -s USE_PTHREADS=1 \ + -s PTHREAD_POOL_SIZE=8 \ + -s INITIAL_MEMORY=1024MB \ + -s TOTAL_MEMORY=1024MB \ + -s FORCE_FILESYSTEM=1 \ + -s EXPORTED_RUNTIME_METHODS=\"['print', 'printErr', 'ccall', 'cwrap']\" \ + ${EXTRA_FLAGS} \ + ") + +# +# chess.wasm +# + +set(TARGET chess.wasm) + +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/index-tmpl.html ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/index.html @ONLY) +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/../helpers.js ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/${TARGET}/helpers.js @ONLY) diff --git a/examples/chess.wasm/emscripten.cpp b/examples/chess.wasm/emscripten.cpp new file mode 100644 index 00000000..237d1ade --- /dev/null +++ b/examples/chess.wasm/emscripten.cpp @@ -0,0 +1,384 @@ +#include "ggml.h" +#include "common.h" +#include "whisper.h" +#include "grammar-parser.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +constexpr int N_THREAD = 8; + +std::vector g_contexts(4, nullptr); + +std::mutex g_mutex; +std::thread g_worker; + +std::atomic g_running(false); + +std::string g_status = ""; +std::string g_status_forced = ""; +std::string g_transcribed = ""; + +std::vector g_pcmf32; + +void command_set_status(const std::string & status) { + std::lock_guard lock(g_mutex); + g_status = status; +} + +std::string command_transcribe( + whisper_context * ctx, + const whisper_full_params & wparams, + const std::vector & pcmf32, + float & logprob_min, + float & logprob_sum, + int & n_tokens, + int64_t & t_ms) { + const auto t_start = std::chrono::high_resolution_clock::now(); + + logprob_min = 0.0f; + logprob_sum = 0.0f; + n_tokens = 0; + t_ms = 0; + + if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) { + return ""; + } + + std::string result; + + const int n_segments = whisper_full_n_segments(ctx); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(ctx, i); + + result += text; + + const int n = whisper_full_n_tokens(ctx, i); + for (int j = 0; j < n; ++j) { + const auto token = whisper_full_get_token_data(ctx, i, j); + + if(token.plog > 0.0f) exit(0); // todo: check for emscripten + logprob_min = std::min(logprob_min, token.plog); + logprob_sum += token.plog; + ++n_tokens; + } + } + + const auto t_end = std::chrono::high_resolution_clock::now(); + t_ms = std::chrono::duration_cast(t_end - t_start).count(); + + return result; +} + +void command_get_audio(int ms, int sample_rate, std::vector & audio) { + const int64_t n_samples = (ms * sample_rate) / 1000; + + int64_t n_take = 0; + if (n_samples > (int) g_pcmf32.size()) { + n_take = g_pcmf32.size(); + } else { + n_take = n_samples; + } + + audio.resize(n_take); + std::copy(g_pcmf32.end() - n_take, g_pcmf32.end(), audio.begin()); +} + +void command_main(size_t index) { + command_set_status("loading data ..."); + + struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_BEAM_SEARCH); + + wparams.n_threads = std::min(N_THREAD, (int) std::thread::hardware_concurrency()); + wparams.offset_ms = 0; + wparams.translate = false; + wparams.no_context = true; + wparams.single_segment = true; + wparams.print_realtime = false; + wparams.print_progress = false; + wparams.print_timestamps = true; + wparams.print_special = false; + + wparams.max_tokens = 32; + wparams.audio_ctx = 768; // partial encoder context for better performance + + wparams.temperature = 0.4f; + wparams.temperature_inc = 1.0f; + wparams.greedy.best_of = 5; + + wparams.beam_search.beam_size = 5; + + wparams.language = "en"; + + printf("command: using %d threads\n", wparams.n_threads); + + bool have_prompt = false; + bool ask_prompt = true; + bool print_energy = false; + + float logprob_min0 = 0.0f; + float logprob_min = 0.0f; + + float logprob_sum0 = 0.0f; + float logprob_sum = 0.0f; + + int n_tokens0 = 0; + int n_tokens = 0; + + std::vector pcmf32_cur; + std::vector pcmf32_prompt; + + // todo: grammar to be based on js input + const std::string k_prompt = "rook to b4, f3,"; + wparams.initial_prompt = "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8,"; + + auto grammar_parsed = grammar_parser::parse( +"\n" +"root ::= init move move? move? \".\"\n" +"prompt ::= init \".\"\n" +"\n" +"# leading space is very important!\n" +"init ::= \" rook to b4, f3\"\n" +"\n" +"move ::= \", \" ((piece | pawn | king) \" \" \"to \"?)? [a-h] [1-8]\n" +"\n" +"piece ::= \"bishop\" | \"rook\" | \"knight\" | \"queen\"\n" +"king ::= \"king\"\n" +"pawn ::= \"pawn\"\n" +"\n" + ); + auto grammar_rules = grammar_parsed.c_rules(); + + if (grammar_parsed.rules.empty()) { + fprintf(stdout, "%s: Failed to parse grammar ...\n", __func__); + } + else { + wparams.grammar_rules = grammar_rules.data(); + wparams.n_grammar_rules = grammar_rules.size(); + wparams.grammar_penalty = 100.0; + } + + // whisper context + auto & ctx = g_contexts[index]; + + const int32_t vad_ms = 2000; + const int32_t prompt_ms = 5000; + const int32_t command_ms = 4000; + + const float vad_thold = 0.1f; + const float freq_thold = -1.0f; + + while (g_running) { + // delay + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + if (ask_prompt) { + fprintf(stdout, "\n"); + fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m"); + fprintf(stdout, "\n"); + + { + char txt[1024]; + snprintf(txt, sizeof(txt), "Say the following phrase: '%s'", k_prompt.c_str()); + command_set_status(txt); + } + + ask_prompt = false; + } + + int64_t t_ms = 0; + + { + command_get_audio(vad_ms, WHISPER_SAMPLE_RATE, pcmf32_cur); + + if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1000, vad_thold, freq_thold, print_energy)) { + fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__); + command_set_status("Speech detected! Processing ..."); + + if (!have_prompt) { + command_get_audio(prompt_ms, WHISPER_SAMPLE_RATE, pcmf32_cur); + + wparams.i_start_rule = grammar_parsed.symbol_ids.at("prompt"); + const auto txt = ::trim(::command_transcribe(ctx, wparams, pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms)); + + fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms); + + const float sim = similarity(txt, k_prompt); + + if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) { + fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__); + ask_prompt = true; + } else { + fprintf(stdout, "\n"); + fprintf(stdout, "%s: The prompt has been recognized!\n", __func__); + fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__); + fprintf(stdout, "\n"); + + { + char txt[1024]; + snprintf(txt, sizeof(txt), "Success! Waiting for voice commands ..."); + command_set_status(txt); + } + + // save the audio for the prompt + pcmf32_prompt = pcmf32_cur; + have_prompt = true; + } + } else { + command_get_audio(command_ms, WHISPER_SAMPLE_RATE, pcmf32_cur); + + // prepend 3 second of silence + pcmf32_cur.insert(pcmf32_cur.begin(), 3*WHISPER_SAMPLE_RATE, 0.0f); + + // prepend the prompt audio + pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end()); + + wparams.i_start_rule = grammar_parsed.symbol_ids.at("root"); + const auto txt = ::trim(::command_transcribe(ctx, wparams, pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms)); + + const float p = 100.0f * std::exp(logprob_min); + + fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str()); + + // find the prompt in the text + float best_sim = 0.0f; + size_t best_len = 0; + for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) { + if (n >= int(txt.size())) { + break; + } + + const auto prompt = txt.substr(0, n); + + const float sim = similarity(prompt, k_prompt); + + //fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim); + + if (sim > best_sim) { + best_sim = sim; + best_len = n; + } + } + + fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p); + const std::string command = ::trim(txt.substr(best_len)); + + fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms); + fprintf(stdout, "\n"); + + { + char txt[1024]; + snprintf(txt, sizeof(txt), "Command '%s', (t = %d ms)", command.c_str(), (int) t_ms); + command_set_status(txt); + } + { + std::lock_guard lock(g_mutex); + g_transcribed = command; + } + } + + g_pcmf32.clear(); + } + } + } + + if (index < g_contexts.size()) { + whisper_free(g_contexts[index]); + g_contexts[index] = nullptr; + } +} + +EMSCRIPTEN_BINDINGS(command) { + emscripten::function("init", emscripten::optional_override([](const std::string & path_model) { + for (size_t i = 0; i < g_contexts.size(); ++i) { + if (g_contexts[i] == nullptr) { + g_contexts[i] = whisper_init_from_file_with_params(path_model.c_str(), whisper_context_default_params()); + if (g_contexts[i] != nullptr) { + g_running = true; + if (g_worker.joinable()) { + g_worker.join(); + } + g_worker = std::thread([i]() { + command_main(i); + }); + + return i + 1; + } else { + return (size_t) 0; + } + } + } + + return (size_t) 0; + })); + + emscripten::function("free", emscripten::optional_override([](size_t index) { + if (g_running) { + g_running = false; + } + })); + + emscripten::function("set_audio", emscripten::optional_override([](size_t index, const emscripten::val & audio) { + --index; + + if (index >= g_contexts.size()) { + return -1; + } + + if (g_contexts[index] == nullptr) { + return -2; + } + + { + std::lock_guard lock(g_mutex); + const int n = audio["length"].as(); + + emscripten::val heap = emscripten::val::module_property("HEAPU8"); + emscripten::val memory = heap["buffer"]; + + g_pcmf32.resize(n); + + emscripten::val memoryView = audio["constructor"].new_(memory, reinterpret_cast(g_pcmf32.data()), n); + memoryView.call("set", audio); + } + + return 0; + })); + + emscripten::function("get_transcribed", emscripten::optional_override([]() { + std::string transcribed; + + { + std::lock_guard lock(g_mutex); + transcribed = std::move(g_transcribed); + } + + return transcribed; + })); + + emscripten::function("get_status", emscripten::optional_override([]() { + std::string status; + + { + std::lock_guard lock(g_mutex); + status = g_status_forced.empty() ? g_status : g_status_forced; + } + + return status; + })); + + emscripten::function("set_status", emscripten::optional_override([](const std::string & status) { + { + std::lock_guard lock(g_mutex); + g_status_forced = status; + } + })); +} diff --git a/examples/chess.wasm/index-tmpl.html b/examples/chess.wasm/index-tmpl.html new file mode 100644 index 00000000..e16d7908 --- /dev/null +++ b/examples/chess.wasm/index-tmpl.html @@ -0,0 +1,391 @@ + + + + command : Voice assistant example using Whisper + WebAssembly + + + + +
+ command : Voice assistant example using Whisper + WebAssembly + +

+ + You can find more about this project on GitHub. + +

+ + More examples: + main | + bench | + stream | + command | + talk | + +

+ +
+ +
+ Whisper model: + + + +
+ +
+ +
+ + + +
+ +
+ +
+ Status: not started + +
[The recognized voice commands will be displayed here]
+
+ +
+ + Debug output: + + +
+ + Troubleshooting + +

+ + The page does some heavy computations, so make sure: + +
    +
  • To use a modern web browser (e.g. Chrome, Firefox)
  • +
  • To use a fast desktop or laptop computer (i.e. not a mobile phone)
  • +
  • Your browser supports WASM Fixed-width SIMD
  • +
+ +
+ + | + Build time: @GIT_DATE@ | + Commit hash: @GIT_SHA1@ | + Commit subject: @GIT_COMMIT_SUBJECT@ | + Source Code | + +
+
+ + + + + +