ruby : Add no_speech_thold (#2641)

* Remove Whisper::Model.[]

* Fix Whisper::Model::URI#request

* Make Whisper::Context#initialize accept pre-converted model name

* Use downloading pre-converted model feature for testing

* Update README

* Remove unnecessary task

* Move whisper/model.rb -> whisper/model/uri.rb

* Update document comment of Whisper::Context#initialize

* Don't show download progress when not tty

* Pass String to raise

* Use cache model file if download fails

* Add test for auto download

* Specify required Ruby version

* Fix a typo

* Remove unnecessary flags

* Initialize Whisper::Params#diarize explicitely

* Remove redundant code from README for simplicity

* Add Whisper::Params#no_speech_thold attribute

* Add test for Whisper::Params#no_speech_thold
This commit is contained in:
KITAITI Makoto 2024-12-18 18:00:50 +09:00 committed by GitHub
parent d34445e960
commit 3a27b2b91b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 89 additions and 65 deletions

View File

@ -22,7 +22,7 @@ Usage
```ruby
require "whisper"
whisper = Whisper::Context.new(Whisper::Model["base"])
whisper = Whisper::Context.new("base")
params = Whisper::Params.new
params.language = "en"
@ -44,17 +44,23 @@ end
Some models are prepared up-front:
```ruby
base_en = Whisper::Model["base.en"]
base_en = Whisper::Model.pre_converted_models["base.en"]
whisper = Whisper::Context.new(base_en)
```
At first time you use a model, it is downloaded automatically. After that, downloaded cached file is used. To clear cache, call `#clear_cache`:
```ruby
Whisper::Model["base"].clear_cache
Whisper::Model.pre_converted_models["base"].clear_cache
```
You can see the list of prepared model names by `Whisper::Model.preconverted_model_names`:
You also can use shorthand for pre-converted models:
```ruby
whisper = Whisper::Context.new("base.en")
```
You can see the list of prepared model names by `Whisper::Model.preconverted_models.keys`:
```ruby
puts Whisper::Model.preconverted_model_names
@ -124,13 +130,6 @@ end
You can also add hook to params called on new segment:
```ruby
def format_time(time_ms)
sec, decimal_part = time_ms.divmod(1000)
min, sec = sec.divmod(60)
hour, min = min.divmod(60)
"%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part]
end
# Add hook before calling #transcribe
params.on_new_segment do |segment|
line = "[%{st} --> %{ed}] %{text}" % {
@ -151,7 +150,7 @@ whisper.transcribe("path/to/audio.wav", params)
You can see model information:
```ruby
whisper = Whisper::Context.new(Whisper::Model["base"])
whisper = Whisper::Context.new("base")
model = whisper.model
model.n_vocab # => 51864
@ -200,7 +199,7 @@ Using this feature, you are also able to suppress log:
Whisper.log_set ->(level, buffer, user_data) {
# do nothing
}, nil
Whisper::Context.new(MODEL)
Whisper::Context.new("base")
```
### Low-level API to transcribe ###
@ -214,7 +213,7 @@ require "wavefile"
reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, :float, 16000))
samples = reader.enum_for(:each_buffer).map(&:samples).flatten
whisper = Whisper::Context.new(Whisper::Model["base"])
whisper = Whisper::Context.new("base")
whisper.full(Whisper::Params.new, samples)
whisper.each_segment do |segment|
puts segment.text

View File

