From 08fa34882f4add2f3ed709ad980158c21432003e Mon Sep 17 00:00:00 2001 From: bobqianic <129547291+bobqianic@users.noreply.github.com> Date: Tue, 3 Oct 2023 20:56:11 +0100 Subject: [PATCH] examples : move wav_writer from stream.cpp to common.h (#1317) * Allocate class on the stack instead of on the heap * Add class wav_writer * fix some minor issues * fix some minor issues * remove potential misleading API --- examples/common.h | 100 +++++++++++++++++++++++++++++++++++++ examples/stream/stream.cpp | 62 ++--------------------- 2 files changed, 105 insertions(+), 57 deletions(-) diff --git a/examples/common.h b/examples/common.h index 7c671588..698ee0bd 100644 --- a/examples/common.h +++ b/examples/common.h @@ -7,6 +7,8 @@ #include #include #include +#include +#include #define COMMON_SAMPLE_RATE 16000 @@ -139,6 +141,104 @@ bool read_wav( std::vector> & pcmf32s, bool stereo); +// Write PCM data into WAV audio file +class wav_writer { +private: + std::ofstream file; + uint32_t dataSize = 0; + std::string wav_filename; + + bool write_header(const uint32_t sample_rate, + const uint16_t bits_per_sample, + const uint16_t channels) { + + file.write("RIFF", 4); + file.write("\0\0\0\0", 4); // Placeholder for file size + file.write("WAVE", 4); + file.write("fmt ", 4); + + const uint32_t sub_chunk_size = 16; + const uint16_t audio_format = 1; // PCM format + const uint32_t byte_rate = sample_rate * channels * bits_per_sample / 8; + const uint16_t block_align = channels * bits_per_sample / 8; + + file.write(reinterpret_cast(&sub_chunk_size), 4); + file.write(reinterpret_cast(&audio_format), 2); + file.write(reinterpret_cast(&channels), 2); + file.write(reinterpret_cast(&sample_rate), 4); + file.write(reinterpret_cast(&byte_rate), 4); + file.write(reinterpret_cast(&block_align), 2); + file.write(reinterpret_cast(&bits_per_sample), 2); + file.write("data", 4); + file.write("\0\0\0\0", 4); // Placeholder for data size + + return true; + } + + // It is assumed that PCM data is normalized to a range from -1 to 1 + bool write_audio(const float * data, size_t length) { + for (size_t i = 0; i < length; ++i) { + const auto intSample = static_cast(data[i] * 32767); + file.write(reinterpret_cast(&intSample), sizeof(int16_t)); + dataSize += sizeof(int16_t); + } + if (file.is_open()) { + file.seekp(4, std::ios::beg); + uint32_t fileSize = 36 + dataSize; + file.write(reinterpret_cast(&fileSize), 4); + file.seekp(40, std::ios::beg); + file.write(reinterpret_cast(&dataSize), 4); + file.seekp(0, std::ios::end); + } + return true; + } + + bool open_wav(const std::string & filename) { + if (filename != wav_filename) { + if (file.is_open()) { + file.close(); + } + } + if (!file.is_open()) { + file.open(filename, std::ios::binary); + wav_filename = filename; + dataSize = 0; + } + return file.is_open(); + } + +public: + bool open(const std::string & filename, + const uint32_t sample_rate, + const uint16_t bits_per_sample, + const uint16_t channels) { + + if (open_wav(filename)) { + write_header(sample_rate, bits_per_sample, channels); + } else { + return false; + } + + return true; + } + + bool close() { + file.close(); + return true; + } + + bool write(const float * data, size_t length) { + return write_audio(data, length); + } + + ~wav_writer() { + if (file.is_open()) { + file.close(); + } + } +}; + + // Apply a high-pass frequency filter to PCM audio // Suppresses frequencies below cutoff Hz void high_pass_filter( diff --git a/examples/stream/stream.cpp b/examples/stream/stream.cpp index ad0131d5..c8a452d1 100644 --- a/examples/stream/stream.cpp +++ b/examples/stream/stream.cpp @@ -2,7 +2,6 @@ // // A very quick-n-dirty implementation serving mainly as a proof of concept. // -#include #include "common-sdl.h" #include "common.h" #include "whisper.h" @@ -13,60 +12,8 @@ #include #include #include -#include -class SimpleWavWriter { -private: - std::ofstream file; - int32_t dataSize = 0; -public: - SimpleWavWriter(const std::string &filename, int sampleRate, int bitsPerSample, int channels) { - file.open(filename, std::ios::binary); - - file.write("RIFF", 4); - file.write("\0\0\0\0", 4); // Placeholder for file size - file.write("WAVE", 4); - file.write("fmt ", 4); - - int32_t subChunkSize = 16; - int16_t audioFormat = 1; // PCM format - int32_t byteRate = sampleRate * channels * bitsPerSample / 8; - int16_t blockAlign = channels * bitsPerSample / 8; - - file.write(reinterpret_cast(&subChunkSize), 4); - file.write(reinterpret_cast(&audioFormat), 2); - file.write(reinterpret_cast(&channels), 2); - file.write(reinterpret_cast(&sampleRate), 4); - file.write(reinterpret_cast(&byteRate), 4); - file.write(reinterpret_cast(&blockAlign), 2); - file.write(reinterpret_cast(&bitsPerSample), 2); - file.write("data", 4); - file.write("\0\0\0\0", 4); // Placeholder for data size - } - - void writeData(const float *data, size_t length) { - for (size_t i = 0; i < length; ++i) { - int16_t intSample = static_cast(data[i] * 32767); - file.write(reinterpret_cast(&intSample), sizeof(int16_t)); - dataSize += sizeof(int16_t); - } - if (file.is_open()) { - file.seekp(4, std::ios::beg); - int32_t fileSize = 36 + dataSize; - file.write(reinterpret_cast(&fileSize), 4); - file.seekp(40, std::ios::beg); - file.write(reinterpret_cast(&dataSize), 4); - file.seekp(0, std::ios::end); - } - } - - ~SimpleWavWriter() { - if (file.is_open()) { - file.close(); - } - } -}; // 500 -> 00:05.000 // 6000 -> 01:00.000 std::string to_timestamp(int64_t t) { @@ -266,8 +213,9 @@ int main(int argc, char ** argv) { return 1; } } + + wav_writer wavWriter; // save wav file - SimpleWavWriter *wavWriter = nullptr; if (params.save_audio) { // Get current date/time for filename time_t now = time(0); @@ -275,7 +223,7 @@ int main(int argc, char ** argv) { strftime(buffer, sizeof(buffer), "%Y%m%d%H%M%S", localtime(&now)); std::string filename = std::string(buffer) + ".wav"; - wavWriter = new SimpleWavWriter(filename, WHISPER_SAMPLE_RATE, 16, 1); + wavWriter.open(filename, WHISPER_SAMPLE_RATE, 16, 1); } printf("[Start speaking]\n"); fflush(stdout); @@ -285,8 +233,8 @@ int main(int argc, char ** argv) { // main audio loop while (is_running) { - if (params.save_audio && wavWriter) { - wavWriter->writeData(pcmf32_new.data(), pcmf32_new.size()); + if (params.save_audio) { + wavWriter.write(pcmf32_new.data(), pcmf32_new.size()); } // handle Ctrl + C is_running = sdl_poll_events();