package whisper import ( "fmt" "io" "runtime" "strings" "time" // Bindings whisper "github.com/ggerganov/whisper.cpp/bindings/go" ) /////////////////////////////////////////////////////////////////////////////// // TYPES type context struct { n int model *model params whisper.Params } // Make sure context adheres to the interface var _ Context = (*context)(nil) /////////////////////////////////////////////////////////////////////////////// // LIFECYCLE func newContext(model *model, params whisper.Params) (Context, error) { context := new(context) context.model = model context.params = params // Return success return context, nil } /////////////////////////////////////////////////////////////////////////////// // PUBLIC METHODS // Set the language to use for speech recognition. func (context *context) SetLanguage(lang string) error { if context.model.ctx == nil { return ErrInternalAppError } if !context.model.IsMultilingual() { return ErrModelNotMultilingual } if lang == "auto" { context.params.SetLanguage(-1) } else if id := context.model.ctx.Whisper_lang_id(lang); id < 0 { return ErrUnsupportedLanguage } else if err := context.params.SetLanguage(id); err != nil { return err } // Return success return nil } func (context *context) IsMultilingual() bool { return context.model.IsMultilingual() } // Get language func (context *context) Language() string { id := context.params.Language() if id == -1 { return "auto" } return whisper.Whisper_lang_str(context.params.Language()) } // Set translate flag func (context *context) SetTranslate(v bool) { context.params.SetTranslate(v) } func (context *context) SetSplitOnWord(v bool) { context.params.SetSplitOnWord(v) } // Set number of threads to use func (context *context) SetThreads(v uint) { context.params.SetThreads(int(v)) } // Set time offset func (context *context) SetOffset(v time.Duration) { context.params.SetOffset(int(v.Milliseconds())) } // Set duration of audio to process func (context *context) SetDuration(v 
time.Duration) { context.params.SetDuration(int(v.Milliseconds())) } // Set timestamp token probability threshold (~0.01) func (context *context) SetTokenThreshold(t float32) { context.params.SetTokenThreshold(t) } // Set timestamp token sum probability threshold (~0.01) func (context *context) SetTokenSumThreshold(t float32) { context.params.SetTokenSumThreshold(t) } // Set max segment length in characters func (context *context) SetMaxSegmentLength(n uint) { context.params.SetMaxSegmentLength(int(n)) } // Set token timestamps flag func (context *context) SetTokenTimestamps(b bool) { context.params.SetTokenTimestamps(b) } // Set max tokens per segment (0 = no limit) func (context *context) SetMaxTokensPerSegment(n uint) { context.params.SetMaxTokensPerSegment(int(n)) } // Set audio encoder context func (context *context) SetAudioCtx(n uint) { context.params.SetAudioCtx(int(n)) } // Set maximum number of text context tokens to store func (context *context) SetMaxContext(n int) { context.params.SetMaxContext(n) } // Set Beam Size func (context *context) SetBeamSize(n int) { context.params.SetBeamSize(n) } // Set Entropy threshold func (context *context) SetEntropyThold(t float32) { context.params.SetEntropyThold(t) } // Set initial prompt func (context *context) SetInitialPrompt(prompt string) { context.params.SetInitialPrompt(prompt) } // ResetTimings resets the mode timings. Should be called before processing func (context *context) ResetTimings() { context.model.ctx.Whisper_reset_timings() } // PrintTimings prints the model timings to stdout. 
func (context *context) PrintTimings() {
	// Mirror the nil guard used by Process and NextSegment; there is nothing
	// to print when the model has no underlying whisper context.
	if context.model.ctx == nil {
		return
	}
	context.model.ctx.Whisper_print_timings()
}

// SystemInfo returns the system information: the configured thread count,
// the number of CPUs reported by the runtime, and the system information
// string reported by the bindings.
func (context *context) SystemInfo() string {
	return fmt.Sprintf("system_info: n_threads = %d / %d | %s\n",
		context.params.Threads(),
		runtime.NumCPU(),
		whisper.Whisper_print_system_info(),
	)
}

// WhisperLangAutoDetect uses mel data at offset_ms to try and auto-detect the
// spoken language. Make sure to call whisper_pcm_to_mel() or
// whisper_set_mel() first.
//
// It returns the probabilities of all languages, or a nil slice together with
// the error on failure. ErrInternalAppError is returned when the model has no
// underlying whisper context.
func (context *context) WhisperLangAutoDetect(offset_ms int, n_threads int) ([]float32, error) {
	if context.model.ctx == nil {
		return nil, ErrInternalAppError
	}
	langProbs, err := context.model.ctx.Whisper_lang_auto_detect(offset_ms, n_threads)
	if err != nil {
		return nil, err
	}
	return langProbs, nil
}

