mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-30 01:08:52 +00:00
163 lines
4.6 KiB
C++
163 lines
4.6 KiB
C++
|
#include "common.h"
|
||
|
|
||
|
// third-party utilities
|
||
|
// use your favorite implementations
|
||
|
#define DR_WAV_IMPLEMENTATION
|
||
|
#include "dr_wav.h"
|
||
|
|
||
|
#include <cmath>
|
||
|
#include <regex>
|
||
|
|
||
|
#ifndef M_PI
|
||
|
#define M_PI 3.14159265358979323846
|
||
|
#endif
|
||
|
|
||
|
// Return a copy of `s` with leading and trailing whitespace removed.
//
// Whitespace is the C-locale set " \t\n\v\f\r" (what "\s" matches in std::regex under
// the default locale). A plain character scan is used instead of std::regex because
// constructing and running a regex on every call is very slow in common standard
// library implementations.
std::string trim(const std::string & s) {
    const char * const whitespace = " \t\n\v\f\r";

    const size_t first = s.find_first_not_of(whitespace);
    if (first == std::string::npos) {
        // empty or all-whitespace input
        return std::string();
    }

    const size_t last = s.find_last_not_of(whitespace);
    return s.substr(first, last - first + 1);
}
|
||
|
|
||
|
// Return a copy of `s` in which every non-overlapping occurrence of `from`, scanning
// left to right, is replaced by `to`. Replacement text is never re-scanned.
//
// An empty `from` matches nothing and `s` is returned unchanged. The previous
// implementation looped forever in that case, because find("") always succeeds.
std::string replace(const std::string & s, const std::string & from, const std::string & to) {
    if (from.empty()) {
        return s;
    }

    std::string result;
    result.reserve(s.size());

    // Build the output in one forward pass instead of replacing in place, which
    // would shift the remainder of the string on every match.
    size_t start = 0;
    size_t pos   = 0;
    while ((pos = s.find(from, start)) != std::string::npos) {
        result.append(s, start, pos - start);
        result += to;
        start = pos + from.size();
    }
    result.append(s, start, std::string::npos);

    return result;
}
|
||
|
|
||
|
bool read_wav(const std::string & fname, std::vector<float>& pcmf32, std::vector<std::vector<float>>& pcmf32s, bool stereo) {
|
||
|
drwav wav;
|
||
|
std::vector<uint8_t> wav_data; // used for pipe input from stdin
|
||
|
|
||
|
if (fname == "-") {
|
||
|
{
|
||
|
uint8_t buf[1024];
|
||
|
while (true)
|
||
|
{
|
||
|
const size_t n = fread(buf, 1, sizeof(buf), stdin);
|
||
|
if (n == 0) {
|
||
|
break;
|
||
|
}
|
||
|
wav_data.insert(wav_data.end(), buf, buf + n);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
|
||
|
fprintf(stderr, "error: failed to open WAV file from stdin\n");
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
|
||
|
}
|
||
|
else if (drwav_init_file(&wav, fname.c_str(), nullptr) == false) {
|
||
|
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname.c_str());
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
if (wav.channels != 1 && wav.channels != 2) {
|
||
|
fprintf(stderr, "%s: WAV file '%s' must be mono or stereo\n", __func__, fname.c_str());
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
if (stereo && wav.channels != 2) {
|
||
|
fprintf(stderr, "%s: WAV file '%s' must be stereo for diarization\n", __func__, fname.c_str());
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
if (wav.sampleRate != COMMON_SAMPLE_RATE) {
|
||
|
fprintf(stderr, "%s: WAV file '%s' must be %i kHz\n", __func__, fname.c_str(), COMMON_SAMPLE_RATE/1000);
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
if (wav.bitsPerSample != 16) {
|
||
|
fprintf(stderr, "%s: WAV file '%s' must be 16-bit\n", __func__, fname.c_str());
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
|
||
|
|
||
|
std::vector<int16_t> pcm16;
|
||
|
pcm16.resize(n*wav.channels);
|
||
|
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
|
||
|
drwav_uninit(&wav);
|
||
|
|
||
|
// convert to mono, float
|
||
|
pcmf32.resize(n);
|
||
|
if (wav.channels == 1) {
|
||
|
for (uint64_t i = 0; i < n; i++) {
|
||
|
pcmf32[i] = float(pcm16[i])/32768.0f;
|
||
|
}
|
||
|
} else {
|
||
|
for (uint64_t i = 0; i < n; i++) {
|
||
|
pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if (stereo) {
|
||
|
// convert to stereo, float
|
||
|
pcmf32s.resize(2);
|
||
|
|
||
|
pcmf32s[0].resize(n);
|
||
|
pcmf32s[1].resize(n);
|
||
|
for (uint64_t i = 0; i < n; i++) {
|
||
|
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
|
||
|
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
// First-order (RC) high-pass filter, applied in place.
//
// Implements the standard discrete RC high-pass recurrence
//     y[0] = x[0]
//     y[i] = alpha * (y[i-1] + x[i] - x[i-1]),   alpha = RC / (RC + dt)
// with RC = 1 / (2*pi*cutoff) and dt = 1 / sample_rate.
//
// data         samples to filter; overwritten with the filtered signal (no-op if empty).
// cutoff       cutoff frequency in Hz.
// sample_rate  sampling rate of `data` in Hz.
void high_pass_filter(std::vector<float> & data, float cutoff, float sample_rate) {
    if (data.empty()) {
        return;
    }

    const float rc = 1.0f / (2.0f * M_PI * cutoff);
    const float dt = 1.0f / sample_rate;
    // High-pass coefficient. dt / (rc + dt) is the *low-pass* coefficient.
    const float alpha = rc / (rc + dt);

    float y      = data[0];
    // Keep the previous *input* sample separately: data[i-1] has already been
    // overwritten with y[i-1] by the time we reach sample i.
    float x_prev = data[0];

    for (size_t i = 1; i < data.size(); i++) {
        const float x = data[i];
        // Difference the inputs first, so a constant signal decays exactly by alpha.
        y = alpha * (y + (x - x_prev));
        x_prev  = x;
        data[i] = y;
    }
}
|
||
|
|
||
|
bool vad_simple(std::vector<float> & pcmf32, int sample_rate, int last_ms, float vad_thold, float freq_thold, bool verbose) {
|
||
|
const int n_samples = pcmf32.size();
|
||
|
const int n_samples_last = (sample_rate * last_ms) / 1000;
|
||
|
|
||
|
if (n_samples_last >= n_samples) {
|
||
|
// not enough samples - assume no speech
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
if (freq_thold > 0.0f) {
|
||
|
high_pass_filter(pcmf32, freq_thold, sample_rate);
|
||
|
}
|
||
|
|
||
|
float energy_all = 0.0f;
|
||
|
float energy_last = 0.0f;
|
||
|
|
||
|
for (int i = 0; i < n_samples; i++) {
|
||
|
energy_all += fabsf(pcmf32[i]);
|
||
|
|
||
|
if (i >= n_samples - n_samples_last) {
|
||
|
energy_last += fabsf(pcmf32[i]);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
energy_all /= n_samples;
|
||
|
energy_last /= n_samples_last;
|
||
|
|
||
|
if (verbose) {
|
||
|
fprintf(stderr, "%s: energy_all: %f, energy_last: %f, vad_thold: %f, freq_thold: %f\n", __func__, energy_all, energy_last, vad_thold, freq_thold);
|
||
|
}
|
||
|
|
||
|
if (energy_last > vad_thold*energy_all) {
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
return true;
|
||
|
}
|