ruby : Make context accept initial parameters, API to retrieve a segment and more ()

* Fix type signature for Whisper.log_set

* Use cache file for model when offline

* Extract ruby_whisper_transcribe() into a file

* Extract Whisper::Error

* Use FileList for ext/*.{c,cpp,h}

* Extract Whisper::Segment

* Extract Whisper::Model

* Extract Whisper::Params

* Extract Whisper::Context

* Extract log_callback function

* Write base code in C rather than C++

* Use chdir instead of Dir.chdir in Rakefile

* Define alloc func for Whisper::Model

* Define Whisper::Params' calback and user data reader

* Add test for Whisper::Params.new with keyword arguments

* Make Whisper::Params.new accept keyword arguments

* Update type signatures

* Update README

* Update CLEAN targets

* Fix document comment for Whisper::Params#new_segment_callback=

* Use macro to define params

* Fix dependency of build task

* Set Whisper.finalize_log_callback visibility to private

* Make Whisper::Context#full and full_parallel return self

* Add test for Whisper::Context#full_get_segment

* Add Whisper::Context#full_get_segment

* Update signatures

* Update README

* Fix signature

* Resplace #initialize with .new in signature file [skip ci]

* Fix potential overflow
This commit is contained in:
KITAITI Makoto 2025-01-21 16:39:54 +09:00 committed by GitHub
parent 7a423f1c00
commit 7ffcd05267
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 2610 additions and 2021 deletions

@ -24,14 +24,15 @@ require "whisper"
whisper = Whisper::Context.new("base")
params = Whisper::Params.new
params.language = "en"
params.offset = 10_000
params.duration = 60_000
params.max_text_tokens = 300
params.translate = true
params.print_timestamps = false
params.initial_prompt = "Initial prompt here."
params = Whisper::Params.new(
language: "en",
offset: 10_000,
duration: 60_000,
max_text_tokens: 300,
translate: true,
print_timestamps: false,
initial_prompt: "Initial prompt here."
)
whisper.transcribe("path/to/audio.wav", params) do |whole_text|
puts whole_text
@ -113,18 +114,18 @@ def format_time(time_ms)
"%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part]
end
whisper.transcribe("path/to/audio.wav", params)
whisper.each_segment.with_index do |segment, index|
line = "[%{nth}: %{st} --> %{ed}] %{text}" % {
nth: index + 1,
st: format_time(segment.start_time),
ed: format_time(segment.end_time),
text: segment.text
}
line << " (speaker turned)" if segment.speaker_next_turn?
puts line
end
whisper
.transcribe("path/to/audio.wav", params)
.each_segment.with_index do |segment, index|
line = "[%{nth}: %{st} --> %{ed}] %{text}" % {
nth: index + 1,
st: format_time(segment.start_time),
ed: format_time(segment.end_time),
text: segment.text
}
line << " (speaker turned)" if segment.speaker_next_turn?
puts line
end
```
@ -215,10 +216,11 @@ reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, :
samples = reader.enum_for(:each_buffer).map(&:samples).flatten
whisper = Whisper::Context.new("base")
whisper.full(Whisper::Params.new, samples)
whisper.each_segment do |segment|
puts segment.text
end
whisper
.full(Whisper::Params.new, samples)
.each_segment do |segment|
puts segment.text
end
```
The second argument `samples` may be an array, an object with `length` and `each` method, or a MemoryView. If you can prepare audio data as C array and export it as a MemoryView, whispercpp accepts and works with it with zero copy.

@ -18,9 +18,11 @@ EXTSOURCES.each do |src|
end
CLEAN.include SOURCES
CLEAN.include FileList["ext/*.o", "ext/*.metal", "ext/whisper.{so,bundle,dll}"]
CLEAN.include FileList["ext/**/*.o", "ext/**/*.metal", "ext/**/*.tmp", "ext/whisper.{so,bundle,dll}"]
task build: ["ext/Makefile", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp", "whispercpp.gemspec"]
SRC = FileList["ext/*.{c,cpp,h}"]
task build: SOURCES
directory "pkg"
CLOBBER.include "pkg"
@ -29,14 +31,14 @@ LIB_NAME = "whisper".ext(RbConfig::CONFIG["DLEXT"])
SO_FILE = File.join("ext", LIB_NAME)
LIB_FILE = File.join("lib", LIB_NAME)
file "ext/Makefile" => ["ext/extconf.rb", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp"] + SOURCES do |t|
Dir.chdir "ext" do
file "ext/Makefile" => SRC + ["ext/extconf.rb"] + SOURCES do |t|
chdir "ext" do
ruby "extconf.rb"
end
end
file SO_FILE => "ext/Makefile" do |t|
Dir.chdir "ext" do
chdir "ext" do
sh "make"
end
end
@ -54,7 +56,7 @@ end
TEST_MEMORY_VIEW = "tests/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t|
Dir.chdir "tests/jfk_reader" do
chdir "tests/jfk_reader" do
ruby "extconf.rb"
sh "make"
end

@ -4,10 +4,8 @@ whisper.bundle
whisper.dll
scripts/get-flags.mk
*.o
*.c
*.cpp
*.h
*.m
*.metal
!ruby_whisper.cpp
!ruby_whisper.h
/*/**/*.c
/*/**/*.cpp
/*/**/*.h
/*/**/*.m
/*/**/*.metal

@ -174,7 +174,14 @@ $OBJ_WHISPER <<
'src/whisper.o'
$objs = $OBJ_GGML + $OBJ_WHISPER + $OBJ_COMMON + $OBJ_SDL
$objs << "ruby_whisper.o"
$objs <<
"ruby_whisper.o" <<
"ruby_whisper_context.o" <<
"ruby_whisper_transcribe.o" <<
"ruby_whisper_params.o" <<
"ruby_whisper_error.o" <<
"ruby_whisper_segment.o" <<
"ruby_whisper_model.o"
$CPPFLAGS = "#{$MK_CPPFLAGS} #{$CPPFLAGS}"
$CFLAGS = "#{$CPPFLAGS} #{$MK_CFLAGS} #{$GF_CFLAGS} #{$CFLAGS}"

@ -0,0 +1,164 @@
#include <ruby.h>
#include <ruby/memory_view.h>
#include "ruby_whisper.h"
VALUE mWhisper;
VALUE cContext;
VALUE cParams;
VALUE eError;
VALUE cSegment;
VALUE cModel;
ID id_to_s;
ID id_call;
ID id___method__;
ID id_to_enum;
ID id_length;
ID id_next;
ID id_new;
ID id_to_path;
ID id_URI;
ID id_pre_converted_models;
static bool is_log_callback_finalized = false;
// High level API
extern VALUE ruby_whisper_segment_allocate(VALUE klass);
extern void init_ruby_whisper_context(VALUE *mWhisper);
extern void init_ruby_whisper_params(VALUE *mWhisper);
extern void init_ruby_whisper_error(VALUE *mWhisper);
extern void init_ruby_whisper_segment(VALUE *mWhisper, VALUE *cSegment);
extern void init_ruby_whisper_model(VALUE *mWhisper);
extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
/*
* call-seq:
* lang_max_id -> Integer
*/
static VALUE ruby_whisper_s_lang_max_id(VALUE self) {
return INT2NUM(whisper_lang_max_id());
}
/*
* call-seq:
* lang_id(lang_name) -> Integer
*/
static VALUE ruby_whisper_s_lang_id(VALUE self, VALUE lang) {
const char * lang_str = StringValueCStr(lang);
const int id = whisper_lang_id(lang_str);
if (-1 == id) {
rb_raise(rb_eArgError, "language not found: %s", lang_str);
}
return INT2NUM(id);
}
/*
* call-seq:
* lang_str(lang_id) -> String
*/
static VALUE ruby_whisper_s_lang_str(VALUE self, VALUE id) {
const int lang_id = NUM2INT(id);
const char * str = whisper_lang_str(lang_id);
if (NULL == str) {
rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
}
return rb_str_new2(str);
}
/*
* call-seq:
* lang_str(lang_id) -> String
*/
static VALUE ruby_whisper_s_lang_str_full(VALUE self, VALUE id) {
const int lang_id = NUM2INT(id);
const char * str_full = whisper_lang_str_full(lang_id);
if (NULL == str_full) {
rb_raise(rb_eIndexError, "id %d outside of language id", lang_id);
}
return rb_str_new2(str_full);
}
static VALUE ruby_whisper_s_finalize_log_callback(VALUE self, VALUE id) {
is_log_callback_finalized = true;
return Qnil;
}
static void
ruby_whisper_log_callback(enum ggml_log_level level, const char * buffer, void * user_data) {
if (is_log_callback_finalized) {
return;
}
VALUE log_callback = rb_iv_get(mWhisper, "log_callback");
VALUE udata = rb_iv_get(mWhisper, "user_data");
rb_funcall(log_callback, id_call, 3, INT2NUM(level), rb_str_new2(buffer), udata);
}
/*
* call-seq:
* log_set ->(level, buffer, user_data) { ... }, user_data -> nil
*/
static VALUE ruby_whisper_s_log_set(VALUE self, VALUE log_callback, VALUE user_data) {
VALUE old_callback = rb_iv_get(self, "log_callback");
if (!NIL_P(old_callback)) {
rb_undefine_finalizer(old_callback);
}
rb_iv_set(self, "log_callback", log_callback);
rb_iv_set(self, "user_data", user_data);
VALUE finalize_log_callback = rb_funcall(mWhisper, rb_intern("method"), 1, rb_str_new2("finalize_log_callback"));
rb_define_finalizer(log_callback, finalize_log_callback);
whisper_log_set(ruby_whisper_log_callback, NULL);
return Qnil;
}
static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
rb_gc_mark(rwm->context);
}
static VALUE ruby_whisper_model_allocate(VALUE klass) {
ruby_whisper_model *rwm;
rwm = ALLOC(ruby_whisper_model);
return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
}
void Init_whisper() {
id_to_s = rb_intern("to_s");
id_call = rb_intern("call");
id___method__ = rb_intern("__method__");
id_to_enum = rb_intern("to_enum");
id_length = rb_intern("length");
id_next = rb_intern("next");
id_new = rb_intern("new");
id_to_path = rb_intern("to_path");
id_URI = rb_intern("URI");
id_pre_converted_models = rb_intern("pre_converted_models");
mWhisper = rb_define_module("Whisper");
rb_define_const(mWhisper, "LOG_LEVEL_NONE", INT2NUM(GGML_LOG_LEVEL_NONE));
rb_define_const(mWhisper, "LOG_LEVEL_INFO", INT2NUM(GGML_LOG_LEVEL_INFO));
rb_define_const(mWhisper, "LOG_LEVEL_WARN", INT2NUM(GGML_LOG_LEVEL_WARN));
rb_define_const(mWhisper, "LOG_LEVEL_ERROR", INT2NUM(GGML_LOG_LEVEL_ERROR));
rb_define_const(mWhisper, "LOG_LEVEL_DEBUG", INT2NUM(GGML_LOG_LEVEL_DEBUG));
rb_define_const(mWhisper, "LOG_LEVEL_CONT", INT2NUM(GGML_LOG_LEVEL_CONT));
rb_define_singleton_method(mWhisper, "lang_max_id", ruby_whisper_s_lang_max_id, 0);
rb_define_singleton_method(mWhisper, "lang_id", ruby_whisper_s_lang_id, 1);
rb_define_singleton_method(mWhisper, "lang_str", ruby_whisper_s_lang_str, 1);
rb_define_singleton_method(mWhisper, "lang_str_full", ruby_whisper_s_lang_str_full, 1);
rb_define_singleton_method(mWhisper, "log_set", ruby_whisper_s_log_set, 2);
rb_define_private_method(rb_singleton_class(mWhisper), "finalize_log_callback", ruby_whisper_s_finalize_log_callback, 1);
init_ruby_whisper_context(&mWhisper);
init_ruby_whisper_params(&mWhisper);
init_ruby_whisper_error(&mWhisper);
init_ruby_whisper_segment(&mWhisper, &cContext);
init_ruby_whisper_model(&mWhisper);
rb_require("whisper/model/uri");
}

File diff suppressed because it is too large Load Diff

@ -22,4 +22,13 @@ typedef struct {
ruby_whisper_callback_container *abort_callback_container;
} ruby_whisper_params;
typedef struct {
VALUE context;
int index;
} ruby_whisper_segment;
typedef struct {
VALUE context;
} ruby_whisper_model;
#endif

@ -0,0 +1,613 @@
#include <ruby.h>
#include <ruby/memory_view.h>
#include "ruby_whisper.h"
extern ID id_to_s;
extern ID id___method__;
extern ID id_to_enum;
extern ID id_length;
extern ID id_next;
extern ID id_new;
extern ID id_to_path;
extern ID id_URI;
extern ID id_pre_converted_models;
extern VALUE cContext;
extern VALUE eError;
extern VALUE cModel;
extern VALUE ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self);
extern VALUE rb_whisper_model_initialize(VALUE context);
extern VALUE rb_whisper_segment_initialize(VALUE context, int index);
extern void register_callbacks(ruby_whisper_params *rwp, VALUE *context);
static void
ruby_whisper_free(ruby_whisper *rw)
{
if (rw->context) {
whisper_free(rw->context);
rw->context = NULL;
}
}
void
rb_whisper_mark(ruby_whisper *rw)
{
// call rb_gc_mark on any ruby references in rw
}
void
rb_whisper_free(ruby_whisper *rw)
{
ruby_whisper_free(rw);
free(rw);
}
static VALUE
ruby_whisper_allocate(VALUE klass)
{
ruby_whisper *rw;
rw = ALLOC(ruby_whisper);
rw->context = NULL;
return Data_Wrap_Struct(klass, rb_whisper_mark, rb_whisper_free, rw);
}
/*
* call-seq:
* new("base.en") -> Whisper::Context
* new("path/to/model.bin") -> Whisper::Context
* new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context
*/
static VALUE
ruby_whisper_initialize(int argc, VALUE *argv, VALUE self)
{
ruby_whisper *rw;
VALUE whisper_model_file_path;
// TODO: we can support init from buffer here too maybe another ruby object to expose
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
Data_Get_Struct(self, ruby_whisper, rw);
VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path);
if (!NIL_P(pre_converted_model)) {
whisper_model_file_path = pre_converted_model;
}
if (TYPE(whisper_model_file_path) == T_STRING) {
const char * whisper_model_file_path_str = StringValueCStr(whisper_model_file_path);
if (strncmp("http://", whisper_model_file_path_str, 7) == 0 || strncmp("https://", whisper_model_file_path_str, 8) == 0) {
VALUE uri_class = rb_const_get(cModel, id_URI);
whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
}
}
if (rb_obj_is_kind_of(whisper_model_file_path, rb_path2class("URI::HTTP"))) {
VALUE uri_class = rb_const_get(cModel, id_URI);
whisper_model_file_path = rb_class_new_instance(1, &whisper_model_file_path, uri_class);
}
if (rb_respond_to(whisper_model_file_path, id_to_path)) {
whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0);
}
if (!rb_respond_to(whisper_model_file_path, id_to_s)) {
rb_raise(rb_eRuntimeError, "Expected file path to model to initialize Whisper::Context");
}
rw->context = whisper_init_from_file_with_params(StringValueCStr(whisper_model_file_path), whisper_context_default_params());
if (rw->context == NULL) {
rb_raise(rb_eRuntimeError, "error: failed to initialize whisper context");
}
return self;
}
/*
* call-seq:
* model_n_vocab -> Integer
*/
VALUE ruby_whisper_model_n_vocab(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_vocab(rw->context));
}
/*
* call-seq:
* model_n_audio_ctx -> Integer
*/
VALUE ruby_whisper_model_n_audio_ctx(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_ctx(rw->context));
}
/*
* call-seq:
* model_n_audio_state -> Integer
*/
VALUE ruby_whisper_model_n_audio_state(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_state(rw->context));
}
/*
* call-seq:
* model_n_audio_head -> Integer
*/
VALUE ruby_whisper_model_n_audio_head(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_head(rw->context));
}
/*
* call-seq:
* model_n_audio_layer -> Integer
*/
VALUE ruby_whisper_model_n_audio_layer(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_layer(rw->context));
}
/*
* call-seq:
* model_n_text_ctx -> Integer
*/
VALUE ruby_whisper_model_n_text_ctx(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_ctx(rw->context));
}
/*
* call-seq:
* model_n_text_state -> Integer
*/
VALUE ruby_whisper_model_n_text_state(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_state(rw->context));
}
/*
* call-seq:
* model_n_text_head -> Integer
*/
VALUE ruby_whisper_model_n_text_head(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_head(rw->context));
}
/*
* call-seq:
* model_n_text_layer -> Integer
*/
VALUE ruby_whisper_model_n_text_layer(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_layer(rw->context));
}
/*
* call-seq:
* model_n_mels -> Integer
*/
VALUE ruby_whisper_model_n_mels(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_n_mels(rw->context));
}
/*
* call-seq:
* model_ftype -> Integer
*/
VALUE ruby_whisper_model_ftype(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_model_ftype(rw->context));
}
/*
* call-seq:
* model_type -> String
*/
VALUE ruby_whisper_model_type(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return rb_str_new2(whisper_model_type_readable(rw->context));
}
/*
* Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
* Not thread safe for same context
* Uses the specified decoding strategy to obtain the text.
*
* call-seq:
* full(params, samples, n_samples) -> nil
* full(params, samples) -> nil
*
* The second argument +samples+ must be an array of samples, respond to :length, or be a MemoryView of an array of float. It must be 32 bit float PCM audio data.
*/
VALUE ruby_whisper_full(int argc, VALUE *argv, VALUE self)
{
if (argc < 2 || argc > 3) {
rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
}
ruby_whisper *rw;
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper, rw);
VALUE params = argv[0];
Data_Get_Struct(params, ruby_whisper_params, rwp);
VALUE samples = argv[1];
int n_samples;
rb_memory_view_t view;
const bool memory_view_available_p = rb_memory_view_available_p(samples);
if (argc == 3) {
n_samples = NUM2INT(argv[2]);
if (TYPE(samples) == T_ARRAY) {
if (RARRAY_LEN(samples) < n_samples) {
rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
}
}
// Should check when samples.respond_to?(:length)?
} else {
if (TYPE(samples) == T_ARRAY) {
n_samples = RARRAY_LEN(samples);
} else if (memory_view_available_p) {
if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
view.obj = Qnil;
rb_raise(rb_eArgError, "unable to get a memory view");
}
n_samples = view.byte_size / view.item_size;
} else if (rb_respond_to(samples, id_length)) {
n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
} else {
rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
}
}
float * c_samples = (float *)malloc(n_samples * sizeof(float));
if (memory_view_available_p) {
c_samples = (float *)view.data;
} else {
if (TYPE(samples) == T_ARRAY) {
for (int i = 0; i < n_samples; i++) {
c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
}
} else {
// TODO: use rb_block_call
VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
for (int i = 0; i < n_samples; i++) {
// TODO: check if iter is exhausted and raise ArgumentError appropriately
VALUE sample = rb_funcall(iter, id_next, 0);
c_samples[i] = RFLOAT_VALUE(sample);
}
}
}
register_callbacks(rwp, &self);
const int result = whisper_full(rw->context, rwp->params, c_samples, n_samples);
if (0 == result) {
return self;
} else {
rb_exc_raise(rb_funcall(eError, id_new, 1, result));
}
}
/*
* Split the input audio in chunks and process each chunk separately using whisper_full_with_state()
* Result is stored in the default state of the context
* Not thread safe if executed in parallel on the same context.
* It seems this approach can offer some speedup in some cases.
* However, the transcription accuracy can be worse at the beginning and end of each chunk.
*
* call-seq:
* full_parallel(params, samples) -> nil
* full_parallel(params, samples, n_samples) -> nil
* full_parallel(params, samples, n_samples, n_processors) -> nil
* full_parallel(params, samples, nil, n_processors) -> nil
*/
static VALUE
ruby_whisper_full_parallel(int argc, VALUE *argv,VALUE self)
{
if (argc < 2 || argc > 4) {
rb_raise(rb_eArgError, "wrong number of arguments (given %d, expected 2..3)", argc);
}
ruby_whisper *rw;
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper, rw);
VALUE params = argv[0];
Data_Get_Struct(params, ruby_whisper_params, rwp);
VALUE samples = argv[1];
int n_samples;
int n_processors;
rb_memory_view_t view;
const bool memory_view_available_p = rb_memory_view_available_p(samples);
switch (argc) {
case 2:
n_processors = 1;
break;
case 3:
n_processors = 1;
break;
case 4:
n_processors = NUM2INT(argv[3]);
break;
}
if (argc >= 3 && !NIL_P(argv[2])) {
n_samples = NUM2INT(argv[2]);
if (TYPE(samples) == T_ARRAY) {
if (RARRAY_LEN(samples) < n_samples) {
rb_raise(rb_eArgError, "samples length %ld is less than n_samples %d", RARRAY_LEN(samples), n_samples);
}
}
// Should check when samples.respond_to?(:length)?
} else if (memory_view_available_p) {
if (!rb_memory_view_get(samples, &view, RUBY_MEMORY_VIEW_SIMPLE)) {
view.obj = Qnil;
rb_raise(rb_eArgError, "unable to get a memory view");
}
n_samples = view.byte_size / view.item_size;
} else {
if (TYPE(samples) == T_ARRAY) {
n_samples = RARRAY_LEN(samples);
} else if (rb_respond_to(samples, id_length)) {
n_samples = NUM2INT(rb_funcall(samples, id_length, 0));
} else {
rb_raise(rb_eArgError, "samples must respond to :length or be a MemoryView of an array of flaot when n_samples is not given");
}
}
float * c_samples = (float *)malloc(n_samples * sizeof(float));
if (memory_view_available_p) {
c_samples = (float *)view.data;
} else {
if (TYPE(samples) == T_ARRAY) {
for (int i = 0; i < n_samples; i++) {
c_samples[i] = RFLOAT_VALUE(rb_ary_entry(samples, i));
}
} else {
// FIXME: use rb_block_call
VALUE iter = rb_funcall(samples, id_to_enum, 1, rb_str_new2("each"));
for (int i = 0; i < n_samples; i++) {
// TODO: check if iter is exhausted and raise ArgumentError
VALUE sample = rb_funcall(iter, id_next, 0);
c_samples[i] = RFLOAT_VALUE(sample);
}
}
}
register_callbacks(rwp, &self);
const int result = whisper_full_parallel(rw->context, rwp->params, c_samples, n_samples, n_processors);
if (0 == result) {
return self;
} else {
rb_exc_raise(rb_funcall(eError, id_new, 1, result));
}
}
/*
* Number of segments.
*
* call-seq:
* full_n_segments -> Integer
*/
static VALUE
ruby_whisper_full_n_segments(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_full_n_segments(rw->context));
}
/*
* Language ID, which can be converted to string by Whisper.lang_str and Whisper.lang_str_full.
*
* call-seq:
* full_lang_id -> Integer
*/
static VALUE
ruby_whisper_full_lang_id(VALUE self)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
return INT2NUM(whisper_full_lang_id(rw->context));
}
static int ruby_whisper_full_check_segment_index(const ruby_whisper * rw, const VALUE i_segment)
{
const int c_i_segment = NUM2INT(i_segment);
if (c_i_segment < 0 || c_i_segment >= whisper_full_n_segments(rw->context)) {
rb_raise(rb_eIndexError, "segment index %d out of range", c_i_segment);
}
return c_i_segment;
}
/*
* Start time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
*
* full_get_segment_t0(3) # => 1668 (16680 ms)
*
* call-seq:
* full_get_segment_t0(segment_index) -> Integer
*/
static VALUE
ruby_whisper_full_get_segment_t0(VALUE self, VALUE i_segment)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const int64_t t0 = whisper_full_get_segment_t0(rw->context, c_i_segment);
return INT2NUM(t0);
}
/*
* End time of a segment indexed by +segment_index+ in centiseconds (10 times milliseconds).
*
* full_get_segment_t1(3) # => 1668 (16680 ms)
*
* call-seq:
* full_get_segment_t1(segment_index) -> Integer
*/
static VALUE
ruby_whisper_full_get_segment_t1(VALUE self, VALUE i_segment)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const int64_t t1 = whisper_full_get_segment_t1(rw->context, c_i_segment);
return INT2NUM(t1);
}
/*
* Whether the next segment indexed by +segment_index+ is predicated as a speaker turn.
*
* full_get_segment_speacker_turn_next(3) # => true
*
* call-seq:
* full_get_segment_speacker_turn_next(segment_index) -> bool
*/
static VALUE
ruby_whisper_full_get_segment_speaker_turn_next(VALUE self, VALUE i_segment)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const bool speaker_turn_next = whisper_full_get_segment_speaker_turn_next(rw->context, c_i_segment);
return speaker_turn_next ? Qtrue : Qfalse;
}
/*
* Text of a segment indexed by +segment_index+.
*
* full_get_segment_text(3) # => "ask not what your country can do for you, ..."
*
* call-seq:
* full_get_segment_text(segment_index) -> String
*/
static VALUE
ruby_whisper_full_get_segment_text(VALUE self, VALUE i_segment)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const char * text = whisper_full_get_segment_text(rw->context, c_i_segment);
return rb_str_new2(text);
}
/*
* call-seq:
* full_get_segment_no_speech_prob(segment_index) -> Float
*/
static VALUE
ruby_whisper_full_get_segment_no_speech_prob(VALUE self, VALUE i_segment)
{
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
const int c_i_segment = ruby_whisper_full_check_segment_index(rw, i_segment);
const float no_speech_prob = whisper_full_get_segment_no_speech_prob(rw->context, c_i_segment);
return DBL2NUM(no_speech_prob);
}
// High level API
static VALUE
ruby_whisper_full_get_segment(VALUE self, VALUE i_segment)
{
return rb_whisper_segment_initialize(self, NUM2INT(i_segment));
}
/*
* Yields each Whisper::Segment:
*
* whisper.transcribe("path/to/audio.wav", params)
* whisper.each_segment do |segment|
* puts segment.text
* end
*
* Returns an Enumerator if no block given:
*
* whisper.transcribe("path/to/audio.wav", params)
* enum = whisper.each_segment
* enum.to_a # => [#<Whisper::Segment>, ...]
*
* call-seq:
* each_segment {|segment| ... }
* each_segment -> Enumerator
*/
static VALUE
ruby_whisper_each_segment(VALUE self)
{
if (!rb_block_given_p()) {
const VALUE method_name = rb_funcall(self, id___method__, 0);
return rb_funcall(self, id_to_enum, 1, method_name);
}
ruby_whisper *rw;
Data_Get_Struct(self, ruby_whisper, rw);
const int n_segments = whisper_full_n_segments(rw->context);
for (int i = 0; i < n_segments; ++i) {
rb_yield(rb_whisper_segment_initialize(self, i));
}
return self;
}
/*
* call-seq:
* model -> Whisper::Model
*/
static VALUE
ruby_whisper_get_model(VALUE self)
{
return rb_whisper_model_initialize(self);
}
void
init_ruby_whisper_context(VALUE *mWhisper)
{
cContext = rb_define_class_under(*mWhisper, "Context", rb_cObject);
rb_define_alloc_func(cContext, ruby_whisper_allocate);
rb_define_method(cContext, "initialize", ruby_whisper_initialize, -1);
rb_define_method(cContext, "transcribe", ruby_whisper_transcribe, -1);
rb_define_method(cContext, "model_n_vocab", ruby_whisper_model_n_vocab, 0);
rb_define_method(cContext, "model_n_audio_ctx", ruby_whisper_model_n_audio_ctx, 0);
rb_define_method(cContext, "model_n_audio_state", ruby_whisper_model_n_audio_state, 0);
rb_define_method(cContext, "model_n_audio_head", ruby_whisper_model_n_audio_head, 0);
rb_define_method(cContext, "model_n_audio_layer", ruby_whisper_model_n_audio_layer, 0);
rb_define_method(cContext, "model_n_text_ctx", ruby_whisper_model_n_text_ctx, 0);
rb_define_method(cContext, "model_n_text_state", ruby_whisper_model_n_text_state, 0);
rb_define_method(cContext, "model_n_text_head", ruby_whisper_model_n_text_head, 0);
rb_define_method(cContext, "model_n_text_layer", ruby_whisper_model_n_text_layer, 0);
rb_define_method(cContext, "model_n_mels", ruby_whisper_model_n_mels, 0);
rb_define_method(cContext, "model_ftype", ruby_whisper_model_ftype, 0);
rb_define_method(cContext, "model_type", ruby_whisper_model_type, 0);
rb_define_method(cContext, "full_n_segments", ruby_whisper_full_n_segments, 0);
rb_define_method(cContext, "full_lang_id", ruby_whisper_full_lang_id, 0);
rb_define_method(cContext, "full_get_segment_t0", ruby_whisper_full_get_segment_t0, 1);
rb_define_method(cContext, "full_get_segment_t1", ruby_whisper_full_get_segment_t1, 1);
rb_define_method(cContext, "full_get_segment_speaker_turn_next", ruby_whisper_full_get_segment_speaker_turn_next, 1);
rb_define_method(cContext, "full_get_segment_text", ruby_whisper_full_get_segment_text, 1);
rb_define_method(cContext, "full_get_segment_no_speech_prob", ruby_whisper_full_get_segment_no_speech_prob, 1);
rb_define_method(cContext, "full", ruby_whisper_full, -1);
rb_define_method(cContext, "full_parallel", ruby_whisper_full_parallel, -1);
// High leve
rb_define_method(cContext, "full_get_segment", ruby_whisper_full_get_segment, 1);
rb_define_method(cContext, "each_segment", ruby_whisper_each_segment, 0);
rb_define_method(cContext, "model", ruby_whisper_get_model, 0);
}

