From 96db0c5a9c8517052026cc672da4b0c63dcf3de5 Mon Sep 17 00:00:00 2001 From: Amanda Der Bedrosian Date: Wed, 19 Mar 2025 00:05:04 -0700 Subject: [PATCH] go : add Encoder Begin Callback (#2900) Adding in EncoderBeginCallback to the Context's Process callback. This optional callback function returns false if computation should be aborted. Co-authored-by: Amanda Der Bedrosian --- bindings/go/README.md | 2 +- bindings/go/examples/go-whisper/process.go | 2 +- bindings/go/pkg/whisper/context.go | 35 ++++++++++++---------- bindings/go/pkg/whisper/context_test.go | 2 +- bindings/go/pkg/whisper/interface.go | 8 +++-- 5 files changed, 28 insertions(+), 21 deletions(-) diff --git a/bindings/go/README.md b/bindings/go/README.md index 6958ede8..cbd2a622 100644 --- a/bindings/go/README.md +++ b/bindings/go/README.md @@ -31,7 +31,7 @@ func main() { if err != nil { panic(err) } - if err := context.Process(samples, nil, nil); err != nil { + if err := context.Process(samples, nil, nil, nil); err != nil { return err } diff --git a/bindings/go/examples/go-whisper/process.go b/bindings/go/examples/go-whisper/process.go index 71e52f01..833947e8 100644 --- a/bindings/go/examples/go-whisper/process.go +++ b/bindings/go/examples/go-whisper/process.go @@ -67,7 +67,7 @@ func Process(model whisper.Model, path string, flags *Flags) error { // Process the data fmt.Fprintf(flags.Output(), " ...processing %q\n", path) context.ResetTimings() - if err := context.Process(data, cb, nil); err != nil { + if err := context.Process(data, nil, cb, nil); err != nil { return err } diff --git a/bindings/go/pkg/whisper/context.go b/bindings/go/pkg/whisper/context.go index 06376b1b..a8061293 100644 --- a/bindings/go/pkg/whisper/context.go +++ b/bindings/go/pkg/whisper/context.go @@ -189,6 +189,7 @@ func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]f // Process new sample data and return any errors func (context *context) Process( data []float32, + callEncoderBegin EncoderBeginCallback, callNewSegment SegmentCallback, callProgress ProgressCallback, ) error { @@ -203,7 +204,20 @@ func (context *context) Process( // We don't do parallel processing at the moment processors := 0 if processors > 1 { - if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, nil, func(new int) { + if err := context.model.ctx.Whisper_full_parallel(context.params, data, processors, callEncoderBegin, + func(new int) { + if callNewSegment != nil { + num_segments := context.model.ctx.Whisper_full_n_segments() + s0 := num_segments - new + for i := s0; i < num_segments; i++ { + callNewSegment(toSegment(context.model.ctx, i)) + } + } + }); err != nil { + return err + } + } else if err := context.model.ctx.Whisper_full(context.params, data, callEncoderBegin, + func(new int) { if callNewSegment != nil { num_segments := context.model.ctx.Whisper_full_n_segments() s0 := num_segments - new @@ -211,22 +225,11 @@ func (context *context) Process( callNewSegment(toSegment(context.model.ctx, i)) } } - }); err != nil { - return err - } - } else if err := context.model.ctx.Whisper_full(context.params, data, nil, func(new int) { - if callNewSegment != nil { - num_segments := context.model.ctx.Whisper_full_n_segments() - s0 := num_segments - new - for i := s0; i < num_segments; i++ { - callNewSegment(toSegment(context.model.ctx, i)) + }, func(progress int) { + if callProgress != nil { + callProgress(progress) } - } - }, func(progress int) { - if callProgress != nil { - callProgress(progress) - } - }); err != nil { + }); err != nil { return err } diff --git a/bindings/go/pkg/whisper/context_test.go b/bindings/go/pkg/whisper/context_test.go index 7d83a8df..51051481 100644 --- a/bindings/go/pkg/whisper/context_test.go +++ b/bindings/go/pkg/whisper/context_test.go @@ -88,6 +88,6 @@ func TestProcess(t *testing.T) { context, err := model.NewContext() assert.NoError(err) - err = context.Process(data, nil, nil) + err = context.Process(data, nil, nil, nil) assert.NoError(err) } diff --git a/bindings/go/pkg/whisper/interface.go b/bindings/go/pkg/whisper/interface.go index 8981b1a8..2b6a9c8e 100644 --- a/bindings/go/pkg/whisper/interface.go +++ b/bindings/go/pkg/whisper/interface.go @@ -16,6 +16,10 @@ type SegmentCallback func(Segment) // processing. It is called during the Process function type ProgressCallback func(int) +// EncoderBeginCallback is the callback function for checking if we want to +// continue processing. It is called during the Process function +type EncoderBeginCallback func() bool + // Model is the interface to a whisper model. Create a new model with the // function whisper.New(string) type Model interface { @@ -31,7 +35,7 @@ type Model interface { Languages() []string } -// Context is the speach recognition context. +// Context is the speech recognition context. type Context interface { SetLanguage(string) error // Set the language to use for speech recognition, use "auto" for auto detect language. SetTranslate(bool) // Set translate flag @@ -58,7 +62,7 @@ type Context interface { // Process mono audio data and return any errors. // If defined, newly generated segments are passed to the // callback function during processing. - Process([]float32, SegmentCallback, ProgressCallback) error + Process([]float32, EncoderBeginCallback, SegmentCallback, ProgressCallback) error // After process is called, return segments until the end of the stream // is reached, when io.EOF is returned.