mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-21 16:09:55 +00:00
wchess: preparing dyn grammar
This commit is contained in:
@ -4,24 +4,6 @@
|
||||
#include "common.h"
|
||||
#include <thread>
|
||||
|
||||
// Grammar source text for the voice-command recognizer.
//
// WChess::run() parses this string with grammar_parser::parse() and installs
// the resulting rules into m_wparams (grammar_rules / n_grammar_rules).
// Rules defined below:
//   root   - a single `move` followed by "."
//   prompt - the fixed phrase `init` followed by "."; run() selects this as
//            the start rule while the spoken prompt has not been recognized yet
//   init   - the literal phrase " rook to b4, f3" (note the leading space,
//            which the grammar's own comment flags as important)
//   move   - a leading space, an optional piece/pawn/king name with an
//            optional "to ", then a square: file [a-h] and rank [1-8]
//   piece / king / pawn - the spoken piece names
static constexpr auto RULES =
|
||||
"\n"
|
||||
"root ::= move \".\"\n"
|
||||
"prompt ::= init \".\"\n"
|
||||
"\n"
|
||||
"# leading space is very important!\n"
|
||||
"init ::= \" rook to b4, f3\"\n"
|
||||
"\n"
|
||||
"move ::= \" \" ((piece | pawn | king) \" \" \"to \"?)? [a-h] [1-8]\n"
|
||||
"\n"
|
||||
"piece ::= \"bishop\" | \"rook\" | \"knight\" | \"queen\"\n"
|
||||
"king ::= \"king\"\n"
|
||||
"pawn ::= \"pawn\"\n"
|
||||
"\n";
|
||||
|
||||
// Phrase the user is asked to say before commands are accepted; run() copies it
// into k_prompt and compares the transcription against it. It is the same text
// as the `init` rule of RULES, minus the leading space.
static constexpr const char * PROMPT = "rook to b4, f3";
|
||||
// Sample sequence of spoken chess moves.
// NOTE(review): no use of CONTEXT is visible in this excerpt - presumably it is
// meant as initial context text for the transcriber; confirm against callers.
static constexpr const char * CONTEXT = "d4 d5 knight to c3, pawn to a1, bishop to b2 king e8,";
|
||||
|
||||
WChess::WChess(whisper_context * ctx,
|
||||
const whisper_full_params & wparams,
|
||||
callbacks cb,
|
||||
@ -78,86 +60,38 @@ void WChess::run() {
|
||||
std::vector<float> pcmf32_cur;
|
||||
std::vector<float> pcmf32_prompt;
|
||||
|
||||
const std::string k_prompt = PROMPT;
|
||||
|
||||
auto grammar_parsed = grammar_parser::parse(RULES);
|
||||
auto grammar_rules = grammar_parsed.c_rules();
|
||||
|
||||
if (grammar_parsed.rules.empty()) {
|
||||
fprintf(stdout, "%s: Failed to parse grammar ...\n", __func__);
|
||||
}
|
||||
else {
|
||||
m_wparams.grammar_rules = grammar_rules.data();
|
||||
m_wparams.n_grammar_rules = grammar_rules.size();
|
||||
}
|
||||
std::string prompt = "";
|
||||
float prompt_prop = 0.0f;
|
||||
|
||||
while (check_running()) {
|
||||
// delay
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
if (ask_prompt) {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: Say the following phrase: '%s%s%s'\n", __func__, "\033[1m", k_prompt.c_str(), "\033[0m");
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
{
|
||||
char txt[1024];
|
||||
snprintf(txt, sizeof(txt), "Say the following phrase: '%s'", k_prompt.c_str());
|
||||
set_status(txt);
|
||||
}
|
||||
|
||||
ask_prompt = false;
|
||||
}
|
||||
|
||||
int64_t t_ms = 0;
|
||||
|
||||
{
|
||||
get_audio(m_settings.vad_ms, pcmf32_cur);
|
||||
|
||||
if (!pcmf32_cur.empty()) {
|
||||
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
|
||||
set_status("Speech detected! Processing ...");
|
||||
fprintf(stdout, "%s: Processing ...\n", __func__);
|
||||
set_status("Processing ...");
|
||||
|
||||
if (!have_prompt) {
|
||||
|
||||
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("prompt");
|
||||
const auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
|
||||
|
||||
fprintf(stdout, "%s: Heard '%s%s%s', (t = %d ms)\n", __func__, "\033[1m", txt.c_str(), "\033[0m", (int) t_ms);
|
||||
|
||||
const float sim = similarity(txt, k_prompt);
|
||||
|
||||
if (txt.length() < 0.8*k_prompt.length() || txt.length() > 1.2*k_prompt.length() || sim < 0.8f) {
|
||||
fprintf(stdout, "%s: WARNING: prompt not recognized, try again\n", __func__);
|
||||
ask_prompt = true;
|
||||
} else {
|
||||
fprintf(stdout, "\n");
|
||||
fprintf(stdout, "%s: The prompt has been recognized!\n", __func__);
|
||||
fprintf(stdout, "%s: Waiting for voice commands ...\n", __func__);
|
||||
fprintf(stdout, "\n");
|
||||
|
||||
{
|
||||
char txt[1024];
|
||||
snprintf(txt, sizeof(txt), "Success! Waiting for voice commands ...");
|
||||
set_status(txt);
|
||||
}
|
||||
|
||||
// save the audio for the prompt
|
||||
pcmf32_prompt = pcmf32_cur;
|
||||
have_prompt = true;
|
||||
}
|
||||
} else {
|
||||
// prepend 3 seconds of silence
|
||||
// pcmf32_cur.insert(pcmf32_cur.begin(), 3*WHISPER_SAMPLE_RATE, 0.0f);
|
||||
|
||||
// prepend the prompt audio
|
||||
// pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
{
|
||||
if (!pcmf32_prompt.empty()) pcmf32_cur.insert(pcmf32_cur.begin(), pcmf32_prompt.begin(), pcmf32_prompt.end());
|
||||
|
||||
if (WHISPER_SAMPLE_RATE > pcmf32_cur.size()) pcmf32_cur.insert(pcmf32_cur.begin(), WHISPER_SAMPLE_RATE - pcmf32_cur.size(), 0.0f);
|
||||
|
||||
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("root");
|
||||
std::string rules = m_board->getRules();
|
||||
fprintf(stdout, "%s: grammar rules:\n'%s'\n", __func__, rules.c_str());
|
||||
|
||||
auto grammar_parsed = grammar_parser::parse(rules.c_str());
|
||||
auto grammar_rules = grammar_parsed.c_rules();
|
||||
|
||||
m_wparams.grammar_rules = grammar_rules.data();
|
||||
m_wparams.n_grammar_rules = grammar_rules.size();
|
||||
|
||||
m_wparams.i_start_rule = grammar_parsed.symbol_ids.at("move");
|
||||
auto txt = ::trim(transcribe(pcmf32_cur, logprob_min, logprob_sum, n_tokens, t_ms));
|
||||
txt = PROMPT + txt;
|
||||
|
||||
const float p = 100.0f * std::exp(logprob_min);
|
||||
|
||||
@ -166,20 +100,18 @@ void WChess::run() {
|
||||
// find the prompt in the text
|
||||
float best_sim = 0.0f;
|
||||
size_t best_len = 0;
|
||||
for (int n = 0.8*k_prompt.size(); n <= 1.2*k_prompt.size(); ++n) {
|
||||
if (n >= int(txt.size())) {
|
||||
break;
|
||||
}
|
||||
if (!prompt.empty()) {
|
||||
auto pos = txt.find_first_of('.');
|
||||
|
||||
const auto prompt = txt.substr(0, n);
|
||||
const auto header = txt.substr(0, pos);
|
||||
|
||||
const float sim = similarity(prompt, k_prompt);
|
||||
const float sim = similarity(prompt, header);
|
||||
|
||||
//fprintf(stderr, "%s: prompt = '%s', sim = %f\n", __func__, prompt.c_str(), sim);
|
||||
|
||||
if (sim > best_sim) {
|
||||
best_sim = sim;
|
||||
best_len = n;
|
||||
best_len = pos + 1;
|
||||
}
|
||||
}
|
||||
|
||||
@ -195,7 +127,10 @@ void WChess::run() {
|
||||
set_status(txt);
|
||||
}
|
||||
if (!command.empty()) {
|
||||
set_moves(m_board->process(command));
|
||||
auto move = m_board->process(command);
|
||||
if (!move.empty()) {
|
||||
set_moves(std::move(move));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
Reference in New Issue
Block a user