@ -0,0 +1,52 @@
#include <ruby.h>
extern VALUE eError;
VALUE ruby_whisper_error_initialize(VALUE self, VALUE code)
{
const int c_code = NUM2INT(code);
const char *raw_message;
switch (c_code) {
case -2:
raw_message = "failed to compute log mel spectrogram";
break;
case -3:
raw_message = "failed to auto-detect language";
break;
case -4:
raw_message = "too many decoders requested";
break;
case -5:
raw_message = "audio_ctx is larger than the maximum allowed";
break;
case -6:
raw_message = "failed to encode";
break;
case -7:
raw_message = "whisper_kv_cache_init() failed for self-attention cache";
break;
case -8:
raw_message = "failed to decode";
break;
case -9:
raw_message = "failed to decode";
break;
default:
raw_message = "unknown error";
break;
}
const VALUE message = rb_str_new2(raw_message);
rb_call_super(1, &message);
rb_iv_set(self, "@code", code);
return self;
}
void
init_ruby_whisper_error(VALUE *mWhisper)
{
eError = rb_define_class_under(*mWhisper, "Error", rb_eStandardError);
rb_define_attr(eError, "code", true, false);
rb_define_method(eError, "initialize", ruby_whisper_error_initialize, 1);
}

@ -0,0 +1,210 @@
#include <ruby.h>
#include "ruby_whisper.h"
extern VALUE cModel;
static void rb_whisper_model_mark(ruby_whisper_model *rwm) {
rb_gc_mark(rwm->context);
}
static VALUE ruby_whisper_model_allocate(VALUE klass) {
ruby_whisper_model *rwm;
rwm = ALLOC(ruby_whisper_model);
return Data_Wrap_Struct(klass, rb_whisper_model_mark, RUBY_DEFAULT_FREE, rwm);
}
VALUE rb_whisper_model_initialize(VALUE context) {
ruby_whisper_model *rwm;
const VALUE model = ruby_whisper_model_allocate(cModel);
Data_Get_Struct(model, ruby_whisper_model, rwm);
rwm->context = context;
return model;
};
/*
* call-seq:
* n_vocab -> Integer
*/
static VALUE
ruby_whisper_model_n_vocab(VALUE self)
{
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_vocab(rw->context));
}
/*
* call-seq:
* n_audio_ctx -> Integer
*/
static VALUE
ruby_whisper_model_n_audio_ctx(VALUE self)
{
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_ctx(rw->context));
}
/*
* call-seq:
* n_audio_state -> Integer
*/
static VALUE
ruby_whisper_model_n_audio_state(VALUE self)
{
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_state(rw->context));
}
/*
* call-seq:
* n_audio_head -> Integer
*/
static VALUE
ruby_whisper_model_n_audio_head(VALUE self)
{
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_head(rw->context));
}
/*
* call-seq:
* n_audio_layer -> Integer
*/
static VALUE
ruby_whisper_model_n_audio_layer(VALUE self)
{
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_audio_layer(rw->context));
}
/*
* call-seq:
* n_text_ctx -> Integer
*/
static VALUE
ruby_whisper_model_n_text_ctx(VALUE self)
{
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_ctx(rw->context));
}
/*
* call-seq:
* n_text_state -> Integer
*/
static VALUE
ruby_whisper_model_n_text_state(VALUE self)
{
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_state(rw->context));
}
/*
* call-seq:
* n_text_head -> Integer
*/
static VALUE
ruby_whisper_model_n_text_head(VALUE self)
{
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_head(rw->context));
}
/*
* call-seq:
* n_text_layer -> Integer
*/
static VALUE
ruby_whisper_model_n_text_layer(VALUE self)
{
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_text_layer(rw->context));
}
/*
* call-seq:
* n_mels -> Integer
*/
static VALUE
ruby_whisper_model_n_mels(VALUE self)
{
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_n_mels(rw->context));
}
/*
* call-seq:
* ftype -> Integer
*/
static VALUE
ruby_whisper_model_ftype(VALUE self)
{
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return INT2NUM(whisper_model_ftype(rw->context));
}
/*
* call-seq:
* type -> String
*/
static VALUE
ruby_whisper_model_type(VALUE self)
{
ruby_whisper_model *rwm;
Data_Get_Struct(self, ruby_whisper_model, rwm);
ruby_whisper *rw;
Data_Get_Struct(rwm->context, ruby_whisper, rw);
return rb_str_new2(whisper_model_type_readable(rw->context));
}
void
init_ruby_whisper_model(VALUE *mWhisper)
{
cModel = rb_define_class_under(*mWhisper, "Model", rb_cObject);
rb_define_alloc_func(cModel, ruby_whisper_model_allocate);
rb_define_method(cModel, "n_vocab", ruby_whisper_model_n_vocab, 0);
rb_define_method(cModel, "n_audio_ctx", ruby_whisper_model_n_audio_ctx, 0);
rb_define_method(cModel, "n_audio_state", ruby_whisper_model_n_audio_state, 0);
rb_define_method(cModel, "n_audio_head", ruby_whisper_model_n_audio_head, 0);
rb_define_method(cModel, "n_audio_layer", ruby_whisper_model_n_audio_layer, 0);
rb_define_method(cModel, "n_text_ctx", ruby_whisper_model_n_text_ctx, 0);
rb_define_method(cModel, "n_text_state", ruby_whisper_model_n_text_state, 0);
rb_define_method(cModel, "n_text_head", ruby_whisper_model_n_text_head, 0);
rb_define_method(cModel, "n_text_layer", ruby_whisper_model_n_text_layer, 0);
rb_define_method(cModel, "n_mels", ruby_whisper_model_n_mels, 0);
rb_define_method(cModel, "ftype", ruby_whisper_model_ftype, 0);
rb_define_method(cModel, "type", ruby_whisper_model_type, 0);
}