// Process runs the model over new sample data and returns any error.
//
// callNewSegment, when non-nil, is invoked for every newly produced segment
// and forces single-segment mode on this context's parameters. callProgress,
// when non-nil, is invoked with the progress value reported by the bindings.
// Processing is never run in parallel.
//
// ErrInternalAppError is returned when the model has no underlying whisper
// context.
func (context *context) Process(
	data []float32,
	callNewSegment SegmentCallback,
	callProgress ProgressCallback,
) error {
	ctx := context.model.ctx
	if ctx == nil {
		return ErrInternalAppError
	}

	// Rewind the NextSegment cursor so that segments of this run are read
	// from the first one again instead of from a stale position.
	// NOTE(review): assumes the bindings number segments per run, starting
	// at zero - confirm against whisper_full.
	context.n = 0

	// If the callback is defined then we force on single_segment mode
	if callNewSegment != nil {
		context.params.SetSingleSegment(true)
	}

	// added is the number of segments appended since the previous call; they
	// occupy the tail of the segment list.
	onNewSegment := func(added int) {
		if callNewSegment == nil {
			return
		}
		total := ctx.Whisper_full_n_segments()
		for i := total - added; i < total; i++ {
			callNewSegment(toSegment(ctx, i))
		}
	}
	onProgress := func(progress int) {
		if callProgress != nil {
			callProgress(progress)
		}
	}

	return ctx.Whisper_full(context.params, data, nil, onNewSegment, onProgress)
}

// NextSegment returns the next segment of tokens.
func (context *context) NextSegment() (Segment, error) { if context.model.ctx == nil { return Segment{}, ErrInternalAppError } if context.n >= context.model.ctx.Whisper_full_n_segments() { return Segment{}, io.EOF } // Populate result result := toSegment(context.model.ctx, context.n) // Increment the cursor context.n++ // Return success return result, nil } // Test for text tokens func (context *context) IsText(t Token) bool { switch { case context.IsBEG(t): return false case context.IsSOT(t): return false case whisper.Token(t.Id) >= context.model.ctx.Whisper_token_eot(): return false case context.IsPREV(t): return false case context.IsSOLM(t): return false case context.IsNOT(t): return false default: return true } } // Test for "begin" token func (context *context) IsBEG(t Token) bool { return whisper.Token(t.Id) == context.model.ctx.Whisper_token_beg() } // Test for "start of transcription" token func (context *context) IsSOT(t Token) bool { return whisper.Token(t.Id) == context.model.ctx.Whisper_token_sot() } // Test for "end of transcription" token func (context *context) IsEOT(t Token) bool { return whisper.Token(t.Id) == context.model.ctx.Whisper_token_eot() } // Test for "start of prev" token func (context *context) IsPREV(t Token) bool { return whisper.Token(t.Id) == context.model.ctx.Whisper_token_prev() } // Test for "start of lm" token func (context *context) IsSOLM(t Token) bool { return whisper.Token(t.Id) == context.model.ctx.Whisper_token_solm() } // Test for "No timestamps" token func (context *context) IsNOT(t Token) bool { return whisper.Token(t.Id) == context.model.ctx.Whisper_token_not() } // Test for token associated with a specific language func (context *context) IsLANG(t Token, lang string) bool { if id := context.model.ctx.Whisper_lang_id(lang); id >= 0 { return whisper.Token(t.Id) == context.model.ctx.Whisper_token_lang(id) } else { return false } } /////////////////////////////////////////////////////////////////////////////// // PRIVATE 
METHODS func toSegment(ctx *whisper.Context, n int) Segment { return Segment{ Num: n, Text: strings.TrimSpace(ctx.Whisper_full_get_segment_text(n)), Start: time.Duration(ctx.Whisper_full_get_segment_t0(n)) * time.Millisecond * 10, End: time.Duration(ctx.Whisper_full_get_segment_t1(n)) * time.Millisecond * 10, Tokens: toTokens(ctx, n), } } func toTokens(ctx *whisper.Context, n int) []Token { result := make([]Token, ctx.Whisper_full_n_tokens(n)) for i := 0; i < len(result); i++ { data := ctx.Whisper_full_get_token_data(n, i) result[i] = Token{ Id: int(ctx.Whisper_full_get_token_id(n, i)), Text: ctx.Whisper_full_get_token_text(n, i), P: ctx.Whisper_full_get_token_p(n, i), Start: time.Duration(data.T0()) * time.Millisecond * 10, End: time.Duration(data.T1()) * time.Millisecond * 10, } } return result }