diff --git a/bindings/go/params.go b/bindings/go/params.go index 3c9dd5ce..5931bb0b 100644 --- a/bindings/go/params.go +++ b/bindings/go/params.go @@ -123,6 +123,11 @@ func (p *Params) SetAudioCtx(n int) { p.audio_ctx = C.int(n) } +// Set initial prompt +func (p *Params) SetInitialPrompt(prompt string) { + p.initial_prompt = C.CString(prompt) +} + /////////////////////////////////////////////////////////////////////////////// // PRIVATE METHODS @@ -147,6 +152,7 @@ func (p *Params) String() string { str += fmt.Sprintf(" offset_ms=%d", p.offset_ms) str += fmt.Sprintf(" duration_ms=%d", p.duration_ms) str += fmt.Sprintf(" audio_ctx=%d", p.audio_ctx) + str += fmt.Sprintf(" initial_prompt=%s", C.GoString(p.initial_prompt)) if p.translate { str += " translate" } diff --git a/bindings/go/pkg/whisper/context.go b/bindings/go/pkg/whisper/context.go index f51d4f89..0863ef6b 100644 --- a/bindings/go/pkg/whisper/context.go +++ b/bindings/go/pkg/whisper/context.go @@ -130,6 +130,11 @@ func (context *context) SetAudioCtx(n uint) { context.params.SetAudioCtx(int(n)) } +// Set initial prompt +func (context *context) SetInitialPrompt(prompt string) { + context.params.SetInitialPrompt(prompt) +} + // ResetTimings resets the mode timings. Should be called before processing func (context *context) ResetTimings() { context.model.ctx.Whisper_reset_timings() diff --git a/bindings/go/pkg/whisper/interface.go b/bindings/go/pkg/whisper/interface.go index 4744271d..4339e16f 100644 --- a/bindings/go/pkg/whisper/interface.go +++ b/bindings/go/pkg/whisper/interface.go @@ -38,17 +38,18 @@ type Context interface { IsMultilingual() bool // Return true if the model is multilingual. Language() string // Get language - SetOffset(time.Duration) // Set offset - SetDuration(time.Duration) // Set duration - SetThreads(uint) // Set number of threads to use - SetSpeedup(bool) // Set speedup flag - SetSplitOnWord(bool) // Set split on word flag - SetTokenThreshold(float32) // Set timestamp token probability threshold - SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold - SetMaxSegmentLength(uint) // Set max segment length in characters - SetTokenTimestamps(bool) // Set token timestamps flag - SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit) - SetAudioCtx(uint) // Set audio encoder context + SetOffset(time.Duration) // Set offset + SetDuration(time.Duration) // Set duration + SetThreads(uint) // Set number of threads to use + SetSpeedup(bool) // Set speedup flag + SetSplitOnWord(bool) // Set split on word flag + SetTokenThreshold(float32) // Set timestamp token probability threshold + SetTokenSumThreshold(float32) // Set timestamp token sum probability threshold + SetMaxSegmentLength(uint) // Set max segment length in characters + SetTokenTimestamps(bool) // Set token timestamps flag + SetMaxTokensPerSegment(uint) // Set max tokens per segment (0 = no limit) + SetAudioCtx(uint) // Set audio encoder context + SetInitialPrompt(prompt string) // Set initial prompt // Process mono audio data and return any errors. // If defined, newly generated segments are passed to the