mirror of
https://github.com/mudler/LocalAI.git
synced 2025-05-12 21:43:14 +00:00
chore: extract realtime models into two categories
One is anyToAny models that requires a VAD model, and one is wrappedModel that requires as well VAD models along others in the pipeline. Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
4f69170273
commit
1796a1713d
@ -82,6 +82,11 @@ type Pipeline struct {
|
|||||||
TTS string `yaml:"tts"`
|
TTS string `yaml:"tts"`
|
||||||
LLM string `yaml:"llm"`
|
LLM string `yaml:"llm"`
|
||||||
Transcription string `yaml:"transcription"`
|
Transcription string `yaml:"transcription"`
|
||||||
|
VAD string `yaml:"vad"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Pipeline) IsNotConfigured() bool {
|
||||||
|
return p.LLM == "" || p.TTS == "" || p.Transcription == ""
|
||||||
}
|
}
|
||||||
|
|
||||||
type File struct {
|
type File struct {
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package openai
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -8,10 +9,10 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/gofiber/websocket/v2"
|
"github.com/gofiber/websocket/v2"
|
||||||
"github.com/mudler/LocalAI/core/backend"
|
|
||||||
"github.com/mudler/LocalAI/core/config"
|
"github.com/mudler/LocalAI/core/config"
|
||||||
grpc "github.com/mudler/LocalAI/pkg/grpc"
|
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
model "github.com/mudler/LocalAI/pkg/model"
|
model "github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
@ -114,95 +115,7 @@ var sessionLock sync.Mutex
|
|||||||
|
|
||||||
// TODO: implement interface as we start to define usages
|
// TODO: implement interface as we start to define usages
|
||||||
type Model interface {
|
type Model interface {
|
||||||
}
|
VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error)
|
||||||
|
|
||||||
type wrappedModel struct {
|
|
||||||
TTSConfig *config.BackendConfig
|
|
||||||
TranscriptionConfig *config.BackendConfig
|
|
||||||
LLMConfig *config.BackendConfig
|
|
||||||
TTSClient grpc.Backend
|
|
||||||
TranscriptionClient grpc.Backend
|
|
||||||
LLMClient grpc.Backend
|
|
||||||
}
|
|
||||||
|
|
||||||
// returns and loads either a wrapped model or a model that support audio-to-audio
|
|
||||||
func newModel(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, modelName string) (Model, error) {
|
|
||||||
|
|
||||||
cfg, err := cl.LoadBackendConfigFileByName(modelName, ml.ModelPath)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !cfg.Validate() {
|
|
||||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cfg.Pipeline.LLM == "" || cfg.Pipeline.TTS == "" || cfg.Pipeline.Transcription == "" {
|
|
||||||
// If we don't have Wrapped model definitions, just return a standard model
|
|
||||||
opts := backend.ModelOptions(*cfg, appConfig, model.WithBackendString(cfg.Backend),
|
|
||||||
model.WithModel(cfg.Model))
|
|
||||||
return ml.Load(opts...)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msg("Loading a wrapped model")
|
|
||||||
|
|
||||||
// Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations
|
|
||||||
cfgLLM, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.LLM, ml.ModelPath)
|
|
||||||
if err != nil {
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !cfg.Validate() {
|
|
||||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfgTTS, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.TTS, ml.ModelPath)
|
|
||||||
if err != nil {
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !cfg.Validate() {
|
|
||||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
cfgSST, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.Transcription, ml.ModelPath)
|
|
||||||
if err != nil {
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !cfg.Validate() {
|
|
||||||
return nil, fmt.Errorf("failed to validate config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
opts := backend.ModelOptions(*cfgTTS, appConfig)
|
|
||||||
ttsClient, err := ml.Load(opts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to load tts model: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
opts = backend.ModelOptions(*cfgSST, appConfig)
|
|
||||||
transcriptionClient, err := ml.Load(opts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to load SST model: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
opts = backend.ModelOptions(*cfgLLM, appConfig)
|
|
||||||
llmClient, err := ml.Load(opts...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to load LLM model: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &wrappedModel{
|
|
||||||
TTSConfig: cfgTTS,
|
|
||||||
TranscriptionConfig: cfgSST,
|
|
||||||
LLMConfig: cfgLLM,
|
|
||||||
TTSClient: ttsClient,
|
|
||||||
TranscriptionClient: transcriptionClient,
|
|
||||||
LLMClient: llmClient,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *websocket.Conn) {
|
func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *websocket.Conn) {
|
||||||
|
169
core/http/endpoints/openai/realtime_model.go
Normal file
169
core/http/endpoints/openai/realtime_model.go
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/mudler/LocalAI/core/backend"
|
||||||
|
"github.com/mudler/LocalAI/core/config"
|
||||||
|
grpcClient "github.com/mudler/LocalAI/pkg/grpc"
|
||||||
|
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||||
|
model "github.com/mudler/LocalAI/pkg/model"
|
||||||
|
"github.com/rs/zerolog/log"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
)
|
||||||
|
|
||||||
|
// wrappedModel represent a model which does not support Any-to-Any operations
|
||||||
|
// This means that we will fake an Any-to-Any model by overriding some of the gRPC client methods
|
||||||
|
// which are for Any-To-Any models, but instead we will call a pipeline (for e.g STT->LLM->TTS)
|
||||||
|
type wrappedModel struct {
|
||||||
|
TTSConfig *config.BackendConfig
|
||||||
|
TranscriptionConfig *config.BackendConfig
|
||||||
|
LLMConfig *config.BackendConfig
|
||||||
|
TTSClient grpcClient.Backend
|
||||||
|
TranscriptionClient grpcClient.Backend
|
||||||
|
LLMClient grpcClient.Backend
|
||||||
|
|
||||||
|
VADConfig *config.BackendConfig
|
||||||
|
VADClient grpcClient.Backend
|
||||||
|
}
|
||||||
|
|
||||||
|
// anyToAnyModel represent a model which supports Any-to-Any operations
|
||||||
|
// We have to wrap this out as well because we want to load two models one for VAD and one for the actual model.
|
||||||
|
// In the future there could be models that accept continous audio input only so this design will be useful for that
|
||||||
|
type anyToAnyModel struct {
|
||||||
|
LLMConfig *config.BackendConfig
|
||||||
|
LLMClient grpcClient.Backend
|
||||||
|
|
||||||
|
VADConfig *config.BackendConfig
|
||||||
|
VADClient grpcClient.Backend
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *wrappedModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) {
|
||||||
|
return m.VADClient.VAD(ctx, in)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *anyToAnyModel) VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) {
|
||||||
|
return m.VADClient.VAD(ctx, in)
|
||||||
|
}
|
||||||
|
|
||||||
|
// returns and loads either a wrapped model or a model that support audio-to-audio
|
||||||
|
func newModel(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, modelName string) (Model, error) {
|
||||||
|
|
||||||
|
cfg, err := cl.LoadBackendConfigFileByName(modelName, ml.ModelPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfg.Validate() {
|
||||||
|
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare VAD model
|
||||||
|
cfgVAD, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.VAD, ml.ModelPath)
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfgVAD.Validate() {
|
||||||
|
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := backend.ModelOptions(*cfgVAD, appConfig)
|
||||||
|
VADClient, err := ml.Load(opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load tts model: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we don't have Wrapped model definitions, just return a standard model
|
||||||
|
if cfg.Pipeline.IsNotConfigured() {
|
||||||
|
|
||||||
|
// Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations
|
||||||
|
cfgAnyToAny, err := cl.LoadBackendConfigFileByName(cfg.Model, ml.ModelPath)
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfgAnyToAny.Validate() {
|
||||||
|
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
opts := backend.ModelOptions(*cfgAnyToAny, appConfig)
|
||||||
|
anyToAnyClient, err := ml.Load(opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load tts model: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &anyToAnyModel{
|
||||||
|
LLMConfig: cfgAnyToAny,
|
||||||
|
LLMClient: anyToAnyClient,
|
||||||
|
VADConfig: cfgVAD,
|
||||||
|
VADClient: VADClient,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msg("Loading a wrapped model")
|
||||||
|
|
||||||
|
// Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations
|
||||||
|
cfgLLM, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.LLM, ml.ModelPath)
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfg.Validate() {
|
||||||
|
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfgTTS, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.TTS, ml.ModelPath)
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfg.Validate() {
|
||||||
|
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfgSST, err := cl.LoadBackendConfigFileByName(cfg.Pipeline.Transcription, ml.ModelPath)
|
||||||
|
if err != nil {
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("failed to load backend config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cfg.Validate() {
|
||||||
|
return nil, fmt.Errorf("failed to validate config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
opts = backend.ModelOptions(*cfgTTS, appConfig)
|
||||||
|
ttsClient, err := ml.Load(opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load tts model: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
opts = backend.ModelOptions(*cfgSST, appConfig)
|
||||||
|
transcriptionClient, err := ml.Load(opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load SST model: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
opts = backend.ModelOptions(*cfgLLM, appConfig)
|
||||||
|
llmClient, err := ml.Load(opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to load LLM model: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &wrappedModel{
|
||||||
|
TTSConfig: cfgTTS,
|
||||||
|
TranscriptionConfig: cfgSST,
|
||||||
|
LLMConfig: cfgLLM,
|
||||||
|
TTSClient: ttsClient,
|
||||||
|
TranscriptionClient: transcriptionClient,
|
||||||
|
LLMClient: llmClient,
|
||||||
|
|
||||||
|
VADConfig: cfgVAD,
|
||||||
|
VADClient: VADClient,
|
||||||
|
}, nil
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user