diff --git a/.github/workflows/bindings.yml b/.github/workflows/bindings-go.yml similarity index 93% rename from .github/workflows/bindings.yml rename to .github/workflows/bindings-go.yml index 02667edb..13f1950a 100644 --- a/.github/workflows/bindings.yml +++ b/.github/workflows/bindings-go.yml @@ -1,4 +1,4 @@ -name: Bindings Tests +name: Bindings Tests (Go) on: push: paths: diff --git a/.github/workflows/bindings-ruby.yml b/.github/workflows/bindings-ruby.yml new file mode 100644 index 00000000..902dfe6a --- /dev/null +++ b/.github/workflows/bindings-ruby.yml @@ -0,0 +1,22 @@ +name: Bindings Tests (Ruby) +on: + push: + paths: + - bindings/ruby/** + - whisper.h + pull_request: + paths: + - bindings/ruby/** + - whisper.h + +jobs: + ubuntu-latest: + runs-on: ubuntu-latest + steps: + - uses: ruby/setup-ruby@v1 + with: + ruby-version: '3.0' + - uses: actions/checkout@v1 + - run: | + cd bindings/ruby/ext + ruby extconf.rb && make diff --git a/bindings/ruby/ext/.gitignore b/bindings/ruby/ext/.gitignore new file mode 100644 index 00000000..7c9cb037 --- /dev/null +++ b/bindings/ruby/ext/.gitignore @@ -0,0 +1,7 @@ +Makefile +ggml.c +ggml.h +whisper.bundle +whisper.cpp +whisper.h +dr_wav.h diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb new file mode 100644 index 00000000..851c52db --- /dev/null +++ b/bindings/ruby/ext/extconf.rb @@ -0,0 +1,21 @@ +require 'mkmf' +system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.cpp')} .") +system("cp #{File.join(File.dirname(__FILE__),'..','..','..','whisper.h')} .") +system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.h')} .") +system("cp #{File.join(File.dirname(__FILE__),'..','..','..','ggml.c')} .") +system("cp #{File.join(File.dirname(__FILE__),'..','..','..','examples','dr_wav.h')} .") + + +# need to use c++ compiler flags +$CXXFLAGS << ' -std=c++11' +# Set to true when building binary gems +if enable_config('static-stdlib', false) + $LDFLAGS << ' -static-libgcc -static-libstdc++' +end + +if enable_config('march-tune-native', false) + $CFLAGS << ' -march=native -mtune=native' + $CXXFLAGS << ' -march=native -mtune=native' +end + +create_makefile('whisper') diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp new file mode 100644 index 00000000..e7416ba2 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -0,0 +1,426 @@ +#include +#include "ruby_whisper.h" +#define DR_WAV_IMPLEMENTATION +#include "dr_wav.h" +#include +#include +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#define BOOL_PARAMS_SETTER(self, prop, value) \ + ruby_whisper_params *rwp; \ + Data_Get_Struct(self, ruby_whisper_params, rwp); \ + if (value == Qfalse || value == Qnil) { \ + rwp->params.prop = false; \ + } else { \ + rwp->params.prop = true; \ + } \ + return value; \ + +#define BOOL_PARAMS_GETTER(self, prop) \ + ruby_whisper_params *rwp; \ + Data_Get_Struct(self, ruby_whisper_params, rwp); \ + if (rwp->params.prop) { \ + return Qtrue; \ + } else { \ + return Qfalse; \ + } + +VALUE mWhisper; +VALUE cContext; +VALUE cParams; + +static void ruby_whisper_free(ruby_whisper *rw) { + if (rw->context) { + whisper_free(rw->context); + rw->context = NULL; + } +} +static void ruby_whisper_params_free(ruby_whisper_params *rwp) { +} + +void rb_whisper_mark(ruby_whisper *rw) { + // call rb_gc_mark on any ruby references in rw +} + +void rb_whisper_free(ruby_whisper *rw) { + ruby_whisper_free(rw); + free(rw); +} + +void rb_whisper_params_mark(ruby_whisper_params *rwp) { +} + +void rb_whisper_params_free(ruby_whisper_params *rwp) { + ruby_whisper_params_free(rwp); + free(rwp); +} + +static VALUE ruby_whisper_allocate(VALUE klass) { + ruby_whisper *rw; + rw = ALLOC(ruby_whisper); + rw->context = NULL; + return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw); +} + +static VALUE ruby_whisper_params_allocate(VALUE klass) { + ruby_whisper_params *rwp; + rwp = ALLOC(ruby_whisper_params); + rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); +} + +static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { + ruby_whisper *rw; + VALUE whisper_model_file_path; + + // TODO: we can support init from buffer here too maybe another ruby object to expose + rb_scan_args(argc, argv, "01", &whisper_model_file_path); + Data_Get_Struct(self, ruby_whisper, rw); + + if (!rb_respond_to(whisper_model_file_path, rb_intern("to_s"))) { + rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); + } + rw->context = whisper_init_from_file(StringValueCStr(whisper_model_file_path)); + if (rw->context == nullptr) { + rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context"); + } + return self; +} + +/* + * transcribe a single file + * can emit to a block results + * + **/ +static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { + ruby_whisper *rw; + ruby_whisper_params *rwp; + VALUE wave_file_path, blk, params; + + rb_scan_args(argc, argv, "02&", &wave_file_path, ¶ms, &blk); + Data_Get_Struct(self, ruby_whisper, rw); + Data_Get_Struct(params, ruby_whisper_params, rwp); + + if (!rb_respond_to(wave_file_path, rb_intern("to_s"))) { + rb_raise(rb_eRuntimeError, "Expected file path to wave file"); + } + + std::string fname_inp = StringValueCStr(wave_file_path); + + std::vector pcmf32; // mono-channel F32 PCM + std::vector> pcmf32s; // stereo-channel F32 PCM + + // WAV input - this is directly from main.cpp example + { + drwav wav; + std::vector wav_data; // used for pipe input from stdin + + if (fname_inp == "-") { + { + uint8_t buf[1024]; + while (true) { + const size_t n = fread(buf, 1, sizeof(buf), stdin); + if (n == 0) { + break; + } + wav_data.insert(wav_data.end(), buf, buf + n); + } + } + + if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) { + fprintf(stderr, "error: failed to open WAV file from stdin\n"); + return self; + } + + fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size()); + } else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) { + fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); + return self; + } + + if (wav.channels != 1 && wav.channels != 2) { + fprintf(stderr, "WAV file '%s' must be mono or stereo\n", fname_inp.c_str()); + return self; + } + + if (rwp->diarize && wav.channels != 2 && rwp->params.print_timestamps == false) { + fprintf(stderr, "WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str()); + return self; + } + + if (wav.sampleRate != WHISPER_SAMPLE_RATE) { + fprintf(stderr, "WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000); + return self; + } + + if (wav.bitsPerSample != 16) { + fprintf(stderr, "WAV file '%s' must be 16-bit\n", fname_inp.c_str()); + return self; + } + + const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8); + + std::vector pcm16; + pcm16.resize(n*wav.channels); + drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); + drwav_uninit(&wav); + + // convert to mono, float + pcmf32.resize(n); + if (wav.channels == 1) { + for (uint64_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[i])/32768.0f; + } + } else { + for (uint64_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f; + } + } + + if (rwp->diarize) { + // convert to stereo, float + pcmf32s.resize(2); + + pcmf32s[0].resize(n); + pcmf32s[1].resize(n); + for (uint64_t i = 0; i < n; i++) { + pcmf32s[0][i] = float(pcm16[2*i])/32768.0f; + pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f; + } + } + } + { + static bool is_aborted = false; // NOTE: this should be atomic to avoid data race + + rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, void * user_data) { + bool is_aborted = *(bool*)user_data; + return !is_aborted; + }; + rwp->params.encoder_begin_callback_user_data = &is_aborted; + } + + if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { + fprintf(stderr, "failed to process audio\n"); + return self; + } + const int n_segments = whisper_full_n_segments(rw->context); + VALUE output = rb_str_new2(""); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(rw->context, i); + output = rb_str_concat(output, rb_str_new2(text)); + } + VALUE idCall = rb_intern("call"); + if (blk != Qnil) { + rb_funcall(blk, idCall, 1, output); + } + return self; +} + +/* + * params.language = "auto" | "en", etc... + */ +static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + if (value == Qfalse || value == Qnil) { + rwp->params.language = "auto"; + } else { + rwp->params.language = StringValueCStr(value); + } + return value; +} +static VALUE ruby_whisper_params_get_language(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + if (rwp->params.language) { + return rb_str_new2(rwp->params.language); + } else { + return rb_str_new2("auto"); + } +} +static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, translate, value) +} +static VALUE ruby_whisper_params_get_translate(VALUE self) { + BOOL_PARAMS_GETTER(self, translate) +} +static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, no_context, value) +} +static VALUE ruby_whisper_params_get_no_context(VALUE self) { + BOOL_PARAMS_GETTER(self, no_context) +} +static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, single_segment, value) +} +static VALUE ruby_whisper_params_get_single_segment(VALUE self) { + BOOL_PARAMS_GETTER(self, single_segment) +} +static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, print_special, value) +} +static VALUE ruby_whisper_params_get_print_special(VALUE self) { + BOOL_PARAMS_GETTER(self, print_special) +} +static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, print_progress, value) +} +static VALUE ruby_whisper_params_get_print_progress(VALUE self) { + BOOL_PARAMS_GETTER(self, print_progress) +} +static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, print_realtime, value) +} +static VALUE ruby_whisper_params_get_print_realtime(VALUE self) { + BOOL_PARAMS_GETTER(self, print_realtime) +} +static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, print_timestamps, value) +} +static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) { + BOOL_PARAMS_GETTER(self, print_timestamps) +} +static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, suppress_blank, value) +} +static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) { + BOOL_PARAMS_GETTER(self, suppress_blank) +} +static VALUE ruby_whisper_params_set_suppress_non_speech_tokens(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, suppress_non_speech_tokens, value) +} +static VALUE ruby_whisper_params_get_suppress_non_speech_tokens(VALUE self) { + BOOL_PARAMS_GETTER(self, suppress_non_speech_tokens) +} +static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) { + BOOL_PARAMS_GETTER(self, token_timestamps) +} +static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, token_timestamps, value) +} +static VALUE ruby_whisper_params_get_split_on_word(VALUE self) { + BOOL_PARAMS_GETTER(self, split_on_word) +} +static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, split_on_word, value) +} +static VALUE ruby_whisper_params_get_speed_up(VALUE self) { + BOOL_PARAMS_GETTER(self, speed_up) +} +static VALUE ruby_whisper_params_set_speed_up(VALUE self, VALUE value) { + BOOL_PARAMS_SETTER(self, speed_up, value) +} +static VALUE ruby_whisper_params_get_diarize(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + if (rwp->diarize) { + return Qtrue; + } else { + return Qfalse; + } +} +static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + if (value == Qfalse || value == Qnil) { + rwp->diarize = false; + } else { + rwp->diarize = true; + } \ + return value; +} + +static VALUE ruby_whisper_params_get_offset(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return INT2NUM(rwp->params.offset_ms); +} +static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.offset_ms = NUM2INT(value); + return value; +} +static VALUE ruby_whisper_params_get_duration(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return INT2NUM(rwp->params.duration_ms); +} +static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.duration_ms = NUM2INT(value); + return value; +} + +static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return INT2NUM(rwp->params.n_max_text_ctx); +} +static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.n_max_text_ctx = NUM2INT(value); + return value; +} + +void Init_whisper() { + mWhisper = rb_define_module("Whisper"); + cContext = rb_define_class_under(mWhisper, "Context", rb_cObject); + cParams = rb_define_class_under(mWhisper, "Params", rb_cObject); + + rb_define_alloc_func(cContext, ruby_whisper_allocate); + rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1); + + rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1); + + rb_define_alloc_func(cParams, ruby_whisper_params_allocate); + + rb_define_method(cParams, "language=", ruby_whisper_params_set_language, 1); + rb_define_method(cParams, "language", ruby_whisper_params_get_language, 0); + rb_define_method(cParams, "translate=", ruby_whisper_params_set_translate, 1); + rb_define_method(cParams, "translate", ruby_whisper_params_get_translate, 0); + rb_define_method(cParams, "no_context=", ruby_whisper_params_set_no_context, 1); + rb_define_method(cParams, "no_context", ruby_whisper_params_get_no_context, 0); + rb_define_method(cParams, "single_segment=", ruby_whisper_params_set_single_segment, 1); + rb_define_method(cParams, "single_segment", ruby_whisper_params_get_single_segment, 0); + rb_define_method(cParams, "print_special", ruby_whisper_params_get_print_special, 0); + rb_define_method(cParams, "print_special=", ruby_whisper_params_set_print_special, 1); + rb_define_method(cParams, "print_progress", ruby_whisper_params_get_print_progress, 0); + rb_define_method(cParams, "print_progress=", ruby_whisper_params_set_print_progress, 1); + rb_define_method(cParams, "print_realtime", ruby_whisper_params_get_print_realtime, 0); + rb_define_method(cParams, "print_realtime=", ruby_whisper_params_set_print_realtime, 1); + rb_define_method(cParams, "print_timestamps", ruby_whisper_params_get_print_timestamps, 0); + rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1); + rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0); + rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1); + rb_define_method(cParams, "suppress_non_speech_tokens", ruby_whisper_params_get_suppress_non_speech_tokens, 0); + rb_define_method(cParams, "suppress_non_speech_tokens=", ruby_whisper_params_set_suppress_non_speech_tokens, 1); + rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0); + rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1); + rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0); + rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1); + rb_define_method(cParams, "speed_up", ruby_whisper_params_get_speed_up, 0); + rb_define_method(cParams, "speed_up=", ruby_whisper_params_set_speed_up, 1); + rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0); + rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1); + + rb_define_method(cParams, "offset", ruby_whisper_params_get_offset, 0); + rb_define_method(cParams, "offset=", ruby_whisper_params_set_offset, 1); + rb_define_method(cParams, "duration", ruby_whisper_params_get_duration, 0); + rb_define_method(cParams, "duration=", ruby_whisper_params_set_duration, 1); + + rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0); + rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1); +} +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h new file mode 100644 index 00000000..8c35b7cb --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper.h @@ -0,0 +1,15 @@ +#ifndef __RUBY_WHISPER_H +#define __RUBY_WHISPER_H + +#include "whisper.h" + +typedef struct { + struct whisper_context *context; +} ruby_whisper; + +typedef struct { + struct whisper_full_params params; + bool diarize; +} ruby_whisper_params; + +#endif diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb new file mode 100644 index 00000000..fa6a3e2d --- /dev/null +++ b/bindings/ruby/tests/test_whisper.rb @@ -0,0 +1,138 @@ +TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..')) +EXTDIR = File.join(TOPDIR, 'ext') +#$LIBDIR = File.join(TOPDIR, 'lib') +#$:.unshift(LIBDIR) +$:.unshift(EXTDIR) + +require 'whisper' +require 'test/unit' + +class TestWhisper < Test::Unit::TestCase + def setup + @params = Whisper::Params.new + end + + def test_language + @params.language = "en" + assert_equal @params.language, "en" + @params.language = "auto" + assert_equal @params.language, "auto" + end + + def test_offset + @params.offset = 10_000 + assert_equal @params.offset, 10_000 + @params.offset = 0 + assert_equal @params.offset, 0 + end + + def test_duration + @params.duration = 60_000 + assert_equal @params.duration, 60_000 + @params.duration = 0 + assert_equal @params.duration, 0 + end + + def test_max_text_tokens + @params.max_text_tokens = 300 + assert_equal @params.max_text_tokens, 300 + @params.max_text_tokens = 0 + assert_equal @params.max_text_tokens, 0 + end + + def test_translate + @params.translate = true + assert @params.translate + @params.translate = false + assert !@params.translate + end + + def test_no_context + @params.no_context = true + assert @params.no_context + @params.no_context = false + assert !@params.no_context + end + + def test_single_segment + @params.single_segment = true + assert @params.single_segment + @params.single_segment = false + assert !@params.single_segment + end + + def test_print_special + @params.print_special = true + assert @params.print_special + @params.print_special = false + assert !@params.print_special + end + + def test_print_progress + @params.print_progress = true + assert @params.print_progress + @params.print_progress = false + assert !@params.print_progress + end + + def test_print_realtime + @params.print_realtime = true + assert @params.print_realtime + @params.print_realtime = false + assert !@params.print_realtime + end + + def test_print_timestamps + @params.print_timestamps = true + assert @params.print_timestamps + @params.print_timestamps = false + assert !@params.print_timestamps + end + + def test_suppress_blank + @params.suppress_blank = true + assert @params.suppress_blank + @params.suppress_blank = false + assert !@params.suppress_blank + end + + def test_suppress_non_speech_tokens + @params.suppress_non_speech_tokens = true + assert @params.suppress_non_speech_tokens + @params.suppress_non_speech_tokens = false + assert !@params.suppress_non_speech_tokens + end + + def test_token_timestamps + @params.token_timestamps = true + assert @params.token_timestamps + @params.token_timestamps = false + assert !@params.token_timestamps + end + + def test_split_on_word + @params.split_on_word = true + assert @params.split_on_word + @params.split_on_word = false + assert !@params.split_on_word + end + + def test_speed_up + @params.speed_up = true + assert @params.speed_up + @params.speed_up = false + assert !@params.speed_up + end + + def test_whisper + @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) + params = Whisper::Params.new + params.print_timestamps = false + + jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + @whisper.transcribe(jfk, params) {|text| + assert_match /ask not what your country can do for you, ask what you can do for your country/, text + } + end + +end