mirror of
https://github.com/mudler/LocalAI.git
synced 2025-01-11 15:32:44 +00:00
feat(whisper): add translate option (#2649)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
9e6dec0bc4
commit
03b1cf51fd
@ -230,6 +230,7 @@ message TranscriptRequest {
|
|||||||
string dst = 2;
|
string dst = 2;
|
||||||
string language = 3;
|
string language = 3;
|
||||||
uint32 threads = 4;
|
uint32 threads = 4;
|
||||||
|
bool translate = 5;
|
||||||
}
|
}
|
||||||
|
|
||||||
message TranscriptResult {
|
message TranscriptResult {
|
||||||
|
@ -29,7 +29,7 @@ func audioToWav(src, dst string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func Transcript(model whisper.Model, audiopath, language string, threads uint) (schema.TranscriptionResult, error) {
|
func Transcript(model whisper.Model, audiopath, language string, translate bool, threads uint) (schema.TranscriptionResult, error) {
|
||||||
res := schema.TranscriptionResult{}
|
res := schema.TranscriptionResult{}
|
||||||
|
|
||||||
dir, err := os.MkdirTemp("", "whisper")
|
dir, err := os.MkdirTemp("", "whisper")
|
||||||
@ -75,6 +75,10 @@ func Transcript(model whisper.Model, audiopath, language string, threads uint) (
|
|||||||
context.SetLanguage("auto")
|
context.SetLanguage("auto")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if translate {
|
||||||
|
context.SetTranslate(true)
|
||||||
|
}
|
||||||
|
|
||||||
if err := context.Process(data, nil, nil); err != nil {
|
if err := context.Process(data, nil, nil); err != nil {
|
||||||
return res, err
|
return res, err
|
||||||
}
|
}
|
||||||
|
@ -22,5 +22,5 @@ func (sd *Whisper) Load(opts *pb.ModelOptions) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.TranscriptionResult, error) {
|
func (sd *Whisper) AudioTranscription(opts *pb.TranscriptRequest) (schema.TranscriptionResult, error) {
|
||||||
return Transcript(sd.whisper, opts.Dst, opts.Language, uint(opts.Threads))
|
return Transcript(sd.whisper, opts.Dst, opts.Language, opts.Translate, uint(opts.Threads))
|
||||||
}
|
}
|
||||||
|
@ -11,7 +11,7 @@ import (
|
|||||||
model "github.com/mudler/LocalAI/pkg/model"
|
model "github.com/mudler/LocalAI/pkg/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func ModelTranscription(audio, language string, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
func ModelTranscription(audio, language string, translate bool, ml *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) {
|
||||||
|
|
||||||
opts := modelOpts(backendConfig, appConfig, []model.Option{
|
opts := modelOpts(backendConfig, appConfig, []model.Option{
|
||||||
model.WithBackendString(model.WhisperBackend),
|
model.WithBackendString(model.WhisperBackend),
|
||||||
@ -33,6 +33,7 @@ func ModelTranscription(audio, language string, ml *model.ModelLoader, backendCo
|
|||||||
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
return whisperModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{
|
||||||
Dst: audio,
|
Dst: audio,
|
||||||
Language: language,
|
Language: language,
|
||||||
|
Translate: translate,
|
||||||
Threads: uint32(*backendConfig.Threads),
|
Threads: uint32(*backendConfig.Threads),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -18,6 +18,7 @@ type TranscriptCMD struct {
|
|||||||
Backend string `short:"b" default:"whisper" help:"Backend to run the transcription model"`
|
Backend string `short:"b" default:"whisper" help:"Backend to run the transcription model"`
|
||||||
Model string `short:"m" required:"" help:"Model name to run the TTS"`
|
Model string `short:"m" required:"" help:"Model name to run the TTS"`
|
||||||
Language string `short:"l" help:"Language of the audio file"`
|
Language string `short:"l" help:"Language of the audio file"`
|
||||||
|
Translate bool `short:"t" help:"Translate the transcription to english"`
|
||||||
Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"`
|
Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"`
|
||||||
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
|
ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"`
|
||||||
BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"`
|
BackendAssetsPath string `env:"LOCALAI_BACKEND_ASSETS_PATH,BACKEND_ASSETS_PATH" type:"path" default:"/tmp/localai/backend_data" help:"Path used to extract libraries that are required by some of the backends in runtime" group:"storage"`
|
||||||
@ -50,7 +51,7 @@ func (t *TranscriptCMD) Run(ctx *cliContext.Context) error {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
tr, err := backend.ModelTranscription(t.Filename, t.Language, ml, c, opts)
|
tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, ml, c, opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -65,7 +65,7 @@ func TranscriptEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, a
|
|||||||
|
|
||||||
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
log.Debug().Msgf("Audio file copied to: %+v", dst)
|
||||||
|
|
||||||
tr, err := backend.ModelTranscription(dst, input.Language, ml, *config, appConfig)
|
tr, err := backend.ModelTranscription(dst, input.Language, input.Translate, ml, *config, appConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,9 @@ type PredictionOptions struct {
|
|||||||
// Also part of the OpenAI official spec
|
// Also part of the OpenAI official spec
|
||||||
Language string `json:"language"`
|
Language string `json:"language"`
|
||||||
|
|
||||||
|
// Only for audio transcription
|
||||||
|
Translate bool `json:"translate"`
|
||||||
|
|
||||||
// Also part of the OpenAI official spec. use it for returning multiple results
|
// Also part of the OpenAI official spec. use it for returning multiple results
|
||||||
N int `json:"n"`
|
N int `json:"n"`
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user