ruby : support new-segment callback (#2506)

* Add Params#new_segment_callback= method

* Add tests for Params#new_segment_callback=

* Group tests for #transcribe

* Don't use static for thread-safety

* Set new_segment_callback only when necessary

* Remove redundant check

* [skip ci] Add Ruby version README

* Revert "Group tests for #transcribe"

This reverts commit 71b65b00cc.

* Revert "Add tests for Params#new_segment_callback="

This reverts commit 81e6df3bab.

* Add test for Context#full_n_segments

* Add Context#full_n_segments

* Add tests for lang API

* Add lang API

* Add tests for Context#full_lang_id API

* Add Context#full_lang_id

* Add abnormal test cases for lang

* Raise appropriate errors from lang APIs

* Add tests for Context#full_get_segment_t{0,1} API

* Add Context#full_get_segment_t{0,1}

* Add tests for Context#full_get_segment_speaker_turn_next API

* Add Context#full_get_segment_speaker_turn_next

* Add tests for Context#full_get_segment_text

* Add Context#full_get_setgment_text

* Add tests for Params#new_segment_callback=

* Run new segment callback

* Split tests to multiple files

* Use container struct for new segment callback

* Add tests for Params#new_segment_callback_user_data=

* Add Whisper::Params#new_user_callback_user_data=

* Add GC-related test for new segment callback

* Protect new segment callback related structs from GC

* Add meaningful test for build

* Rename: new_segment_callback_user_data -> new_segment_callback_container

* Add tests for Whisper::Segment

* Add Whisper::Segment and Whisper::Context#each_segment

* Extract c_ruby_whisper_callback_container_allocate()

* Add test for Whisper::Params#on_new_segment

* Add Whisper::Params#on_new_egment

* Assign symbol IDs to variables

* Make extsources.yaml simpler

* Update README

* Add document comments

* Add test for calling Whisper::Params#on_new_segment multiple times

* Add file dependencies to GitHub actions config and .gitignore

* Add more files to ext/.gitignore
This commit is contained in:
KITAITI Makoto
2024-10-28 22:43:27 +09:00
committed by GitHub
parent c0ea41f6b2
commit fc49ee4479
14 changed files with 1117 additions and 175 deletions

View File

@ -0,0 +1,76 @@
require "test/unit"
require "whisper"
class TestCallback < Test::Unit::TestCase
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
def setup
@params = Whisper::Params.new
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
@audio = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
end
def test_new_segment_callback
@params.new_segment_callback = ->(context, state, n_new, user_data) {
assert_kind_of Integer, n_new
assert n_new > 0
assert_same @whisper, context
n_segments = context.full_n_segments
n_new.times do |i|
i_segment = n_segments - 1 + i
start_time = context.full_get_segment_t0(i_segment) * 10
end_time = context.full_get_segment_t1(i_segment) * 10
text = context.full_get_segment_text(i_segment)
assert_kind_of Integer, start_time
assert start_time >= 0
assert_kind_of Integer, end_time
assert end_time > 0
assert_match /ask not what your country can do for you, ask what you can do for your country/, text if i_segment == 0
end
}
@whisper.transcribe(@audio, @params)
end
def test_new_segment_callback_closure
search_word = "what"
@params.new_segment_callback = ->(context, state, n_new, user_data) {
n_segments = context.full_n_segments
n_new.times do |i|
i_segment = n_segments - 1 + i
text = context.full_get_segment_text(i_segment)
if text.include?(search_word)
t0 = context.full_get_segment_t0(i_segment)
t1 = context.full_get_segment_t1(i_segment)
raise "search word '#{search_word}' found at between #{t0} and #{t1}"
end
end
}
assert_raise RuntimeError do
@whisper.transcribe(@audio, @params)
end
end
def test_new_segment_callback_user_data
udata = Object.new
@params.new_segment_callback_user_data = udata
@params.new_segment_callback = ->(context, state, n_new, user_data) {
assert_same udata, user_data
}
@whisper.transcribe(@audio, @params)
end
def test_new_segment_callback_user_data_gc
@params.new_segment_callback_user_data = "My user data"
@params.new_segment_callback = ->(context, state, n_new, user_data) {
assert_equal "My user data", user_data
}
GC.start
assert_same @whisper, @whisper.transcribe(@audio, @params)
end
end

View File

@ -0,0 +1,28 @@
require 'test/unit'
require 'tempfile'
require 'tmpdir'
require 'shellwords'
class TestPackage < Test::Unit::TestCase
def test_build
Tempfile.create do |file|
assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
assert file.size > 0
end
end
sub_test_case "Building binary on installation" do
def setup
system "rake", "build", exception: true
end
def test_install
filename = `rake -Tbuild`.match(/(whispercpp-(?:.+)\.gem)/)[1]
basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
Dir.mktmpdir do |dir|
system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true
assert_path_exist File.join(dir, "gems/whispercpp-1.3.0/lib", basename)
end
end
end
end

View File

@ -0,0 +1,112 @@
require 'whisper'
class TestParams < Test::Unit::TestCase
def setup
@params = Whisper::Params.new
end
def test_language
@params.language = "en"
assert_equal @params.language, "en"
@params.language = "auto"
assert_equal @params.language, "auto"
end
def test_offset
@params.offset = 10_000
assert_equal @params.offset, 10_000
@params.offset = 0
assert_equal @params.offset, 0
end
def test_duration
@params.duration = 60_000
assert_equal @params.duration, 60_000
@params.duration = 0
assert_equal @params.duration, 0
end
def test_max_text_tokens
@params.max_text_tokens = 300
assert_equal @params.max_text_tokens, 300
@params.max_text_tokens = 0
assert_equal @params.max_text_tokens, 0
end
def test_translate
@params.translate = true
assert @params.translate
@params.translate = false
assert !@params.translate
end
def test_no_context
@params.no_context = true
assert @params.no_context
@params.no_context = false
assert !@params.no_context
end
def test_single_segment
@params.single_segment = true
assert @params.single_segment
@params.single_segment = false
assert !@params.single_segment
end
def test_print_special
@params.print_special = true
assert @params.print_special
@params.print_special = false
assert !@params.print_special
end
def test_print_progress
@params.print_progress = true
assert @params.print_progress
@params.print_progress = false
assert !@params.print_progress
end
def test_print_realtime
@params.print_realtime = true
assert @params.print_realtime
@params.print_realtime = false
assert !@params.print_realtime
end
def test_print_timestamps
@params.print_timestamps = true
assert @params.print_timestamps
@params.print_timestamps = false
assert !@params.print_timestamps
end
def test_suppress_blank
@params.suppress_blank = true
assert @params.suppress_blank
@params.suppress_blank = false
assert !@params.suppress_blank
end
def test_suppress_non_speech_tokens
@params.suppress_non_speech_tokens = true
assert @params.suppress_non_speech_tokens
@params.suppress_non_speech_tokens = false
assert !@params.suppress_non_speech_tokens
end
def test_token_timestamps
@params.token_timestamps = true
assert @params.token_timestamps
@params.token_timestamps = false
assert !@params.token_timestamps
end
def test_split_on_word
@params.split_on_word = true
assert @params.split_on_word
@params.split_on_word = false
assert !@params.split_on_word
end
end

View File

@ -0,0 +1,87 @@
require "test/unit"
require "whisper"
class TestSegment < Test::Unit::TestCase
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
class << self
attr_reader :whisper
def startup
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
params = Whisper::Params.new
params.print_timestamps = false
jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
@whisper.transcribe(jfk, params)
end
end
def test_iteration
whisper.each_segment do |segment|
assert_instance_of Whisper::Segment, segment
end
end
def test_enumerator
enum = whisper.each_segment
assert_instance_of Enumerator, enum
enum.to_a.each_with_index do |segment, index|
assert_instance_of Whisper::Segment, segment
assert_kind_of Integer, index
end
end
def test_start_time
i = 0
whisper.each_segment do |segment|
assert_equal 0, segment.start_time if i == 0
i += 1
end
end
def test_end_time
i = 0
whisper.each_segment do |segment|
assert_equal whisper.full_get_segment_t1(i) * 10, segment.end_time
i += 1
end
end
def test_on_new_segment
params = Whisper::Params.new
seg = nil
index = 0
params.on_new_segment do |segment|
assert_instance_of Whisper::Segment, segment
if index == 0
seg = segment
assert_equal 0, segment.start_time
assert_match /ask not what your country can do for you, ask what you can do for your country/, segment.text
end
index += 1
end
whisper.transcribe(File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav'), params)
assert_equal 0, seg.start_time
assert_match /ask not what your country can do for you, ask what you can do for your country/, seg.text
end
def test_on_new_segment_twice
params = Whisper::Params.new
seg = nil
params.on_new_segment do |segment|
seg = segment
return
end
params.on_new_segment do |segment|
assert_same seg, segment
return
end
whisper.transcribe(File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav'), params)
end
private
def whisper
self.class.whisper
end
end

View File

@ -1,121 +1,13 @@
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
require 'whisper'
require 'test/unit'
require 'tempfile'
require 'tmpdir'
require 'shellwords'
class TestWhisper < Test::Unit::TestCase
TOPDIR = File.expand_path(File.join(File.dirname(__FILE__), '..'))
def setup
@params = Whisper::Params.new
end
def test_language
@params.language = "en"
assert_equal @params.language, "en"
@params.language = "auto"
assert_equal @params.language, "auto"
end
def test_offset
@params.offset = 10_000
assert_equal @params.offset, 10_000
@params.offset = 0
assert_equal @params.offset, 0
end
def test_duration
@params.duration = 60_000
assert_equal @params.duration, 60_000
@params.duration = 0
assert_equal @params.duration, 0
end
def test_max_text_tokens
@params.max_text_tokens = 300
assert_equal @params.max_text_tokens, 300
@params.max_text_tokens = 0
assert_equal @params.max_text_tokens, 0
end
def test_translate
@params.translate = true
assert @params.translate
@params.translate = false
assert !@params.translate
end
def test_no_context
@params.no_context = true
assert @params.no_context
@params.no_context = false
assert !@params.no_context
end
def test_single_segment
@params.single_segment = true
assert @params.single_segment
@params.single_segment = false
assert !@params.single_segment
end
def test_print_special
@params.print_special = true
assert @params.print_special
@params.print_special = false
assert !@params.print_special
end
def test_print_progress
@params.print_progress = true
assert @params.print_progress
@params.print_progress = false
assert !@params.print_progress
end
def test_print_realtime
@params.print_realtime = true
assert @params.print_realtime
@params.print_realtime = false
assert !@params.print_realtime
end
def test_print_timestamps
@params.print_timestamps = true
assert @params.print_timestamps
@params.print_timestamps = false
assert !@params.print_timestamps
end
def test_suppress_blank
@params.suppress_blank = true
assert @params.suppress_blank
@params.suppress_blank = false
assert !@params.suppress_blank
end
def test_suppress_non_speech_tokens
@params.suppress_non_speech_tokens = true
assert @params.suppress_non_speech_tokens
@params.suppress_non_speech_tokens = false
assert !@params.suppress_non_speech_tokens
end
def test_token_timestamps
@params.token_timestamps = true
assert @params.token_timestamps
@params.token_timestamps = false
assert !@params.token_timestamps
end
def test_split_on_word
@params.split_on_word = true
assert @params.split_on_word
@params.split_on_word = false
assert !@params.split_on_word
end
def test_whisper
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
params = Whisper::Params.new
@ -127,25 +19,81 @@ class TestWhisper < Test::Unit::TestCase
}
end
def test_build
Tempfile.create do |file|
assert system("gem", "build", "whispercpp.gemspec", "--output", file.to_path.shellescape, exception: true)
assert_path_exist file.to_path
sub_test_case "After transcription" do
class << self
attr_reader :whisper
def startup
@whisper = Whisper::Context.new(File.join(TOPDIR, '..', '..', 'models', 'ggml-base.en.bin'))
params = Whisper::Params.new
params.print_timestamps = false
jfk = File.join(TOPDIR, '..', '..', 'samples', 'jfk.wav')
@whisper.transcribe(jfk, params)
end
end
def whisper
self.class.whisper
end
def test_full_n_segments
assert_equal 1, whisper.full_n_segments
end
def test_full_lang_id
assert_equal 0, whisper.full_lang_id
end
def test_full_get_segment_t0
assert_equal 0, whisper.full_get_segment_t0(0)
assert_raise IndexError do
whisper.full_get_segment_t0(whisper.full_n_segments)
end
assert_raise IndexError do
whisper.full_get_segment_t0(-1)
end
end
def test_full_get_segment_t1
t1 = whisper.full_get_segment_t1(0)
assert_kind_of Integer, t1
assert t1 > 0
assert_raise IndexError do
whisper.full_get_segment_t1(whisper.full_n_segments)
end
end
def test_full_get_segment_speaker_turn_next
assert_false whisper.full_get_segment_speaker_turn_next(0)
end
def test_full_get_segment_text
assert_match /ask not what your country can do for you, ask what you can do for your country/, whisper.full_get_segment_text(0)
end
end
sub_test_case "Building binary on installation" do
def setup
system "rake", "build", exception: true
end
def test_lang_max_id
assert_kind_of Integer, Whisper.lang_max_id
end
def test_install
filename = `rake -Tbuild`.match(/(whispercpp-(?:.+)\.gem)/)[1]
basename = "whisper.#{RbConfig::CONFIG["DLEXT"]}"
Dir.mktmpdir do |dir|
system "gem", "install", "--install-dir", dir.shellescape, "pkg/#{filename.shellescape}", exception: true
assert_path_exist File.join(dir, "gems/whispercpp-1.3.0/lib", basename)
end
def test_lang_id
assert_equal 0, Whisper.lang_id("en")
assert_raise ArgumentError do
Whisper.lang_id("non existing language")
end
end
def test_lang_str
assert_equal "en", Whisper.lang_str(0)
assert_raise IndexError do
Whisper.lang_str(Whisper.lang_max_id + 1)
end
end
def test_lang_str_full
assert_equal "english", Whisper.lang_str_full(0)
assert_raise IndexError do
Whisper.lang_str_full(Whisper.lang_max_id + 1)
end
end
end