File diff suppressed because it is too large Load Diff

@ -0,0 +1,123 @@
#include <ruby.h>
#include "ruby_whisper.h"
extern VALUE cSegment;
static void
rb_whisper_segment_mark(ruby_whisper_segment *rws)
{
rb_gc_mark(rws->context);
}
VALUE
ruby_whisper_segment_allocate(VALUE klass)
{
ruby_whisper_segment *rws;
rws = ALLOC(ruby_whisper_segment);
return Data_Wrap_Struct(klass, rb_whisper_segment_mark, RUBY_DEFAULT_FREE, rws);
}
VALUE
rb_whisper_segment_initialize(VALUE context, int index)
{
ruby_whisper_segment *rws;
const VALUE segment = ruby_whisper_segment_allocate(cSegment);
Data_Get_Struct(segment, ruby_whisper_segment, rws);
rws->context = context;
rws->index = index;
return segment;
};
/*
* Start time in milliseconds.
*
* call-seq:
* start_time -> Integer
*/
static VALUE
ruby_whisper_segment_get_start_time(VALUE self)
{
ruby_whisper_segment *rws;
Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
Data_Get_Struct(rws->context, ruby_whisper, rw);
const int64_t t0 = whisper_full_get_segment_t0(rw->context, rws->index);
// able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
return INT2NUM(t0 * 10);
}
/*
* End time in milliseconds.
*
* call-seq:
* end_time -> Integer
*/
static VALUE
ruby_whisper_segment_get_end_time(VALUE self)
{
ruby_whisper_segment *rws;
Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
Data_Get_Struct(rws->context, ruby_whisper, rw);
const int64_t t1 = whisper_full_get_segment_t1(rw->context, rws->index);
// able to multiply 10 without overflow because to_timestamp() in whisper.cpp does it
return INT2NUM(t1 * 10);
}
/*
* Whether the next segment is predicted as a speaker turn.
*
* call-seq:
* speaker_turn_next? -> bool
*/
static VALUE
ruby_whisper_segment_get_speaker_turn_next(VALUE self)
{
ruby_whisper_segment *rws;
Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
Data_Get_Struct(rws->context, ruby_whisper, rw);
return whisper_full_get_segment_speaker_turn_next(rw->context, rws->index) ? Qtrue : Qfalse;
}
/*
* call-seq:
* text -> String
*/
static VALUE
ruby_whisper_segment_get_text(VALUE self)
{
ruby_whisper_segment *rws;
Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
Data_Get_Struct(rws->context, ruby_whisper, rw);
const char * text = whisper_full_get_segment_text(rw->context, rws->index);
return rb_str_new2(text);
}
/*
* call-seq:
* no_speech_prob -> Float
*/
static VALUE
ruby_whisper_segment_get_no_speech_prob(VALUE self)
{
ruby_whisper_segment *rws;
Data_Get_Struct(self, ruby_whisper_segment, rws);
ruby_whisper *rw;
Data_Get_Struct(rws->context, ruby_whisper, rw);
return DBL2NUM(whisper_full_get_segment_no_speech_prob(rw->context, rws->index));
}
void
init_ruby_whisper_segment(VALUE *mWhisper, VALUE *cContext)
{
cSegment = rb_define_class_under(*mWhisper, "Segment", rb_cObject);
rb_define_alloc_func(cSegment, ruby_whisper_segment_allocate);
rb_define_method(cSegment, "start_time", ruby_whisper_segment_get_start_time, 0);
rb_define_method(cSegment, "end_time", ruby_whisper_segment_get_end_time, 0);
rb_define_method(cSegment, "speaker_next_turn?", ruby_whisper_segment_get_speaker_turn_next, 0);
rb_define_method(cSegment, "text", ruby_whisper_segment_get_text, 0);
rb_define_method(cSegment, "no_speech_prob", ruby_whisper_segment_get_no_speech_prob, 0);
}

