examples : refactor in order to reuse code and reduce duplication (#482)

* examples : refactor common code into a library

* examples : refactor common SDL code into a library

* make : update Makefile to use common libs

* common : fix MSVC M_PI ..

* addon.node : link common lib
This commit is contained in:
Georgi Gerganov
2023-02-15 19:28:10 +02:00
committed by GitHub
parent 0336161b7d
commit 09d7d2b68e
19 changed files with 580 additions and 1254 deletions

View File

@ -7,7 +7,7 @@ if (WHISPER_SUPPORT_SDL2)
# TODO: this is temporary
# need to export ggml symbols for MSVC, but too lazy ..
add_executable(${TARGET} talk.cpp gpt-2.cpp ../../ggml.c ../../whisper.cpp)
add_executable(${TARGET} talk.cpp gpt-2.cpp ../common.cpp ../common-sdl.cpp ../../ggml.c ../../whisper.cpp)
include(DefaultTargetOptions)

View File

@ -1,16 +1,14 @@
// Talk with AI
//
#include "common.h"
#include "common-sdl.h"
#include "whisper.h"
#include "gpt-2.h"
#include <SDL.h>
#include <SDL_audio.h>
#include <cassert>
#include <cstdio>
#include <fstream>
#include <mutex>
#include <regex>
#include <string>
#include <thread>
@ -105,320 +103,6 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, "\n");
}
//
// SDL Audio capture
//
class audio_async {
public:
audio_async(int len_ms);
~audio_async();
bool init(int capture_id, int sample_rate);
// start capturing audio via the provided SDL callback
// keep last len_ms seconds of audio in a circular buffer
bool resume();
bool pause();
bool clear();
// callback to be called by SDL
void callback(uint8_t * stream, int len);
// get audio data from the circular buffer
void get(int ms, std::vector<float> & audio);
private:
SDL_AudioDeviceID m_dev_id_in = 0;
int m_len_ms = 0;
int m_sample_rate = 0;
bool m_running = false;
std::mutex m_mutex;
std::vector<float> m_audio;
std::vector<float> m_audio_new;
size_t m_audio_pos = 0;
size_t m_audio_len = 0;
};
audio_async::audio_async(int len_ms) {
m_len_ms = len_ms;
}
audio_async::~audio_async() {
if (m_dev_id_in) {
SDL_CloseAudioDevice(m_dev_id_in);
}
}
bool audio_async::init(int capture_id, int sample_rate) {
SDL_LogSetPriority(SDL_LOG_CATEGORY_APPLICATION, SDL_LOG_PRIORITY_INFO);
if (SDL_Init(SDL_INIT_AUDIO) < 0) {
SDL_LogError(SDL_LOG_CATEGORY_APPLICATION, "Couldn't initialize SDL: %s\n", SDL_GetError());
return false;
}
SDL_SetHintWithPriority(SDL_HINT_AUDIO_RESAMPLING_MODE, "medium", SDL_HINT_OVERRIDE);
{
int nDevices = SDL_GetNumAudioDevices(SDL_TRUE);
fprintf(stderr, "%s: found %d capture devices:\n", __func__, nDevices);
for (int i = 0; i < nDevices; i++) {
fprintf(stderr, "%s: - Capture device #%d: '%s'\n", __func__, i, SDL_GetAudioDeviceName(i, SDL_TRUE));
}
}
SDL_AudioSpec capture_spec_requested;
SDL_AudioSpec capture_spec_obtained;
SDL_zero(capture_spec_requested);
SDL_zero(capture_spec_obtained);
capture_spec_requested.freq = sample_rate;
capture_spec_requested.format = AUDIO_F32;
capture_spec_requested.channels = 1;
capture_spec_requested.samples = 1024;
capture_spec_requested.callback = [](void * userdata, uint8_t * stream, int len) {
audio_async * audio = (audio_async *) userdata;
audio->callback(stream, len);
};
capture_spec_requested.userdata = this;
if (capture_id >= 0) {
fprintf(stderr, "%s: attempt to open capture device %d : '%s' ...\n", __func__, capture_id, SDL_GetAudioDeviceName(capture_id, SDL_TRUE));
m_dev_id_in = SDL_OpenAudioDevice(SDL_GetAudioDeviceName(capture_id, SDL_TRUE), SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
} else {
fprintf(stderr, "%s: attempt to open default capture device ...\n", __func__);
m_dev_id_in = SDL_OpenAudioDevice(nullptr, SDL_TRUE, &capture_spec_requested, &capture_spec_obtained, 0);
}
if (!m_dev_id_in) {
fprintf(stderr, "%s: couldn't open an audio device for capture: %s!\n", __func__, SDL_GetError());
m_dev_id_in = 0;
return false;
} else {
fprintf(stderr, "%s: obtained spec for input device (SDL Id = %d):\n", __func__, m_dev_id_in);
fprintf(stderr, "%s: - sample rate: %d\n", __func__, capture_spec_obtained.freq);
fprintf(stderr, "%s: - format: %d (required: %d)\n", __func__, capture_spec_obtained.format,
capture_spec_requested.format);
fprintf(stderr, "%s: - channels: %d (required: %d)\n", __func__, capture_spec_obtained.channels,
capture_spec_requested.channels);
fprintf(stderr, "%s: - samples per frame: %d\n", __func__, capture_spec_obtained.samples);
fprintf(stderr, "\n");
}
m_sample_rate = capture_spec_obtained.freq;
m_audio.resize((m_sample_rate*m_len_ms)/1000);
return true;
}
bool audio_async::resume() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to resume!\n", __func__);
return false;
}
if (m_running) {
fprintf(stderr, "%s: already running!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 0);
m_running = true;
return true;
}
bool audio_async::pause() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to pause!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: already paused!\n", __func__);
return false;
}
SDL_PauseAudioDevice(m_dev_id_in, 1);
m_running = false;
return true;
}
bool audio_async::clear() {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to clear!\n", __func__);
return false;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return false;
}
{
std::lock_guard<std::mutex> lock(m_mutex);
m_audio_pos = 0;
m_audio_len = 0;
}
return true;
}
// callback to be called by SDL
void audio_async::callback(uint8_t * stream, int len) {
if (!m_running) {
return;
}
const size_t n_samples = len / sizeof(float);
m_audio_new.resize(n_samples);
memcpy(m_audio_new.data(), stream, n_samples * sizeof(float));
//fprintf(stderr, "%s: %zu samples, pos %zu, len %zu\n", __func__, n_samples, m_audio_pos, m_audio_len);
{
std::lock_guard<std::mutex> lock(m_mutex);
if (m_audio_pos + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - m_audio_pos;
memcpy(&m_audio[m_audio_pos], stream, n0 * sizeof(float));
memcpy(&m_audio[0], &stream[n0], (n_samples - n0) * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = m_audio.size();
} else {
memcpy(&m_audio[m_audio_pos], stream, n_samples * sizeof(float));
m_audio_pos = (m_audio_pos + n_samples) % m_audio.size();
m_audio_len = std::min(m_audio_len + n_samples, m_audio.size());
}
}
}
void audio_async::get(int ms, std::vector<float> & result) {
if (!m_dev_id_in) {
fprintf(stderr, "%s: no audio device to get audio from!\n", __func__);
return;
}
if (!m_running) {
fprintf(stderr, "%s: not running!\n", __func__);
return;
}
result.clear();
{
std::lock_guard<std::mutex> lock(m_mutex);
if (ms <= 0) {
ms = m_len_ms;
}
size_t n_samples = (m_sample_rate * ms) / 1000;
if (n_samples > m_audio_len) {
n_samples = m_audio_len;
}
result.resize(n_samples);
int s0 = m_audio_pos - n_samples;
if (s0 < 0) {
s0 += m_audio.size();
}
if (s0 + n_samples > m_audio.size()) {
const size_t n0 = m_audio.size() - s0;
memcpy(result.data(), &m_audio[s0], n0 * sizeof(float));
memcpy(&result[n0], &m_audio[0], (n_samples - n0) * sizeof(float));
} else {
memcpy(result.data(), &m_audio[s0], n_samples * sizeof(float));
}
}
}
///////////////////////////
std::string trim(const std::string & s) {
std::regex e("^\\s+|\\s+$");
return std::regex_replace(s, e, "");
}
std::string replace(const std::string & s, const std::string & from, const std::string & to) {
std::string result = s;
size_t pos = 0;
while ((pos = result.find(from, pos)) != std::string::npos) {
result.replace(pos, from.length(), to);
pos += to.length();
}
return result;
}
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
const float rc = 1.0f / (2.0f * M_PI * cutoff);
const float dt = 1.0f / sample_rate;
const float alpha = dt / (rc + dt);
float y = data[0];
for (size_t i = 1; i < data.size(); i++) {
y = alpha * (y + data[i] - data[i - 1]);
data[i] = y;
}
}
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
const int n_samples = pcmf32.size();
const int n_samples_last = (sample_rate * last_ms) / 1000;
if (n_samples_last >= n_samples) {
// not enough samples - assume no speech
return false;
}
if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, freq_thold, sample_rate);
}
float energy_all = 0.0f;
float energy_last = 0.0f;
for (int i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
if (i >= n_samples - n_samples_last) {
energy_last += fabsf(pcmf32[i]);
}
}
energy_all /= n_samples;
energy_last /= n_samples_last;
if (verbose) {
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
}
if (energy_last > vad_thold*energy_all) {
return false;
}
return true;
}
std::string transcribe(whisper_context * ctx, const whisper_params & params, const std::vector<float> & pcmf32, float & prob, int64_t & t_ms) {
const auto t_start = std::chrono::high_resolution_clock::now();
@ -557,22 +241,10 @@ int main(int argc, char ** argv) {
// main loop
while (is_running) {
// handle Ctrl + C
{
SDL_Event event;
while (SDL_PollEvent(&event)) {
switch (event.type) {
case SDL_QUIT:
{
is_running = false;
} break;
default:
break;
}
}
is_running = sdl_poll_events();
if (!is_running) {
break;
}
if (!is_running) {
break;
}
// delay
@ -583,7 +255,7 @@ int main(int argc, char ** argv) {
{
audio.get(2000, pcmf32_cur);
if (vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
if (::vad_simple(pcmf32_cur, WHISPER_SAMPLE_RATE, 1250, params.vad_thold, params.freq_thold, params.print_energy) || force_speak) {
fprintf(stdout, "%s: Speech detected! Processing ...\n", __func__);
audio.get(params.voice_ms, pcmf32_cur);