diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index 13ff1f00..f66d8d65 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -24,14 +24,15 @@ require "whisper" whisper = Whisper::Context.new("base") -params = Whisper::Params.new -params.language = "en" -params.offset = 10_000 -params.duration = 60_000 -params.max_text_tokens = 300 -params.translate = true -params.print_timestamps = false -params.initial_prompt = "Initial prompt here." +params = Whisper::Params.new( + language: "en", + offset: 10_000, + duration: 60_000, + max_text_tokens: 300, + translate: true, + print_timestamps: false, + initial_prompt: "Initial prompt here." +) whisper.transcribe("path/to/audio.wav", params) do |whole_text| puts whole_text @@ -113,18 +114,18 @@ def format_time(time_ms) "%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part] end -whisper.transcribe("path/to/audio.wav", params) - -whisper.each_segment.with_index do |segment, index| - line = "[%{nth}: %{st} --> %{ed}] %{text}" % { - nth: index + 1, - st: format_time(segment.start_time), - ed: format_time(segment.end_time), - text: segment.text - } - line << " (speaker turned)" if segment.speaker_next_turn? - puts line -end +whisper + .transcribe("path/to/audio.wav", params) + .each_segment.with_index do |segment, index| + line = "[%{nth}: %{st} --> %{ed}] %{text}" % { + nth: index + 1, + st: format_time(segment.start_time), + ed: format_time(segment.end_time), + text: segment.text + } + line << " (speaker turned)" if segment.speaker_next_turn? 
+ puts line + end ``` @@ -215,10 +216,11 @@ reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, : samples = reader.enum_for(:each_buffer).map(&:samples).flatten whisper = Whisper::Context.new("base") -whisper.full(Whisper::Params.new, samples) -whisper.each_segment do |segment| - puts segment.text -end +whisper + .full(Whisper::Params.new, samples) + .each_segment do |segment| + puts segment.text + end ``` The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy. diff --git a/bindings/ruby/Rakefile b/bindings/ruby/Rakefile index 3a7809b7..0d52e88a 100644 --- a/bindings/ruby/Rakefile +++ b/bindings/ruby/Rakefile @@ -18,9 +18,11 @@ EXTSOURCES.each do |src| end CLEAN.include SOURCES -CLEAN.include FileList["ext/*.o", "ext/*.metal", "ext/whisper.{so,bundle,dll}"] +CLEAN.include FileList["ext/**/*.o", "ext/**/*.metal", "ext/**/*.tmp", "ext/whisper.{so,bundle,dll}"] -task build: ["ext/Makefile", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp", "whispercpp.gemspec"] +SRC = FileList["ext/*.{c,cpp,h}"] + +task build: SOURCES directory "pkg" CLOBBER.include "pkg" @@ -29,14 +31,14 @@ LIB_NAME = "whisper".ext(RbConfig::CONFIG["DLEXT"]) SO_FILE = File.join("ext", LIB_NAME) LIB_FILE = File.join("lib", LIB_NAME) -file "ext/Makefile" => ["ext/extconf.rb", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp"] + SOURCES do |t| - Dir.chdir "ext" do +file "ext/Makefile" => SRC + ["ext/extconf.rb"] + SOURCES do |t| + chdir "ext" do ruby "extconf.rb" end end file SO_FILE => "ext/Makefile" do |t| - Dir.chdir "ext" do + chdir "ext" do sh "make" end end @@ -54,7 +56,7 @@ end TEST_MEMORY_VIEW = "tests/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}" file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t| - Dir.chdir "tests/jfk_reader" do + chdir "tests/jfk_reader" do ruby "extconf.rb" sh 
"make" end diff --git a/bindings/ruby/ext/.gitignore b/bindings/ruby/ext/.gitignore index e96a8584..7703146f 100644 --- a/bindings/ruby/ext/.gitignore +++ b/bindings/ruby/ext/.gitignore @@ -4,10 +4,8 @@ whisper.bundle whisper.dll scripts/get-flags.mk *.o -*.c -*.cpp -*.h -*.m -*.metal -!ruby_whisper.cpp -!ruby_whisper.h +/*/**/*.c +/*/**/*.cpp +/*/**/*.h +/*/**/*.m +/*/**/*.metal diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index 6ffac109..af50904d 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -174,7 +174,14 @@ $OBJ_WHISPER << 'src/whisper.o' $objs = $OBJ_GGML + $OBJ_WHISPER + $OBJ_COMMON + $OBJ_SDL -$objs << "ruby_whisper.o" +$objs << + "ruby_whisper.o" << + "ruby_whisper_context.o" << + "ruby_whisper_transcribe.o" << + "ruby_whisper_params.o" << + "ruby_whisper_error.o" << + "ruby_whisper_segment.o" << + "ruby_whisper_model.o" $CPPFLAGS = "#{$MK_CPPFLAGS} #{$CPPFLAGS}" $CFLAGS = "#{$CPPFLAGS} #{$MK_CFLAGS} #{$GF_CFLAGS} #{$CFLAGS}" diff --git a/bindings/ruby/ext/ruby_whisper.c b/bindings/ruby/ext/ruby_whisper.c new file mode 100644 index 00000000..43227786 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper.c @@ -0,0 +1,164 @@ +#include +#include +#include "ruby_whisper.h" + +VALUE mWhisper; +VALUE cContext; +VALUE cParams; +VALUE eError; + +VALUE cSegment; +VALUE cModel; + +ID id_to_s; +ID id_call; +ID id___method__; +ID id_to_enum; +ID id_length; +ID id_next; +ID id_new; +ID id_to_path; +ID id_URI; +ID id_pre_converted_models; + +static bool is_log_callback_finalized = false; + +// High level API +extern VALUE ruby_whisper_segment_allocate(VALUE klass); + +extern void init_ruby_whisper_context(VALUE *mWhisper); +extern void init_ruby_whisper_params(VALUE *mWhisper); +extern void init_ruby_whisper_error(VALUE *mWhisper); +extern void init_ruby_whisper_segment(VALUE *mWhisper, VALUE *cSegment); +extern void init_ruby_whisper_model(VALUE *mWhisper); +extern void register_callbacks(ruby_whisper_params 
*rwp, VALUE *context); + +/* + * call-seq: + * lang_max_id -> Integer + */ +static VALUE ruby_whisper_s_lang_max_id(VALUE self) { + return INT2NUM(whisper_lang_max_id()); +} + +/* + * call-seq: + * lang_id(lang_name) -> Integer + */ +static VALUE ruby_whisper_s_lang_id(VALUE self, VALUE lang) { + const char * lang_str = StringValueCStr(lang); + const int id = whisper_lang_id(lang_str); + if (-1 == id) { + rb_raise(rb_eArgError, "language not found: %s", lang_str); + } + return INT2NUM(id); +} + +/* + * call-seq: + * lang_str(lang_id) -> String + */ +static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) { + const int lang_id = NUM2INT(id); + const char * str = whisper_lang_str(lang_id); + if (NULL == str) { + rb_raise(rb_eIndexError, "id %d outside of language id", lang_id); + } + return rb_str_new2(str); +} + +/* + * call-seq: + * lang_str(lang_id) -> String + */ +static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) { + const int lang_id = NUM2INT(id); + const char * str_full = whisper_lang_str_full(lang_id); + if (NULL == str_full) { + rb_raise(rb_eIndexError, "id %d outside of language id", lang_id); + } + return rb_str_new2(str_full); +} + +static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) { + is_log_callback_finalized = true; + return Qnil; +} + +static void +ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * user_data) { + if (is_log_callback_finalized) { + return; + } + VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); + VALUE udata = rb_iv_get(mWhisper, "user_data"); + rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata); +} + +/* + * call-seq: + * log_set ->(level, buffer, user_data) { ... 
}, user_data -> nil + */ +static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) { + VALUE old_callback = rb_iv_get(self, "log_callback"); + if (!NIL_P(old_callback)) { + rb_undefine_finalizer(old_callback); + } + + rb_iv_set(self, "log_callback", log_callback); + rb_iv_set(self, "user_data", user_data); + + VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback")); + rb_define_finalizer(log_callback, finalize_log_callback); + + whisper_log_set(ruby_whisper_log_callback, NULL); + + return Qnil; +} + +static void rb_whisper_model_mark(ruby_whisper_model *rwm) { + rb_gc_mark(rwm->context); +} + +static VALUE ruby_whisper_model_allocate(VALUE klass) { + ruby_whisper_model *rwm; + rwm = ALLOC(ruby_whisper_model); + return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm); +} + +void Init_whisper() { + id_to_s = rb_intern("to_s"); + id_call = rb_intern("call"); + id___method__ = rb_intern("__method__"); + id_to_enum = rb_intern("to_enum"); + id_length = rb_intern("length"); + id_next = rb_intern("next"); + id_new = rb_intern("new"); + id_to_path = rb_intern("to_path"); + id_URI = rb_intern("URI"); + id_pre_converted_models = rb_intern("pre_converted_models"); + + mWhisper = rb_define_module("Whisper"); + + rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE)); + rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO)); + rb_define_const(mWhisper, "LOG_LEVEL_WARN", INT2NUM(GGML_LOG_LEVEL_WARN)); + rb_define_const(mWhisper, "LOG_LEVEL_ERROR", INT2NUM(GGML_LOG_LEVEL_ERROR)); + rb_define_const(mWhisper, "LOG_LEVEL_DEBUG", INT2NUM(GGML_LOG_LEVEL_DEBUG)); + rb_define_const(mWhisper, "LOG_LEVEL_CONT", INT2NUM(GGML_LOG_LEVEL_CONT)); + + rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0); + rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1); + rb_define_singleton_method(mWhisper, 
"lang_str", ruby_whisper_s_lang_str, 1); + rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1); + rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2); + rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1); + + init_ruby_whisper_context(&mWhisper); + init_ruby_whisper_params(&mWhisper); + init_ruby_whisper_error(&mWhisper); + init_ruby_whisper_segment(&mWhisper, &cContext); + init_ruby_whisper_model(&mWhisper); + + rb_require("whisper/model/uri"); +} diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp deleted file mode 100644 index 5979f208..00000000 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ /dev/null @@ -1,1962 +0,0 @@ -#include -#include -#include "ruby_whisper.h" -#define DR_WAV_IMPLEMENTATION -#include "dr_wav.h" -#include -#include -#include -#include -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -#define BOOL_PARAMS_SETTER(self, prop, value) \ - ruby_whisper_params *rwp; \ - Data_Get_Struct(self, ruby_whisper_params, rwp); \ - if (value == Qfalse || value == Qnil) { \ - rwp->params.prop = false; \ - } else { \ - rwp->params.prop = true; \ - } \ - return value; \ - -#define BOOL_PARAMS_GETTER(self, prop) \ - ruby_whisper_params *rwp; \ - Data_Get_Struct(self, ruby_whisper_params, rwp); \ - if (rwp->params.prop) { \ - return Qtrue; \ - } else { \ - return Qfalse; \ - } - -VALUE mWhisper; -VALUE cContext; -VALUE cParams; -VALUE eError; - -VALUE cSegment; -VALUE cModel; - -static ID id_to_s; -static ID id_call; -static ID id___method__; -static ID id_to_enum; -static ID id_length; -static ID id_next; -static ID id_new; -static ID id_to_path; -static ID id_URI; -static ID id_pre_converted_models; - -static bool is_log_callback_finalized = false; - -// High level API -static VALUE rb_whisper_segment_initialize(VALUE context, int index); - -/* - * call-seq: - * lang_max_id -> Integer - 
*/ -static VALUE ruby_whisper_s_lang_max_id(VALUE self) { - return INT2NUM(whisper_lang_max_id()); -} - -/* - * call-seq: - * lang_id(lang_name) -> Integer - */ -static VALUE ruby_whisper_s_lang_id(VALUE self, VALUE lang) { - const char * lang_str = StringValueCStr(lang); - const int id = whisper_lang_id(lang_str); - if (-1 == id) { - rb_raise(rb_eArgError, "language not found: %s", lang_str); - } - return INT2NUM(id); -} - -/* - * call-seq: - * lang_str(lang_id) -> String - */ -static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) { - const int lang_id = NUM2INT(id); - const char * str = whisper_lang_str(lang_id); - if (nullptr == str) { - rb_raise(rb_eIndexError, "id %d outside of language id", lang_id); - } - return rb_str_new2(str); -} - -/* - * call-seq: - * lang_str(lang_id) -> String - */ -static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) { - const int lang_id = NUM2INT(id); - const char * str_full = whisper_lang_str_full(lang_id); - if (nullptr == str_full) { - rb_raise(rb_eIndexError, "id %d outside of language id", lang_id); - } - return rb_str_new2(str_full); -} - -static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) { - is_log_callback_finalized = true; - return Qnil; -} - -/* - * call-seq: - * log_set ->(level, buffer, user_data) { ... 
}, user_data -> nil - */ -static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) { - VALUE old_callback = rb_iv_get(self, "log_callback"); - if (!NIL_P(old_callback)) { - rb_undefine_finalizer(old_callback); - } - - rb_iv_set(self, "log_callback", log_callback); - rb_iv_set(self, "user_data", user_data); - - VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback")); - rb_define_finalizer(log_callback, finalize_log_callback); - - whisper_log_set([](ggml_log_level level, const char * buffer, void * user_data) { - if (is_log_callback_finalized) { - return; - } - VALUE log_callback = rb_iv_get(mWhisper, "log_callback"); - VALUE udata = rb_iv_get(mWhisper, "user_data"); - rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata); - }, nullptr); - - return Qnil; -} - -static void ruby_whisper_free(ruby_whisper *rw) { - if (rw->context) { - whisper_free(rw->context); - rw->context = NULL; - } -} - -static void ruby_whisper_params_free(ruby_whisper_params *rwp) { -} - -void rb_whisper_mark(ruby_whisper *rw) { - // call rb_gc_mark on any ruby references in rw -} - -void rb_whisper_free(ruby_whisper *rw) { - ruby_whisper_free(rw); - free(rw); -} - -void rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) { - rb_gc_mark(rwc->user_data); - rb_gc_mark(rwc->callback); - rb_gc_mark(rwc->callbacks); -} - -void rb_whisper_params_mark(ruby_whisper_params *rwp) { - rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container); - rb_whisper_callbcack_container_mark(rwp->progress_callback_container); - rb_whisper_callbcack_container_mark(rwp->abort_callback_container); -} - -void rb_whisper_params_free(ruby_whisper_params *rwp) { - // How to free user_data and callback only when not referred to by others? 
- ruby_whisper_params_free(rwp); - free(rwp); -} - -static VALUE ruby_whisper_allocate(VALUE klass) { - ruby_whisper *rw; - rw = ALLOC(ruby_whisper); - rw->context = NULL; - return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw); -} - -static ruby_whisper_callback_container * rb_whisper_callback_container_allocate() { - ruby_whisper_callback_container *container; - container = ALLOC(ruby_whisper_callback_container); - container->context = nullptr; - container->user_data = Qnil; - container->callback = Qnil; - container->callbacks = rb_ary_new(); - return container; -} - -static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) { - const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; - - // Currently, doesn't support state because - // those require to resolve GC-related problems. - if (!NIL_P(container->callback)) { - rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data); - } - const long callbacks_len = RARRAY_LEN(container->callbacks); - if (0 == callbacks_len) { - return; - } - const int n_segments = whisper_full_n_segments_from_state(state); - for (int i = n_new; i > 0; i--) { - int i_segment = n_segments - i; - VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment); - for (int j = 0; j < callbacks_len; j++) { - VALUE cb = rb_ary_entry(container->callbacks, j); - rb_funcall(cb, id_call, 1, segment); - } - } -} - -static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) { - const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; - const VALUE progress = INT2NUM(progress_cur); - // Currently, doesn't support state because - // those require to resolve GC-related problems. 
- if (!NIL_P(container->callback)) { - rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data); - } - const long callbacks_len = RARRAY_LEN(container->callbacks); - if (0 == callbacks_len) { - return; - } - for (int j = 0; j < callbacks_len; j++) { - VALUE cb = rb_ary_entry(container->callbacks, j); - rb_funcall(cb, id_call, 1, progress); - } -} - -static bool abort_callback(void * user_data) { - const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; - if (!NIL_P(container->callback)) { - VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data); - if (!NIL_P(result) && Qfalse != result) { - return true; - } - } - const long callbacks_len = RARRAY_LEN(container->callbacks); - if (0 == callbacks_len) { - return false; - } - for (int j = 0; j < callbacks_len; j++) { - VALUE cb = rb_ary_entry(container->callbacks, j); - VALUE result = rb_funcall(cb, id_call, 1, container->user_data); - if (!NIL_P(result) && Qfalse != result) { - return true; - } - } - return false; -} - -static VALUE ruby_whisper_params_allocate(VALUE klass) { - ruby_whisper_params *rwp; - rwp = ALLOC(ruby_whisper_params); - rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); - rwp->diarize = false; - rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); - rwp->progress_callback_container = rb_whisper_callback_container_allocate(); - rwp->abort_callback_container = rb_whisper_callback_container_allocate(); - return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); -} - -/* - * call-seq: - * new("base.en") -> Whisper::Context - * new("path/to/model.bin") -> Whisper::Context - * new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context - */ -static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { - ruby_whisper *rw; - VALUE whisper_model_file_path; - - // TODO: we can support 
init from buffer here too maybe another ruby object to expose - rb_scan_args(argc, argv, "01", &whisper_model_file_path); - Data_Get_Struct(self, ruby_whisper, rw); - - VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0); - VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path); - if (!NIL_P(pre_converted_model)) { - whisper_model_file_path = pre_converted_model; - } - if (TYPE(whisper_model_file_path) == T_STRING) { - const char * whisper_model_file_path_str = StringValueCStr(whisper_model_file_path); - if (strncmp("http://", whisper_model_file_path_str, 7) == 0 || strncmp("https://", whisper_model_file_path_str, 8) == 0) { - VALUE uri_class = rb_const_get(cModel, id_URI); - whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class); - } - } - if (rb_obj_is_kind_of(whisper_model_file_path, rb_path2class("URI::HTTP"))) { - VALUE uri_class = rb_const_get(cModel, id_URI); - whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class); - } - if (rb_respond_to(whisper_model_file_path, id_to_path)) { - whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0); - } - if (!rb_respond_to(whisper_model_file_path, id_to_s)) { - rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); - } - rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params()); - if (rw->context == nullptr) { - rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context"); - } - return self; -} - -static void register_callbacks(ruby_whisper_params * rwp, VALUE * self) { - if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { - rwp->new_segment_callback_container->context = self; - rwp->params.new_segment_callback = new_segment_callback; - rwp->params.new_segment_callback_user_data = 
rwp->new_segment_callback_container; - } - - if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) { - rwp->progress_callback_container->context = self; - rwp->params.progress_callback = progress_callback; - rwp->params.progress_callback_user_data = rwp->progress_callback_container; - } - - if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { - rwp->abort_callback_container->context = self; - rwp->params.abort_callback = abort_callback; - rwp->params.abort_callback_user_data = rwp->abort_callback_container; - } -} - -/* - * transcribe a single file - * can emit to a block results - * - * params = Whisper::Params.new - * params.duration = 60_000 - * whisper.transcribe "path/to/audio.wav", params do |text| - * puts text - * end - * - * call-seq: - * transcribe(path_to_audio, params) {|text| ...} - **/ -static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { - ruby_whisper *rw; - ruby_whisper_params *rwp; - VALUE wave_file_path, blk, params; - - rb_scan_args(argc, argv, "02&", &wave_file_path, ¶ms, &blk); - Data_Get_Struct(self, ruby_whisper, rw); - Data_Get_Struct(params, ruby_whisper_params, rwp); - - if (!rb_respond_to(wave_file_path, id_to_s)) { - rb_raise(rb_eRuntimeError, "Expected file path to wave file"); - } - - std::string fname_inp = StringValueCStr(wave_file_path); - - std::vector pcmf32; // mono-channel F32 PCM - std::vector> pcmf32s; // stereo-channel F32 PCM - - // WAV input - this is directly from main.cpp example - { - drwav wav; - std::vector wav_data; // used for pipe input from stdin - - if (fname_inp == "-") { - { - uint8_t buf[1024]; - while (true) { - const size_t n = fread(buf, 1, sizeof(buf), stdin); - if (n == 0) { - break; - } - wav_data.insert(wav_data.end(), buf, buf + n); - } - } - - if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) { - fprintf(stderr, "error: 
failed to open WAV file from stdin\n"); - return self; - } - - fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size()); - } else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) { - fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); - return self; - } - - if (wav.channels != 1 && wav.channels != 2) { - fprintf(stderr, "WAV file '%s' must be mono or stereo\n", fname_inp.c_str()); - return self; - } - - if (rwp->diarize && wav.channels != 2 && rwp->params.print_timestamps == false) { - fprintf(stderr, "WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str()); - return self; - } - - if (wav.sampleRate != WHISPER_SAMPLE_RATE) { - fprintf(stderr, "WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000); - return self; - } - - if (wav.bitsPerSample != 16) { - fprintf(stderr, "WAV file '%s' must be 16-bit\n", fname_inp.c_str()); - return self; - } - - const uint64_t n = wav_data.empty() ? 
wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8); - - std::vector pcm16; - pcm16.resize(n*wav.channels); - drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); - drwav_uninit(&wav); - - // convert to mono, float - pcmf32.resize(n); - if (wav.channels == 1) { - for (uint64_t i = 0; i < n; i++) { - pcmf32[i] = float(pcm16[i])/32768.0f; - } - } else { - for (uint64_t i = 0; i < n; i++) { - pcmf32[i] = float(pcm16[2*i] + pcm16[2*i + 1])/65536.0f; - } - } - - if (rwp->diarize) { - // convert to stereo, float - pcmf32s.resize(2); - - pcmf32s[0].resize(n); - pcmf32s[1].resize(n); - for (uint64_t i = 0; i < n; i++) { - pcmf32s[0][i] = float(pcm16[2*i])/32768.0f; - pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f; - } - } - } - { - static bool is_aborted = false; // NOTE: this should be atomic to avoid data race - - rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { - bool is_aborted = *(bool*)user_data; - return !is_aborted; - }; - rwp->params.encoder_begin_callback_user_data = &is_aborted; - } - - register_callbacks(rwp, &self); - - if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { - fprintf(stderr, "failed to process audio\n"); - return self; - } - const int n_segments = whisper_full_n_segments(rw->context); - VALUE output = rb_str_new2(""); - for (int i = 0; i < n_segments; ++i) { - const char * text = whisper_full_get_segment_text(rw->context, i); - output = rb_str_concat(output, rb_str_new2(text)); - } - VALUE idCall = id_call; - if (blk != Qnil) { - rb_funcall(blk, idCall, 1, output); - } - return self; -} - -/* - * call-seq: - * model_n_vocab -> Integer - */ -VALUE ruby_whisper_model_n_vocab(VALUE self) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - return INT2NUM(whisper_model_n_vocab(rw->context)); -} - -/* - * call-seq: - * model_n_audio_ctx -> Integer - */ -VALUE ruby_whisper_model_n_audio_ctx(VALUE 
self) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - return INT2NUM(whisper_model_n_audio_ctx(rw->context)); -} - -/* - * call-seq: - * model_n_audio_state -> Integer - */ -VALUE ruby_whisper_model_n_audio_state(VALUE self) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - return INT2NUM(whisper_model_n_audio_state(rw->context)); -} - -/* - * call-seq: - * model_n_audio_head -> Integer - */ -VALUE ruby_whisper_model_n_audio_head(VALUE self) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - return INT2NUM(whisper_model_n_audio_head(rw->context)); -} - -/* - * call-seq: - * model_n_audio_layer -> Integer - */ -VALUE ruby_whisper_model_n_audio_layer(VALUE self) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - return INT2NUM(whisper_model_n_audio_layer(rw->context)); -} - -/* - * call-seq: - * model_n_text_ctx -> Integer - */ -VALUE ruby_whisper_model_n_text_ctx(VALUE self) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - return INT2NUM(whisper_model_n_text_ctx(rw->context)); -} - -/* - * call-seq: - * model_n_text_state -> Integer - */ -VALUE ruby_whisper_model_n_text_state(VALUE self) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - return INT2NUM(whisper_model_n_text_state(rw->context)); -} - -/* - * call-seq: - * model_n_text_head -> Integer - */ -VALUE ruby_whisper_model_n_text_head(VALUE self) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - return INT2NUM(whisper_model_n_text_head(rw->context)); -} - -/* - * call-seq: - * model_n_text_layer -> Integer - */ -VALUE ruby_whisper_model_n_text_layer(VALUE self) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - return INT2NUM(whisper_model_n_text_layer(rw->context)); -} - -/* - * call-seq: - * model_n_mels -> Integer - */ -VALUE ruby_whisper_model_n_mels(VALUE self) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - return 
INT2NUM(whisper_model_n_mels(rw->context)); -} - -/* - * call-seq: - * model_ftype -> Integer - */ -VALUE ruby_whisper_model_ftype(VALUE self) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - return INT2NUM(whisper_model_ftype(rw->context)); -} - -/* - * call-seq: - * model_type -> String - */ -VALUE ruby_whisper_model_type(VALUE self) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - return rb_str_new2(whisper_model_type_readable(rw->context)); -} - -/* - * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text - * Not thread safe for same context - * Uses the specified decoding strategy to obtain the text. - * - * call-seq: - * full(params, samples, n_samples) -> nil - * full(params, samples) -> nil - * - * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. - */ -VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) { - if (argc < 2 || argc > 3) { - rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); - } - - ruby_whisper *rw; - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper, rw); - VALUE params = argv[0]; - Data_Get_Struct(params, ruby_whisper_params, rwp); - VALUE samples = argv[1]; - int n_samples; - rb_memory_view_t view; - const bool memory_view_available_p = rb_memory_view_available_p(samples); - if (argc == 3) { - n_samples = NUM2INT(argv[2]); - if (TYPE(samples) == T_ARRAY) { - if (RARRAY_LEN(samples) < n_samples) { - rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples); - } - } - // Should check when samples.respond_to?(:length)? 
- } else { - if (TYPE(samples) == T_ARRAY) { - n_samples = RARRAY_LEN(samples); - } else if (memory_view_available_p) { - if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) { - view.obj = Qnil; - rb_raise(rb_eArgError, "unable to get a memory view"); - } - n_samples = view.byte_size / view.item_size; - } else if (rb_respond_to(samples, id_length)) { - n_samples = NUM2INT(rb_funcall(samples, id_length, 0)); - } else { - rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given"); - } - } - float * c_samples = (float *)malloc(n_samples * sizeof(float)); - if (memory_view_available_p) { - c_samples = (float *)view.data; - } else { - if (TYPE(samples) == T_ARRAY) { - for (int i = 0; i < n_samples; i++) { - c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i)); - } - } else { - // TODO: use rb_block_call - VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each")); - for (int i = 0; i < n_samples; i++) { - // TODO: check if iter is exhausted and raise ArgumentError appropriately - VALUE sample = rb_funcall(iter, id_next, 0); - c_samples[i] = RFLOAT_VALUE(sample); - } - } - } - register_callbacks(rwp, &self); - const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples); - if (0 == result) { - return Qnil; - } else { - rb_exc_raise(rb_funcall(eError, id_new, 1, result)); - } -} - -/* - * Split the input audio in chunks and process each chunk separately using whisper_full_with_state() - * Result is stored in the default state of the context - * Not thread safe if executed in parallel on the same context. - * It seems this approach can offer some speedup in some cases. - * However, the transcription accuracy can be worse at the beginning and end of each chunk. 
- * - * call-seq: - * full_parallel(params, samples) -> nil - * full_parallel(params, samples, n_samples) -> nil - * full_parallel(params, samples, n_samples, n_processors) -> nil - * full_parallel(params, samples, nil, n_processors) -> nil - */ -static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) { - if (argc < 2 || argc > 4) { - rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); - } - - ruby_whisper *rw; - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper, rw); - VALUE params = argv[0]; - Data_Get_Struct(params, ruby_whisper_params, rwp); - VALUE samples = argv[1]; - int n_samples; - int n_processors; - rb_memory_view_t view; - const bool memory_view_available_p = rb_memory_view_available_p(samples); - switch (argc) { - case 2: - n_processors = 1; - break; - case 3: - n_processors = 1; - break; - case 4: - n_processors = NUM2INT(argv[3]); - break; - } - if (argc >= 3 && !NIL_P(argv[2])) { - n_samples = NUM2INT(argv[2]); - if (TYPE(samples) == T_ARRAY) { - if (RARRAY_LEN(samples) < n_samples) { - rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples); - } - } - // Should check when samples.respond_to?(:length)? 
- } else if (memory_view_available_p) { - if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) { - view.obj = Qnil; - rb_raise(rb_eArgError, "unable to get a memory view"); - } - n_samples = view.byte_size / view.item_size; - } else { - if (TYPE(samples) == T_ARRAY) { - n_samples = RARRAY_LEN(samples); - } else if (rb_respond_to(samples, id_length)) { - n_samples = NUM2INT(rb_funcall(samples, id_length, 0)); - } else { - rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given"); - } - } - float * c_samples = (float *)malloc(n_samples * sizeof(float)); - if (memory_view_available_p) { - c_samples = (float *)view.data; - } else { - if (TYPE(samples) == T_ARRAY) { - for (int i = 0; i < n_samples; i++) { - c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i)); - } - } else { - // FIXME: use rb_block_call - VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each")); - for (int i = 0; i < n_samples; i++) { - // TODO: check if iter is exhausted and raise ArgumentError - VALUE sample = rb_funcall(iter, id_next, 0); - c_samples[i] = RFLOAT_VALUE(sample); - } - } - } - register_callbacks(rwp, &self); - const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors); - if (0 == result) { - return Qnil; - } else { - rb_exc_raise(rb_funcall(eError, id_new, 1, result)); - } -} - -/* - * Number of segments. - * - * call-seq: - * full_n_segments -> Integer - */ -static VALUE ruby_whisper_full_n_segments(VALUE self) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - return INT2NUM(whisper_full_n_segments(rw->context)); -} - -/* - * Language ID, which can be converted to string by Whisper.lang_str and Whisper.lang_str_full. 
- * - * call-seq: - * full_lang_id -> Integer - */ -static VALUE ruby_whisper_full_lang_id(VALUE self) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - return INT2NUM(whisper_full_lang_id(rw->context)); -} - -static int ruby_whisper_full_check_segment_index(const ruby_whisper * rw, const VALUE i_segment) { - const int c_i_segment = NUM2INT(i_segment); - if (c_i_segment < 0 || c_i_segment >= whisper_full_n_segments(rw->context)) { - rb_raise(rb_eIndexError, "segment index %d out of range", c_i_segment); - } - return c_i_segment; -} - -/* - * Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds). - * - * full_get_segment_t0(3) # => 1668 (16680 ms) - * - * call-seq: - * full_get_segment_t0(segment_index) -> Integer - */ -static VALUE ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); - const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment); - return INT2NUM(t0); -} - -/* - * End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds). - * - * full_get_segment_t1(3) # => 1668 (16680 ms) - * - * call-seq: - * full_get_segment_t1(segment_index) -> Integer - */ -static VALUE ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); - const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment); - return INT2NUM(t1); -} - -/* - * Whether the next segment indexed by +segment_index+ is predicated as a speaker turn. 
- * - * full_get_segment_speacker_turn_next(3) # => true - * - * call-seq: - * full_get_segment_speacker_turn_next(segment_index) -> bool - */ -static VALUE ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); - const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment); - return speaker_turn_next ? Qtrue : Qfalse; -} - -/* - * Text of a segment indexed by +segment_index+. - * - * full_get_segment_text(3) # => "ask not what your country can do for you, ..." - * - * call-seq: - * full_get_segment_text(segment_index) -> String - */ -static VALUE ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); - const char * text = whisper_full_get_segment_text(rw->context, c_i_segment); - return rb_str_new2(text); -} - -/* - * call-seq: - * full_get_segment_no_speech_prob(segment_index) -> Float - */ -static VALUE ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment) { - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); - const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment); - return DBL2NUM(no_speech_prob); -} - -/* - * params.language = "auto" | "en", etc... 
- * - * call-seq: - * language = lang_name -> lang_name - */ -static VALUE ruby_whisper_params_set_language(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - if (value == Qfalse || value == Qnil) { - rwp->params.language = "auto"; - } else { - rwp->params.language = StringValueCStr(value); - } - return value; -} -/* - * call-seq: - * language -> String - */ -static VALUE ruby_whisper_params_get_language(VALUE self) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - if (rwp->params.language) { - return rb_str_new2(rwp->params.language); - } else { - return rb_str_new2("auto"); - } -} -/* - * call-seq: - * translate = do_translate -> do_translate - */ -static VALUE ruby_whisper_params_set_translate(VALUE self, VALUE value) { - BOOL_PARAMS_SETTER(self, translate, value) -} -/* - * call-seq: - * translate -> bool - */ -static VALUE ruby_whisper_params_get_translate(VALUE self) { - BOOL_PARAMS_GETTER(self, translate) -} -/* - * call-seq: - * no_context = dont_use_context -> dont_use_context - */ -static VALUE ruby_whisper_params_set_no_context(VALUE self, VALUE value) { - BOOL_PARAMS_SETTER(self, no_context, value) -} -/* - * If true, does not use past transcription (if any) as initial prompt for the decoder. - * - * call-seq: - * no_context -> bool - */ -static VALUE ruby_whisper_params_get_no_context(VALUE self) { - BOOL_PARAMS_GETTER(self, no_context) -} -/* - * call-seq: - * single_segment = force_single -> force_single - */ -static VALUE ruby_whisper_params_set_single_segment(VALUE self, VALUE value) { - BOOL_PARAMS_SETTER(self, single_segment, value) -} -/* - * If true, forces single segment output (useful for streaming). 
- * - * call-seq: - * single_segment -> bool - */ -static VALUE ruby_whisper_params_get_single_segment(VALUE self) { - BOOL_PARAMS_GETTER(self, single_segment) -} -/* - * call-seq: - * print_special = force_print -> force_print - */ -static VALUE ruby_whisper_params_set_print_special(VALUE self, VALUE value) { - BOOL_PARAMS_SETTER(self, print_special, value) -} -/* - * If true, prints special tokens (e.g. , , , etc.). - * - * call-seq: - * print_special -> bool - */ -static VALUE ruby_whisper_params_get_print_special(VALUE self) { - BOOL_PARAMS_GETTER(self, print_special) -} -/* - * call-seq: - * print_progress = force_print -> force_print - */ -static VALUE ruby_whisper_params_set_print_progress(VALUE self, VALUE value) { - BOOL_PARAMS_SETTER(self, print_progress, value) -} -/* - * If true, prints progress information. - * - * call-seq: - * print_progress -> bool - */ -static VALUE ruby_whisper_params_get_print_progress(VALUE self) { - BOOL_PARAMS_GETTER(self, print_progress) -} -/* - * call-seq: - * print_realtime = force_print -> force_print - */ -static VALUE ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) { - BOOL_PARAMS_SETTER(self, print_realtime, value) -} -/* - * If true, prints results from within whisper.cpp. (avoid it, use callback instead) - * call-seq: - * print_realtime -> bool - */ -static VALUE ruby_whisper_params_get_print_realtime(VALUE self) { - BOOL_PARAMS_GETTER(self, print_realtime) -} -/* - * call-seq: - * print_timestamps = force_print -> force_print - */ -static VALUE ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) { - BOOL_PARAMS_SETTER(self, print_timestamps, value) -} -/* - * If true, prints timestamps for each text segment when printing realtime. 
- * - * call-seq: - * print_timestamps -> bool - */ -static VALUE ruby_whisper_params_get_print_timestamps(VALUE self) { - BOOL_PARAMS_GETTER(self, print_timestamps) -} -/* - * call-seq: - * suppress_blank = force_suppress -> force_suppress - */ -static VALUE ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) { - BOOL_PARAMS_SETTER(self, suppress_blank, value) -} -/* - * If true, suppresses blank outputs. - * - * call-seq: - * suppress_blank -> bool - */ -static VALUE ruby_whisper_params_get_suppress_blank(VALUE self) { - BOOL_PARAMS_GETTER(self, suppress_blank) -} -/* - * call-seq: - * suppress_nst = force_suppress -> force_suppress - */ -static VALUE ruby_whisper_params_set_suppress_nst(VALUE self, VALUE value) { - BOOL_PARAMS_SETTER(self, suppress_nst, value) -} -/* - * If true, suppresses non-speech-tokens. - * - * call-seq: - * suppress_nst -> bool - */ -static VALUE ruby_whisper_params_get_suppress_nst(VALUE self) { - BOOL_PARAMS_GETTER(self, suppress_nst) -} -/* - * If true, enables token-level timestamps. - * - * call-seq: - * token_timestamps -> bool - */ -static VALUE ruby_whisper_params_get_token_timestamps(VALUE self) { - BOOL_PARAMS_GETTER(self, token_timestamps) -} -/* - * call-seq: - * token_timestamps = force_timestamps -> force_timestamps - */ -static VALUE ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) { - BOOL_PARAMS_SETTER(self, token_timestamps, value) -} -/* - * If true, split on word rather than on token (when used with max_len). 
- * - * call-seq: - * translate -> bool - */ -static VALUE ruby_whisper_params_get_split_on_word(VALUE self) { - BOOL_PARAMS_GETTER(self, split_on_word) -} -/* - * call-seq: - * split_on_word = force_split -> force_split - */ -static VALUE ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) { - BOOL_PARAMS_SETTER(self, split_on_word, value) -} -/* - * Tokens to provide to the whisper decoder as initial prompt - * these are prepended to any existing text context from a previous call - * use whisper_tokenize() to convert text to tokens. - * Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224). - * - * call-seq: - * initial_prompt -> String - */ -static VALUE ruby_whisper_params_get_initial_prompt(VALUE self) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - return rwp->params.initial_prompt == nullptr ? Qnil : rb_str_new2(rwp->params.initial_prompt); -} -/* - * call-seq: - * initial_prompt = prompt -> prompt - */ -static VALUE ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->params.initial_prompt = StringValueCStr(value); - return value; -} -/* - * If true, enables diarization. - * - * call-seq: - * diarize -> bool - */ -static VALUE ruby_whisper_params_get_diarize(VALUE self) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - if (rwp->diarize) { - return Qtrue; - } else { - return Qfalse; - } -} -/* - * call-seq: - * diarize = force_diarize -> force_diarize - */ -static VALUE ruby_whisper_params_set_diarize(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - if (value == Qfalse || value == Qnil) { - rwp->diarize = false; - } else { - rwp->diarize = true; - } \ - return value; -} - -/* - * Start offset in ms. 
- * - * call-seq: - * offset -> Integer - */ -static VALUE ruby_whisper_params_get_offset(VALUE self) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - return INT2NUM(rwp->params.offset_ms); -} -/* - * call-seq: - * offset = offset_ms -> offset_ms - */ -static VALUE ruby_whisper_params_set_offset(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->params.offset_ms = NUM2INT(value); - return value; -} -/* - * Audio duration to process in ms. - * - * call-seq: - * duration -> Integer - */ -static VALUE ruby_whisper_params_get_duration(VALUE self) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - return INT2NUM(rwp->params.duration_ms); -} -/* - * call-seq: - * duration = duration_ms -> duration_ms - */ -static VALUE ruby_whisper_params_set_duration(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->params.duration_ms = NUM2INT(value); - return value; -} - -/* - * Max tokens to use from past text as prompt for the decoder. 
- * - * call-seq: - * max_text_tokens -> Integer - */ -static VALUE ruby_whisper_params_get_max_text_tokens(VALUE self) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - return INT2NUM(rwp->params.n_max_text_ctx); -} -/* - * call-seq: - * max_text_tokens = n_tokens -> n_tokens - */ -static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->params.n_max_text_ctx = NUM2INT(value); - return value; -} -/* - * call-seq: - * temperature -> Float - */ -static VALUE ruby_whisper_params_get_temperature(VALUE self) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - return DBL2NUM(rwp->params.temperature); -} -/* - * call-seq: - * temperature = temp -> temp - */ -static VALUE ruby_whisper_params_set_temperature(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->params.temperature = RFLOAT_VALUE(value); - return value; -} -/* - * See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97 - * - * call-seq: - * max_initial_ts -> Flaot - */ -static VALUE ruby_whisper_params_get_max_initial_ts(VALUE self) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - return DBL2NUM(rwp->params.max_initial_ts); -} -/* - * call-seq: - * max_initial_ts = timestamp -> timestamp - */ -static VALUE ruby_whisper_params_set_max_initial_ts(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->params.max_initial_ts = RFLOAT_VALUE(value); - return value; -} -/* - * call-seq: - * length_penalty -> Float - */ -static VALUE ruby_whisper_params_get_length_penalty(VALUE self) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - return DBL2NUM(rwp->params.length_penalty); -} -/* - * call-seq: - * length_penalty 
= penalty -> penalty - */ -static VALUE ruby_whisper_params_set_length_penalty(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->params.length_penalty = RFLOAT_VALUE(value); - return value; -} -/* - * call-seq: - * temperature_inc -> Float - */ -static VALUE ruby_whisper_params_get_temperature_inc(VALUE self) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - return DBL2NUM(rwp->params.temperature_inc); -} -/* - * call-seq: - * temperature_inc = inc -> inc - */ -static VALUE ruby_whisper_params_set_temperature_inc(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->params.temperature_inc = RFLOAT_VALUE(value); - return value; -} -/* - * Similar to OpenAI's "compression_ratio_threshold" - * - * call-seq: - * entropy_thold -> Float - */ -static VALUE ruby_whisper_params_get_entropy_thold(VALUE self) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - return DBL2NUM(rwp->params.entropy_thold); -} -/* - * call-seq: - * entropy_thold = threshold -> threshold - */ -static VALUE ruby_whisper_params_set_entropy_thold(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->params.entropy_thold = RFLOAT_VALUE(value); - return value; -} -/* - * call-seq: - * logprob_thold -> Float - */ -static VALUE ruby_whisper_params_get_logprob_thold(VALUE self) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - return DBL2NUM(rwp->params.logprob_thold); -} -/* - * call-seq: - * logprob_thold = threshold -> threshold - */ -static VALUE ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->params.logprob_thold = RFLOAT_VALUE(value); - return value; -} -/* - * call-seq: - * no_speech_thold -> Float - */ -static VALUE 
ruby_whisper_params_get_no_speech_thold(VALUE self) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - return DBL2NUM(rwp->params.no_speech_thold); -} -/* - * call-seq: - * no_speech_thold = threshold -> threshold - */ -static VALUE ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->params.no_speech_thold = RFLOAT_VALUE(value); - return value; -} -/* - * Sets new segment callback, called for every newly generated text segment. - * - * params.new_segment_callback = ->(context, _, n_new, user_data) { - * # ... - * } - * - * call-seq: - * new_segment_callback = callback -> callback - */ -static VALUE ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->new_segment_callback_container->callback = value; - return value; -} -/* - * Sets user data passed to the last argument of new segment callback. - * - * call-seq: - * new_segment_callback_user_data = user_data -> use_data - */ -static VALUE ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->new_segment_callback_container->user_data = value; - return value; -} -/* - * Sets progress callback, called on each progress update. - * - * params.new_segment_callback = ->(context, _, n_new, user_data) { - * # ... - * } - * - * call-seq: - * progress_callback = callback -> callback - */ -static VALUE ruby_whisper_params_set_progress_callback(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->progress_callback_container->callback = value; - return value; -} -/* - * Sets user data passed to the last argument of progress callback. 
- * - * call-seq: - * progress_callback_user_data = user_data -> use_data - */ -static VALUE ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->progress_callback_container->user_data = value; - return value; -} -/* - * Sets abort callback, called to check if the process should be aborted. - * - * params.abort_callback = ->(user_data) { - * # ... - * } - * - * call-seq: - * abort_callback = callback -> callback - */ -static VALUE ruby_whisper_params_set_abort_callback(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->abort_callback_container->callback = value; - return value; -} -/* - * Sets user data passed to the last argument of abort callback. - * - * call-seq: - * abort_callback_user_data = user_data -> use_data - */ -static VALUE ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value) { - ruby_whisper_params *rwp; - Data_Get_Struct(self, ruby_whisper_params, rwp); - rwp->abort_callback_container->user_data = value; - return value; -} - -// High level API - -typedef struct { - VALUE context; - int index; -} ruby_whisper_segment; - -typedef struct { - VALUE context; -} ruby_whisper_model; - -static void rb_whisper_segment_mark(ruby_whisper_segment *rws) { - rb_gc_mark(rws->context); -} - -static VALUE ruby_whisper_segment_allocate(VALUE klass) { - ruby_whisper_segment *rws; - rws = ALLOC(ruby_whisper_segment); - return Data_Wrap_Struct(klass, rb_whisper_segment_mark, RUBY_DEFAULT_FREE, rws); -} - -static VALUE rb_whisper_segment_initialize(VALUE context, int index) { - ruby_whisper_segment *rws; - const VALUE segment = ruby_whisper_segment_allocate(cSegment); - Data_Get_Struct(segment, ruby_whisper_segment, rws); - rws->context = context; - rws->index = index; - return segment; -}; - -/* - * Yields each Whisper::Segment: - * - * whisper.transcribe("path/to/audio.wav", 
params) - * whisper.each_segment do |segment| - * puts segment.text - * end - * - * Returns an Enumerator if no block given: - * - * whisper.transcribe("path/to/audio.wav", params) - * enum = whisper.each_segment - * enum.to_a # => [#, ...] - * - * call-seq: - * each_segment {|segment| ... } - * each_segment -> Enumerator - */ -static VALUE ruby_whisper_each_segment(VALUE self) { - if (!rb_block_given_p()) { - const VALUE method_name = rb_funcall(self, id___method__, 0); - return rb_funcall(self, id_to_enum, 1, method_name); - } - - ruby_whisper *rw; - Data_Get_Struct(self, ruby_whisper, rw); - - const int n_segments = whisper_full_n_segments(rw->context); - for (int i = 0; i < n_segments; ++i) { - rb_yield(rb_whisper_segment_initialize(self, i)); - } - - return self; -} - -/* - * Hook called on new segment. Yields each Whisper::Segment. - * - * whisper.on_new_segment do |segment| - * # ... - * end - * - * call-seq: - * on_new_segment {|segment| ... } - */ -static VALUE ruby_whisper_params_on_new_segment(VALUE self) { - ruby_whisper_params *rws; - Data_Get_Struct(self, ruby_whisper_params, rws); - const VALUE blk = rb_block_proc(); - rb_ary_push(rws->new_segment_callback_container->callbacks, blk); - return Qnil; -} - -/* - * Hook called on progress update. Yields each progress Integer between 0 and 100. - * - * whisper.on_progress do |progress| - * # ... - * end - * - * call-seq: - * on_progress {|progress| ... } - */ -static VALUE ruby_whisper_params_on_progress(VALUE self) { - ruby_whisper_params *rws; - Data_Get_Struct(self, ruby_whisper_params, rws); - const VALUE blk = rb_block_proc(); - rb_ary_push(rws->progress_callback_container->callbacks, blk); - return Qnil; -} - -/* - * Call block to determine whether abort or not. Return +true+ when you want to abort. - * - * params.abort_on do - * if some_condition - * true # abort - * else - * false # continue - * end - * end - * - * call-seq: - * abort_on { ... 
} - */ -static VALUE ruby_whisper_params_abort_on(VALUE self) { - ruby_whisper_params *rws; - Data_Get_Struct(self, ruby_whisper_params, rws); - const VALUE blk = rb_block_proc(); - rb_ary_push(rws->abort_callback_container->callbacks, blk); - return Qnil; -} - -/* - * Start time in milliseconds. - * - * call-seq: - * start_time -> Integer - */ -static VALUE ruby_whisper_segment_get_start_time(VALUE self) { - ruby_whisper_segment *rws; - Data_Get_Struct(self, ruby_whisper_segment, rws); - ruby_whisper *rw; - Data_Get_Struct(rws->context, ruby_whisper, rw); - const int64_t t0 = whisper_full_get_segment_t0(rw->context, rws->index); - // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it - return INT2NUM(t0 * 10); -} - -/* - * End time in milliseconds. - * - * call-seq: - * end_time -> Integer - */ -static VALUE ruby_whisper_segment_get_end_time(VALUE self) { - ruby_whisper_segment *rws; - Data_Get_Struct(self, ruby_whisper_segment, rws); - ruby_whisper *rw; - Data_Get_Struct(rws->context, ruby_whisper, rw); - const int64_t t1 = whisper_full_get_segment_t1(rw->context, rws->index); - // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it - return INT2NUM(t1 * 10); -} - -/* - * Whether the next segment is predicted as a speaker turn. - * - * call-seq: - * speaker_turn_next? -> bool - */ -static VALUE ruby_whisper_segment_get_speaker_turn_next(VALUE self) { - ruby_whisper_segment *rws; - Data_Get_Struct(self, ruby_whisper_segment, rws); - ruby_whisper *rw; - Data_Get_Struct(rws->context, ruby_whisper, rw); - return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? 
Qtrue : Qfalse; -} - -/* - * call-seq: - * text -> String - */ -static VALUE ruby_whisper_segment_get_text(VALUE self) { - ruby_whisper_segment *rws; - Data_Get_Struct(self, ruby_whisper_segment, rws); - ruby_whisper *rw; - Data_Get_Struct(rws->context, ruby_whisper, rw); - const char * text = whisper_full_get_segment_text(rw->context, rws->index); - return rb_str_new2(text); -} - -/* - * call-seq: - * no_speech_prob -> Float - */ -static VALUE ruby_whisper_segment_get_no_speech_prob(VALUE self) { - ruby_whisper_segment *rws; - Data_Get_Struct(self, ruby_whisper_segment, rws); - ruby_whisper *rw; - Data_Get_Struct(rws->context, ruby_whisper, rw); - return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index)); -} - -static void rb_whisper_model_mark(ruby_whisper_model *rwm) { - rb_gc_mark(rwm->context); -} - -static VALUE ruby_whisper_model_allocate(VALUE klass) { - ruby_whisper_model *rwm; - rwm = ALLOC(ruby_whisper_model); - return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm); -} - -static VALUE rb_whisper_model_initialize(VALUE context) { - ruby_whisper_model *rwm; - const VALUE model = ruby_whisper_model_allocate(cModel); - Data_Get_Struct(model, ruby_whisper_model, rwm); - rwm->context = context; - return model; -}; - -/* - * call-seq: - * model -> Whisper::Model - */ -static VALUE ruby_whisper_get_model(VALUE self) { - return rb_whisper_model_initialize(self); -} - -/* - * call-seq: - * n_vocab -> Integer - */ -static VALUE ruby_whisper_c_model_n_vocab(VALUE self) { - ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); - ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); - return INT2NUM(whisper_model_n_vocab(rw->context)); -} - -/* - * call-seq: - * n_audio_ctx -> Integer - */ -static VALUE ruby_whisper_c_model_n_audio_ctx(VALUE self) { - ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); - ruby_whisper *rw; - Data_Get_Struct(rwm->context, 
ruby_whisper, rw); - return INT2NUM(whisper_model_n_audio_ctx(rw->context)); -} - -/* - * call-seq: - * n_audio_state -> Integer - */ -static VALUE ruby_whisper_c_model_n_audio_state(VALUE self) { - ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); - ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); - return INT2NUM(whisper_model_n_audio_state(rw->context)); -} - -/* - * call-seq: - * n_audio_head -> Integer - */ -static VALUE ruby_whisper_c_model_n_audio_head(VALUE self) { - ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); - ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); - return INT2NUM(whisper_model_n_audio_head(rw->context)); -} - -/* - * call-seq: - * n_audio_layer -> Integer - */ -static VALUE ruby_whisper_c_model_n_audio_layer(VALUE self) { - ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); - ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); - return INT2NUM(whisper_model_n_audio_layer(rw->context)); -} - -/* - * call-seq: - * n_text_ctx -> Integer - */ -static VALUE ruby_whisper_c_model_n_text_ctx(VALUE self) { - ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); - ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); - return INT2NUM(whisper_model_n_text_ctx(rw->context)); -} - -/* - * call-seq: - * n_text_state -> Integer - */ -static VALUE ruby_whisper_c_model_n_text_state(VALUE self) { - ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); - ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); - return INT2NUM(whisper_model_n_text_state(rw->context)); -} - -/* - * call-seq: - * n_text_head -> Integer - */ -static VALUE ruby_whisper_c_model_n_text_head(VALUE self) { - ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); - ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); - return 
INT2NUM(whisper_model_n_text_head(rw->context)); -} - -/* - * call-seq: - * n_text_layer -> Integer - */ -static VALUE ruby_whisper_c_model_n_text_layer(VALUE self) { - ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); - ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); - return INT2NUM(whisper_model_n_text_layer(rw->context)); -} - -/* - * call-seq: - * n_mels -> Integer - */ -static VALUE ruby_whisper_c_model_n_mels(VALUE self) { - ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); - ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); - return INT2NUM(whisper_model_n_mels(rw->context)); -} - -/* - * call-seq: - * ftype -> Integer - */ -static VALUE ruby_whisper_c_model_ftype(VALUE self) { - ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); - ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); - return INT2NUM(whisper_model_ftype(rw->context)); -} - -/* - * call-seq: - * type -> String - */ -static VALUE ruby_whisper_c_model_type(VALUE self) { - ruby_whisper_model *rwm; - Data_Get_Struct(self, ruby_whisper_model, rwm); - ruby_whisper *rw; - Data_Get_Struct(rwm->context, ruby_whisper, rw); - return rb_str_new2(whisper_model_type_readable(rw->context)); -} - -static VALUE ruby_whisper_error_initialize(VALUE self, VALUE code) { - const int c_code = NUM2INT(code); - const char *raw_message; - switch (c_code) { - case -2: - raw_message = "failed to compute log mel spectrogram"; - break; - case -3: - raw_message = "failed to auto-detect language"; - break; - case -4: - raw_message = "too many decoders requested"; - break; - case -5: - raw_message = "audio_ctx is larger than the maximum allowed"; - break; - case -6: - raw_message = "failed to encode"; - break; - case -7: - raw_message = "whisper_kv_cache_init() failed for self-attention cache"; - break; - case -8: - raw_message = "failed to decode"; - break; - case -9: - raw_message = "failed to 
decode"; - break; - default: - raw_message = "unknown error"; - break; - } - const VALUE message = rb_str_new2(raw_message); - rb_call_super(1, &message); - rb_iv_set(self, "@code", code); - - return self; -} - - -void Init_whisper() { - id_to_s = rb_intern("to_s"); - id_call = rb_intern("call"); - id___method__ = rb_intern("__method__"); - id_to_enum = rb_intern("to_enum"); - id_length = rb_intern("length"); - id_next = rb_intern("next"); - id_new = rb_intern("new"); - id_to_path = rb_intern("to_path"); - id_URI = rb_intern("URI"); - id_pre_converted_models = rb_intern("pre_converted_models"); - - mWhisper = rb_define_module("Whisper"); - cContext = rb_define_class_under(mWhisper, "Context", rb_cObject); - cParams = rb_define_class_under(mWhisper, "Params", rb_cObject); - eError = rb_define_class_under(mWhisper, "Error", rb_eStandardError); - - rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE)); - rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO)); - rb_define_const(mWhisper, "LOG_LEVEL_WARN", INT2NUM(GGML_LOG_LEVEL_WARN)); - rb_define_const(mWhisper, "LOG_LEVEL_ERROR", INT2NUM(GGML_LOG_LEVEL_ERROR)); - rb_define_const(mWhisper, "LOG_LEVEL_DEBUG", INT2NUM(GGML_LOG_LEVEL_DEBUG)); - rb_define_const(mWhisper, "LOG_LEVEL_CONT", INT2NUM(GGML_LOG_LEVEL_CONT)); - - rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0); - rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1); - rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1); - rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1); - rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2); - rb_define_singleton_method(mWhisper, "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1); - - rb_define_alloc_func(cContext, ruby_whisper_allocate); - rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1); - - 
rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1); - rb_define_method(cContext, "model_n_vocab", ruby_whisper_model_n_vocab, 0); - rb_define_method(cContext, "model_n_audio_ctx", ruby_whisper_model_n_audio_ctx, 0); - rb_define_method(cContext, "model_n_audio_state", ruby_whisper_model_n_audio_state, 0); - rb_define_method(cContext, "model_n_audio_head", ruby_whisper_model_n_audio_head, 0); - rb_define_method(cContext, "model_n_audio_layer", ruby_whisper_model_n_audio_layer, 0); - rb_define_method(cContext, "model_n_text_ctx", ruby_whisper_model_n_text_ctx, 0); - rb_define_method(cContext, "model_n_text_state", ruby_whisper_model_n_text_state, 0); - rb_define_method(cContext, "model_n_text_head", ruby_whisper_model_n_text_head, 0); - rb_define_method(cContext, "model_n_text_layer", ruby_whisper_model_n_text_layer, 0); - rb_define_method(cContext, "model_n_mels", ruby_whisper_model_n_mels, 0); - rb_define_method(cContext, "model_ftype", ruby_whisper_model_ftype, 0); - rb_define_method(cContext, "model_type", ruby_whisper_model_type, 0); - rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0); - rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0); - rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1); - rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1); - rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1); - rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1); - rb_define_method(cContext, "full_get_segment_no_speech_prob", ruby_whisper_full_get_segment_no_speech_prob, 1); - rb_define_method(cContext, "full", ruby_whisper_full, -1); - rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1); - - rb_define_alloc_func(cParams, ruby_whisper_params_allocate); - - rb_define_method(cParams, "language=", 
ruby_whisper_params_set_language, 1); - rb_define_method(cParams, "language", ruby_whisper_params_get_language, 0); - rb_define_method(cParams, "translate=", ruby_whisper_params_set_translate, 1); - rb_define_method(cParams, "translate", ruby_whisper_params_get_translate, 0); - rb_define_method(cParams, "no_context=", ruby_whisper_params_set_no_context, 1); - rb_define_method(cParams, "no_context", ruby_whisper_params_get_no_context, 0); - rb_define_method(cParams, "single_segment=", ruby_whisper_params_set_single_segment, 1); - rb_define_method(cParams, "single_segment", ruby_whisper_params_get_single_segment, 0); - rb_define_method(cParams, "print_special", ruby_whisper_params_get_print_special, 0); - rb_define_method(cParams, "print_special=", ruby_whisper_params_set_print_special, 1); - rb_define_method(cParams, "print_progress", ruby_whisper_params_get_print_progress, 0); - rb_define_method(cParams, "print_progress=", ruby_whisper_params_set_print_progress, 1); - rb_define_method(cParams, "print_realtime", ruby_whisper_params_get_print_realtime, 0); - rb_define_method(cParams, "print_realtime=", ruby_whisper_params_set_print_realtime, 1); - rb_define_method(cParams, "print_timestamps", ruby_whisper_params_get_print_timestamps, 0); - rb_define_method(cParams, "print_timestamps=", ruby_whisper_params_set_print_timestamps, 1); - rb_define_method(cParams, "suppress_blank", ruby_whisper_params_get_suppress_blank, 0); - rb_define_method(cParams, "suppress_blank=", ruby_whisper_params_set_suppress_blank, 1); - rb_define_method(cParams, "suppress_nst", ruby_whisper_params_get_suppress_nst, 0); - rb_define_method(cParams, "suppress_nst=", ruby_whisper_params_set_suppress_nst, 1); - rb_define_method(cParams, "token_timestamps", ruby_whisper_params_get_token_timestamps, 0); - rb_define_method(cParams, "token_timestamps=", ruby_whisper_params_set_token_timestamps, 1); - rb_define_method(cParams, "split_on_word", ruby_whisper_params_get_split_on_word, 0); - 
rb_define_method(cParams, "split_on_word=", ruby_whisper_params_set_split_on_word, 1); - rb_define_method(cParams, "initial_prompt", ruby_whisper_params_get_initial_prompt, 0); - rb_define_method(cParams, "initial_prompt=", ruby_whisper_params_set_initial_prompt, 1); - rb_define_method(cParams, "diarize", ruby_whisper_params_get_diarize, 0); - rb_define_method(cParams, "diarize=", ruby_whisper_params_set_diarize, 1); - - rb_define_method(cParams, "offset", ruby_whisper_params_get_offset, 0); - rb_define_method(cParams, "offset=", ruby_whisper_params_set_offset, 1); - rb_define_method(cParams, "duration", ruby_whisper_params_get_duration, 0); - rb_define_method(cParams, "duration=", ruby_whisper_params_set_duration, 1); - - rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0); - rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1); - rb_define_method(cParams, "temperature", ruby_whisper_params_get_temperature, 0); - rb_define_method(cParams, "temperature=", ruby_whisper_params_set_temperature, 1); - rb_define_method(cParams, "max_initial_ts", ruby_whisper_params_get_max_initial_ts, 0); - rb_define_method(cParams, "max_initial_ts=", ruby_whisper_params_set_max_initial_ts, 1); - rb_define_method(cParams, "length_penalty", ruby_whisper_params_get_length_penalty, 0); - rb_define_method(cParams, "length_penalty=", ruby_whisper_params_set_length_penalty, 1); - rb_define_method(cParams, "temperature_inc", ruby_whisper_params_get_temperature_inc, 0); - rb_define_method(cParams, "temperature_inc=", ruby_whisper_params_set_temperature_inc, 1); - rb_define_method(cParams, "entropy_thold", ruby_whisper_params_get_entropy_thold, 0); - rb_define_method(cParams, "entropy_thold=", ruby_whisper_params_set_entropy_thold, 1); - rb_define_method(cParams, "logprob_thold", ruby_whisper_params_get_logprob_thold, 0); - rb_define_method(cParams, "logprob_thold=", ruby_whisper_params_set_logprob_thold, 1); - 
rb_define_method(cParams, "no_speech_thold", ruby_whisper_params_get_no_speech_thold, 0); - rb_define_method(cParams, "no_speech_thold=", ruby_whisper_params_set_no_speech_thold, 1); - - rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1); - rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1); - rb_define_method(cParams, "progress_callback=", ruby_whisper_params_set_progress_callback, 1); - rb_define_method(cParams, "progress_callback_user_data=", ruby_whisper_params_set_progress_callback_user_data, 1); - rb_define_method(cParams, "abort_callback=", ruby_whisper_params_set_abort_callback, 1); - rb_define_method(cParams, "abort_callback_user_data=", ruby_whisper_params_set_abort_callback_user_data, 1); - - rb_define_attr(eError, "code", true, false); - rb_define_method(eError, "initialize", ruby_whisper_error_initialize, 1); - - // High leve - cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject); - - rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate); - rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0); - rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0); - rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0); - rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0); - rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0); - rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0); - rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0); - rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0); - rb_define_method(cSegment, "no_speech_prob", ruby_whisper_segment_get_no_speech_prob, 0); - - cModel = rb_define_class_under(mWhisper, "Model", rb_cObject); - rb_define_alloc_func(cModel, ruby_whisper_model_allocate); - 
rb_define_method(cContext, "model", ruby_whisper_get_model, 0); - rb_define_method(cModel, "n_vocab", ruby_whisper_c_model_n_vocab, 0); - rb_define_method(cModel, "n_audio_ctx", ruby_whisper_c_model_n_audio_ctx, 0); - rb_define_method(cModel, "n_audio_state", ruby_whisper_c_model_n_audio_state, 0); - rb_define_method(cModel, "n_audio_head", ruby_whisper_c_model_n_audio_head, 0); - rb_define_method(cModel, "n_audio_layer", ruby_whisper_c_model_n_audio_layer, 0); - rb_define_method(cModel, "n_text_ctx", ruby_whisper_c_model_n_text_ctx, 0); - rb_define_method(cModel, "n_text_state", ruby_whisper_c_model_n_text_state, 0); - rb_define_method(cModel, "n_text_head", ruby_whisper_c_model_n_text_head, 0); - rb_define_method(cModel, "n_text_layer", ruby_whisper_c_model_n_text_layer, 0); - rb_define_method(cModel, "n_mels", ruby_whisper_c_model_n_mels, 0); - rb_define_method(cModel, "ftype", ruby_whisper_c_model_ftype, 0); - rb_define_method(cModel, "type", ruby_whisper_c_model_type, 0); - - rb_require("whisper/model/uri"); -} -#ifdef __cplusplus -} -#endif diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 21e36c49..bbf3435e 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -22,4 +22,13 @@ typedef struct { ruby_whisper_callback_container *abort_callback_container; } ruby_whisper_params; +typedef struct { + VALUE context; + int index; +} ruby_whisper_segment; + +typedef struct { + VALUE context; +} ruby_whisper_model; + #endif diff --git a/bindings/ruby/ext/ruby_whisper_context.c b/bindings/ruby/ext/ruby_whisper_context.c new file mode 100644 index 00000000..df375218 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_context.c @@ -0,0 +1,613 @@ +#include +#include +#include "ruby_whisper.h" + +extern ID id_to_s; +extern ID id___method__; +extern ID id_to_enum; +extern ID id_length; +extern ID id_next; +extern ID id_new; +extern ID id_to_path; +extern ID id_URI; +extern ID id_pre_converted_models; + 
+extern VALUE cContext; +extern VALUE eError; +extern VALUE cModel; + +extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self); +extern VALUE rb_whisper_model_initialize(VALUE context); +extern VALUE rb_whisper_segment_initialize(VALUE context, int index); +extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context); + +static void +ruby_whisper_free(ruby_whisper *rw) +{ + if (rw->context) { + whisper_free(rw->context); + rw->context = NULL; + } +} + +void +rb_whisper_mark(ruby_whisper *rw) +{ + // call rb_gc_mark on any ruby references in rw +} + +void +rb_whisper_free(ruby_whisper *rw) +{ + ruby_whisper_free(rw); + free(rw); +} + +static VALUE +ruby_whisper_allocate(VALUE klass) +{ + ruby_whisper *rw; + rw = ALLOC(ruby_whisper); + rw->context = NULL; + return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw); +} + +/* + * call-seq: + * new("base.en") -> Whisper::Context + * new("path/to/model.bin") -> Whisper::Context + * new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context + */ +static VALUE +ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) +{ + ruby_whisper *rw; + VALUE whisper_model_file_path; + + // TODO: we can support init from buffer here too maybe another ruby object to expose + rb_scan_args(argc, argv, "01", &whisper_model_file_path); + Data_Get_Struct(self, ruby_whisper, rw); + + VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0); + VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path); + if (!NIL_P(pre_converted_model)) { + whisper_model_file_path = pre_converted_model; + } + if (TYPE(whisper_model_file_path) == T_STRING) { + const char * whisper_model_file_path_str = StringValueCStr(whisper_model_file_path); + if (strncmp("http://", whisper_model_file_path_str, 7) == 0 || strncmp("https://", whisper_model_file_path_str, 8) == 0) { + VALUE uri_class = rb_const_get(cModel, id_URI); + whisper_model_file_path = 
rb_class_new_instance(1, &whisper_model_file_path, uri_class); + } + } + if (rb_obj_is_kind_of(whisper_model_file_path, rb_path2class("URI::HTTP"))) { + VALUE uri_class = rb_const_get(cModel, id_URI); + whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class); + } + if (rb_respond_to(whisper_model_file_path, id_to_path)) { + whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0); + } + if (!rb_respond_to(whisper_model_file_path, id_to_s)) { + rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context"); + } + rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params()); + if (rw->context == NULL) { + rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context"); + } + return self; +} + +/* + * call-seq: + * model_n_vocab -> Integer + */ +VALUE ruby_whisper_model_n_vocab(VALUE self) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_vocab(rw->context)); +} + +/* + * call-seq: + * model_n_audio_ctx -> Integer + */ +VALUE ruby_whisper_model_n_audio_ctx(VALUE self) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_audio_ctx(rw->context)); +} + +/* + * call-seq: + * model_n_audio_state -> Integer + */ +VALUE ruby_whisper_model_n_audio_state(VALUE self) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_audio_state(rw->context)); +} + +/* + * call-seq: + * model_n_audio_head -> Integer + */ +VALUE ruby_whisper_model_n_audio_head(VALUE self) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_audio_head(rw->context)); +} + +/* + * call-seq: + * model_n_audio_layer -> Integer + */ +VALUE ruby_whisper_model_n_audio_layer(VALUE self) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return 
INT2NUM(whisper_model_n_audio_layer(rw->context)); +} + +/* + * call-seq: + * model_n_text_ctx -> Integer + */ +VALUE ruby_whisper_model_n_text_ctx(VALUE self) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_text_ctx(rw->context)); +} + +/* + * call-seq: + * model_n_text_state -> Integer + */ +VALUE ruby_whisper_model_n_text_state(VALUE self) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_text_state(rw->context)); +} + +/* + * call-seq: + * model_n_text_head -> Integer + */ +VALUE ruby_whisper_model_n_text_head(VALUE self) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_text_head(rw->context)); +} + +/* + * call-seq: + * model_n_text_layer -> Integer + */ +VALUE ruby_whisper_model_n_text_layer(VALUE self) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_text_layer(rw->context)); +} + +/* + * call-seq: + * model_n_mels -> Integer + */ +VALUE ruby_whisper_model_n_mels(VALUE self) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_n_mels(rw->context)); +} + +/* + * call-seq: + * model_ftype -> Integer + */ +VALUE ruby_whisper_model_ftype(VALUE self) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_model_ftype(rw->context)); +} + +/* + * call-seq: + * model_type -> String + */ +VALUE ruby_whisper_model_type(VALUE self) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return rb_str_new2(whisper_model_type_readable(rw->context)); +} + +/* + * Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text + * Not thread safe for same context + * Uses the specified decoding strategy to obtain the text. 
+ * + * call-seq: + * full(params, samples, n_samples) -> self + * full(params, samples) -> self + * + * The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data. + * Returns +self+ on success and raises Whisper::Error carrying the whisper_full() status code otherwise. + */ +VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) +{ + if (argc < 2 || argc > 3) { + rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc); + } + + ruby_whisper *rw; + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper, rw); + VALUE params = argv[0]; + Data_Get_Struct(params, ruby_whisper_params, rwp); + VALUE samples = argv[1]; + int n_samples; + rb_memory_view_t view; + const bool memory_view_available_p = rb_memory_view_available_p(samples); + if (argc == 3) { + n_samples = NUM2INT(argv[2]); + if (TYPE(samples) == T_ARRAY) { + if (RARRAY_LEN(samples) < n_samples) { + rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples); + } + } else if (memory_view_available_p) { + // view.data is read below, so the view must be acquired even when n_samples is given explicitly + if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) { + view.obj = Qnil; + rb_raise(rb_eArgError, "unable to get a memory view"); + } + if (view.byte_size / view.item_size < n_samples) { + const long n_available = (long)(view.byte_size / view.item_size); + rb_memory_view_release(&view); + rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", n_available, n_samples); + } + } + // Should check when samples.respond_to?(:length)?
+ } else { + if (TYPE(samples) == T_ARRAY) { + n_samples = RARRAY_LEN(samples); + } else if (memory_view_available_p) { + if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) { + view.obj = Qnil; + rb_raise(rb_eArgError, "unable to get a memory view"); + } + n_samples = view.byte_size / view.item_size; + } else if (rb_respond_to(samples, id_length)) { + n_samples = NUM2INT(rb_funcall(samples, id_length, 0)); + } else { + rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given"); + } + } + // Only allocate a scratch buffer when the samples have to be copied out of Ruby objects; + // a MemoryView is handed to whisper as-is. + float * c_samples; + if (memory_view_available_p) { + c_samples = (float *)view.data; + } else { + c_samples = ALLOC_N(float, n_samples); + if (TYPE(samples) == T_ARRAY) { + for (int i = 0; i < n_samples; i++) { + c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i)); + } + } else { + // TODO: use rb_block_call + VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each")); + for (int i = 0; i < n_samples; i++) { + // TODO: check if iter is exhausted and raise ArgumentError appropriately + VALUE sample = rb_funcall(iter, id_next, 0); + c_samples[i] = RFLOAT_VALUE(sample); + } + } + } + register_callbacks(rwp, &self); + const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples); + // Give back whichever sample storage was acquired above + if (memory_view_available_p) { + rb_memory_view_release(&view); + } else { + xfree(c_samples); + } + if (0 == result) { + return self; + } else { + // rb_funcall takes VALUE arguments, so the C status code has to be boxed first + rb_exc_raise(rb_funcall(eError, id_new, 1, INT2NUM(result))); + } +} + +/* + * Split the input audio in chunks and process each chunk separately using whisper_full_with_state() + * Result is stored in the default state of the context + * Not thread safe if executed in parallel on the same context. + * It seems this approach can offer some speedup in some cases. + * However, the transcription accuracy can be worse at the beginning and end of each chunk.
+ * + * call-seq: + * full_parallel(params, samples) -> self + * full_parallel(params, samples, n_samples) -> self + * full_parallel(params, samples, n_samples, n_processors) -> self + * full_parallel(params, samples, nil, n_processors) -> self + * + * +n_processors+ defaults to 1. Returns +self+ on success and raises Whisper::Error carrying the whisper_full_parallel() status code otherwise. + */ +static VALUE +ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) +{ + if (argc < 2 || argc > 4) { + rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..4)", argc); + } + + ruby_whisper *rw; + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper, rw); + VALUE params = argv[0]; + Data_Get_Struct(params, ruby_whisper_params, rwp); + VALUE samples = argv[1]; + int n_samples; + int n_processors; + rb_memory_view_t view; + const bool memory_view_available_p = rb_memory_view_available_p(samples); + switch (argc) { + case 2: + n_processors = 1; + break; + case 3: + n_processors = 1; + break; + case 4: + n_processors = NUM2INT(argv[3]); + break; + } + if (argc >= 3 && !NIL_P(argv[2])) { + n_samples = NUM2INT(argv[2]); + if (TYPE(samples) == T_ARRAY) { + if (RARRAY_LEN(samples) < n_samples) { + rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples); + } + } else if (memory_view_available_p) { + // view.data is read below, so the view must be acquired even when n_samples is given explicitly + if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) { + view.obj = Qnil; + rb_raise(rb_eArgError, "unable to get a memory view"); + } + if (view.byte_size / view.item_size < n_samples) { + const long n_available = (long)(view.byte_size / view.item_size); + rb_memory_view_release(&view); + rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", n_available, n_samples); + } + } + // Should check when samples.respond_to?(:length)?
+ } else if (memory_view_available_p) { + if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) { + view.obj = Qnil; + rb_raise(rb_eArgError, "unable to get a memory view"); + } + n_samples = view.byte_size / view.item_size; + } else { + if (TYPE(samples) == T_ARRAY) { + n_samples = RARRAY_LEN(samples); + } else if (rb_respond_to(samples, id_length)) { + n_samples = NUM2INT(rb_funcall(samples, id_length, 0)); + } else { + rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given"); + } + } + // Only allocate a scratch buffer when the samples have to be copied out of Ruby objects; + // a MemoryView is handed to whisper as-is. + float * c_samples; + if (memory_view_available_p) { + c_samples = (float *)view.data; + } else { + c_samples = ALLOC_N(float, n_samples); + if (TYPE(samples) == T_ARRAY) { + for (int i = 0; i < n_samples; i++) { + c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i)); + } + } else { + // FIXME: use rb_block_call + VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each")); + for (int i = 0; i < n_samples; i++) { + // TODO: check if iter is exhausted and raise ArgumentError + VALUE sample = rb_funcall(iter, id_next, 0); + c_samples[i] = RFLOAT_VALUE(sample); + } + } + } + register_callbacks(rwp, &self); + const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors); + // Give back whichever sample storage was acquired above + if (memory_view_available_p) { + rb_memory_view_release(&view); + } else { + xfree(c_samples); + } + if (0 == result) { + return self; + } else { + // rb_funcall takes VALUE arguments, so the C status code has to be boxed first + rb_exc_raise(rb_funcall(eError, id_new, 1, INT2NUM(result))); + } +} + +/* + * Number of segments. + * + * call-seq: + * full_n_segments -> Integer + */ +static VALUE +ruby_whisper_full_n_segments(VALUE self) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_full_n_segments(rw->context)); +} + +/* + * Language ID, which can be converted to string by Whisper.lang_str and Whisper.lang_str_full.
+ * + * call-seq: + * full_lang_id -> Integer + */ +static VALUE +ruby_whisper_full_lang_id(VALUE self) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + return INT2NUM(whisper_full_lang_id(rw->context)); +} + +static int ruby_whisper_full_check_segment_index(const ruby_whisper * rw, const VALUE i_segment) +{ + const int c_i_segment = NUM2INT(i_segment); + if (c_i_segment < 0 || c_i_segment >= whisper_full_n_segments(rw->context)) { + rb_raise(rb_eIndexError, "segment index %d out of range", c_i_segment); + } + return c_i_segment; +} + +/* + * Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds). + * + * full_get_segment_t0(3) # => 1668 (16680 ms) + * + * call-seq: + * full_get_segment_t0(segment_index) -> Integer + */ +static VALUE +ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment); + return INT2NUM(t0); +} + +/* + * End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds). + * + * full_get_segment_t1(3) # => 1668 (16680 ms) + * + * call-seq: + * full_get_segment_t1(segment_index) -> Integer + */ +static VALUE +ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment); + return INT2NUM(t1); +} + +/* + * Whether the next segment indexed by +segment_index+ is predicated as a speaker turn. 
+ * + * full_get_segment_speaker_turn_next(3) # => true + * + * call-seq: + * full_get_segment_speaker_turn_next(segment_index) -> bool + */ +static VALUE +ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment); + return speaker_turn_next ? Qtrue : Qfalse; +} + +/* + * Text of a segment indexed by +segment_index+. + * + * full_get_segment_text(3) # => "ask not what your country can do for you, ..." + * + * call-seq: + * full_get_segment_text(segment_index) -> String + */ +static VALUE +ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + const char * text = whisper_full_get_segment_text(rw->context, c_i_segment); + return rb_str_new2(text); +} + +/* + * call-seq: + * full_get_segment_no_speech_prob(segment_index) -> Float + */ +static VALUE +ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment); + return DBL2NUM(no_speech_prob); +} + +// High level API + +/* + * Whisper::Segment for the segment indexed by +segment_index+. + * Raises IndexError when +segment_index+ is out of range. + * + * call-seq: + * full_get_segment(segment_index) -> Whisper::Segment + */ +static VALUE +ruby_whisper_full_get_segment(VALUE self, VALUE i_segment) +{ + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + // Validate up front like the other full_get_segment_* readers so a Segment never wraps an invalid index + const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment); + return rb_whisper_segment_initialize(self, c_i_segment); +} + +/* + * Yields each Whisper::Segment: + * + * whisper.transcribe("path/to/audio.wav", params) + * whisper.each_segment do |segment| + * puts segment.text + * end + * + * Returns an Enumerator if no block given: + * + * whisper.transcribe("path/to/audio.wav", params) + * enum =
whisper.each_segment + * enum.to_a # => [#, ...] + * + * call-seq: + * each_segment {|segment| ... } + * each_segment -> Enumerator + */ +static VALUE +ruby_whisper_each_segment(VALUE self) +{ + if (!rb_block_given_p()) { + const VALUE method_name = rb_funcall(self, id___method__, 0); + return rb_funcall(self, id_to_enum, 1, method_name); + } + + ruby_whisper *rw; + Data_Get_Struct(self, ruby_whisper, rw); + + const int n_segments = whisper_full_n_segments(rw->context); + for (int i = 0; i < n_segments; ++i) { + rb_yield(rb_whisper_segment_initialize(self, i)); + } + + return self; +} + +/* + * call-seq: + * model -> Whisper::Model + */ +static VALUE +ruby_whisper_get_model(VALUE self) +{ + return rb_whisper_model_initialize(self); +} + +void +init_ruby_whisper_context(VALUE *mWhisper) +{ + cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject); + + rb_define_alloc_func(cContext, ruby_whisper_allocate); + rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1); + + rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1); + rb_define_method(cContext, "model_n_vocab", ruby_whisper_model_n_vocab, 0); + rb_define_method(cContext, "model_n_audio_ctx", ruby_whisper_model_n_audio_ctx, 0); + rb_define_method(cContext, "model_n_audio_state", ruby_whisper_model_n_audio_state, 0); + rb_define_method(cContext, "model_n_audio_head", ruby_whisper_model_n_audio_head, 0); + rb_define_method(cContext, "model_n_audio_layer", ruby_whisper_model_n_audio_layer, 0); + rb_define_method(cContext, "model_n_text_ctx", ruby_whisper_model_n_text_ctx, 0); + rb_define_method(cContext, "model_n_text_state", ruby_whisper_model_n_text_state, 0); + rb_define_method(cContext, "model_n_text_head", ruby_whisper_model_n_text_head, 0); + rb_define_method(cContext, "model_n_text_layer", ruby_whisper_model_n_text_layer, 0); + rb_define_method(cContext, "model_n_mels", ruby_whisper_model_n_mels, 0); + rb_define_method(cContext, "model_ftype", 
ruby_whisper_model_ftype, 0); + rb_define_method(cContext, "model_type", ruby_whisper_model_type, 0); + rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0); + rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0); + rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1); + rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1); + rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1); + rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1); + rb_define_method(cContext, "full_get_segment_no_speech_prob", ruby_whisper_full_get_segment_no_speech_prob, 1); + rb_define_method(cContext, "full", ruby_whisper_full, -1); + rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1); + + // High level API + rb_define_method(cContext, "full_get_segment", ruby_whisper_full_get_segment, 1); + rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0); + + rb_define_method(cContext, "model", ruby_whisper_get_model, 0); +} diff --git a/bindings/ruby/ext/ruby_whisper_error.c b/bindings/ruby/ext/ruby_whisper_error.c new file mode 100644 index 00000000..b4dbec0c --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_error.c @@ -0,0 +1,52 @@ +#include + +extern VALUE eError; + +VALUE ruby_whisper_error_initialize(VALUE self, VALUE code) +{ + const int c_code = NUM2INT(code); + const char *raw_message; + switch (c_code) { + case -2: + raw_message = "failed to compute log mel spectrogram"; + break; + case -3: + raw_message = "failed to auto-detect language"; + break; + case -4: + raw_message = "too many decoders requested"; + break; + case -5: + raw_message = "audio_ctx is larger than the maximum allowed"; + break; + case -6: + raw_message = "failed to encode"; + break; + case -7: + raw_message = "whisper_kv_cache_init() failed for self-attention cache"; +
break; + case -8: + raw_message = "failed to decode"; + break; + case -9: + raw_message = "failed to decode"; + break; + default: + raw_message = "unknown error"; + break; + } + const VALUE message = rb_str_new2(raw_message); + rb_call_super(1, &message); + rb_iv_set(self, "@code", code); + + return self; +} + +void +init_ruby_whisper_error(VALUE *mWhisper) +{ + eError = rb_define_class_under(*mWhisper, "Error", rb_eStandardError); + + rb_define_attr(eError, "code", true, false); + rb_define_method(eError, "initialize", ruby_whisper_error_initialize, 1); +} diff --git a/bindings/ruby/ext/ruby_whisper_model.c b/bindings/ruby/ext/ruby_whisper_model.c new file mode 100644 index 00000000..1e0648fd --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_model.c @@ -0,0 +1,210 @@ +#include +#include "ruby_whisper.h" + +extern VALUE cModel; + +static void rb_whisper_model_mark(ruby_whisper_model *rwm) { + rb_gc_mark(rwm->context); +} + +static VALUE ruby_whisper_model_allocate(VALUE klass) { + ruby_whisper_model *rwm = ALLOC(ruby_whisper_model); + rwm->context = Qnil; /* rb_whisper_model_mark passes this field to rb_gc_mark, so it must never hold garbage */ + return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm); +} + +VALUE rb_whisper_model_initialize(VALUE context) { + ruby_whisper_model *rwm; + const VALUE model = ruby_whisper_model_allocate(cModel); + Data_Get_Struct(model, ruby_whisper_model, rwm); + rwm->context = context; + return model; +} + +/* + * call-seq: + * n_vocab -> Integer + */ +static VALUE +ruby_whisper_model_n_vocab(VALUE self) +{ + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_vocab(rw->context)); +} + +/* + * call-seq: + * n_audio_ctx -> Integer + */ +static VALUE +ruby_whisper_model_n_audio_ctx(VALUE self) +{ + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return
INT2NUM(whisper_model_n_audio_ctx(rw->context)); +} + +/* + * call-seq: + * n_audio_state -> Integer + */ +static VALUE +ruby_whisper_model_n_audio_state(VALUE self) +{ + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_audio_state(rw->context)); +} + +/* + * call-seq: + * n_audio_head -> Integer + */ +static VALUE +ruby_whisper_model_n_audio_head(VALUE self) +{ + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_audio_head(rw->context)); +} + +/* + * call-seq: + * n_audio_layer -> Integer + */ +static VALUE +ruby_whisper_model_n_audio_layer(VALUE self) +{ + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_audio_layer(rw->context)); +} + +/* + * call-seq: + * n_text_ctx -> Integer + */ +static VALUE +ruby_whisper_model_n_text_ctx(VALUE self) +{ + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_text_ctx(rw->context)); +} + +/* + * call-seq: + * n_text_state -> Integer + */ +static VALUE +ruby_whisper_model_n_text_state(VALUE self) +{ + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_text_state(rw->context)); +} + +/* + * call-seq: + * n_text_head -> Integer + */ +static VALUE +ruby_whisper_model_n_text_head(VALUE self) +{ + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_text_head(rw->context)); +} 
+ +/* + * call-seq: + * n_text_layer -> Integer + */ +static VALUE +ruby_whisper_model_n_text_layer(VALUE self) +{ + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_text_layer(rw->context)); +} + +/* + * call-seq: + * n_mels -> Integer + */ +static VALUE +ruby_whisper_model_n_mels(VALUE self) +{ + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_n_mels(rw->context)); +} + +/* + * call-seq: + * ftype -> Integer + */ +static VALUE +ruby_whisper_model_ftype(VALUE self) +{ + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return INT2NUM(whisper_model_ftype(rw->context)); +} + +/* + * call-seq: + * type -> String + */ +static VALUE +ruby_whisper_model_type(VALUE self) +{ + ruby_whisper_model *rwm; + Data_Get_Struct(self, ruby_whisper_model, rwm); + ruby_whisper *rw; + Data_Get_Struct(rwm->context, ruby_whisper, rw); + return rb_str_new2(whisper_model_type_readable(rw->context)); +} + +void +init_ruby_whisper_model(VALUE *mWhisper) +{ + cModel = rb_define_class_under(*mWhisper, "Model", rb_cObject); + + rb_define_alloc_func(cModel, ruby_whisper_model_allocate); + rb_define_method(cModel, "n_vocab", ruby_whisper_model_n_vocab, 0); + rb_define_method(cModel, "n_audio_ctx", ruby_whisper_model_n_audio_ctx, 0); + rb_define_method(cModel, "n_audio_state", ruby_whisper_model_n_audio_state, 0); + rb_define_method(cModel, "n_audio_head", ruby_whisper_model_n_audio_head, 0); + rb_define_method(cModel, "n_audio_layer", ruby_whisper_model_n_audio_layer, 0); + rb_define_method(cModel, "n_text_ctx", ruby_whisper_model_n_text_ctx, 0); + rb_define_method(cModel, "n_text_state", ruby_whisper_model_n_text_state, 0); + 
rb_define_method(cModel, "n_text_head", ruby_whisper_model_n_text_head, 0); + rb_define_method(cModel, "n_text_layer", ruby_whisper_model_n_text_layer, 0); + rb_define_method(cModel, "n_mels", ruby_whisper_model_n_mels, 0); + rb_define_method(cModel, "ftype", ruby_whisper_model_ftype, 0); + rb_define_method(cModel, "type", ruby_whisper_model_type, 0); +} diff --git a/bindings/ruby/ext/ruby_whisper_params.c b/bindings/ruby/ext/ruby_whisper_params.c new file mode 100644 index 00000000..0446db32 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_params.c @@ -0,0 +1,1077 @@ +#include +#include "ruby_whisper.h" + +#define BOOL_PARAMS_SETTER(self, prop, value) \ + ruby_whisper_params *rwp; \ + Data_Get_Struct(self, ruby_whisper_params, rwp); \ + if (value == Qfalse || value == Qnil) { \ + rwp->params.prop = false; \ + } else { \ + rwp->params.prop = true; \ + } \ + return value; \ + +#define BOOL_PARAMS_GETTER(self, prop) \ + ruby_whisper_params *rwp; \ + Data_Get_Struct(self, ruby_whisper_params, rwp); \ + if (rwp->params.prop) { \ + return Qtrue; \ + } else { \ + return Qfalse; \ + } + +#define DEFINE_PARAM(param_name, nth) \ + id_ ## param_name = rb_intern(#param_name); \ + param_names[nth] = id_ ## param_name; \ + rb_define_method(cParams, #param_name, ruby_whisper_params_get_ ## param_name, 0); \ + rb_define_method(cParams, #param_name "=", ruby_whisper_params_set_ ## param_name, 1); + +#define RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT 30 + +extern VALUE cParams; + +extern ID id_call; + +extern VALUE rb_whisper_segment_initialize(VALUE context, int index); + +static ID param_names[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT]; +static ID id_language; +static ID id_translate; +static ID id_no_context; +static ID id_single_segment; +static ID id_print_special; +static ID id_print_progress; +static ID id_print_realtime; +static ID id_print_timestamps; +static ID id_suppress_blank; +static ID id_suppress_nst; +static ID id_token_timestamps; +static ID id_split_on_word; +static ID 
id_initial_prompt; +static ID id_diarize; +static ID id_offset; +static ID id_duration; +static ID id_max_text_tokens; +static ID id_temperature; +static ID id_max_initial_ts; +static ID id_length_penalty; +static ID id_temperature_inc; +static ID id_entropy_thold; +static ID id_logprob_thold; +static ID id_no_speech_thold; +static ID id_new_segment_callback; +static ID id_new_segment_callback_user_data; +static ID id_progress_callback; +static ID id_progress_callback_user_data; +static ID id_abort_callback; +static ID id_abort_callback_user_data; + +static void +rb_whisper_callbcack_container_mark(ruby_whisper_callback_container *rwc) +{ + rb_gc_mark(rwc->user_data); + rb_gc_mark(rwc->callback); + rb_gc_mark(rwc->callbacks); +} + +static ruby_whisper_callback_container* +rb_whisper_callback_container_allocate() { + ruby_whisper_callback_container *container; + container = ALLOC(ruby_whisper_callback_container); + container->context = NULL; + container->user_data = Qnil; + container->callback = Qnil; + container->callbacks = rb_ary_new(); + return container; +} + +static void new_segment_callback(struct whisper_context *ctx, struct whisper_state *state, int n_new, void *user_data) { + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + + // Currently, doesn't support state because + // those require to resolve GC-related problems. 
+ if (!NIL_P(container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, INT2NUM(n_new), container->user_data); + } + const long callbacks_len = RARRAY_LEN(container->callbacks); + if (0 == callbacks_len) { + return; + } + const int n_segments = whisper_full_n_segments_from_state(state); + for (int i = n_new; i > 0; i--) { + int i_segment = n_segments - i; + VALUE segment = rb_whisper_segment_initialize(*container->context, i_segment); + for (int j = 0; j < callbacks_len; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + rb_funcall(cb, id_call, 1, segment); + } + } +} + +static void progress_callback(struct whisper_context *ctx, struct whisper_state *state, int progress_cur, void *user_data) { + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + const VALUE progress = INT2NUM(progress_cur); + // Currently, doesn't support state because + // those require to resolve GC-related problems. + if (!NIL_P(container->callback)) { + rb_funcall(container->callback, id_call, 4, *container->context, Qnil, progress, container->user_data); + } + const long callbacks_len = RARRAY_LEN(container->callbacks); + if (0 == callbacks_len) { + return; + } + for (int j = 0; j < callbacks_len; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + rb_funcall(cb, id_call, 1, progress); + } +} + +static bool abort_callback(void * user_data) { + const ruby_whisper_callback_container *container = (ruby_whisper_callback_container *)user_data; + if (!NIL_P(container->callback)) { + VALUE result = rb_funcall(container->callback, id_call, 1, container->user_data); + if (!NIL_P(result) && Qfalse != result) { + return true; + } + } + const long callbacks_len = RARRAY_LEN(container->callbacks); + if (0 == callbacks_len) { + return false; + } + for (int j = 0; j < callbacks_len; j++) { + VALUE cb = rb_ary_entry(container->callbacks, j); + VALUE result = rb_funcall(cb, id_call, 1, 
container->user_data); + if (!NIL_P(result) && Qfalse != result) { + return true; + } + } + return false; +} + +void register_callbacks(ruby_whisper_params * rwp, VALUE * context) { + if (!NIL_P(rwp->new_segment_callback_container->callback) || 0 != RARRAY_LEN(rwp->new_segment_callback_container->callbacks)) { + rwp->new_segment_callback_container->context = context; + rwp->params.new_segment_callback = new_segment_callback; + rwp->params.new_segment_callback_user_data = rwp->new_segment_callback_container; + } + + if (!NIL_P(rwp->progress_callback_container->callback) || 0 != RARRAY_LEN(rwp->progress_callback_container->callbacks)) { + rwp->progress_callback_container->context = context; + rwp->params.progress_callback = progress_callback; + rwp->params.progress_callback_user_data = rwp->progress_callback_container; + } + + if (!NIL_P(rwp->abort_callback_container->callback) || 0 != RARRAY_LEN(rwp->abort_callback_container->callbacks)) { + rwp->abort_callback_container->context = context; + rwp->params.abort_callback = abort_callback; + rwp->params.abort_callback_user_data = rwp->abort_callback_container; + } +} + +void +rb_whisper_params_mark(ruby_whisper_params *rwp) +{ + rb_whisper_callbcack_container_mark(rwp->new_segment_callback_container); + rb_whisper_callbcack_container_mark(rwp->progress_callback_container); + rb_whisper_callbcack_container_mark(rwp->abort_callback_container); +} + +void +ruby_whisper_params_free(ruby_whisper_params *rwp) +{ +} + +void +rb_whisper_params_free(ruby_whisper_params *rwp) +{ + // How to free user_data and callback only when not referred to by others? 
+ ruby_whisper_params_free(rwp); + free(rwp); +} + +static VALUE +ruby_whisper_params_allocate(VALUE klass) +{ + ruby_whisper_params *rwp; + rwp = ALLOC(ruby_whisper_params); + rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + rwp->diarize = false; + rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); + rwp->progress_callback_container = rb_whisper_callback_container_allocate(); + rwp->abort_callback_container = rb_whisper_callback_container_allocate(); + return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); +} + +/* + * params.language = "auto" | "en", etc... + * + * call-seq: + * language = lang_name -> lang_name + */ +static VALUE +ruby_whisper_params_set_language(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + if (value == Qfalse || value == Qnil) { + rwp->params.language = "auto"; + } else { + rwp->params.language = StringValueCStr(value); + } + return value; +} +/* + * call-seq: + * language -> String + */ +static VALUE +ruby_whisper_params_get_language(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + if (rwp->params.language) { + return rb_str_new2(rwp->params.language); + } else { + return rb_str_new2("auto"); + } +} +/* + * call-seq: + * translate = do_translate -> do_translate + */ +static VALUE +ruby_whisper_params_set_translate(VALUE self, VALUE value) +{ + BOOL_PARAMS_SETTER(self, translate, value) +} +/* + * call-seq: + * translate -> bool + */ +static VALUE +ruby_whisper_params_get_translate(VALUE self) +{ + BOOL_PARAMS_GETTER(self, translate) +} +/* + * call-seq: + * no_context = dont_use_context -> dont_use_context + */ +static VALUE +ruby_whisper_params_set_no_context(VALUE self, VALUE value) +{ + BOOL_PARAMS_SETTER(self, no_context, value) +} +/* + * If true, does not use past transcription (if any) as initial prompt for the decoder. 
+ * + * call-seq: + * no_context -> bool + */ +static VALUE +ruby_whisper_params_get_no_context(VALUE self) +{ + BOOL_PARAMS_GETTER(self, no_context) +} +/* + * call-seq: + * single_segment = force_single -> force_single + */ +static VALUE +ruby_whisper_params_set_single_segment(VALUE self, VALUE value) +{ + BOOL_PARAMS_SETTER(self, single_segment, value) +} +/* + * If true, forces single segment output (useful for streaming). + * + * call-seq: + * single_segment -> bool + */ +static VALUE +ruby_whisper_params_get_single_segment(VALUE self) +{ + BOOL_PARAMS_GETTER(self, single_segment) +} +/* + * call-seq: + * print_special = force_print -> force_print + */ +static VALUE +ruby_whisper_params_set_print_special(VALUE self, VALUE value) +{ + BOOL_PARAMS_SETTER(self, print_special, value) +} +/* + * If true, prints special tokens (e.g. , , , etc.). + * + * call-seq: + * print_special -> bool + */ +static VALUE +ruby_whisper_params_get_print_special(VALUE self) +{ + BOOL_PARAMS_GETTER(self, print_special) +} +/* + * call-seq: + * print_progress = force_print -> force_print + */ +static VALUE +ruby_whisper_params_set_print_progress(VALUE self, VALUE value) +{ + BOOL_PARAMS_SETTER(self, print_progress, value) +} +/* + * If true, prints progress information. + * + * call-seq: + * print_progress -> bool + */ +static VALUE +ruby_whisper_params_get_print_progress(VALUE self) +{ + BOOL_PARAMS_GETTER(self, print_progress) +} +/* + * call-seq: + * print_realtime = force_print -> force_print + */ +static VALUE +ruby_whisper_params_set_print_realtime(VALUE self, VALUE value) +{ + BOOL_PARAMS_SETTER(self, print_realtime, value) +} +/* + * If true, prints results from within whisper.cpp. 
(avoid it, use callback instead) + * call-seq: + * print_realtime -> bool + */ +static VALUE +ruby_whisper_params_get_print_realtime(VALUE self) +{ + BOOL_PARAMS_GETTER(self, print_realtime) +} +/* + * call-seq: + * print_timestamps = force_print -> force_print + */ +static VALUE +ruby_whisper_params_set_print_timestamps(VALUE self, VALUE value) +{ + BOOL_PARAMS_SETTER(self, print_timestamps, value) +} +/* + * If true, prints timestamps for each text segment when printing realtime. + * + * call-seq: + * print_timestamps -> bool + */ +static VALUE +ruby_whisper_params_get_print_timestamps(VALUE self) +{ + BOOL_PARAMS_GETTER(self, print_timestamps) +} +/* + * call-seq: + * suppress_blank = force_suppress -> force_suppress + */ +static VALUE +ruby_whisper_params_set_suppress_blank(VALUE self, VALUE value) +{ + BOOL_PARAMS_SETTER(self, suppress_blank, value) +} +/* + * If true, suppresses blank outputs. + * + * call-seq: + * suppress_blank -> bool + */ +static VALUE +ruby_whisper_params_get_suppress_blank(VALUE self) +{ + BOOL_PARAMS_GETTER(self, suppress_blank) +} +/* + * call-seq: + * suppress_nst = force_suppress -> force_suppress + */ +static VALUE +ruby_whisper_params_set_suppress_nst(VALUE self, VALUE value) +{ + BOOL_PARAMS_SETTER(self, suppress_nst, value) +} +/* + * If true, suppresses non-speech-tokens. + * + * call-seq: + * suppress_nst -> bool + */ +static VALUE +ruby_whisper_params_get_suppress_nst(VALUE self) +{ + BOOL_PARAMS_GETTER(self, suppress_nst) +} +/* + * If true, enables token-level timestamps. 
+ * + * call-seq: + * token_timestamps -> bool + */ +static VALUE +ruby_whisper_params_get_token_timestamps(VALUE self) +{ + BOOL_PARAMS_GETTER(self, token_timestamps) +} +/* + * call-seq: + * token_timestamps = force_timestamps -> force_timestamps + */ +static VALUE +ruby_whisper_params_set_token_timestamps(VALUE self, VALUE value) +{ + BOOL_PARAMS_SETTER(self, token_timestamps, value) +} +/* + * If true, split on word rather than on token (when used with max_len). + * + * call-seq: + * translate -> bool + */ +static VALUE +ruby_whisper_params_get_split_on_word(VALUE self) +{ + BOOL_PARAMS_GETTER(self, split_on_word) +} +/* + * call-seq: + * split_on_word = force_split -> force_split + */ +static VALUE +ruby_whisper_params_set_split_on_word(VALUE self, VALUE value) +{ + BOOL_PARAMS_SETTER(self, split_on_word, value) +} +/* + * Tokens to provide to the whisper decoder as initial prompt + * these are prepended to any existing text context from a previous call + * use whisper_tokenize() to convert text to tokens. + * Maximum of whisper_n_text_ctx()/2 tokens are used (typically 224). + * + * call-seq: + * initial_prompt -> String + */ +static VALUE +ruby_whisper_params_get_initial_prompt(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return rwp->params.initial_prompt == NULL ? Qnil : rb_str_new2(rwp->params.initial_prompt); +} +/* + * call-seq: + * initial_prompt = prompt -> prompt + */ +static VALUE +ruby_whisper_params_set_initial_prompt(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.initial_prompt = StringValueCStr(value); + return value; +} +/* + * If true, enables diarization. 
+ * + * call-seq: + * diarize -> bool + */ +static VALUE +ruby_whisper_params_get_diarize(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + if (rwp->diarize) { + return Qtrue; + } else { + return Qfalse; + } +} +/* + * call-seq: + * diarize = force_diarize -> force_diarize + */ +static VALUE +ruby_whisper_params_set_diarize(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + if (value == Qfalse || value == Qnil) { + rwp->diarize = false; + } else { + rwp->diarize = true; + } \ + return value; +} + +/* + * Start offset in ms. + * + * call-seq: + * offset -> Integer + */ +static VALUE +ruby_whisper_params_get_offset(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return INT2NUM(rwp->params.offset_ms); +} +/* + * call-seq: + * offset = offset_ms -> offset_ms + */ +static VALUE +ruby_whisper_params_set_offset(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.offset_ms = NUM2INT(value); + return value; +} +/* + * Audio duration to process in ms. + * + * call-seq: + * duration -> Integer + */ +static VALUE +ruby_whisper_params_get_duration(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return INT2NUM(rwp->params.duration_ms); +} +/* + * call-seq: + * duration = duration_ms -> duration_ms + */ +static VALUE +ruby_whisper_params_set_duration(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.duration_ms = NUM2INT(value); + return value; +} + +/* + * Max tokens to use from past text as prompt for the decoder. 
+ * + * call-seq: + * max_text_tokens -> Integer + */ +static VALUE +ruby_whisper_params_get_max_text_tokens(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return INT2NUM(rwp->params.n_max_text_ctx); +} +/* + * call-seq: + * max_text_tokens = n_tokens -> n_tokens + */ +static VALUE +ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.n_max_text_ctx = NUM2INT(value); + return value; +} +/* + * call-seq: + * temperature -> Float + */ +static VALUE +ruby_whisper_params_get_temperature(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return DBL2NUM(rwp->params.temperature); +} +/* + * call-seq: + * temperature = temp -> temp + */ +static VALUE +ruby_whisper_params_set_temperature(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.temperature = RFLOAT_VALUE(value); + return value; +} +/* + * See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97 + * + * call-seq: + * max_initial_ts -> Flaot + */ +static VALUE +ruby_whisper_params_get_max_initial_ts(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return DBL2NUM(rwp->params.max_initial_ts); +} +/* + * call-seq: + * max_initial_ts = timestamp -> timestamp + */ +static VALUE +ruby_whisper_params_set_max_initial_ts(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.max_initial_ts = RFLOAT_VALUE(value); + return value; +} +/* + * call-seq: + * length_penalty -> Float + */ +static VALUE +ruby_whisper_params_get_length_penalty(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return DBL2NUM(rwp->params.length_penalty); +} +/* + * call-seq: + * 
length_penalty = penalty -> penalty + */ +static VALUE +ruby_whisper_params_set_length_penalty(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.length_penalty = RFLOAT_VALUE(value); + return value; +} +/* + * call-seq: + * temperature_inc -> Float + */ +static VALUE +ruby_whisper_params_get_temperature_inc(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return DBL2NUM(rwp->params.temperature_inc); +} +/* + * call-seq: + * temperature_inc = inc -> inc + */ +static VALUE +ruby_whisper_params_set_temperature_inc(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.temperature_inc = RFLOAT_VALUE(value); + return value; +} +/* + * Similar to OpenAI's "compression_ratio_threshold" + * + * call-seq: + * entropy_thold -> Float + */ +static VALUE +ruby_whisper_params_get_entropy_thold(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return DBL2NUM(rwp->params.entropy_thold); +} +/* + * call-seq: + * entropy_thold = threshold -> threshold + */ +static VALUE +ruby_whisper_params_set_entropy_thold(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.entropy_thold = RFLOAT_VALUE(value); + return value; +} +/* + * call-seq: + * logprob_thold -> Float + */ +static VALUE +ruby_whisper_params_get_logprob_thold(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return DBL2NUM(rwp->params.logprob_thold); +} +/* + * call-seq: + * logprob_thold = threshold -> threshold + */ +static VALUE +ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.logprob_thold = RFLOAT_VALUE(value); + return value; +} +/* + * call-seq: + * no_speech_thold -> 
Float + */ +static VALUE +ruby_whisper_params_get_no_speech_thold(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return DBL2NUM(rwp->params.no_speech_thold); +} +/* + * call-seq: + * no_speech_thold = threshold -> threshold + */ +static VALUE +ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.no_speech_thold = RFLOAT_VALUE(value); + return value; +} +static VALUE +ruby_whisper_params_get_new_segment_callback(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return rwp->new_segment_callback_container->callback; +} +/* + * Sets new segment callback, called for every newly generated text segment. + * + * params.new_segment_callback = ->(context, _, n_new, user_data) { + * # ... + * } + * + * call-seq: + * new_segment_callback = callback -> callback + */ +static VALUE +ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->new_segment_callback_container->callback = value; + return value; +} +static VALUE +ruby_whisper_params_get_new_segment_callback_user_data(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return rwp->new_segment_callback_container->user_data; +} +/* + * Sets user data passed to the last argument of new segment callback. 
+ * + * call-seq: + * new_segment_callback_user_data = user_data -> use_data + */ +static VALUE +ruby_whisper_params_set_new_segment_callback_user_data(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->new_segment_callback_container->user_data = value; + return value; +} +static VALUE +ruby_whisper_params_get_progress_callback(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return rwp->progress_callback_container->callback; +} +/* + * Sets progress callback, called on each progress update. + * + * params.new_segment_callback = ->(context, _, progress, user_data) { + * # ... + * } + * + * +progress+ is an Integer between 0 and 100. + * + * call-seq: + * progress_callback = callback -> callback + */ +static VALUE +ruby_whisper_params_set_progress_callback(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->progress_callback_container->callback = value; + return value; +} +static VALUE +ruby_whisper_params_get_progress_callback_user_data(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return rwp->progress_callback_container->user_data; +} +/* + * Sets user data passed to the last argument of progress callback. + * + * call-seq: + * progress_callback_user_data = user_data -> use_data + */ +static VALUE +ruby_whisper_params_set_progress_callback_user_data(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->progress_callback_container->user_data = value; + return value; +} +static VALUE +ruby_whisper_params_get_abort_callback(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return rwp->abort_callback_container->callback; +} +/* + * Sets abort callback, called to check if the process should be aborted. 
+ * + * params.abort_callback = ->(user_data) { + * # ... + * } + * + * call-seq: + * abort_callback = callback -> callback + */ +static VALUE +ruby_whisper_params_set_abort_callback(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->abort_callback_container->callback = value; + return value; +} +static VALUE +ruby_whisper_params_get_abort_callback_user_data(VALUE self) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return rwp->abort_callback_container->user_data; +} +/* + * Sets user data passed to the last argument of abort callback. + * + * call-seq: + * abort_callback_user_data = user_data -> use_data + */ +static VALUE +ruby_whisper_params_set_abort_callback_user_data(VALUE self, VALUE value) +{ + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->abort_callback_container->user_data = value; + return value; +} + +#define SET_PARAM_IF_SAME(param_name) \ + if (id == id_ ## param_name) { \ + ruby_whisper_params_set_ ## param_name(self, value); \ + continue; \ + } + +static VALUE +ruby_whisper_params_initialize(int argc, VALUE *argv, VALUE self) +{ + + VALUE kw_hash; + VALUE values[RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT] = {Qundef}; + VALUE value; + ruby_whisper_params *rwp; + ID id; + int i; + + rb_scan_args_kw(RB_SCAN_ARGS_KEYWORDS, argc, argv, ":", &kw_hash); + if (NIL_P(kw_hash)) { + return self; + } + + rb_get_kwargs(kw_hash, ¶m_names, 0, RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT, &values); + Data_Get_Struct(self, ruby_whisper_params, rwp); + + for (i = 0; i < RUBY_WHISPER_PARAMS_PARAM_NAMES_COUNT; i++) { + id = param_names[i]; + value = values[i]; + if (value == Qundef) { + continue; + } + if (id == id_diarize) { + rwp->diarize = value; + continue; + } else { + SET_PARAM_IF_SAME(language) + SET_PARAM_IF_SAME(translate) + SET_PARAM_IF_SAME(no_context) + SET_PARAM_IF_SAME(single_segment) + SET_PARAM_IF_SAME(print_special) + 
SET_PARAM_IF_SAME(print_progress) + SET_PARAM_IF_SAME(print_realtime) + SET_PARAM_IF_SAME(print_timestamps) + SET_PARAM_IF_SAME(suppress_blank) + SET_PARAM_IF_SAME(suppress_nst) + SET_PARAM_IF_SAME(token_timestamps) + SET_PARAM_IF_SAME(split_on_word) + SET_PARAM_IF_SAME(initial_prompt) + SET_PARAM_IF_SAME(offset) + SET_PARAM_IF_SAME(duration) + SET_PARAM_IF_SAME(max_text_tokens) + SET_PARAM_IF_SAME(temperature) + SET_PARAM_IF_SAME(max_initial_ts) + SET_PARAM_IF_SAME(length_penalty) + SET_PARAM_IF_SAME(temperature_inc) + SET_PARAM_IF_SAME(entropy_thold) + SET_PARAM_IF_SAME(logprob_thold) + SET_PARAM_IF_SAME(no_speech_thold) + SET_PARAM_IF_SAME(new_segment_callback) + SET_PARAM_IF_SAME(new_segment_callback_user_data) + SET_PARAM_IF_SAME(progress_callback) + SET_PARAM_IF_SAME(progress_callback_user_data) + SET_PARAM_IF_SAME(abort_callback) + SET_PARAM_IF_SAME(abort_callback_user_data) + } + } + + return self; +} + +#undef SET_PARAM_IF_SAME + +/* + * Hook called on new segment. Yields each Whisper::Segment. + * + * whisper.on_new_segment do |segment| + * # ... + * end + * + * call-seq: + * on_new_segment {|segment| ... } + */ +static VALUE +ruby_whisper_params_on_new_segment(VALUE self) +{ + ruby_whisper_params *rws; + Data_Get_Struct(self, ruby_whisper_params, rws); + const VALUE blk = rb_block_proc(); + rb_ary_push(rws->new_segment_callback_container->callbacks, blk); + return Qnil; +} + +/* + * Hook called on progress update. Yields each progress Integer between 0 and 100. + * + * whisper.on_progress do |progress| + * # ... + * end + * + * call-seq: + * on_progress {|progress| ... } + */ +static VALUE +ruby_whisper_params_on_progress(VALUE self) +{ + ruby_whisper_params *rws; + Data_Get_Struct(self, ruby_whisper_params, rws); + const VALUE blk = rb_block_proc(); + rb_ary_push(rws->progress_callback_container->callbacks, blk); + return Qnil; +} + +/* + * Call block to determine whether abort or not. Return +true+ when you want to abort. 
+ * + * params.abort_on do + * if some_condition + * true # abort + * else + * false # continue + * end + * end + * + * call-seq: + * abort_on { ... } + */ +static VALUE +ruby_whisper_params_abort_on(VALUE self) +{ + ruby_whisper_params *rws; + Data_Get_Struct(self, ruby_whisper_params, rws); + const VALUE blk = rb_block_proc(); + rb_ary_push(rws->abort_callback_container->callbacks, blk); + return Qnil; +} + +void +init_ruby_whisper_params(VALUE *mWhisper) +{ + cParams = rb_define_class_under(*mWhisper, "Params", rb_cObject); + + rb_define_alloc_func(cParams, ruby_whisper_params_allocate); + rb_define_method(cParams, "initialize", ruby_whisper_params_initialize, -1); + + DEFINE_PARAM(language, 0) + DEFINE_PARAM(translate, 1) + DEFINE_PARAM(no_context, 2) + DEFINE_PARAM(single_segment, 3) + DEFINE_PARAM(print_special, 4) + DEFINE_PARAM(print_progress, 5) + DEFINE_PARAM(print_realtime, 6) + DEFINE_PARAM(print_timestamps, 7) + DEFINE_PARAM(suppress_blank, 8) + DEFINE_PARAM(suppress_nst, 9) + DEFINE_PARAM(token_timestamps, 10) + DEFINE_PARAM(split_on_word, 11) + DEFINE_PARAM(initial_prompt, 12) + DEFINE_PARAM(diarize, 13) + DEFINE_PARAM(offset, 14) + DEFINE_PARAM(duration, 15) + DEFINE_PARAM(max_text_tokens, 16) + DEFINE_PARAM(temperature, 17) + DEFINE_PARAM(max_initial_ts, 18) + DEFINE_PARAM(length_penalty, 19) + DEFINE_PARAM(temperature_inc, 20) + DEFINE_PARAM(entropy_thold, 21) + DEFINE_PARAM(logprob_thold, 22) + DEFINE_PARAM(no_speech_thold, 23) + DEFINE_PARAM(new_segment_callback, 24) + DEFINE_PARAM(new_segment_callback_user_data, 25) + DEFINE_PARAM(progress_callback, 26) + DEFINE_PARAM(progress_callback_user_data, 27) + DEFINE_PARAM(abort_callback, 28) + DEFINE_PARAM(abort_callback_user_data, 29) + + rb_define_method(cParams, "on_new_segment", ruby_whisper_params_on_new_segment, 0); + rb_define_method(cParams, "on_progress", ruby_whisper_params_on_progress, 0); + rb_define_method(cParams, "abort_on", ruby_whisper_params_abort_on, 0); +} diff --git 
a/bindings/ruby/ext/ruby_whisper_segment.c b/bindings/ruby/ext/ruby_whisper_segment.c new file mode 100644 index 00000000..3440ff95 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_segment.c @@ -0,0 +1,123 @@ +#include <ruby.h> +#include "ruby_whisper.h" + +extern VALUE cSegment; + +static void +rb_whisper_segment_mark(ruby_whisper_segment *rws) +{ + rb_gc_mark(rws->context); +} + +VALUE +ruby_whisper_segment_allocate(VALUE klass) +{ + ruby_whisper_segment *rws; + rws = ALLOC(ruby_whisper_segment); + return Data_Wrap_Struct(klass, rb_whisper_segment_mark, RUBY_DEFAULT_FREE, rws); +} + +VALUE +rb_whisper_segment_initialize(VALUE context, int index) +{ + ruby_whisper_segment *rws; + const VALUE segment = ruby_whisper_segment_allocate(cSegment); + Data_Get_Struct(segment, ruby_whisper_segment, rws); + rws->context = context; + rws->index = index; + return segment; +}; + +/* + * Start time in milliseconds. + * + * call-seq: + * start_time -> Integer + */ +static VALUE +ruby_whisper_segment_get_start_time(VALUE self) +{ + ruby_whisper_segment *rws; + Data_Get_Struct(self, ruby_whisper_segment, rws); + ruby_whisper *rw; + Data_Get_Struct(rws->context, ruby_whisper, rw); + const int64_t t0 = whisper_full_get_segment_t0(rw->context, rws->index); + // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it + return INT2NUM(t0 * 10); +} + +/* + * End time in milliseconds. + * + * call-seq: + * end_time -> Integer + */ +static VALUE +ruby_whisper_segment_get_end_time(VALUE self) +{ + ruby_whisper_segment *rws; + Data_Get_Struct(self, ruby_whisper_segment, rws); + ruby_whisper *rw; + Data_Get_Struct(rws->context, ruby_whisper, rw); + const int64_t t1 = whisper_full_get_segment_t1(rw->context, rws->index); + // able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it + return INT2NUM(t1 * 10); +} + +/* + * Whether the next segment is predicted as a speaker turn. + * + * call-seq: + * speaker_next_turn? 
-> bool + */ +static VALUE +ruby_whisper_segment_get_speaker_turn_next(VALUE self) +{ + ruby_whisper_segment *rws; + Data_Get_Struct(self, ruby_whisper_segment, rws); + ruby_whisper *rw; + Data_Get_Struct(rws->context, ruby_whisper, rw); + return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? Qtrue : Qfalse; +} + +/* + * call-seq: + * text -> String + */ +static VALUE +ruby_whisper_segment_get_text(VALUE self) +{ + ruby_whisper_segment *rws; + Data_Get_Struct(self, ruby_whisper_segment, rws); + ruby_whisper *rw; + Data_Get_Struct(rws->context, ruby_whisper, rw); + const char * text = whisper_full_get_segment_text(rw->context, rws->index); + return rb_str_new2(text); +} + +/* + * call-seq: + * no_speech_prob -> Float + */ +static VALUE +ruby_whisper_segment_get_no_speech_prob(VALUE self) +{ + ruby_whisper_segment *rws; + Data_Get_Struct(self, ruby_whisper_segment, rws); + ruby_whisper *rw; + Data_Get_Struct(rws->context, ruby_whisper, rw); + return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index)); +} + +void +init_ruby_whisper_segment(VALUE *mWhisper, VALUE *cContext) +{ + cSegment = rb_define_class_under(*mWhisper, "Segment", rb_cObject); + + rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate); + rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0); + rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0); + rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0); + rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0); + rb_define_method(cSegment, "no_speech_prob", ruby_whisper_segment_get_no_speech_prob, 0); +} diff --git a/bindings/ruby/ext/ruby_whisper_transcribe.cpp b/bindings/ruby/ext/ruby_whisper_transcribe.cpp new file mode 100644 index 00000000..d50ed063 --- /dev/null +++ b/bindings/ruby/ext/ruby_whisper_transcribe.cpp @@ -0,0 +1,159 @@ +#include <ruby.h> +#include "ruby_whisper.h" +#define 
DR_WAV_IMPLEMENTATION +#include "dr_wav.h" +#include <string> +#include <vector> + +#ifdef __cplusplus +extern "C" { +#endif + +extern ID id_to_s; +extern ID id_call; + +extern void +register_callbacks(ruby_whisper_params * rwp, VALUE * self); + +/* + * transcribe a single file + * can emit to a block results + * + * params = Whisper::Params.new + * params.duration = 60_000 + * whisper.transcribe "path/to/audio.wav", params do |text| + * puts text + * end + * + * call-seq: + * transcribe(path_to_audio, params) {|text| ...} + **/ +VALUE +ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { + ruby_whisper *rw; + ruby_whisper_params *rwp; + VALUE wave_file_path, blk, params; + + rb_scan_args(argc, argv, "02&", &wave_file_path, &params, &blk); + Data_Get_Struct(self, ruby_whisper, rw); + Data_Get_Struct(params, ruby_whisper_params, rwp); + + if (!rb_respond_to(wave_file_path, id_to_s)) { + rb_raise(rb_eRuntimeError, "Expected file path to wave file"); + } + + std::string fname_inp = StringValueCStr(wave_file_path); + + std::vector<float> pcmf32; // mono-channel F32 PCM + std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM + + // WAV input - this is directly from main.cpp example + { + drwav wav; + std::vector<uint8_t> wav_data; // used for pipe input from stdin + + if (fname_inp == "-") { + { + uint8_t buf[1024]; + while (true) { + const size_t n = fread(buf, 1, sizeof(buf), stdin); + if (n == 0) { + break; + } + wav_data.insert(wav_data.end(), buf, buf + n); + } + } + + if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) { + fprintf(stderr, "error: failed to open WAV file from stdin\n"); + return self; + } + + fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size()); + } else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) { + fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str()); + return self; + } + + if (wav.channels != 1 && wav.channels != 2) { + fprintf(stderr, "WAV file '%s' must be mono or stereo\n", 
fname_inp.c_str()); + return self; + } + + if (rwp->diarize && wav.channels != 2 && rwp->params.print_timestamps == false) { + fprintf(stderr, "WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str()); + return self; + } + + if (wav.sampleRate != WHISPER_SAMPLE_RATE) { + fprintf(stderr, "WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000); + return self; + } + + if (wav.bitsPerSample != 16) { + fprintf(stderr, "WAV file '%s' must be 16-bit\n", fname_inp.c_str()); + return self; + } + + const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8); + + std::vector<int16_t> pcm16; + pcm16.resize(n*wav.channels); + drwav_read_pcm_frames_s16(&wav, n, pcm16.data()); + drwav_uninit(&wav); + + // convert to mono, float + pcmf32.resize(n); + if (wav.channels == 1) { + for (uint64_t i = 0; i < n; i++) { + pcmf32[i] = float(pcm16[i])/32768.0f; + } + } else { + for (uint64_t i = 0; i < n; i++) { + pcmf32[i] = float((int32_t)pcm16[2*i] + pcm16[2*i + 1])/65536.0f; + } + } + + if (rwp->diarize) { + // convert to stereo, float + pcmf32s.resize(2); + + pcmf32s[0].resize(n); + pcmf32s[1].resize(n); + for (uint64_t i = 0; i < n; i++) { + pcmf32s[0][i] = float(pcm16[2*i])/32768.0f; + pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f; + } + } + } + { + static bool is_aborted = false; // NOTE: this should be atomic to avoid data race + + rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) { + bool is_aborted = *(bool*)user_data; + return !is_aborted; + }; + rwp->params.encoder_begin_callback_user_data = &is_aborted; + } + + register_callbacks(rwp, &self); + + if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { + fprintf(stderr, "failed to process audio\n"); + return self; + } + const int n_segments = whisper_full_n_segments(rw->context); + VALUE output = 
rb_str_new2(""); + for (int i = 0; i < n_segments; ++i) { + const char * text = whisper_full_get_segment_text(rw->context, i); + output = rb_str_concat(output, rb_str_new2(text)); + } + VALUE idCall = id_call; + if (blk != Qnil) { + rb_funcall(blk, idCall, 1, output); + } + return self; +} +#ifdef __cplusplus +} +#endif diff --git a/bindings/ruby/lib/whisper/model/uri.rb b/bindings/ruby/lib/whisper/model/uri.rb index b43d90dd..ce19f715 100644 --- a/bindings/ruby/lib/whisper/model/uri.rb +++ b/bindings/ruby/lib/whisper/model/uri.rb @@ -65,6 +65,13 @@ module Whisper end end end + rescue => err + if cache_path.exist? + warn err + # Use cache file + else + raise + end end def download(response) diff --git a/bindings/ruby/sig/whisper.rbs b/bindings/ruby/sig/whisper.rbs index aff2ae73..85d941cb 100644 --- a/bindings/ruby/sig/whisper.rbs +++ b/bindings/ruby/sig/whisper.rbs @@ -20,13 +20,12 @@ module Whisper def self.lang_id: (string name) -> Integer def self.lang_str: (Integer id) -> String def self.lang_str_full: (Integer id) -> String - def self.log_set=: (log_callback) -> log_callback - def self.finalize_log_callback: (void) -> void # Second argument of ObjectSpace.define_finalizer + def self.log_set: (log_callback, Object? 
user_data) -> log_callback class Context - def initialize: (string | _ToPath | ::URI::HTTP ) -> void - def transcribe: (string, Params) -> void - | (string, Params) { (String) -> void } -> void + def self.new: (string | _ToPath | ::URI::HTTP) -> instance + def transcribe: (string, Params) -> self + | (string, Params) { (String) -> void } -> self def model_n_vocab: () -> Integer def model_n_audio_ctx: () -> Integer def model_n_audio_state: () -> Integer @@ -35,6 +34,10 @@ module Whisper def model_n_mels: () -> Integer def model_ftype: () -> Integer def model_type: () -> String + def each_segment: { (Segment) -> void } -> void + | () -> Enumerator[Segment] + def model: () -> Model + def full_get_segment: (Integer nth) -> Segment def full_n_segments: () -> Integer def full_lang_id: () -> Integer def full_get_segment_t0: (Integer) -> Integer @@ -42,18 +45,46 @@ module Whisper def full_get_segment_speaker_turn_next: (Integer) -> (true | false) def full_get_segment_text: (Integer) -> String def full_get_segment_no_speech_prob: (Integer) -> Float - def full: (Params, Array[Float], ?Integer) -> void - | (Params, _Samples, ?Integer) -> void - def full_parallel: (Params, Array[Float], ?Integer) -> void - | (Params, _Samples, ?Integer) -> void - | (Params, _Samples, ?Integer?, Integer) -> void - def each_segment: { (Segment) -> void } -> void - | () -> Enumerator[Segment] - def model: () -> Model + def full: (Params, Array[Float] samples, ?Integer n_samples) -> self + | (Params, _Samples, ?Integer n_samples) -> self + def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self + | (Params, _Samples, ?Integer n_samples) -> self + | (Params, _Samples, ?Integer? 
n_samples, Integer n_processors) -> self end class Params - def initialize: () -> void + def self.new: ( + ?language: string, + ?translate: boolish, + ?no_context: boolish, + ?single_segment: boolish, + ?print_special: boolish, + ?print_progress: boolish, + ?print_realtime: boolish, + ?print_timestamps: boolish, + ?suppress_blank: boolish, + ?suppress_nst: boolish, + ?token_timestamps: boolish, + ?split_on_word: boolish, + ?initial_prompt: string | nil, + ?diarize: boolish, + ?offset: Integer, + ?duration: Integer, + ?max_text_tokens: Integer, + ?temperature: Float, + ?max_initial_ts: Float, + ?length_penalty: Float, + ?temperature_inc: Float, + ?entropy_thold: Float, + ?logprob_thold: Float, + ?no_speech_thold: Float, + ?new_segment_callback: new_segment_callback, + ?new_segment_callback_user_data: Object, + ?progress_callback: progress_callback, + ?progress_callback_user_data: Object, + ?abort_callback: abort_callback, + ?abort_callback_user_data: Object + ) -> instance def language=: (String) -> String # TODO: Enumerate lang names def language: () -> String def translate=: (boolish) -> boolish @@ -79,7 +110,7 @@ module Whisper def split_on_word=: (boolish) -> boolish def split_on_word: () -> (true | false) def initial_prompt=: (_ToS) -> _ToS - def initial_prompt: () -> String + def initial_prompt: () -> (String | nil) def diarize=: (boolish) -> boolish def diarize: () -> (true | false) def offset=: (Integer) -> Integer @@ -103,19 +134,25 @@ module Whisper def no_speech_thold=: (Float) -> Float def no_speech_thold: () -> Float def new_segment_callback=: (new_segment_callback) -> new_segment_callback + def new_segment_callback: () -> (new_segment_callback | nil) def new_segment_callback_user_data=: (Object) -> Object + def new_segment_callback_user_data: () -> Object def progress_callback=: (progress_callback) -> progress_callback + def progress_callback: () -> (progress_callback | nil) def progress_callback_user_data=: (Object) -> Object + def 
progress_callback_user_data: () -> Object def abort_callback=: (abort_callback) -> abort_callback + def abort_callback: () -> (abort_callback | nil) def abort_callback_user_data=: (Object) -> Object + def abort_callback_user_data: () -> Object def on_new_segment: { (Segment) -> void } -> void - def on_progress: { (Integer) -> void } -> void - def abort_on: { (Object) -> boolish } -> void + def on_progress: { (Integer progress) -> void } -> void + def abort_on: { (Object user_data) -> boolish } -> void end class Model def self.pre_converted_models: () -> Hash[String, Model::URI] - def initialize: () -> void + def self.new: () -> instance def n_vocab: () -> Integer def n_audio_ctx: () -> Integer def n_audio_state: () -> Integer @@ -130,14 +167,13 @@ module Whisper def type: () -> String class URI - def initialize: (string | ::URI::HTTP) -> void + def self.new: (string | ::URI::HTTP) -> self def to_path: -> String def clear_cache: -> void end end class Segment - def initialize: () -> void def start_time: () -> Integer def end_time: () -> Integer def speaker_next_turn?: () -> (true | false) @@ -148,6 +184,6 @@ module Whisper class Error < StandardError attr_reader code: Integer - def initialize: (Integer) -> void + def self.new: (Integer code) -> instance end end diff --git a/bindings/ruby/tests/test_params.rb b/bindings/ruby/tests/test_params.rb index 7981bfaa..0cc49433 100644 --- a/bindings/ruby/tests/test_params.rb +++ b/bindings/ruby/tests/test_params.rb @@ -1,6 +1,39 @@ require_relative "helper" class TestParams < TestBase + PARAM_NAMES = [ + :language, + :translate, + :no_context, + :single_segment, + :print_special, + :print_progress, + :print_realtime, + :print_timestamps, + :suppress_blank, + :suppress_nst, + :token_timestamps, + :split_on_word, + :initial_prompt, + :diarize, + :offset, + :duration, + :max_text_tokens, + :temperature, + :max_initial_ts, + :length_penalty, + :temperature_inc, + :entropy_thold, + :logprob_thold, + :no_speech_thold, + 
:new_segment_callback, + :new_segment_callback_user_data, + :progress_callback, + :progress_callback_user_data, + :abort_callback, + :abort_callback_user_data, + ] + def setup @params = Whisper::Params.new end @@ -157,4 +190,57 @@ class TestParams < TestBase @params.no_speech_thold = 0.2 assert_in_delta 0.2, @params.no_speech_thold end + + def test_new_with_kw_args + params = Whisper::Params.new(language: "es") + assert_equal "es", params.language + assert_equal 1.0, params.max_initial_ts + end + + def test_new_with_kw_args_non_existent + assert_raise ArgumentError do + Whisper::Params.new(non_existent: "value") + end + end + + def test_new_with_kw_args_wrong_type + assert_raise TypeError do + Whisper::Params.new(language: 3) + end + end + + data(PARAM_NAMES.collect {|param| [param, param]}.to_h) + def test_new_with_kw_args_default_values(param) + default_value = @params.send(param) + value = case [param, default_value] + in [*, true | false] + !default_value + in [*, Integer | Float] + default_value + 1 + in [:language, *] + "es" + in [:initial_prompt, *] + "Initial prompt" + in [/_callback\Z/, *] + proc {} + in [/_user_data\Z/, *] + Object.new + end + params = Whisper::Params.new(param => value) + if Float === value + assert_in_delta value, params.send(param) + else + assert_equal value, params.send(param) + end + + PARAM_NAMES.reject {|name| name == param}.each do |name| + expected = @params.send(name) + actual = params.send(name) + if Float === expected + assert_in_delta expected, actual + else + assert_equal expected, actual + end + end + end end diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 5b0d189e..76b92c73 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -29,6 +29,12 @@ class TestWhisper < TestBase assert_equal 0, whisper.full_lang_id end + def test_full_get_segment + segment = whisper.full_get_segment(0) + assert_equal 0, segment.start_time + assert_match /ask not 
what your country can do for you, ask what you can do for your country/, segment.text + end + def test_full_get_segment_t0 assert_equal 0, whisper.full_get_segment_t0(0) assert_raise IndexError do