@ -0,0 +1,159 @@
#include <ruby.h>
#include "ruby_whisper.h"
#define DR_WAV_IMPLEMENTATION
#include "dr_wav.h"
#include <string>
#include <vector>
#ifdef __cplusplus
extern "C" {
#endif
extern ID id_to_s;
extern ID id_call;
extern void
register_callbacks(ruby_whisper_params * rwp, VALUE * self);
/*
* transcribe a single file
* can emit to a block results
*
* params = Whisper::Params.new
* params.duration = 60_000
* whisper.transcribe "path/to/audio.wav", params do |text|
* puts text
* end
*
* call-seq:
* transcribe(path_to_audio, params) {|text| ...}
**/
VALUE
ruby_whisper_transcribe(int argc, VALUE *argv, VALUE self) {
ruby_whisper *rw;
ruby_whisper_params *rwp;
VALUE wave_file_path, blk, params;
rb_scan_args(argc, argv, "02&", &wave_file_path, &params, &blk);
Data_Get_Struct(self, ruby_whisper, rw);
Data_Get_Struct(params, ruby_whisper_params, rwp);
if (!rb_respond_to(wave_file_path, id_to_s)) {
rb_raise(rb_eRuntimeError, "Expected file path to wave file");
}
std::string fname_inp = StringValueCStr(wave_file_path);
std::vector<float> pcmf32; // mono-channel F32 PCM
std::vector<std::vector<float>> pcmf32s; // stereo-channel F32 PCM
// WAV input - this is directly from main.cpp example
{
drwav wav;
std::vector<uint8_t> wav_data; // used for pipe input from stdin
if (fname_inp == "-") {
{
uint8_t buf[1024];
while (true) {
const size_t n = fread(buf, 1, sizeof(buf), stdin);
if (n == 0) {
break;
}
wav_data.insert(wav_data.end(), buf, buf + n);
}
}
if (drwav_init_memory(&wav, wav_data.data(), wav_data.size(), nullptr) == false) {
fprintf(stderr, "error: failed to open WAV file from stdin\n");
return self;
}
fprintf(stderr, "%s: read %zu bytes from stdin\n", __func__, wav_data.size());
} else if (drwav_init_file(&wav, fname_inp.c_str(), nullptr) == false) {
fprintf(stderr, "error: failed to open '%s' as WAV file\n", fname_inp.c_str());
return self;
}
if (wav.channels != 1 && wav.channels != 2) {
fprintf(stderr, "WAV file '%s' must be mono or stereo\n", fname_inp.c_str());
return self;
}
if (rwp->diarize && wav.channels != 2 && rwp->params.print_timestamps == false) {
fprintf(stderr, "WAV file '%s' must be stereo for diarization and timestamps have to be enabled\n", fname_inp.c_str());
return self;
}
if (wav.sampleRate != WHISPER_SAMPLE_RATE) {
fprintf(stderr, "WAV file '%s' must be %i kHz\n", fname_inp.c_str(), WHISPER_SAMPLE_RATE/1000);
return self;
}
if (wav.bitsPerSample != 16) {
fprintf(stderr, "WAV file '%s' must be 16-bit\n", fname_inp.c_str());
return self;
}
const uint64_t n = wav_data.empty() ? wav.totalPCMFrameCount : wav_data.size()/(wav.channels*wav.bitsPerSample/8);
std::vector<int16_t> pcm16;
pcm16.resize(n*wav.channels);
drwav_read_pcm_frames_s16(&wav, n, pcm16.data());
drwav_uninit(&wav);
// convert to mono, float
pcmf32.resize(n);
if (wav.channels == 1) {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float(pcm16[i])/32768.0f;
}
} else {
for (uint64_t i = 0; i < n; i++) {
pcmf32[i] = float((int32_t)pcm16[2*i] + pcm16[2*i + 1])/65536.0f;
}
}
if (rwp->diarize) {
// convert to stereo, float
pcmf32s.resize(2);
pcmf32s[0].resize(n);
pcmf32s[1].resize(n);
for (uint64_t i = 0; i < n; i++) {
pcmf32s[0][i] = float(pcm16[2*i])/32768.0f;
pcmf32s[1][i] = float(pcm16[2*i + 1])/32768.0f;
}
}
}
{
static bool is_aborted = false; // NOTE: this should be atomic to avoid data race
rwp->params.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
bool is_aborted = *(bool*)user_data;
return !is_aborted;
};
rwp->params.encoder_begin_callback_user_data = &is_aborted;
}
register_callbacks(rwp, &self);
if (whisper_full_parallel(rw->context, rwp->params, pcmf32.data(), pcmf32.size(), 1) != 0) {
fprintf(stderr, "failed to process audio\n");
return self;
}
const int n_segments = whisper_full_n_segments(rw->context);
VALUE output = rb_str_new2("");
for (int i = 0; i < n_segments; ++i) {
const char * text = whisper_full_get_segment_text(rw->context, i);
output = rb_str_concat(output, rb_str_new2(text));
}
VALUE idCall = id_call;
if (blk != Qnil) {
rb_funcall(blk, idCall, 1, output);
}
return self;
}
#ifdef __cplusplus
}
#endif

