2023-11-25 10:16:48 +02:00
|
|
|
#include "WChess.h"
|
2023-11-25 11:34:06 +02:00
|
|
|
#include "Chessboard.h"
|
2023-11-25 10:16:48 +02:00
|
|
|
#include "grammar-parser.h"
|
|
|
|
#include "common.h"
|
|
|
|
#include <thread>
|
|
|
|
|
2023-11-25 11:34:06 +02:00
|
|
|
WChess::WChess(whisper_context * ctx,
|
2023-11-25 10:16:48 +02:00
|
|
|
const whisper_full_params & wparams,
|
2023-11-25 11:34:06 +02:00
|
|
|
callbacks cb,
|
|
|
|
settings s)
|
2023-11-25 10:16:48 +02:00
|
|
|
: m_ctx(ctx)
|
|
|
|
, m_wparams(wparams)
|
2023-11-25 11:34:06 +02:00
|
|
|
, m_cb(cb)
|
|
|
|
, m_settings(s)
|
|
|
|
, m_board(new Chessboard())
|
2023-11-25 10:16:48 +02:00
|
|
|
{}
|
|
|
|
|
2023-11-25 11:34:06 +02:00
|
|
|
WChess::~WChess() = default;
|
|
|
|
|
|
|
|
void WChess::set_status(const std::string& msg) const {
|
|
|
|
if (m_cb.set_status) (*m_cb.set_status)(msg);
|
2023-11-25 10:16:48 +02:00
|
|
|
}
|
|
|
|
|
2023-11-25 11:34:06 +02:00
|
|
|
void WChess::set_moves(const std::string& moves) const {
|
|
|
|
if (m_cb.set_moves) (*m_cb.set_moves)(moves);
|
2023-11-25 10:16:48 +02:00
|
|
|
}
|
|
|
|
|
2023-11-25 11:34:06 +02:00
|
|
|
bool WChess::check_running() const {
|
|
|
|
if (m_cb.check_running) return (*m_cb.check_running)();
|
2023-11-25 10:16:48 +02:00
|
|
|
return false;
|
|
|
|
}
|
|
|
|
|
2023-11-28 16:21:46 +02:00
|
|
|
void WChess::clear_audio() const {
|
|
|
|
if (m_cb.clear_audio) (*m_cb.clear_audio)();
|
2023-11-28 13:37:26 +02:00
|
|
|
}
|
|
|
|
|
2023-11-25 11:34:06 +02:00
|
|
|
void WChess::get_audio(int ms, std::vector<float>& pcmf32) const {
|
|
|
|
if (m_cb.get_audio) (*m_cb.get_audio)(ms, pcmf32);
|
2023-11-25 10:16:48 +02:00
|
|
|
}
|
|
|
|
|
2023-11-25 11:34:06 +02:00
|
|
|
std::string WChess::stringify_board() const {
|
|
|
|
return m_board->stringifyBoard();
|
2023-11-25 10:16:48 +02:00
|
|
|
}
|
|
|
|
|
2023-11-25 11:34:06 +02:00
|
|
|
void WChess::run() {
|
2023-11-25 10:16:48 +02:00
|
|
|
set_status("loading data ...");
|
|
|
|
|
2023-11-29 09:25:45 +02:00
|
|
|
bool have_prompt = true;
|
|
|
|
bool ask_prompt = false;
|
2023-11-25 10:16:48 +02:00
|
|
|
|
|
|
|
float logprob_min0 = 0.0f;
|
|
|
|
float logprob_min = 0.0f;
|
|
|
|
|
|
|
|
float logprob_sum0 = 0.0f;
|
|
|
|
float logprob_sum = 0.0f;
|
|
|
|
|
|
|
|
int n_tokens0 = 0;
|
|
|
|
int n_tokens = 0;
|
|
|
|
|
|
|
|
std::vector<float> pcmf32_cur;
|
|
|
|
std::vector<float> pcmf32_prompt;
|
|
|
|
|
2023-11-29 15:29:16 +02:00
|
|
|
std::string prompt = "";
|
|
|
|
float prompt_prop = 0.0f;
|
2023-11-25 10:16:48 +02:00
|
|
|
|
|
|
|
while (check_running()) {
|
|
|
|
// delay
|
|
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
|
|
|
|
|
|
|
int64_t t_ms = 0;
|
|
|
|
|
|
|
|
{
|
2023-11-25 11:34:06 +02:00
|
|
|
get_audio(m_settings.vad_ms, pcmf32_cur);
|
2023-11-25 10:16:48 +02:00
|
|
|
|
2023-11-28 19:03:17 +02:00
|
|
|
if (!pcmf32_cur.empty()) {
|
2023-11-29 15:29:16 +02:00
|
|
|
fprintf(stdout, "%s: Processing ...\n", __func__);
|
|
|
|
set_status("Processing ...");
|
2023-11-25 10:16:48 +02:00
|
|
|
|
2023-11-29 15:29:16 +02:00
|
|
|
{
|
|
|
|
if (!pcmf32_prompt.empty()) pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
2023-11-25 10:16:48 +02:00
|
|
|
|
2023-11-29 15:29:16 +02:00
|
|
|
if (WHISPER_SAMPLE_RATE > pcmf32_cur.size()) pcmf32_cur.insert(pcmf32_cur.begin(), WHISPER_SAMPLE_RATE - pcmf32_cur.size(), 0.0f);
|
2023-11-25 10:16:48 +02:00
|
|
|
|
2023-11-29 15:29:16 +02:00
|
|
|
std::string rules = m_board->getRules();
|
|
|
|
fprintf(stdout, "%s: grammar rules:\n'%s'\n", __func__, rules.c_str());
|
2023-11-25 10:16:48 +02:00
|
|
|
|
2023-11-29 15:29:16 +02:00
|
|
|
auto grammar_parsed = grammar_parser::parse(rules.c_str());
|
|
|
|
auto grammar_rules = grammar_parsed.c_rules();
|
2023-11-29 09:25:45 +02:00
|
|
|
|
2023-11-29 15:29:16 +02:00
|
|
|
m_wparams.grammar_rules = grammar_rules.data();
|
|
|
|
m_wparams.n_grammar_rules = grammar_rules.size();
|
2023-11-25 10:16:48 +02:00
|
|
|
|
2023-11-29 15:29:16 +02:00
|
|
|
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("move");
|
2023-11-29 09:25:45 +02:00
|
|
|
auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
|
2023-11-25 10:16:48 +02:00
|
|
|
|
|
|
|
const float p = 100.0f * std::exp(logprob_min);
|
|
|
|
|
|
|
|
fprintf(stdout, "%s: heard '%s'\n", __func__, txt.c_str());
|
|
|
|
|
|
|
|
// find the prompt in the text
|
|
|
|
float best_sim = 0.0f;
|
|
|
|
size_t best_len = 0;
|
2023-11-29 15:29:16 +02:00
|
|
|
if (!prompt.empty()) {
|
|
|
|
auto pos = txt.find_first_of('.');
|
2023-11-25 10:16:48 +02:00
|
|
|
|
2023-11-29 15:29:16 +02:00
|
|
|
const auto header = txt.substr(0, pos);
|
2023-11-25 10:16:48 +02:00
|
|
|
|
2023-11-29 15:29:16 +02:00
|
|
|
const float sim = similarity(prompt, header);
|
2023-11-25 10:16:48 +02:00
|
|
|
|
|
|
|
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
|
|
|
|
|
|
|
|
if (sim > best_sim) {
|
|
|
|
best_sim = sim;
|
2023-11-29 15:29:16 +02:00
|
|
|
best_len = pos + 1;
|
2023-11-25 10:16:48 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
fprintf(stdout, "%s: DEBUG: txt = '%s', prob = %.2f%%\n", __func__, txt.c_str(), p);
|
|
|
|
std::string command = ::trim(txt.substr(best_len));
|
|
|
|
|
|
|
|
fprintf(stdout, "%s: Command '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", command.c_str(), "\033[0m", (int) t_ms);
|
|
|
|
fprintf(stdout, "\n");
|
|
|
|
|
|
|
|
{
|
|
|
|
char txt[1024];
|
|
|
|
snprintf(txt, sizeof(txt), "Command '%s', (t = %d ms)", command.c_str(), (int) t_ms);
|
|
|
|
set_status(txt);
|
|
|
|
}
|
|
|
|
if (!command.empty()) {
|
2023-11-29 15:29:16 +02:00
|
|
|
auto move = m_board->process(command);
|
|
|
|
if (!move.empty()) {
|
|
|
|
set_moves(std::move(move));
|
|
|
|
}
|
2023-11-25 10:16:48 +02:00
|
|
|
}
|
|
|
|
}
|
2023-11-28 13:37:26 +02:00
|
|
|
|
|
|
|
clear_audio();
|
2023-11-25 10:16:48 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-11-25 11:34:06 +02:00
|
|
|
std::string WChess::transcribe(
|
2023-11-25 10:16:48 +02:00
|
|
|
const std::vector<float> & pcmf32,
|
|
|
|
float & logprob_min,
|
|
|
|
float & logprob_sum,
|
|
|
|
int & n_tokens,
|
|
|
|
int64_t & t_ms) {
|
|
|
|
const auto t_start = std::chrono::high_resolution_clock::now();
|
|
|
|
|
|
|
|
logprob_min = 0.0f;
|
|
|
|
logprob_sum = 0.0f;
|
|
|
|
n_tokens = 0;
|
|
|
|
t_ms = 0;
|
|
|
|
|
|
|
|
if (whisper_full(m_ctx, m_wparams, pcmf32.data(), pcmf32.size()) != 0) {
|
2023-11-25 11:34:06 +02:00
|
|
|
return {};
|
2023-11-25 10:16:48 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
std::string result;
|
|
|
|
|
|
|
|
const int n_segments = whisper_full_n_segments(m_ctx);
|
|
|
|
for (int i = 0; i < n_segments; ++i) {
|
|
|
|
const char * text = whisper_full_get_segment_text(m_ctx, i);
|
|
|
|
|
|
|
|
result += text;
|
|
|
|
|
|
|
|
const int n = whisper_full_n_tokens(m_ctx, i);
|
|
|
|
for (int j = 0; j < n; ++j) {
|
|
|
|
const auto token = whisper_full_get_token_data(m_ctx, i, j);
|
|
|
|
|
|
|
|
if(token.plog > 0.0f) return {};
|
|
|
|
logprob_min = std::min(logprob_min, token.plog);
|
|
|
|
logprob_sum += token.plog;
|
|
|
|
++n_tokens;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
const auto t_end = std::chrono::high_resolution_clock::now();
|
|
|
|
t_ms = std::chrono::duration_cast<std::chrono::milliseconds>(t_end - t_start).count();
|
|
|
|
|
|
|
|
return result;
|
|
|
|
}
|