Merge pull request #1 from ggerganov/gg/wchess

wchess : add clear_audio callback
This commit is contained in:
fraxy-v 2023-11-28 15:45:17 +02:00 committed by GitHub
commit 8dba8204eb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 26 additions and 4 deletions

View File

@ -1,3 +1,4 @@
set(CMAKE_CXX_STANDARD 11)
add_subdirectory(libwchess) add_subdirectory(libwchess)

View File

@ -48,6 +48,11 @@ bool WChess::check_running() const {
return false; return false;
} }
bool WChess::clear_audio() const {
if (m_cb.clear_audio) return (*m_cb.clear_audio)();
return false;
}
void WChess::get_audio(int ms, std::vector<float>& pcmf32) const { void WChess::get_audio(int ms, std::vector<float>& pcmf32) const {
if (m_cb.get_audio) (*m_cb.get_audio)(ms, pcmf32); if (m_cb.get_audio) (*m_cb.get_audio)(ms, pcmf32);
} }
@ -195,6 +200,8 @@ void WChess::run() {
set_moves(m_board->process(command)); set_moves(m_board->process(command));
} }
} }
clear_audio();
} }
} }
} }

View File

@ -12,20 +12,22 @@ public:
using CheckRunningCb = bool (*)(); using CheckRunningCb = bool (*)();
using GetAudioCb = void (*)(int, std::vector<float> &); using GetAudioCb = void (*)(int, std::vector<float> &);
using SetMovesCb = void (*)(const std::string &); using SetMovesCb = void (*)(const std::string &);
using CleartAudioCb = bool (*)();
struct callbacks { struct callbacks {
SetStatusCb set_status = nullptr; SetStatusCb set_status = nullptr;
CheckRunningCb check_running = nullptr; CheckRunningCb check_running = nullptr;
GetAudioCb get_audio = nullptr; GetAudioCb get_audio = nullptr;
SetMovesCb set_moves = nullptr; SetMovesCb set_moves = nullptr;
CleartAudioCb clear_audio = nullptr;
}; };
struct settings { struct settings {
int32_t vad_ms = 2000; int32_t vad_ms = 2000;
int32_t prompt_ms = 5000; int32_t prompt_ms = 5000;
int32_t command_ms = 4000; int32_t command_ms = 4000;
float vad_thold = 0.1f; float vad_thold = 0.2f;
float freq_thold = -1.0f; float freq_thold = 100.0f;
bool print_energy = false; bool print_energy = false;
}; };
@ -44,6 +46,7 @@ private:
void set_status(const std::string& msg) const; void set_status(const std::string& msg) const;
void set_moves(const std::string& moves) const; void set_moves(const std::string& moves) const;
bool check_running() const; bool check_running() const;
bool clear_audio() const;
std::string transcribe( std::string transcribe(
const std::vector<float> & pcmf32, const std::vector<float> & pcmf32,
float & logprob_min, float & logprob_min,

View File

@ -118,6 +118,10 @@ void get_audio(int ms, std::vector<float> & pcmf32_cur) {
g_audio.get(ms, pcmf32_cur); g_audio.get(ms, pcmf32_cur);
} }
bool clear_audio() {
g_audio.clear();
}
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
whisper_params params; whisper_params params;
@ -180,6 +184,7 @@ int main(int argc, char ** argv) {
cb.check_running = sdl_poll_events; cb.check_running = sdl_poll_events;
cb.get_audio = get_audio; cb.get_audio = get_audio;
cb.set_moves = set_moves; cb.set_moves = set_moves;
cb.clear_audio = clear_audio;
WChess::settings s; WChess::settings s;
s.vad_ms = 2000; s.vad_ms = 2000;

View File

@ -44,10 +44,15 @@ void get_audio(int ms, std::vector<float> & audio) {
} }
bool check_running() { bool check_running() {
g_pcmf32.clear(); //g_pcmf32.clear();
return g_running; return g_running;
} }
bool clear_audio() {
g_pcmf32.clear();
return true;
}
void wchess_main(size_t i) { void wchess_main(size_t i) {
struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY); struct whisper_full_params wparams = whisper_full_default_params(whisper_sampling_strategy::WHISPER_SAMPLING_GREEDY);
@ -81,6 +86,7 @@ void wchess_main(size_t i) {
cb.check_running = check_running; cb.check_running = check_running;
cb.get_audio = get_audio; cb.get_audio = get_audio;
cb.set_moves = set_moves; cb.set_moves = set_moves;
cb.clear_audio = clear_audio;
WChess(g_contexts[i], wparams, cb, {}).run(); WChess(g_contexts[i], wparams, cb, {}).run();
if (i < g_contexts.size()) { if (i < g_contexts.size()) {