mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2025-06-15 21:38:07 +00:00
ruby : Add low-level methods to transcribe (#2585)
* Add tests for Whisper::Context#full * Add Whisper::Context#full * Add tests for Whisper::Error * Add document of Whisper::Context#full [skip ci] * Add additional signature for Whisper::Context#full * Add description to Whisper::Context#full * Add test for Whisper::Context#full_parallel * Add Whisper::Context#full_parallel * Hide Whisper's instance methods from Ruby code * Add class to test MemoryView * Build test class before running test * Add test for MemoryView * Make Whisper::Context#full and #full_parallel accept MemoryView * Use Ruby 3.1 on CI * Add comment on samples data type * Update README * Update README * Remove unused code
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
#include <ruby.h>
|
||||
#include <ruby/memory_view.h>
|
||||
#include "ruby_whisper.h"
|
||||
#define DR_WAV_IMPLEMENTATION
|
||||
#include "dr_wav.h"
|
||||
@ -35,11 +36,15 @@ extern "C" {
|
||||
VALUE mWhisper;
|
||||
VALUE cContext;
|
||||
VALUE cParams;
|
||||
VALUE eError;
|
||||
|
||||
static ID id_to_s;
|
||||
static ID id_call;
|
||||
static ID id___method__;
|
||||
static ID id_to_enum;
|
||||
static ID id_length;
|
||||
static ID id_next;
|
||||
static ID id_new;
|
||||
|
||||
static bool is_log_callback_finalized = false;
|
||||
|
||||
@ -100,13 +105,13 @@ static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
|
||||
* log_set ->(level, buffer, user_data) { ... }, user_data -> nil
|
||||
*/
|
||||
static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
|
||||
VALUE old_callback = rb_iv_get(self, "@log_callback");
|
||||
VALUE old_callback = rb_iv_get(self, "log_callback");
|
||||
if (!NIL_P(old_callback)) {
|
||||
rb_undefine_finalizer(old_callback);
|
||||
}
|
||||
|
||||
rb_iv_set(self, "@log_callback", log_callback);
|
||||
rb_iv_set(self, "@user_data", user_data);
|
||||
rb_iv_set(self, "log_callback", log_callback);
|
||||
rb_iv_set(self, "user_data", user_data);
|
||||
|
||||
VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
|
||||
rb_define_finalizer(log_callback, finalize_log_callback);
|
||||
@ -115,8 +120,8 @@ static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_d
|
||||
if (is_log_callback_finalized) {
|
||||
return;
|
||||
}
|
||||
VALUE log_callback = rb_iv_get(mWhisper, "@log_callback");
|
||||
VALUE udata = rb_iv_get(mWhisper, "@user_data");
|
||||
VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
|
||||
VALUE udata = rb_iv_get(mWhisper, "user_data");
|
||||
rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata);
|
||||
}, nullptr);
|
||||
|
||||
@ -544,6 +549,168 @@ VALUE ruby_whisper_model_type(VALUE self) {
|
||||
return rb_str_new2(whisper_model_type_readable(rw->context));
|
||||
}
|
||||
|
||||
/*
|
||||
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||
* Not thread safe for same context
|
||||
* Uses the specified decoding strategy to obtain the text.
|
||||
*
|
||||
* call-seq:
|
||||
* full(params, samples, n_samples) -> nil
|
||||
* full(params, samples) -> nil
|
||||
*
|
||||
* The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
|
||||
*/
|
||||
VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self) {
|
||||
if (argc < 2 || argc > 3) {
|
||||
rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
|
||||
}
|
||||
|
||||
ruby_whisper *rw;
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
VALUE params = argv[0];
|
||||
Data_Get_Struct(params, ruby_whisper_params, rwp);
|
||||
VALUE samples = argv[1];
|
||||
int n_samples;
|
||||
rb_memory_view_t view;
|
||||
const bool memory_view_available_p = rb_memory_view_available_p(samples);
|
||||
if (argc == 3) {
|
||||
n_samples = NUM2INT(argv[2]);
|
||||
if (TYPE(samples) == T_ARRAY) {
|
||||
if (RARRAY_LEN(samples) < n_samples) {
|
||||
rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
|
||||
}
|
||||
}
|
||||
// Should check when samples.respond_to?(:length)?
|
||||
} else {
|
||||
if (TYPE(samples) == T_ARRAY) {
|
||||
n_samples = RARRAY_LEN(samples);
|
||||
} else if (memory_view_available_p) {
|
||||
if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
|
||||
view.obj = Qnil;
|
||||
rb_raise(rb_eArgError, "unable to get a memory view");
|
||||
}
|
||||
n_samples = view.byte_size / view.item_size;
|
||||
} else if (rb_respond_to(samples, id_length)) {
|
||||
n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
|
||||
} else {
|
||||
rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
|
||||
}
|
||||
}
|
||||
float * c_samples = (float *)malloc(n_samples * sizeof(float));
|
||||
if (memory_view_available_p) {
|
||||
c_samples = (float *)view.data;
|
||||
} else {
|
||||
if (TYPE(samples) == T_ARRAY) {
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
|
||||
}
|
||||
} else {
|
||||
// TODO: use rb_block_call
|
||||
VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
// TODO: check if iter is exhausted and raise ArgumentError appropriately
|
||||
VALUE sample = rb_funcall(iter, id_next, 0);
|
||||
c_samples[i] = RFLOAT_VALUE(sample);
|
||||
}
|
||||
}
|
||||
}
|
||||
const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
|
||||
if (0 == result) {
|
||||
return Qnil;
|
||||
} else {
|
||||
rb_exc_raise(rb_funcall(eError, id_new, 1, result));
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
|
||||
* Result is stored in the default state of the context
|
||||
* Not thread safe if executed in parallel on the same context.
|
||||
* It seems this approach can offer some speedup in some cases.
|
||||
* However, the transcription accuracy can be worse at the beginning and end of each chunk.
|
||||
*
|
||||
* call-seq:
|
||||
* full_parallel(params, samples) -> nil
|
||||
* full_parallel(params, samples, n_samples) -> nil
|
||||
* full_parallel(params, samples, n_samples, n_processors) -> nil
|
||||
* full_parallel(params, samples, nil, n_processors) -> nil
|
||||
*/
|
||||
static VALUE ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self) {
|
||||
if (argc < 2 || argc > 4) {
|
||||
rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
|
||||
}
|
||||
|
||||
ruby_whisper *rw;
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
VALUE params = argv[0];
|
||||
Data_Get_Struct(params, ruby_whisper_params, rwp);
|
||||
VALUE samples = argv[1];
|
||||
int n_samples;
|
||||
int n_processors;
|
||||
rb_memory_view_t view;
|
||||
const bool memory_view_available_p = rb_memory_view_available_p(samples);
|
||||
switch (argc) {
|
||||
case 2:
|
||||
n_processors = 1;
|
||||
break;
|
||||
case 3:
|
||||
n_processors = 1;
|
||||
break;
|
||||
case 4:
|
||||
n_processors = NUM2INT(argv[3]);
|
||||
break;
|
||||
}
|
||||
if (argc >= 3 && !NIL_P(argv[2])) {
|
||||
n_samples = NUM2INT(argv[2]);
|
||||
if (TYPE(samples) == T_ARRAY) {
|
||||
if (RARRAY_LEN(samples) < n_samples) {
|
||||
rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
|
||||
}
|
||||
}
|
||||
// Should check when samples.respond_to?(:length)?
|
||||
} else if (memory_view_available_p) {
|
||||
if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
|
||||
view.obj = Qnil;
|
||||
rb_raise(rb_eArgError, "unable to get a memory view");
|
||||
}
|
||||
n_samples = view.byte_size / view.item_size;
|
||||
} else {
|
||||
if (TYPE(samples) == T_ARRAY) {
|
||||
n_samples = RARRAY_LEN(samples);
|
||||
} else if (rb_respond_to(samples, id_length)) {
|
||||
n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
|
||||
} else {
|
||||
rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
|
||||
}
|
||||
}
|
||||
float * c_samples = (float *)malloc(n_samples * sizeof(float));
|
||||
if (memory_view_available_p) {
|
||||
c_samples = (float *)view.data;
|
||||
} else {
|
||||
if (TYPE(samples) == T_ARRAY) {
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
|
||||
}
|
||||
} else {
|
||||
// FIXME: use rb_block_call
|
||||
VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
|
||||
for (int i = 0; i < n_samples; i++) {
|
||||
// TODO: check if iter is exhausted and raise ArgumentError
|
||||
VALUE sample = rb_funcall(iter, id_next, 0);
|
||||
c_samples[i] = RFLOAT_VALUE(sample);
|
||||
}
|
||||
}
|
||||
}
|
||||
const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
|
||||
if (0 == result) {
|
||||
return Qnil;
|
||||
} else {
|
||||
rb_exc_raise(rb_funcall(eError, id_new, 1, result));
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
* Number of segments.
|
||||
*
|
||||
@ -1518,15 +1685,59 @@ static VALUE ruby_whisper_c_model_type(VALUE self) {
|
||||
return rb_str_new2(whisper_model_type_readable(rw->context));
|
||||
}
|
||||
|
||||
static VALUE ruby_whisper_error_initialize(VALUE self, VALUE code) {
|
||||
const int c_code = NUM2INT(code);
|
||||
char *raw_message;
|
||||
switch (c_code) {
|
||||
case -2:
|
||||
raw_message = "failed to compute log mel spectrogram";
|
||||
break;
|
||||
case -3:
|
||||
raw_message = "failed to auto-detect language";
|
||||
break;
|
||||
case -4:
|
||||
raw_message = "too many decoders requested";
|
||||
break;
|
||||
case -5:
|
||||
raw_message = "audio_ctx is larger than the maximum allowed";
|
||||
break;
|
||||
case -6:
|
||||
raw_message = "failed to encode";
|
||||
break;
|
||||
case -7:
|
||||
raw_message = "whisper_kv_cache_init() failed for self-attention cache";
|
||||
break;
|
||||
case -8:
|
||||
raw_message = "failed to decode";
|
||||
break;
|
||||
case -9:
|
||||
raw_message = "failed to decode";
|
||||
break;
|
||||
default:
|
||||
raw_message = "unknown error";
|
||||
break;
|
||||
}
|
||||
const VALUE message = rb_str_new2(raw_message);
|
||||
rb_call_super(1, &message);
|
||||
rb_iv_set(self, "@code", code);
|
||||
|
||||
return self;
|
||||
}
|
||||
|
||||
|
||||
void Init_whisper() {
|
||||
id_to_s = rb_intern("to_s");
|
||||
id_call = rb_intern("call");
|
||||
id___method__ = rb_intern("__method__");
|
||||
id_to_enum = rb_intern("to_enum");
|
||||
id_length = rb_intern("length");
|
||||
id_next = rb_intern("next");
|
||||
id_new = rb_intern("new");
|
||||
|
||||
mWhisper = rb_define_module("Whisper");
|
||||
cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
|
||||
cParams = rb_define_class_under(mWhisper, "Params", rb_cObject);
|
||||
eError = rb_define_class_under(mWhisper, "Error", rb_eStandardError);
|
||||
|
||||
rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
|
||||
rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
|
||||
@ -1564,6 +1775,8 @@ void Init_whisper() {
|
||||
rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
|
||||
rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
|
||||
rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
|
||||
rb_define_method(cContext, "full", ruby_whisper_full, -1);
|
||||
rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
|
||||
|
||||
rb_define_alloc_func(cParams, ruby_whisper_params_allocate);
|
||||
|
||||
@ -1623,6 +1836,9 @@ void Init_whisper() {
|
||||
rb_define_method(cParams, "abort_callback=", ruby_whisper_params_set_abort_callback, 1);
|
||||
rb_define_method(cParams, "abort_callback_user_data=", ruby_whisper_params_set_abort_callback_user_data, 1);
|
||||
|
||||
rb_define_attr(eError, "code", true, false);
|
||||
rb_define_method(eError, "initialize", ruby_whisper_error_initialize, 1);
|
||||
|
||||
// High leve
|
||||
cSegment = rb_define_class_under(mWhisper, "Segment", rb_cObject);
|
||||
|
||||
|
Reference in New Issue
Block a user