ruby : add VAD support, migration to Ruby's newer API (#3197)
Some checks failed
Bindings Tests (Ruby) / ubuntu-22 (push) Has been cancelled
CI / determine-tag (push) Has been cancelled
CI / ubuntu-22 (linux/amd64) (push) Has been cancelled
CI / ubuntu-22 (linux/ppc64le) (push) Has been cancelled
CI / ubuntu-22-arm64 (linux/arm64) (push) Has been cancelled
CI / ubuntu-22-arm-v7 (linux/arm/v7) (push) Has been cancelled
CI / macOS-latest (generic/platform=iOS) (push) Has been cancelled
CI / macOS-latest (generic/platform=macOS) (push) Has been cancelled
CI / macOS-latest (generic/platform=tvOS) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-arm64 (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc-arm64 (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-arm-v7 (linux/arm/v7, Debug) (push) Has been cancelled
CI / ubuntu-22-gcc-arm-v7 (linux/arm/v7, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/amd64, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/amd64, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/arm64, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/arm64, Release) (push) Has been cancelled
CI / ubuntu-22-clang (linux/ppc64le, Debug) (push) Has been cancelled
CI / ubuntu-22-clang (linux/ppc64le, Release) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, ADDRESS) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, THREAD) (push) Has been cancelled
CI / ubuntu-22-gcc-sanitized (linux/amd64, UNDEFINED) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/amd64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm/v7, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/arm64, icx, icpx, ON) (push) Has been cancelled
CI / ubuntu-22-cmake-sycl-fp16 (linux/ppc64le, icx, icpx, ON) (push) Has been cancelled
CI / windows-msys2 (Release, clang-x86_64, CLANG64) (push) Has been cancelled
CI / windows-msys2 (Release, ucrt-x86_64, UCRT64) (push) Has been cancelled
CI / windows (Win32, Release, win32-x86, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows (x64, Release, win32-x86-64, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (Win32, ON, x86, 0.3.29, Release, x86, 2.28.5, ON) (push) Has been cancelled
CI / windows-blas (x64, ON, x64_64, 0.3.29, Release, x64, 2.28.5, ON) (push) Has been cancelled
CI / windows-cublas (x64, Release, ON, 11.8.0, ON, 2.28.5) (push) Has been cancelled
CI / windows-cublas (x64, Release, ON, 12.2.0, ON, 2.28.5) (push) Has been cancelled
CI / emscripten (Release) (push) Has been cancelled
CI / android (push) Has been cancelled
CI / android_java (push) Has been cancelled
CI / quantize (push) Has been cancelled
CI / vad (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main-musa.Dockerfile platform:linux/amd64 tag:main-musa]) (push) Has been cancelled
Publish Docker image / Push Docker image to Docker Hub (map[dockerfile:.devops/main.Dockerfile platform:linux/amd64 tag:main]) (push) Has been cancelled
Examples WASM / deploy-wasm-github-pages (push) Has been cancelled
CI / ios-xcode-build (Release) (push) Has been cancelled
CI / bindings-java (push) Has been cancelled
CI / release (push) Has been cancelled
CI / coreml-base-en (push) Has been cancelled

* Add VAD models

* Extract function to normalize model path from ruby_whisper_initialize()

* Define ruby_whisper_vad_params struct

* Add VAD-related features to Whisper::Params

* Add tests for VAD-related features

* Define Whisper::VADParams

* Add Whisper::VAD::Params attributes

* Add test suite for VAD::Params

* Make older test to follow namespace change

* Add test for transcription with VAD

* Add assertion for test_vad_params

* Add signatures for VAD-related methods

* Define VAD::Params#==

* Add test for VAD::Params#==

* Fix Params#vad_params

* Add test for Params#vad_params

* Fix signature of Params#vad_params

* Use macro to define VAD::Params params

* Define VAD::Params#initialize

* Add tests for VAD::Params#initialize

* Add signature for VAD::Params.new

* Add documentation on VAD in README

* Wrap register_callbask in prepare_transcription for clear meanings

* Set whisper_params.vad_params just before transcription

* Don't touch NULL

* Define ruby_whisper_params_type

* Use TypedData_XXX for ruby_whisper_params instead of Data_XXX

* Remove unused functions

* Define rb_whisper_model_data_type

* Use TypedData_XXX for ruby_whisper_model instead of Data_XXX

* Define ruby_whisper_segment_type

* Use TypedData_XXX for ruby_whisper_segment instead of Data_XXX

* Define ruby_whisper_type

* Use TypedData_XXX for ruby_whisper instead of Data_XXX

* Qualify with const
This commit is contained in:
KITAITI Makoto
2025-05-28 20:05:12 +09:00
committed by GitHub
parent 5720426d97
commit 1f5fdbecb4
14 changed files with 925 additions and 173 deletions

View File

@ -16,10 +16,11 @@ extern VALUE cContext;
extern VALUE eError;
extern VALUE cModel;
extern const rb_data_type_t ruby_whisper_params_type;
extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
extern VALUE rb_whisper_model_initialize(VALUE context);
extern VALUE rb_whisper_segment_initialize(VALUE context, int index);
extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
extern void prepare_transcription(ruby_whisper_params *rwp, VALUE *context);
static void
ruby_whisper_free(ruby_whisper *rw)
@ -37,19 +38,64 @@ rb_whisper_mark(ruby_whisper *rw)
}
void
rb_whisper_free(ruby_whisper *rw)
rb_whisper_free(void *p)
{
ruby_whisper *rw = (ruby_whisper *)p;
ruby_whisper_free(rw);
free(rw);
}
static size_t
ruby_whisper_memsize(const void *p)
{
const ruby_whisper *rw = (const ruby_whisper *)p;
size_t size = sizeof(rw);
if (!rw) {
return 0;
}
return size;
}
const rb_data_type_t ruby_whisper_type = {
"ruby_whisper",
{0, rb_whisper_free, ruby_whisper_memsize,},
0, 0,
0
};
static VALUE
ruby_whisper_allocate(VALUE klass)
{
ruby_whisper *rw;
rw = ALLOC(ruby_whisper);
VALUE obj = TypedData_Make_Struct(klass, ruby_whisper, &ruby_whisper_type, rw);
rw->context = NULL;
return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
return obj;
}
VALUE
ruby_whisper_normalize_model_path(VALUE model_path)
{
VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
VALUE pre_converted_model = rb_hash_aref(pre_converted_models, model_path);
if (!NIL_P(pre_converted_model)) {
model_path = pre_converted_model;
}
else if (TYPE(model_path) == T_STRING) {
const char * model_path_str = StringValueCStr(model_path);
if (strncmp("http://", model_path_str, 7) == 0 || strncmp("https://", model_path_str, 8) == 0) {
VALUE uri_class = rb_const_get(cModel, id_URI);
model_path = rb_class_new_instance(1, &model_path, uri_class);
}
}
else if (rb_obj_is_kind_of(model_path, rb_path2class("URI::HTTP"))) {
VALUE uri_class = rb_const_get(cModel, id_URI);
model_path = rb_class_new_instance(1, &model_path, uri_class);
}
if (rb_respond_to(model_path, id_to_path)) {
model_path = rb_funcall(model_path, id_to_path, 0);
}
return model_path;
}
/*
@ -66,27 +112,9 @@ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
// TODO: we can support init from buffer here too maybe another ruby object to expose
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path);
if (!NIL_P(pre_converted_model)) {
whisper_model_file_path = pre_converted_model;
}
if (TYPE(whisper_model_file_path) == T_STRING) {
const char * whisper_model_file_path_str = StringValueCStr(whisper_model_file_path);
if (strncmp("http://", whisper_model_file_path_str, 7) == 0 || strncmp("https://", whisper_model_file_path_str, 8) == 0) {
VALUE uri_class = rb_const_get(cModel, id_URI);
whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
}
}
if (rb_obj_is_kind_of(whisper_model_file_path, rb_path2class("URI::HTTP"))) {
VALUE uri_class = rb_const_get(cModel, id_URI);
whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
}
if (rb_respond_to(whisper_model_file_path, id_to_path)) {
whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0);
}
whisper_model_file_path = ruby_whisper_normalize_model_path(whisper_model_file_path);
if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
}
@ -104,7 +132,7 @@ ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
VALUE ruby_whisper_model_n_vocab(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
return INT2NUM(whisper_model_n_vocab(rw->context));
}
@ -115,7 +143,7 @@ VALUE ruby_whisper_model_n_vocab(VALUE self)
VALUE ruby_whisper_model_n_audio_ctx(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
return INT2NUM(whisper_model_n_audio_ctx(rw->context));
}
@ -126,7 +154,7 @@ VALUE ruby_whisper_model_n_audio_ctx(VALUE self)
VALUE ruby_whisper_model_n_audio_state(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
return INT2NUM(whisper_model_n_audio_state(rw->context));
}
@ -137,7 +165,7 @@ VALUE ruby_whisper_model_n_audio_state(VALUE self)
VALUE ruby_whisper_model_n_audio_head(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
return INT2NUM(whisper_model_n_audio_head(rw->context));
}
@ -148,7 +176,7 @@ VALUE ruby_whisper_model_n_audio_head(VALUE self)
VALUE ruby_whisper_model_n_audio_layer(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
return INT2NUM(whisper_model_n_audio_layer(rw->context));
}
@ -159,7 +187,7 @@ VALUE ruby_whisper_model_n_audio_layer(VALUE self)
VALUE ruby_whisper_model_n_text_ctx(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
return INT2NUM(whisper_model_n_text_ctx(rw->context));
}
@ -170,7 +198,7 @@ VALUE ruby_whisper_model_n_text_ctx(VALUE self)
VALUE ruby_whisper_model_n_text_state(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
return INT2NUM(whisper_model_n_text_state(rw->context));
}
@ -181,7 +209,7 @@ VALUE ruby_whisper_model_n_text_state(VALUE self)
VALUE ruby_whisper_model_n_text_head(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
return INT2NUM(whisper_model_n_text_head(rw->context));
}
@ -192,7 +220,7 @@ VALUE ruby_whisper_model_n_text_head(VALUE self)
VALUE ruby_whisper_model_n_text_layer(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
return INT2NUM(whisper_model_n_text_layer(rw->context));
}
@ -203,7 +231,7 @@ VALUE ruby_whisper_model_n_text_layer(VALUE self)
VALUE ruby_whisper_model_n_mels(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
return INT2NUM(whisper_model_n_mels(rw->context));
}
@ -214,7 +242,7 @@ VALUE ruby_whisper_model_n_mels(VALUE self)
VALUE ruby_whisper_model_ftype(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
return INT2NUM(whisper_model_ftype(rw->context));
}
@ -225,7 +253,7 @@ VALUE ruby_whisper_model_ftype(VALUE self)
VALUE ruby_whisper_model_type(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
return rb_str_new2(whisper_model_type_readable(rw->context));
}
@ -248,9 +276,9 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
ruby_whisper *rw;
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
VALUE params = argv[0];
Data_Get_Struct(params, ruby_whisper_params, rwp);
TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
VALUE samples = argv[1];
int n_samples;
rb_memory_view_t view;
@ -296,7 +324,7 @@ VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
}
}
}
register_callbacks(rwp, &self);
prepare_transcription(rwp, &self);
const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
if (0 == result) {
return self;
@ -327,9 +355,9 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
ruby_whisper *rw;
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
VALUE params = argv[0];
Data_Get_Struct(params, ruby_whisper_params, rwp);
TypedData_Get_Struct(params, ruby_whisper_params, &ruby_whisper_params_type, rwp);
VALUE samples = argv[1];
int n_samples;
int n_processors;
@ -387,7 +415,7 @@ ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
}
}
}
register_callbacks(rwp, &self);
prepare_transcription(rwp, &self);
const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
if (0 == result) {
return self;
@ -406,7 +434,7 @@ static VALUE
ruby_whisper_full_n_segments(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
return INT2NUM(whisper_full_n_segments(rw->context));
}
@ -420,7 +448,7 @@ static VALUE
ruby_whisper_full_lang_id(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
return INT2NUM(whisper_full_lang_id(rw->context));
}
@ -445,7 +473,7 @@ static VALUE
ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment);
return INT2NUM(t0);
@ -463,7 +491,7 @@ static VALUE
ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment);
return INT2NUM(t1);
@ -481,7 +509,7 @@ static VALUE
ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment);
return speaker_turn_next ? Qtrue : Qfalse;
@ -499,7 +527,7 @@ static VALUE
ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const char * text = whisper_full_get_segment_text(rw->context, c_i_segment);
return rb_str_new2(text);
@ -513,7 +541,7 @@ static VALUE
ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment);
return DBL2NUM(no_speech_prob);
@ -554,7 +582,7 @@ ruby_whisper_each_segment(VALUE self)
}
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
TypedData_Get_Struct(self, ruby_whisper, &ruby_whisper_type, rw);
const int n_segments = whisper_full_n_segments(rw->context);
for (int i = 0; i < n_segments; ++i) {