feat(whisper): add translate option (#2649)

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-06-24 19:21:22 +02:00 committed by GitHub
parent 9e6dec0bc4
commit 03b1cf51fd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 18 additions and 8 deletions

View File

@ -230,6 +230,7 @@ message TranscriptRequest {
string dst = 2; string dst = 2;
string language = 3; string language = 3;
uint32 threads = 4; uint32 threads = 4;
bool translate = 5;
} }
message TranscriptResult { message TranscriptResult {

View File

@ -29,7 +29,7 @@ func audioToWav(src, dst string) error {
return nil return nil
} }
func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.TranscriptionResult, error) { func Transcript(model whisper.Model, audiopath, language string, translate bool, threads uint) (schema.TranscriptionResult, error) {
res := schema.TranscriptionResult{} res := schema.TranscriptionResult{}
dir, err := os.MkdirTemp("", "whisper") dir, err := os.MkdirTemp("", "whisper")
@ -75,6 +75,10 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) (
context.SetLanguage("auto") context.SetLanguage("auto")
} }
if translate {
context.SetTranslate(true)
}
if err := context.Process(data, nil, nil); err != nil { if err := context.Process(data, nil, nil); err != nil {
return res, err return res, err
} }

View File

@ -22,5 +22,5 @@ func (sd *Whisper) Load(opts *pb.ModelOptions) error {
} }
func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.TranscriptionResult, error) { func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.TranscriptionResult, error) {
return Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads)) return Transcript(sd.whisper, opts.Dst, opts.Language, opts.Translate, uint(opts.Threads))
} }

View File

@ -11,7 +11,7 @@ import (
model "github.com/mudler/LocalAI/pkg/model" model "github.com/mudler/LocalAI/pkg/model"
) )
func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
opts := modelOpts(backendConfig, appConfig, []model.Option{ opts := modelOpts(backendConfig, appConfig, []model.Option{
model.WithBackendString(model.WhisperBackend), model.WithBackendString(model.WhisperBackend),
@ -33,6 +33,7 @@ func ModelTranscription(audio, language string, ml *model.ModelLoader, backendCo
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
Dst: audio, Dst: audio,
Language: language, Language: language,
Translate: translate,
Threads: uint32(*backendConfig.Threads), Threads: uint32(*backendConfig.Threads),
}) })
} }

View File

@ -18,6 +18,7 @@ type TranscriptCMD struct {
Backend string `short:"b" default:"whisper" help:"Backend to run the transcription model"` Backend string `short:"b" default:"whisper" help:"Backend to run the transcription model"`
Model string `short:"m" required:"" help:"Model name to run the TTS"` Model string `short:"m" required:"" help:"Model name to run the TTS"`
Language string `short:"l" help:"Language of the audio file"` Language string `short:"l" help:"Language of the audio file"`
Translate bool `short:"t" help:"Translate the transcription to english"`
Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"` Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"`
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"` BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"`
@ -50,7 +51,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
} }
}() }()
tr, err := backend.ModelTranscription(t.Filename, t.Language, ml, c, opts) tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, ml, c, opts)
if err != nil { if err != nil {
return err return err
} }

View File

@ -65,7 +65,7 @@ func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
log.Debug().Msgf("Audio file copied to: %+v", dst) log.Debug().Msgf("Audio file copied to: %+v", dst)
tr, err := backend.ModelTranscription(dst, input.Language, ml, *config, appConfig) tr, err := backend.ModelTranscription(dst, input.Language, input.Translate, ml, *config, appConfig)
if err != nil { if err != nil {
return err return err
} }

View File

@ -8,6 +8,9 @@ type PredictionOptions struct {
// Also part of the OpenAI official spec // Also part of the OpenAI official spec
Language string `json:"language"` Language string `json:"language"`
// Only for audio transcription
Translate bool `json:"translate"`
// Also part of the OpenAI official spec. use it for returning multiple results // Also part of the OpenAI official spec. use it for returning multiple results
N int `json:"n"` N int `json:"n"`