@ -65,6 +65,13 @@ module Whisper
end
end
end
rescue => err
if cache_path.exist?
warn err
# Use cache file
else
raise
end
end
def download(response)

@ -20,13 +20,12 @@ module Whisper
def self.lang_id: (string name) -> Integer
def self.lang_str: (Integer id) -> String
def self.lang_str_full: (Integer id) -> String
def self.log_set=: (log_callback) -> log_callback
def self.finalize_log_callback: (void) -> void # Second argument of ObjectSpace.define_finalizer
def self.log_set: (log_callback, Object? user_data) -> log_callback
class Context
def initialize: (string | _ToPath | ::URI::HTTP ) -> void
def transcribe: (string, Params) -> void
| (string, Params) { (String) -> void } -> void
def self.new: (string | _ToPath | ::URI::HTTP) -> instance
def transcribe: (string, Params) -> self
| (string, Params) { (String) -> void } -> self
def model_n_vocab: () -> Integer
def model_n_audio_ctx: () -> Integer
def model_n_audio_state: () -> Integer
@ -35,6 +34,10 @@ module Whisper
def model_n_mels: () -> Integer
def model_ftype: () -> Integer
def model_type: () -> String
def each_segment: { (Segment) -> void } -> void
| () -> Enumerator[Segment]
def model: () -> Model
def full_get_segment: (Integer nth) -> Segment
def full_n_segments: () -> Integer
def full_lang_id: () -> Integer
def full_get_segment_t0: (Integer) -> Integer
@ -42,18 +45,46 @@ module Whisper
def full_get_segment_speaker_turn_next: (Integer) -> (true | false)
def full_get_segment_text: (Integer) -> String
def full_get_segment_no_speech_prob: (Integer) -> Float
def full: (Params, Array[Float], ?Integer) -> void
| (Params, _Samples, ?Integer) -> void
def full_parallel: (Params, Array[Float], ?Integer) -> void
| (Params, _Samples, ?Integer) -> void
| (Params, _Samples, ?Integer?, Integer) -> void
def each_segment: { (Segment) -> void } -> void
| () -> Enumerator[Segment]
def model: () -> Model
def full: (Params, Array[Float] samples, ?Integer n_samples) -> self
| (Params, _Samples, ?Integer n_samples) -> self
def full_parallel: (Params, Array[Float], ?Integer n_samples) -> self
| (Params, _Samples, ?Integer n_samples) -> self
| (Params, _Samples, ?Integer? n_samples, Integer n_processors) -> self
end
class Params
def initialize: () -> void
def self.new: (
?language: string,
?translate: boolish,
?no_context: boolish,
?single_segment: boolish,
?print_special: boolish,
?print_progress: boolish,
?print_realtime: boolish,
?print_timestamps: boolish,
?suppress_blank: boolish,
?suppress_nst: boolish,
?token_timestamps: boolish,
?split_on_word: boolish,
?initial_prompt: string | nil,
?diarize: boolish,
?offset: Integer,
?duration: Integer,
?max_text_tokens: Integer,
?temperature: Float,
?max_initial_ts: Float,
?length_penalty: Float,
?temperature_inc: Float,
?entropy_thold: Float,
?logprob_thold: Float,
?no_speech_thold: Float,
?new_segment_callback: new_segment_callback,
?new_segment_callback_user_data: Object,
?progress_callback: progress_callback,
?progress_callback_user_data: Object,
?abort_callback: abort_callback,
?abort_callback_user_data: Object
) -> instance
def language=: (String) -> String # TODO: Enumerate lang names
def language: () -> String
def translate=: (boolish) -> boolish
@ -79,7 +110,7 @@ module Whisper
def split_on_word=: (boolish) -> boolish
def split_on_word: () -> (true | false)
def initial_prompt=: (_ToS) -> _ToS
def initial_prompt: () -> String
def initial_prompt: () -> (String | nil)
def diarize=: (boolish) -> boolish
def diarize: () -> (true | false)
def offset=: (Integer) -> Integer
@ -103,19 +134,25 @@ module Whisper
def no_speech_thold=: (Float) -> Float
def no_speech_thold: () -> Float
def new_segment_callback=: (new_segment_callback) -> new_segment_callback
def new_segment_callback: () -> (new_segment_callback | nil)
def new_segment_callback_user_data=: (Object) -> Object
def new_segment_callback_user_data: () -> Object
def progress_callback=: (progress_callback) -> progress_callback
def progress_callback: () -> (progress_callback | nil)
def progress_callback_user_data=: (Object) -> Object
def progress_callback_user_data: () -> Object
def abort_callback=: (abort_callback) -> abort_callback
def abort_callback: () -> (abort_callback | nil)
def abort_callback_user_data=: (Object) -> Object
def abort_callback_user_data: () -> Object
def on_new_segment: { (Segment) -> void } -> void
def on_progress: { (Integer) -> void } -> void
def abort_on: { (Object) -> boolish } -> void
def on_progress: { (Integer progress) -> void } -> void
def abort_on: { (Object user_data) -> boolish } -> void
end
class Model
def self.pre_converted_models: () -> Hash[String, Model::URI]
def initialize: () -> void
def self.new: () -> instance
def n_vocab: () -> Integer
def n_audio_ctx: () -> Integer
def n_audio_state: () -> Integer
@ -130,14 +167,13 @@ module Whisper
def type: () -> String
class URI
def initialize: (string | ::URI::HTTP) -> void
def self.new: (string | ::URI::HTTP) -> self
def to_path: -> String
def clear_cache: -> void
end
end
class Segment
def initialize: () -> void
def start_time: () -> Integer
def end_time: () -> Integer
def speaker_next_turn?: () -> (true | false)
@ -148,6 +184,6 @@ module Whisper
class Error < StandardError
attr_reader code: Integer
def initialize: (Integer) -> void
def self.new: (Integer code) -> instance
end
end

