diff --git a/core/config/backend_config.go b/core/config/backend_config.go index 696bab63..d5a4586b 100644 --- a/core/config/backend_config.go +++ b/core/config/backend_config.go @@ -81,7 +81,7 @@ type BackendConfig struct { type Pipeline struct { TTS string `yaml:"tts"` LLM string `yaml:"llm"` - Transcription string `yaml:"sst"` + Transcription string `yaml:"transcription"` } type File struct { diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go index 71d064dd..00fe28f7 100644 --- a/core/http/endpoints/openai/realtime.go +++ b/core/http/endpoints/openai/realtime.go @@ -10,7 +10,9 @@ import ( "github.com/gofiber/websocket/v2" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" + grpc "github.com/mudler/LocalAI/pkg/grpc" model "github.com/mudler/LocalAI/pkg/model" + "github.com/rs/zerolog/log" ) @@ -111,13 +113,17 @@ type Model interface { } type wrappedModel struct { - TTS *config.BackendConfig - SST *config.BackendConfig - LLM *config.BackendConfig + TTSConfig *config.BackendConfig + TranscriptionConfig *config.BackendConfig + LLMConfig *config.BackendConfig + TTSClient grpc.Backend + TranscriptionClient grpc.Backend + LLMClient grpc.Backend } // returns and loads either a wrapped model or a model that support audio-to-audio func newModel(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, modelName string) (Model, error) { + cfg, err := cl.LoadBackendConfigFileByName(modelName, ml.ModelPath) if err != nil { return nil, fmt.Errorf("failed to load backend config: %w", err) @@ -134,6 +140,8 @@ func newModel(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig * return ml.BackendLoader(opts...) } + log.Debug().Msg("Loading a wrapped model") + // Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations cfgLLM, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.LLM, ml.ModelPath) if err != nil { @@ -165,10 +173,31 @@ func newModel(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig * return nil, fmt.Errorf("failed to validate config: %w", err) } + opts := backend.ModelOptions(*cfgTTS, appConfig) + ttsClient, err := ml.BackendLoader(opts...) + if err != nil { + return nil, fmt.Errorf("failed to load tts model: %w", err) + } + + opts = backend.ModelOptions(*cfgSST, appConfig) + transcriptionClient, err := ml.BackendLoader(opts...) + if err != nil { + return nil, fmt.Errorf("failed to load SST model: %w", err) + } + + opts = backend.ModelOptions(*cfgLLM, appConfig) + llmClient, err := ml.BackendLoader(opts...) + if err != nil { + return nil, fmt.Errorf("failed to load LLM model: %w", err) + } + return &wrappedModel{ - TTS: cfgTTS, - SST: cfgSST, - LLM: cfgLLM, + TTSConfig: cfgTTS, + TranscriptionConfig: cfgSST, + LLMConfig: cfgLLM, + TTSClient: ttsClient, + TranscriptionClient: transcriptionClient, + LLMClient: llmClient, }, nil }