diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 9d933453..96c43520 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -73,6 +73,7 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) { ruby_whisper_params *rwp; rwp = ALLOC(ruby_whisper_params); rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + rwp->new_segment_callback = Qnil; return Data_Wrap_Struct(klass, rb_whisper_params_mark, rb_whisper_params_free, rwp); } @@ -205,6 +206,28 @@ static VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) { }; rwp->params.encoder_begin_callback_user_data = &is_aborted; } + { + // This cannot be used later because it is not incremented when new_segment_callback is not given. + static int n_segments = 0; + + rwp->params.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) { + VALUE callback = *(VALUE *)user_data; + if (NIL_P(callback)){ + return; + } + + for (int i = 0; i < n_new; i++) { + const int i_segment = n_segments + i; + const char * text = whisper_full_get_segment_text_from_state(state, i_segment); + // Multiplying 10 shouldn't cause overflow because to_timestamp() in whisper.cpp does it + const int64_t t0 = whisper_full_get_segment_t0_from_state(state, i_segment) * 10; + const int64_t t1 = whisper_full_get_segment_t1_from_state(state, i_segment) * 10; + rb_funcall(callback, rb_intern("call"), 4, rb_str_new2(text), INT2NUM(t0), INT2NUM(t1), INT2FIX(i_segment)); + } + n_segments += n_new; + }; + rwp->params.new_segment_callback_user_data = &rwp->new_segment_callback; + } if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) { fprintf(stderr, "failed to process audio\n"); @@ -365,6 +388,12 @@ static VALUE ruby_whisper_params_set_max_text_tokens(VALUE self, VALUE value) { rwp->params.n_max_text_ctx = NUM2INT(value); return value; } +static VALUE ruby_whisper_params_set_new_segment_callback(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->new_segment_callback = value; + return value; +} void Init_whisper() { mWhisper = rb_define_module("Whisper"); @@ -412,6 +441,8 @@ void Init_whisper() { rb_define_method(cParams, "max_text_tokens", ruby_whisper_params_get_max_text_tokens, 0); rb_define_method(cParams, "max_text_tokens=", ruby_whisper_params_set_max_text_tokens, 1); + + rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1); } #ifdef __cplusplus } diff --git a/bindings/ruby/ext/ruby_whisper.h b/bindings/ruby/ext/ruby_whisper.h index 8c35b7cb..988750a8 100644 --- a/bindings/ruby/ext/ruby_whisper.h +++ b/bindings/ruby/ext/ruby_whisper.h @@ -10,6 +10,7 @@ typedef struct { typedef struct { struct whisper_full_params params; bool diarize; + VALUE new_segment_callback; } ruby_whisper_params; #endif