From 3a27b2b91b6cda56d31590da9c1093e5dc96c818 Mon Sep 17 00:00:00 2001 From: KITAITI Makoto Date: Wed, 18 Dec 2024 18:00:50 +0900 Subject: [PATCH] ruby : Add no_speech_thold (#2641) * Remove Whisper::Model.[] * Fix Whisper::Model::URI#request * Make Whisper::Context#initialize accept pre-converted model name * Use downloading pre-converted model feature for testing * Update README * Remove unnecessary task * Move whisper/model.rb -> whisper/model/uri.rb * Update document comment of Whisper::Context#initialize * Don't show download progress when not tty * Pass String to raise * Use cache model file if download fails * Add test for auto download * Specify required Ruby version * Fix a typo * Remove unnecessary flags * Initialize Whisper::Params#diarize explicitly * Remove redundant code from README for simplicity * Add Whisper::Params#no_speech_thold attribute * Add test for Whisper::Params#no_speech_thold --- bindings/ruby/README.md | 27 +++++++------- bindings/ruby/Rakefile | 14 ++----- bindings/ruby/ext/extconf.rb | 5 --- bindings/ruby/ext/ruby_whisper.cpp | 37 +++++++++++++++++-- bindings/ruby/lib/whisper.rb | 2 +- .../lib/whisper/{model.rb => model/uri.rb} | 22 +++++------ bindings/ruby/tests/helper.rb | 1 - bindings/ruby/tests/test_callback.rb | 11 ++---- bindings/ruby/tests/test_model.rb | 15 ++++++-- bindings/ruby/tests/test_params.rb | 6 +++ bindings/ruby/tests/test_segment.rb | 2 +- bindings/ruby/tests/test_whisper.rb | 10 ++--- bindings/ruby/whispercpp.gemspec | 2 +- 13 files changed, 89 insertions(+), 65 deletions(-) rename bindings/ruby/lib/whisper/{model.rb => model/uri.rb} (88%) diff --git a/bindings/ruby/README.md b/bindings/ruby/README.md index e7065bf9..03a8b9e1 100644 --- a/bindings/ruby/README.md +++ b/bindings/ruby/README.md @@ -22,7 +22,7 @@ Usage ```ruby require "whisper" -whisper = Whisper::Context.new(Whisper::Model["base"]) +whisper = Whisper::Context.new("base") params = Whisper::Params.new params.language = "en" @@ -44,17 +44,23 @@ end 
Some models are prepared up-front: ```ruby -base_en = Whisper::Model["base.en"] +base_en = Whisper::Model.pre_converted_models["base.en"] whisper = Whisper::Context.new(base_en) ``` At first time you use a model, it is downloaded automatically. After that, downloaded cached file is used. To clear cache, call `#clear_cache`: ```ruby -Whisper::Model["base"].clear_cache +Whisper::Model.pre_converted_models["base"].clear_cache ``` -You can see the list of prepared model names by `Whisper::Model.preconverted_model_names`: +You can also use a shorthand for pre-converted models: + +```ruby +whisper = Whisper::Context.new("base.en") +``` + +You can see the list of prepared model names by `Whisper::Model.pre_converted_models.keys`: ```ruby puts Whisper::Model.preconverted_model_names @@ -124,13 +130,6 @@ end You can also add hook to params called on new segment: ```ruby -def format_time(time_ms) - sec, decimal_part = time_ms.divmod(1000) - min, sec = sec.divmod(60) - hour, min = min.divmod(60) - "%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part] -end - # Add hook before calling #transcribe params.on_new_segment do |segment| line = "[%{st} --> %{ed}] %{text}" % { @@ -151,7 +150,7 @@ whisper.transcribe("path/to/audio.wav", params) You can see model information: ```ruby -whisper = Whisper::Context.new(Whisper::Model["base"]) +whisper = Whisper::Context.new("base") model = whisper.model model.n_vocab # => 51864 @@ -200,7 +199,7 @@ Using this feature, you are also able to suppress log: Whisper.log_set ->(level, buffer, user_data) { # do nothing }, nil -Whisper::Context.new(MODEL) +Whisper::Context.new("base") ``` ### Low-level API to transcribe ### @@ -214,7 +213,7 @@ require "wavefile" reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, :float, 16000)) samples = reader.enum_for(:each_buffer).map(&:samples).flatten -whisper = Whisper::Context.new(Whisper::Model["base"]) +whisper = Whisper::Context.new("base") whisper.full(Whisper::Params.new, samples) 
whisper.each_segment do |segment| puts segment.text diff --git a/bindings/ruby/Rakefile b/bindings/ruby/Rakefile index f640dce9..3a7809b7 100644 --- a/bindings/ruby/Rakefile +++ b/bindings/ruby/Rakefile @@ -25,7 +25,6 @@ task build: ["ext/Makefile", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp", "whis directory "pkg" CLOBBER.include "pkg" -TEST_MODEL = "../../models/ggml-base.en.bin" LIB_NAME = "whisper".ext(RbConfig::CONFIG["DLEXT"]) SO_FILE = File.join("ext", LIB_NAME) LIB_FILE = File.join("lib", LIB_NAME) @@ -41,23 +40,17 @@ file SO_FILE => "ext/Makefile" do |t| sh "make" end end -CLEAN.include LIB_FILE +CLEAN.include SO_FILE directory "lib" file LIB_FILE => [SO_FILE, "lib"] do |t| copy t.source, t.name end +CLEAN.include LIB_FILE Rake::TestTask.new do |t| t.test_files = FileList["tests/test_*.rb"] end -task test: [TEST_MODEL, LIB_FILE] - -file TEST_MODEL do - Dir.chdir "../.." do - sh "./models/download-ggml-model.sh base.en" - end -end TEST_MEMORY_VIEW = "tests/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}" file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t| @@ -67,4 +60,5 @@ file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t| end end CLEAN.include "tests/jfk_reader/jfk_reader.{o,#{RbConfig::CONFIG['DLEXT']}}" -task test: TEST_MEMORY_VIEW + +task test: [LIB_FILE, TEST_MEMORY_VIEW] diff --git a/bindings/ruby/ext/extconf.rb b/bindings/ruby/ext/extconf.rb index 59388ffe..fbae2517 100644 --- a/bindings/ruby/ext/extconf.rb +++ b/bindings/ruby/ext/extconf.rb @@ -111,11 +111,6 @@ unless ENV['RISCV'] $MK_CFLAGS << ' -march=native -mtune=native' $HOST_CXXFLAGS << ' -march=native -mtune=native' end - - if $UNAME_M.match? 
/aarch64.*/ - $MK_CFLAGS << ' -mcpu=native' - $MK_CXXFLAGS << ' -mcpu=native' - end else $MK_CFLAGS << ' -march=rv64gcv -mabi=lp64d' $MK_CXXFLAGS << ' -march=rv64gcv -mabi=lp64d' diff --git a/bindings/ruby/ext/ruby_whisper.cpp b/bindings/ruby/ext/ruby_whisper.cpp index 83fc53fc..26e9def4 100644 --- a/bindings/ruby/ext/ruby_whisper.cpp +++ b/bindings/ruby/ext/ruby_whisper.cpp @@ -38,6 +38,9 @@ VALUE cContext; VALUE cParams; VALUE eError; +VALUE cSegment; +VALUE cModel; + static ID id_to_s; static ID id_call; static ID id___method__; @@ -46,6 +49,7 @@ static ID id_length; static ID id_next; static ID id_new; static ID id_to_path; +static ID id_pre_converted_models; static bool is_log_callback_finalized = false; @@ -187,6 +191,7 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) { ruby_whisper_params *rwp; rwp = ALLOC(ruby_whisper_params); rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY); + rwp->diarize = false; rwp->new_segment_callback_container = rb_whisper_callback_container_allocate(); rwp->progress_callback_container = rb_whisper_callback_container_allocate(); rwp->abort_callback_container = rb_whisper_callback_container_allocate(); @@ -195,7 +200,7 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) { /* * call-seq: - * new(Whisper::Model["base.en"]) -> Whisper::Context + * new("base.en") -> Whisper::Context * new("path/to/model.bin") -> Whisper::Context * new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context */ @@ -207,6 +212,11 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) { rb_scan_args(argc, argv, "01", &whisper_model_file_path); Data_Get_Struct(self, ruby_whisper, rw); + VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0); + VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path); + if (!NIL_P(pre_converted_model)) { + whisper_model_file_path = pre_converted_model; + } if 
(rb_respond_to(whisper_model_file_path, id_to_path)) { whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0); } @@ -1251,6 +1261,25 @@ static VALUE ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value) { rwp->params.logprob_thold = RFLOAT_VALUE(value); return value; } +/* + * call-seq: + * no_speech_thold -> Float + */ +static VALUE ruby_whisper_params_get_no_speech_thold(VALUE self) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + return DBL2NUM(rwp->params.no_speech_thold); +} +/* + * call-seq: + * no_speech_thold = threshold -> threshold + */ +static VALUE ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value) { + ruby_whisper_params *rwp; + Data_Get_Struct(self, ruby_whisper_params, rwp); + rwp->params.no_speech_thold = RFLOAT_VALUE(value); + return value; +} /* * Sets new segment callback, called for every newly generated text segment. * @@ -1347,9 +1376,6 @@ typedef struct { VALUE context; } ruby_whisper_model; -VALUE cSegment; -VALUE cModel; - static void rb_whisper_segment_mark(ruby_whisper_segment *rws) { rb_gc_mark(rws->context); } @@ -1740,6 +1766,7 @@ void Init_whisper() { id_next = rb_intern("next"); id_new = rb_intern("new"); id_to_path = rb_intern("to_path"); + id_pre_converted_models = rb_intern("pre_converted_models"); mWhisper = rb_define_module("Whisper"); cContext = rb_define_class_under(mWhisper, "Context", rb_cObject); @@ -1835,6 +1862,8 @@ void Init_whisper() { rb_define_method(cParams, "entropy_thold=", ruby_whisper_params_set_entropy_thold, 1); rb_define_method(cParams, "logprob_thold", ruby_whisper_params_get_logprob_thold, 0); rb_define_method(cParams, "logprob_thold=", ruby_whisper_params_set_logprob_thold, 1); + rb_define_method(cParams, "no_speech_thold", ruby_whisper_params_get_no_speech_thold, 0); + rb_define_method(cParams, "no_speech_thold=", ruby_whisper_params_set_no_speech_thold, 1); rb_define_method(cParams, "new_segment_callback=", 
ruby_whisper_params_set_new_segment_callback, 1); rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1); diff --git a/bindings/ruby/lib/whisper.rb b/bindings/ruby/lib/whisper.rb index 4c8e01e2..3a0b844e 100644 --- a/bindings/ruby/lib/whisper.rb +++ b/bindings/ruby/lib/whisper.rb @@ -1,2 +1,2 @@ require "whisper.so" -require "whisper/model" +require "whisper/model/uri" diff --git a/bindings/ruby/lib/whisper/model.rb b/bindings/ruby/lib/whisper/model/uri.rb similarity index 88% rename from bindings/ruby/lib/whisper/model.rb rename to bindings/ruby/lib/whisper/model/uri.rb index be67dff3..5ca77ed4 100644 --- a/bindings/ruby/lib/whisper/model.rb +++ b/bindings/ruby/lib/whisper/model/uri.rb @@ -1,6 +1,7 @@ require "whisper.so" require "uri" require "net/http" +require "time" require "pathname" require "io/console/size" @@ -56,9 +57,11 @@ class Whisper::Model when Net::HTTPOK download response when Net::HTTPRedirection - request URI(response["location"]) + request URI(response["location"]), headers else - raise response + return if headers.key?("if-modified-since") # Use cache file + + raise "#{response.code} #{response.message}\n#{response.body}" end end end @@ -81,6 +84,7 @@ class Whisper::Model end def show_progress(current, size) + return unless $stderr.tty? 
return unless size unless @prev @@ -111,7 +115,7 @@ class Whisper::Model end end - @names = {} + @pre_converted_models = {} %w[ tiny tiny.en @@ -137,23 +141,17 @@ class Whisper::Model large-v1 large-v2 large-v2-q5_0 - large-v2-8_0 + large-v2-q8_0 large-v3 large-v3-q5_0 large-v3-turbo large-v3-turbo-q5_0 large-v3-turbo-q8_0 ].each do |name| - @names[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin") + @pre_converted_models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin") end class << self - def [](name) - @names[name] - end - - def preconverted_model_names - @names.keys - end + attr_reader :pre_converted_models end end diff --git a/bindings/ruby/tests/helper.rb b/bindings/ruby/tests/helper.rb index 0c761a37..da52f268 100644 --- a/bindings/ruby/tests/helper.rb +++ b/bindings/ruby/tests/helper.rb @@ -3,6 +3,5 @@ require "whisper" require_relative "jfk_reader/jfk_reader" class TestBase < Test::Unit::TestCase - MODEL = File.join(__dir__, "..", "..", "..", "models", "ggml-base.en.bin") AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav") end diff --git a/bindings/ruby/tests/test_callback.rb b/bindings/ruby/tests/test_callback.rb index 1234d31d..8e9becf7 100644 --- a/bindings/ruby/tests/test_callback.rb +++ b/bindings/ruby/tests/test_callback.rb @@ -1,14 +1,11 @@ -require "test/unit" -require "whisper" - -class TestCallback < Test::Unit::TestCase - TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..')) +require_relative "helper" +class TestCallback < TestBase def setup GC.start @params = Whisper::Params.new - @whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin')) - @audio = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav') + @whisper = Whisper::Context.new("base.en") + @audio = File.join(AUDIO) end def test_new_segment_callback diff --git a/bindings/ruby/tests/test_model.rb b/bindings/ruby/tests/test_model.rb index 
598dbde9..1362fc46 100644 --- a/bindings/ruby/tests/test_model.rb +++ b/bindings/ruby/tests/test_model.rb @@ -3,12 +3,12 @@ require "pathname" class TestModel < TestBase def test_model - whisper = Whisper::Context.new(MODEL) + whisper = Whisper::Context.new("base.en") assert_instance_of Whisper::Model, whisper.model end def test_attributes - whisper = Whisper::Context.new(MODEL) + whisper = Whisper::Context.new("base.en") model = whisper.model assert_equal 51864, model.n_vocab @@ -26,7 +26,7 @@ class TestModel < TestBase end def test_gc - model = Whisper::Context.new(MODEL).model + model = Whisper::Context.new("base.en").model GC.start assert_equal 51864, model.n_vocab @@ -44,7 +44,7 @@ class TestModel < TestBase end def test_pathname - path = Pathname(MODEL) + path = Pathname(Whisper::Model.pre_converted_models["base.en"].to_path) whisper = Whisper::Context.new(path) model = whisper.model @@ -61,4 +61,11 @@ class TestModel < TestBase assert_equal 1, model.ftype assert_equal "base", model.type end + + def test_auto_download + path = Whisper::Model.pre_converted_models["base.en"].to_path + + assert_path_exist path + assert_equal 147964211, File.size(path) + end end diff --git a/bindings/ruby/tests/test_params.rb b/bindings/ruby/tests/test_params.rb index bf73fd6b..d2667ef0 100644 --- a/bindings/ruby/tests/test_params.rb +++ b/bindings/ruby/tests/test_params.rb @@ -151,4 +151,10 @@ class TestParams < TestBase @params.logprob_thold = -0.5 assert_in_delta -0.5, @params.logprob_thold end + + def test_no_speech_thold + assert_in_delta 0.6, @params.no_speech_thold + @params.no_speech_thold = 0.2 + assert_in_delta 0.2, @params.no_speech_thold + end end diff --git a/bindings/ruby/tests/test_segment.rb b/bindings/ruby/tests/test_segment.rb index 8129ae5d..559bcea7 100644 --- a/bindings/ruby/tests/test_segment.rb +++ b/bindings/ruby/tests/test_segment.rb @@ -5,7 +5,7 @@ class TestSegment < TestBase attr_reader :whisper def startup - @whisper = 
Whisper::Context.new(TestBase::MODEL) + @whisper = Whisper::Context.new("base.en") params = Whisper::Params.new params.print_timestamps = false @whisper.transcribe(TestBase::AUDIO, params) diff --git a/bindings/ruby/tests/test_whisper.rb b/bindings/ruby/tests/test_whisper.rb index 1f3ac269..115569ed 100644 --- a/bindings/ruby/tests/test_whisper.rb +++ b/bindings/ruby/tests/test_whisper.rb @@ -11,7 +11,7 @@ class TestWhisper < TestBase end def test_whisper - @whisper = Whisper::Context.new(MODEL) + @whisper = Whisper::Context.new("base.en") params = Whisper::Params.new params.print_timestamps = false @@ -25,7 +25,7 @@ class TestWhisper < TestBase attr_reader :whisper def startup - @whisper = Whisper::Context.new(TestBase::MODEL) + @whisper = Whisper::Context.new("base.en") params = Whisper::Params.new params.print_timestamps = false @whisper.transcribe(TestBase::AUDIO, params) @@ -104,7 +104,7 @@ class TestWhisper < TestBase logs << [level, buffer, udata] } Whisper.log_set log_callback, user_data - Whisper::Context.new(MODEL) + Whisper::Context.new("base.en") assert logs.length > 30 logs.each do |log| @@ -120,7 +120,7 @@ class TestWhisper < TestBase }, nil dev = StringIO.new("") $stderr = dev - Whisper::Context.new(MODEL) + Whisper::Context.new("base.en") assert_empty dev.string ensure $stderr = stderr @@ -129,7 +129,7 @@ class TestWhisper < TestBase sub_test_case "full" do def setup super - @whisper = Whisper::Context.new(MODEL) + @whisper = Whisper::Context.new("base.en") @samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15} end diff --git a/bindings/ruby/whispercpp.gemspec b/bindings/ruby/whispercpp.gemspec index 2fd9cad9..d8f5c0d8 100644 --- a/bindings/ruby/whispercpp.gemspec +++ b/bindings/ruby/whispercpp.gemspec @@ -23,7 +23,7 @@ Gem::Specification.new do |s| s.test_files = s.files.select {|file| file.start_with? "tests/"} s.extensions << 'ext/extconf.rb' - + s.required_ruby_version = '>= 3.1.0' #### Documentation and testing. 
s.homepage = 'https://github.com/ggerganov/whisper.cpp'