diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go index c36bad96..d70c42b0 100644 --- a/core/http/endpoints/openai/realtime.go +++ b/core/http/endpoints/openai/realtime.go @@ -10,7 +10,9 @@ import ( "time" "github.com/go-audio/audio" + "github.com/gofiber/fiber/v2" "github.com/gofiber/websocket/v2" + "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/grpc/proto" model "github.com/mudler/LocalAI/pkg/model" @@ -121,10 +123,14 @@ var sessionLock sync.Mutex type Model interface { VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error) Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error) - PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error + PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error } -func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *websocket.Conn) { +func Realtime(application *application.Application) fiber.Handler { + return websocket.New(registerRealtime(application)) +} + +func registerRealtime(application *application.Application) func(c *websocket.Conn) { return func(c *websocket.Conn) { log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String()) @@ -153,7 +159,12 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app session.Conversations[conversationID] = conversation session.DefaultConversationID = conversationID - m, err := newModel(cl, ml, appConfig, model) + m, err := newModel( + application.BackendLoader(), + application.ModelLoader(), + application.ApplicationConfig(), + model, + ) if err != nil { log.Error().Msgf("failed to load model: %s", err.Error()) sendError(c, "model_load_error", "Failed to load model", "", "") @@ -210,7 +221,13 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app sendError(c, "invalid_session_update", "Invalid session update format", "", "") continue } - if err := updateSession(session, &sessionUpdate, cl, ml, appConfig); err != nil { + if err := updateSession( + session, + &sessionUpdate, + application.BackendLoader(), + application.ModelLoader(), + application.ApplicationConfig(), + ); err != nil { log.Error().Msgf("failed to update session: %s", err.Error()) sendError(c, "session_update_error", "Failed to update session", "", "") continue diff --git a/core/http/endpoints/openai/realtime_model.go b/core/http/endpoints/openai/realtime_model.go index 20b77862..3b06c783 100644 --- a/core/http/endpoints/openai/realtime_model.go +++ b/core/http/endpoints/openai/realtime_model.go @@ -59,7 +59,7 @@ func (m *wrappedModel) Predict(ctx context.Context, in *proto.PredictOptions, op return m.LLMClient.Predict(ctx, in) } -func (m *wrappedModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error { +func (m *wrappedModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(reply *proto.Reply), opts ...grpc.CallOption) error { // TODO: Convert with pipeline (audio to text, text to llm, result to tts, and return it) return m.LLMClient.PredictStream(ctx, in, f) @@ -69,7 +69,7 @@ func (m *anyToAnyModel) Predict(ctx context.Context, in *proto.PredictOptions, o return m.LLMClient.Predict(ctx, in) } -func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error { +func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(reply *proto.Reply), opts ...grpc.CallOption) error { return m.LLMClient.PredictStream(ctx, in, f) } diff --git a/core/http/routes/openai.go b/core/http/routes/openai.go index 8349d76c..fec66cf8 100644 --- a/core/http/routes/openai.go +++ b/core/http/routes/openai.go @@ -2,7 +2,6 @@ package routes import ( "github.com/gofiber/fiber/v2" - "github.com/gofiber/websocket/v2" "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/http/endpoints/localai" "github.com/mudler/LocalAI/core/http/endpoints/openai" @@ -13,7 +12,7 @@ func RegisterOpenAIRoutes(app *fiber.App, // openAI compatible API endpoint // realtime - app.Get("/v1/realtime", websocket.New(openai.RegisterRealtime(cl, ml, appConfig))) + app.Get("/v1/realtime", openai.Realtime(application)) // chat app.Post("/v1/chat/completions", diff --git a/go.mod b/go.mod index 72adc007..d8a66d7c 100644 --- a/go.mod +++ b/go.mod @@ -100,7 +100,6 @@ require ( github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/go-viper/mapstructure/v2 v2.0.0 // indirect github.com/gofiber/contrib/websocket v1.3.2 // indirect - github.com/gofiber/websocket/v2 v2.2.1 // indirect github.com/google/s2a-go v0.1.7 // indirect github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/googleapis/gax-go/v2 v2.12.4 // indirect diff --git a/go.sum b/go.sum index 5a13b4ea..b9fe0cb8 100644 --- a/go.sum +++ b/go.sum @@ -167,6 +167,7 @@ github.com/fasthttp/websocket v1.5.3 h1:TPpQuLwJYfd4LJPXvHDYPMFWbLjsT91n3GpWtCQt github.com/fasthttp/websocket v1.5.3/go.mod h1:46gg/UBmTU1kUaTcwQXpUxtRwG2PvIZYeA8oL6vF3Fs= github.com/fasthttp/websocket v1.5.8 h1:k5DpirKkftIF/w1R8ZzjSgARJrs54Je9YJK37DL/Ah8= github.com/fasthttp/websocket v1.5.8/go.mod h1:d08g8WaT6nnyvg9uMm8K9zMYyDjfKyj3170AtPRuVU0= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc= github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg= github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag= @@ -410,6 +411,7 @@ github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+ github.com/labstack/echo/v4 v4.13.3 h1:pwhpCPrTl5qry5HRdM5FwdXnhXSLSY+WE+YQSeCaafY= github.com/labstack/echo/v4 v4.13.3/go.mod h1:o90YNEeQWjDozo584l7AwhJMHN0bOC4tAfg+Xox9q5g= github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= +github.com/labstack/echo/v4 v4.12.0/go.mod h1:UP9Cr2DJXbOK3Kr9ONYzNowSh7HP0aG0ShAyycHSJvM= github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU= github.com/lib/pq v0.0.0-20180327071824-d34b9ff171c2 h1:hRGSmZu7j271trc9sneMrpOW7GN5ngLm8YUZIPzf394= github.com/lib/pq v0.0.0-20180327071824-d34b9ff171c2/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=