@ -25,7 +25,6 @@ task build: ["ext/Makefile", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp", "whis
directory "pkg"
CLOBBER.include "pkg"
TEST_MODEL = "../../models/ggml-base.en.bin"
LIB_NAME = "whisper".ext(RbConfig::CONFIG["DLEXT"])
SO_FILE = File.join("ext", LIB_NAME)
LIB_FILE = File.join("lib", LIB_NAME)
@ -41,23 +40,17 @@ file SO_FILE => "ext/Makefile" do |t|
sh "make"
end
end
CLEAN.include LIB_FILE
CLEAN.include SO_FILE
directory "lib"
file LIB_FILE => [SO_FILE, "lib"] do |t|
copy t.source, t.name
end
CLEAN.include LIB_FILE
Rake::TestTask.new do |t|
t.test_files = FileList["tests/test_*.rb"]
end
task test: [TEST_MODEL, LIB_FILE]
file TEST_MODEL do
Dir.chdir "../.." do
sh "./models/download-ggml-model.sh base.en"
end
end
TEST_MEMORY_VIEW = "tests/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t|
@ -67,4 +60,5 @@ file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t|
end
end
CLEAN.include "tests/jfk_reader/jfk_reader.{o,#{RbConfig::CONFIG['DLEXT']}}"
task test: TEST_MEMORY_VIEW
task test: [LIB_FILE, TEST_MEMORY_VIEW]

View File

@ -111,11 +111,6 @@ unless ENV['RISCV']
$MK_CFLAGS << ' -march=native -mtune=native'
$HOST_CXXFLAGS << ' -march=native -mtune=native'
end
if $UNAME_M.match? /aarch64.*/
$MK_CFLAGS << ' -mcpu=native'
$MK_CXXFLAGS << ' -mcpu=native'
end
else
$MK_CFLAGS << ' -march=rv64gcv -mabi=lp64d'
$MK_CXXFLAGS << ' -march=rv64gcv -mabi=lp64d'

View File

@ -38,6 +38,9 @@ VALUE cContext;
VALUE cParams;
VALUE eError;
VALUE cSegment;
VALUE cModel;
static ID id_to_s;
static ID id_call;
static ID id___method__;
@ -46,6 +49,7 @@ static ID id_length;
static ID id_next;
static ID id_new;
static ID id_to_path;
static ID id_pre_converted_models;
static bool is_log_callback_finalized = false;
@ -187,6 +191,7 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) {
ruby_whisper_params *rwp;
rwp = ALLOC(ruby_whisper_params);
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
rwp->diarize = false;
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
rwp->progress_callback_container = rb_whisper_callback_container_allocate();
rwp->abort_callback_container = rb_whisper_callback_container_allocate();
@ -195,7 +200,7 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) {
/*
* call-seq:
* new(Whisper::Model["base.en"]) -> Whisper::Context
* new("base.en") -> Whisper::Context
* new("path/to/model.bin") -> Whisper::Context
* new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context
*/
@ -207,6 +212,11 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
Data_Get_Struct(self, ruby_whisper, rw);
VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path);
if (!NIL_P(pre_converted_model)) {
whisper_model_file_path = pre_converted_model;
}
if (rb_respond_to(whisper_model_file_path, id_to_path)) {
whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0);
}
@ -1251,6 +1261,25 @@ static VALUE ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value) {
rwp->params.logprob_thold = RFLOAT_VALUE(value);
return value;
}
/*
* call-seq:
* no_speech_thold -> Float
*/
static VALUE ruby_whisper_params_get_no_speech_thold(VALUE self) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
return DBL2NUM(rwp->params.no_speech_thold);
}
/*
* call-seq:
* no_speech_thold = threshold -> threshold
*/
static VALUE ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value) {
ruby_whisper_params *rwp;
Data_Get_Struct(self, ruby_whisper_params, rwp);
rwp->params.no_speech_thold = RFLOAT_VALUE(value);
return value;
}
/*
* Sets new segment callback, called for every newly generated text segment.
*
@ -1347,9 +1376,6 @@ typedef struct {
VALUE context;
} ruby_whisper_model;
VALUE cSegment;
VALUE cModel;
static void rb_whisper_segment_mark(ruby_whisper_segment *rws) {
rb_gc_mark(rws->context);
}
@ -1740,6 +1766,7 @@ void Init_whisper() {
id_next = rb_intern("next");
id_new = rb_intern("new");
id_to_path = rb_intern("to_path");
id_pre_converted_models = rb_intern("pre_converted_models");
mWhisper = rb_define_module("Whisper");
cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
@ -1835,6 +1862,8 @@ void Init_whisper() {
rb_define_method(cParams, "entropy_thold=", ruby_whisper_params_set_entropy_thold, 1);
rb_define_method(cParams, "logprob_thold", ruby_whisper_params_get_logprob_thold, 0);
rb_define_method(cParams, "logprob_thold=", ruby_whisper_params_set_logprob_thold, 1);
rb_define_method(cParams, "no_speech_thold", ruby_whisper_params_get_no_speech_thold, 0);
rb_define_method(cParams, "no_speech_thold=", ruby_whisper_params_set_no_speech_thold, 1);
rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1);
rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1);

View File

@ -1,2 +1,2 @@
require "whisper.so"
require "whisper/model"
require "whisper/model/uri"

View File

@ -1,6 +1,7 @@
require "whisper.so"
require "uri"
require "net/http"
require "time"
require "pathname"
require "io/console/size"
@ -56,9 +57,11 @@ class Whisper::Model
when Net::HTTPOK
download response
when Net::HTTPRedirection
request URI(response["location"])
request URI(response["location"]), headers
else
raise response
return if headers.key?("if-modified-since") # Use cache file
raise "#{response.code} #{response.message}\n#{response.body}"
end
end
end
@ -81,6 +84,7 @@ class Whisper::Model
end
def show_progress(current, size)
return unless $stderr.tty?
return unless size
unless @prev
@ -111,7 +115,7 @@ class Whisper::Model
end
end
@names = {}
@pre_converted_models = {}
%w[
tiny
tiny.en
@ -137,23 +141,17 @@ class Whisper::Model
large-v1
large-v2
large-v2-q5_0
large-v2-8_0
large-v2-q8_0
large-v3
large-v3-q5_0
large-v3-turbo
large-v3-turbo-q5_0
large-v3-turbo-q8_0
].each do |name|
@names[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin")
@pre_converted_models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin")
end
class << self
def [](name)
@names[name]
end
def preconverted_model_names
@names.keys
end
attr_reader :pre_converted_models
end
end

View File

@ -3,6 +3,5 @@ require "whisper"
require_relative "jfk_reader/jfk_reader"
class TestBase < Test::Unit::TestCase
MODEL = File.join(__dir__, "..", "..", "..", "models", "ggml-base.en.bin")
AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
end

View File

@ -1,14 +1,11 @@
require "test/unit"
require "whisper"
class TestCallback < Test::Unit::TestCase
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
require_relative "helper"
class TestCallback < TestBase
def setup
GC.start
@params = Whisper::Params.new
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
@audio = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
@whisper = Whisper::Context.new("base.en")
@audio = File.join(AUDIO)
end
def test_new_segment_callback

View File

@ -3,12 +3,12 @@ require "pathname"
class TestModel < TestBase
def test_model
whisper = Whisper::Context.new(MODEL)
whisper = Whisper::Context.new("base.en")
assert_instance_of Whisper::Model, whisper.model
end
def test_attributes
whisper = Whisper::Context.new(MODEL)
whisper = Whisper::Context.new("base.en")
model = whisper.model
assert_equal 51864, model.n_vocab
@ -26,7 +26,7 @@ class TestModel < TestBase
end
def test_gc
model = Whisper::Context.new(MODEL).model
model = Whisper::Context.new("base.en").model
GC.start
assert_equal 51864, model.n_vocab
@ -44,7 +44,7 @@ class TestModel < TestBase
end
def test_pathname
path = Pathname(MODEL)
path = Pathname(Whisper::Model.pre_converted_models["base.en"].to_path)
whisper = Whisper::Context.new(path)
model = whisper.model
@ -61,4 +61,11 @@ class TestModel < TestBase
assert_equal 1, model.ftype
assert_equal "base", model.type
end
def test_auto_download
path = Whisper::Model.pre_converted_models["base.en"].to_path
assert_path_exist path
assert_equal 147964211, File.size(path)
end
end

View File

@ -151,4 +151,10 @@ class TestParams < TestBase
@params.logprob_thold = -0.5
assert_in_delta -0.5, @params.logprob_thold
end
def test_no_speech_thold
assert_in_delta 0.6, @params.no_speech_thold
@params.no_speech_thold = 0.2
assert_in_delta 0.2, @params.no_speech_thold
end
end

View File

@ -5,7 +5,7 @@ class TestSegment < TestBase
attr_reader :whisper
def startup
@whisper = Whisper::Context.new(TestBase::MODEL)
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
params.print_timestamps = false
@whisper.transcribe(TestBase::AUDIO, params)

View File

@ -11,7 +11,7 @@ class TestWhisper < TestBase
end
def test_whisper
@whisper = Whisper::Context.new(MODEL)
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
params.print_timestamps = false
@ -25,7 +25,7 @@ class TestWhisper < TestBase
attr_reader :whisper
def startup
@whisper = Whisper::Context.new(TestBase::MODEL)
@whisper = Whisper::Context.new("base.en")
params = Whisper::Params.new
params.print_timestamps = false
@whisper.transcribe(TestBase::AUDIO, params)
@ -104,7 +104,7 @@ class TestWhisper < TestBase
logs << [level, buffer, udata]
}
Whisper.log_set log_callback, user_data
Whisper::Context.new(MODEL)
Whisper::Context.new("base.en")
assert logs.length > 30
logs.each do |log|
@ -120,7 +120,7 @@ class TestWhisper < TestBase
}, nil
dev = StringIO.new("")
$stderr = dev
Whisper::Context.new(MODEL)
Whisper::Context.new("base.en")
assert_empty dev.string
ensure
$stderr = stderr
@ -129,7 +129,7 @@ class TestWhisper < TestBase
sub_test_case "full" do
def setup
super
@whisper = Whisper::Context.new(MODEL)
@whisper = Whisper::Context.new("base.en")
@samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15}
end

View File

@ -23,7 +23,7 @@ Gem::Specification.new do |s|
s.test_files = s.files.select {|file| file.start_with? "tests/"}
s.extensions << 'ext/extconf.rb'
s.required_ruby_version = '>= 3.1.0'
#### Documentation and testing.
s.homepage = 'https://github.com/ggerganov/whisper.cpp'