@ -1,6 +1,39 @@
require_relative "helper"
class TestParams < TestBase
PARAM_NAMES = [
:language,
:translate,
:no_context,
:single_segment,
:print_special,
:print_progress,
:print_realtime,
:print_timestamps,
:suppress_blank,
:suppress_nst,
:token_timestamps,
:split_on_word,
:initial_prompt,
:diarize,
:offset,
:duration,
:max_text_tokens,
:temperature,
:max_initial_ts,
:length_penalty,
:temperature_inc,
:entropy_thold,
:logprob_thold,
:no_speech_thold,
:new_segment_callback,
:new_segment_callback_user_data,
:progress_callback,
:progress_callback_user_data,
:abort_callback,
:abort_callback_user_data,
]
def setup
@params = Whisper::Params.new
end
@ -157,4 +190,57 @@ class TestParams < TestBase
@params.no_speech_thold = 0.2
assert_in_delta 0.2, @params.no_speech_thold
end
def test_new_with_kw_args
params = Whisper::Params.new(language: "es")
assert_equal "es", params.language
assert_equal 1.0, params.max_initial_ts
end
def test_new_with_kw_args_non_existent
assert_raise ArgumentError do
Whisper::Params.new(non_existent: "value")
end
end
def test_new_with_kw_args_wrong_type
assert_raise TypeError do
Whisper::Params.new(language: 3)
end
end
data(PARAM_NAMES.collect {|param| [param, param]}.to_h)
def test_new_with_kw_args_default_values(param)
default_value = @params.send(param)
value = case [param, default_value]
in [*, true | false]
!default_value
in [*, Integer | Float]
default_value + 1
in [:language, *]
"es"
in [:initial_prompt, *]
"Initial prompt"
in [/_callback\Z/, *]
proc {}
in [/_user_data\Z/, *]
Object.new
end
params = Whisper::Params.new(param => value)
if Float === value
assert_in_delta value, params.send(param)
else
assert_equal value, params.send(param)
end
PARAM_NAMES.reject {|name| name == param}.each do |name|
expected = @params.send(name)
actual = params.send(name)
if Float === expected
assert_in_delta expected, actual
else
assert_equal expected, actual
end
end
end
end

@ -29,6 +29,12 @@ class TestWhisper < TestBase
assert_equal 0, whisper.full_lang_id
end
def test_full_get_segment
segment = whisper.full_get_segment(0)
assert_equal 0, segment.start_time
assert_match /ask not what your country can do for you, ask what you can do for your country/, segment.text
end
def test_full_get_segment_t0
assert_equal 0, whisper.full_get_segment_t0(0)
assert_raise IndexError do