mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2024-12-18 20:27:53 +00:00
ruby : Add no_speech_thold (#2641)
* Remove Whisper::Model.[] * Fix Whisper::Model::URI#request * Make Whisper::Context#initialize accept pre-converted model name * Use downloading pre-converted model feature for testing * Update README * Remove unnecessary task * Move whisper/model.rb -> whisper/model/uri.rb * Update document comment of Whisper::Context#initialize * Don't show download progress when not tty * Pass String to raise * Use cache model file if download fails * Add test for auto download * Specify required Ruby version * Fix a typo * Remove unnecessary flags * Initialize Whisper::Params#diarize explicitely * Remove redundant code from README for simplicity * Add Whisper::Params#no_speech_thold attribute * Add test for Whisper::Params#no_speech_thold
This commit is contained in:
parent
d34445e960
commit
3a27b2b91b
@ -22,7 +22,7 @@ Usage
|
||||
```ruby
|
||||
require "whisper"
|
||||
|
||||
whisper = Whisper::Context.new(Whisper::Model["base"])
|
||||
whisper = Whisper::Context.new("base")
|
||||
|
||||
params = Whisper::Params.new
|
||||
params.language = "en"
|
||||
@ -44,17 +44,23 @@ end
|
||||
Some models are prepared up-front:
|
||||
|
||||
```ruby
|
||||
base_en = Whisper::Model["base.en"]
|
||||
base_en = Whisper::Model.pre_converted_models["base.en"]
|
||||
whisper = Whisper::Context.new(base_en)
|
||||
```
|
||||
|
||||
At first time you use a model, it is downloaded automatically. After that, downloaded cached file is used. To clear cache, call `#clear_cache`:
|
||||
|
||||
```ruby
|
||||
Whisper::Model["base"].clear_cache
|
||||
Whisper::Model.pre_converted_models["base"].clear_cache
|
||||
```
|
||||
|
||||
You can see the list of prepared model names by `Whisper::Model.preconverted_model_names`:
|
||||
You also can use shorthand for pre-converted models:
|
||||
|
||||
```ruby
|
||||
whisper = Whisper::Context.new("base.en")
|
||||
```
|
||||
|
||||
You can see the list of prepared model names by `Whisper::Model.preconverted_models.keys`:
|
||||
|
||||
```ruby
|
||||
puts Whisper::Model.preconverted_model_names
|
||||
@ -124,13 +130,6 @@ end
|
||||
You can also add hook to params called on new segment:
|
||||
|
||||
```ruby
|
||||
def format_time(time_ms)
|
||||
sec, decimal_part = time_ms.divmod(1000)
|
||||
min, sec = sec.divmod(60)
|
||||
hour, min = min.divmod(60)
|
||||
"%02d:%02d:%02d.%03d" % [hour, min, sec, decimal_part]
|
||||
end
|
||||
|
||||
# Add hook before calling #transcribe
|
||||
params.on_new_segment do |segment|
|
||||
line = "[%{st} --> %{ed}] %{text}" % {
|
||||
@ -151,7 +150,7 @@ whisper.transcribe("path/to/audio.wav", params)
|
||||
You can see model information:
|
||||
|
||||
```ruby
|
||||
whisper = Whisper::Context.new(Whisper::Model["base"])
|
||||
whisper = Whisper::Context.new("base")
|
||||
model = whisper.model
|
||||
|
||||
model.n_vocab # => 51864
|
||||
@ -200,7 +199,7 @@ Using this feature, you are also able to suppress log:
|
||||
Whisper.log_set ->(level, buffer, user_data) {
|
||||
# do nothing
|
||||
}, nil
|
||||
Whisper::Context.new(MODEL)
|
||||
Whisper::Context.new("base")
|
||||
```
|
||||
|
||||
### Low-level API to transcribe ###
|
||||
@ -214,7 +213,7 @@ require "wavefile"
|
||||
reader = WaveFile::Reader.new("path/to/audio.wav", WaveFile::Format.new(:mono, :float, 16000))
|
||||
samples = reader.enum_for(:each_buffer).map(&:samples).flatten
|
||||
|
||||
whisper = Whisper::Context.new(Whisper::Model["base"])
|
||||
whisper = Whisper::Context.new("base")
|
||||
whisper.full(Whisper::Params.new, samples)
|
||||
whisper.each_segment do |segment|
|
||||
puts segment.text
|
||||
|
@ -25,7 +25,6 @@ task build: ["ext/Makefile", "ext/ruby_whisper.h", "ext/ruby_whisper.cpp", "whis
|
||||
directory "pkg"
|
||||
CLOBBER.include "pkg"
|
||||
|
||||
TEST_MODEL = "../../models/ggml-base.en.bin"
|
||||
LIB_NAME = "whisper".ext(RbConfig::CONFIG["DLEXT"])
|
||||
SO_FILE = File.join("ext", LIB_NAME)
|
||||
LIB_FILE = File.join("lib", LIB_NAME)
|
||||
@ -41,23 +40,17 @@ file SO_FILE => "ext/Makefile" do |t|
|
||||
sh "make"
|
||||
end
|
||||
end
|
||||
CLEAN.include LIB_FILE
|
||||
CLEAN.include SO_FILE
|
||||
|
||||
directory "lib"
|
||||
file LIB_FILE => [SO_FILE, "lib"] do |t|
|
||||
copy t.source, t.name
|
||||
end
|
||||
CLEAN.include LIB_FILE
|
||||
|
||||
Rake::TestTask.new do |t|
|
||||
t.test_files = FileList["tests/test_*.rb"]
|
||||
end
|
||||
task test: [TEST_MODEL, LIB_FILE]
|
||||
|
||||
file TEST_MODEL do
|
||||
Dir.chdir "../.." do
|
||||
sh "./models/download-ggml-model.sh base.en"
|
||||
end
|
||||
end
|
||||
|
||||
TEST_MEMORY_VIEW = "tests/jfk_reader/jfk_reader.#{RbConfig::CONFIG['DLEXT']}"
|
||||
file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t|
|
||||
@ -67,4 +60,5 @@ file TEST_MEMORY_VIEW => "tests/jfk_reader/jfk_reader.c" do |t|
|
||||
end
|
||||
end
|
||||
CLEAN.include "tests/jfk_reader/jfk_reader.{o,#{RbConfig::CONFIG['DLEXT']}}"
|
||||
task test: TEST_MEMORY_VIEW
|
||||
|
||||
task test: [LIB_FILE, TEST_MEMORY_VIEW]
|
||||
|
@ -111,11 +111,6 @@ unless ENV['RISCV']
|
||||
$MK_CFLAGS << ' -march=native -mtune=native'
|
||||
$HOST_CXXFLAGS << ' -march=native -mtune=native'
|
||||
end
|
||||
|
||||
if $UNAME_M.match? /aarch64.*/
|
||||
$MK_CFLAGS << ' -mcpu=native'
|
||||
$MK_CXXFLAGS << ' -mcpu=native'
|
||||
end
|
||||
else
|
||||
$MK_CFLAGS << ' -march=rv64gcv -mabi=lp64d'
|
||||
$MK_CXXFLAGS << ' -march=rv64gcv -mabi=lp64d'
|
||||
|
@ -38,6 +38,9 @@ VALUE cContext;
|
||||
VALUE cParams;
|
||||
VALUE eError;
|
||||
|
||||
VALUE cSegment;
|
||||
VALUE cModel;
|
||||
|
||||
static ID id_to_s;
|
||||
static ID id_call;
|
||||
static ID id___method__;
|
||||
@ -46,6 +49,7 @@ static ID id_length;
|
||||
static ID id_next;
|
||||
static ID id_new;
|
||||
static ID id_to_path;
|
||||
static ID id_pre_converted_models;
|
||||
|
||||
static bool is_log_callback_finalized = false;
|
||||
|
||||
@ -187,6 +191,7 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) {
|
||||
ruby_whisper_params *rwp;
|
||||
rwp = ALLOC(ruby_whisper_params);
|
||||
rwp->params = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
rwp->diarize = false;
|
||||
rwp->new_segment_callback_container = rb_whisper_callback_container_allocate();
|
||||
rwp->progress_callback_container = rb_whisper_callback_container_allocate();
|
||||
rwp->abort_callback_container = rb_whisper_callback_container_allocate();
|
||||
@ -195,7 +200,7 @@ static VALUE ruby_whisper_params_allocate(VALUE klass) {
|
||||
|
||||
/*
|
||||
* call-seq:
|
||||
* new(Whisper::Model["base.en"]) -> Whisper::Context
|
||||
* new("base.en") -> Whisper::Context
|
||||
* new("path/to/model.bin") -> Whisper::Context
|
||||
* new(Whisper::Model::URI.new("https://example.net/uri/of/model.bin")) -> Whisper::Context
|
||||
*/
|
||||
@ -207,6 +212,11 @@ static VALUE ruby_whisper_initialize(int argc, VALUE *argv, VALUE self) {
|
||||
rb_scan_args(argc, argv, "01", &whisper_model_file_path);
|
||||
Data_Get_Struct(self, ruby_whisper, rw);
|
||||
|
||||
VALUE pre_converted_models = rb_funcall(cModel, id_pre_converted_models, 0);
|
||||
VALUE pre_converted_model = rb_hash_aref(pre_converted_models, whisper_model_file_path);
|
||||
if (!NIL_P(pre_converted_model)) {
|
||||
whisper_model_file_path = pre_converted_model;
|
||||
}
|
||||
if (rb_respond_to(whisper_model_file_path, id_to_path)) {
|
||||
whisper_model_file_path = rb_funcall(whisper_model_file_path, id_to_path, 0);
|
||||
}
|
||||
@ -1251,6 +1261,25 @@ static VALUE ruby_whisper_params_set_logprob_thold(VALUE self, VALUE value) {
|
||||
rwp->params.logprob_thold = RFLOAT_VALUE(value);
|
||||
return value;
|
||||
}
|
||||
/*
|
||||
* call-seq:
|
||||
* no_speech_thold -> Float
|
||||
*/
|
||||
static VALUE ruby_whisper_params_get_no_speech_thold(VALUE self) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
return DBL2NUM(rwp->params.no_speech_thold);
|
||||
}
|
||||
/*
|
||||
* call-seq:
|
||||
* no_speech_thold = threshold -> threshold
|
||||
*/
|
||||
static VALUE ruby_whisper_params_set_no_speech_thold(VALUE self, VALUE value) {
|
||||
ruby_whisper_params *rwp;
|
||||
Data_Get_Struct(self, ruby_whisper_params, rwp);
|
||||
rwp->params.no_speech_thold = RFLOAT_VALUE(value);
|
||||
return value;
|
||||
}
|
||||
/*
|
||||
* Sets new segment callback, called for every newly generated text segment.
|
||||
*
|
||||
@ -1347,9 +1376,6 @@ typedef struct {
|
||||
VALUE context;
|
||||
} ruby_whisper_model;
|
||||
|
||||
VALUE cSegment;
|
||||
VALUE cModel;
|
||||
|
||||
static void rb_whisper_segment_mark(ruby_whisper_segment *rws) {
|
||||
rb_gc_mark(rws->context);
|
||||
}
|
||||
@ -1740,6 +1766,7 @@ void Init_whisper() {
|
||||
id_next = rb_intern("next");
|
||||
id_new = rb_intern("new");
|
||||
id_to_path = rb_intern("to_path");
|
||||
id_pre_converted_models = rb_intern("pre_converted_models");
|
||||
|
||||
mWhisper = rb_define_module("Whisper");
|
||||
cContext = rb_define_class_under(mWhisper, "Context", rb_cObject);
|
||||
@ -1835,6 +1862,8 @@ void Init_whisper() {
|
||||
rb_define_method(cParams, "entropy_thold=", ruby_whisper_params_set_entropy_thold, 1);
|
||||
rb_define_method(cParams, "logprob_thold", ruby_whisper_params_get_logprob_thold, 0);
|
||||
rb_define_method(cParams, "logprob_thold=", ruby_whisper_params_set_logprob_thold, 1);
|
||||
rb_define_method(cParams, "no_speech_thold", ruby_whisper_params_get_no_speech_thold, 0);
|
||||
rb_define_method(cParams, "no_speech_thold=", ruby_whisper_params_set_no_speech_thold, 1);
|
||||
|
||||
rb_define_method(cParams, "new_segment_callback=", ruby_whisper_params_set_new_segment_callback, 1);
|
||||
rb_define_method(cParams, "new_segment_callback_user_data=", ruby_whisper_params_set_new_segment_callback_user_data, 1);
|
||||
|
@ -1,2 +1,2 @@
|
||||
require "whisper.so"
|
||||
require "whisper/model"
|
||||
require "whisper/model/uri"
|
||||
|
@ -1,6 +1,7 @@
|
||||
require "whisper.so"
|
||||
require "uri"
|
||||
require "net/http"
|
||||
require "time"
|
||||
require "pathname"
|
||||
require "io/console/size"
|
||||
|
||||
@ -56,9 +57,11 @@ class Whisper::Model
|
||||
when Net::HTTPOK
|
||||
download response
|
||||
when Net::HTTPRedirection
|
||||
request URI(response["location"])
|
||||
request URI(response["location"]), headers
|
||||
else
|
||||
raise response
|
||||
return if headers.key?("if-modified-since") # Use cache file
|
||||
|
||||
raise "#{response.code} #{response.message}\n#{response.body}"
|
||||
end
|
||||
end
|
||||
end
|
||||
@ -81,6 +84,7 @@ class Whisper::Model
|
||||
end
|
||||
|
||||
def show_progress(current, size)
|
||||
return unless $stderr.tty?
|
||||
return unless size
|
||||
|
||||
unless @prev
|
||||
@ -111,7 +115,7 @@ class Whisper::Model
|
||||
end
|
||||
end
|
||||
|
||||
@names = {}
|
||||
@pre_converted_models = {}
|
||||
%w[
|
||||
tiny
|
||||
tiny.en
|
||||
@ -137,23 +141,17 @@ class Whisper::Model
|
||||
large-v1
|
||||
large-v2
|
||||
large-v2-q5_0
|
||||
large-v2-8_0
|
||||
large-v2-q8_0
|
||||
large-v3
|
||||
large-v3-q5_0
|
||||
large-v3-turbo
|
||||
large-v3-turbo-q5_0
|
||||
large-v3-turbo-q8_0
|
||||
].each do |name|
|
||||
@names[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin")
|
||||
@pre_converted_models[name] = URI.new("https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-#{name}.bin")
|
||||
end
|
||||
|
||||
class << self
|
||||
def [](name)
|
||||
@names[name]
|
||||
end
|
||||
|
||||
def preconverted_model_names
|
||||
@names.keys
|
||||
end
|
||||
attr_reader :pre_converted_models
|
||||
end
|
||||
end
|
@ -3,6 +3,5 @@ require "whisper"
|
||||
require_relative "jfk_reader/jfk_reader"
|
||||
|
||||
class TestBase < Test::Unit::TestCase
|
||||
MODEL = File.join(__dir__, "..", "..", "..", "models", "ggml-base.en.bin")
|
||||
AUDIO = File.join(__dir__, "..", "..", "..", "samples", "jfk.wav")
|
||||
end
|
||||
|
@ -1,14 +1,11 @@
|
||||
require "test/unit"
|
||||
require "whisper"
|
||||
|
||||
class TestCallback < Test::Unit::TestCase
|
||||
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
|
||||
require_relative "helper"
|
||||
|
||||
class TestCallback < TestBase
|
||||
def setup
|
||||
GC.start
|
||||
@params = Whisper::Params.new
|
||||
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
|
||||
@audio = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
|
||||
@whisper = Whisper::Context.new("base.en")
|
||||
@audio = File.join(AUDIO)
|
||||
end
|
||||
|
||||
def test_new_segment_callback
|
||||
|
@ -3,12 +3,12 @@ require "pathname"
|
||||
|
||||
class TestModel < TestBase
|
||||
def test_model
|
||||
whisper = Whisper::Context.new(MODEL)
|
||||
whisper = Whisper::Context.new("base.en")
|
||||
assert_instance_of Whisper::Model, whisper.model
|
||||
end
|
||||
|
||||
def test_attributes
|
||||
whisper = Whisper::Context.new(MODEL)
|
||||
whisper = Whisper::Context.new("base.en")
|
||||
model = whisper.model
|
||||
|
||||
assert_equal 51864, model.n_vocab
|
||||
@ -26,7 +26,7 @@ class TestModel < TestBase
|
||||
end
|
||||
|
||||
def test_gc
|
||||
model = Whisper::Context.new(MODEL).model
|
||||
model = Whisper::Context.new("base.en").model
|
||||
GC.start
|
||||
|
||||
assert_equal 51864, model.n_vocab
|
||||
@ -44,7 +44,7 @@ class TestModel < TestBase
|
||||
end
|
||||
|
||||
def test_pathname
|
||||
path = Pathname(MODEL)
|
||||
path = Pathname(Whisper::Model.pre_converted_models["base.en"].to_path)
|
||||
whisper = Whisper::Context.new(path)
|
||||
model = whisper.model
|
||||
|
||||
@ -61,4 +61,11 @@ class TestModel < TestBase
|
||||
assert_equal 1, model.ftype
|
||||
assert_equal "base", model.type
|
||||
end
|
||||
|
||||
def test_auto_download
|
||||
path = Whisper::Model.pre_converted_models["base.en"].to_path
|
||||
|
||||
assert_path_exist path
|
||||
assert_equal 147964211, File.size(path)
|
||||
end
|
||||
end
|
||||
|
@ -151,4 +151,10 @@ class TestParams < TestBase
|
||||
@params.logprob_thold = -0.5
|
||||
assert_in_delta -0.5, @params.logprob_thold
|
||||
end
|
||||
|
||||
def test_no_speech_thold
|
||||
assert_in_delta 0.6, @params.no_speech_thold
|
||||
@params.no_speech_thold = 0.2
|
||||
assert_in_delta 0.2, @params.no_speech_thold
|
||||
end
|
||||
end
|
||||
|
@ -5,7 +5,7 @@ class TestSegment < TestBase
|
||||
attr_reader :whisper
|
||||
|
||||
def startup
|
||||
@whisper = Whisper::Context.new(TestBase::MODEL)
|
||||
@whisper = Whisper::Context.new("base.en")
|
||||
params = Whisper::Params.new
|
||||
params.print_timestamps = false
|
||||
@whisper.transcribe(TestBase::AUDIO, params)
|
||||
|
@ -11,7 +11,7 @@ class TestWhisper < TestBase
|
||||
end
|
||||
|
||||
def test_whisper
|
||||
@whisper = Whisper::Context.new(MODEL)
|
||||
@whisper = Whisper::Context.new("base.en")
|
||||
params = Whisper::Params.new
|
||||
params.print_timestamps = false
|
||||
|
||||
@ -25,7 +25,7 @@ class TestWhisper < TestBase
|
||||
attr_reader :whisper
|
||||
|
||||
def startup
|
||||
@whisper = Whisper::Context.new(TestBase::MODEL)
|
||||
@whisper = Whisper::Context.new("base.en")
|
||||
params = Whisper::Params.new
|
||||
params.print_timestamps = false
|
||||
@whisper.transcribe(TestBase::AUDIO, params)
|
||||
@ -104,7 +104,7 @@ class TestWhisper < TestBase
|
||||
logs << [level, buffer, udata]
|
||||
}
|
||||
Whisper.log_set log_callback, user_data
|
||||
Whisper::Context.new(MODEL)
|
||||
Whisper::Context.new("base.en")
|
||||
|
||||
assert logs.length > 30
|
||||
logs.each do |log|
|
||||
@ -120,7 +120,7 @@ class TestWhisper < TestBase
|
||||
}, nil
|
||||
dev = StringIO.new("")
|
||||
$stderr = dev
|
||||
Whisper::Context.new(MODEL)
|
||||
Whisper::Context.new("base.en")
|
||||
assert_empty dev.string
|
||||
ensure
|
||||
$stderr = stderr
|
||||
@ -129,7 +129,7 @@ class TestWhisper < TestBase
|
||||
sub_test_case "full" do
|
||||
def setup
|
||||
super
|
||||
@whisper = Whisper::Context.new(MODEL)
|
||||
@whisper = Whisper::Context.new("base.en")
|
||||
@samples = File.read(AUDIO, nil, 78).unpack("s<*").collect {|i| i.to_f / 2**15}
|
||||
end
|
||||
|
||||
|
@ -23,7 +23,7 @@ Gem::Specification.new do |s|
|
||||
s.test_files = s.files.select {|file| file.start_with? "tests/"}
|
||||
|
||||
s.extensions << 'ext/extconf.rb'
|
||||
|
||||
s.required_ruby_version = '>= 3.1.0'
|
||||
|
||||
#### Documentation and testing.
|
||||
s.homepage = 'https://github.com/ggerganov/whisper.cpp'
|
||||
|
Loading…
Reference in New Issue
Block a user