Small adaptations

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-12-27 18:39:56 +01:00
parent 06e438d68b
commit c526f05de5
5 changed files with 26 additions and 9 deletions

View File

@ -10,7 +10,9 @@ import (
"time"
"github.com/go-audio/audio"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/config"
"github.com/mudler/LocalAI/pkg/grpc/proto"
model "github.com/mudler/LocalAI/pkg/model"
@ -121,10 +123,14 @@ var sessionLock sync.Mutex
type Model interface {
VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error)
Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error)
PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error
}
func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *websocket.Conn) {
func Realtime(application *application.Application) fiber.Handler {
return websocket.New(registerRealtime(application))
}
func registerRealtime(application *application.Application) func(c *websocket.Conn) {
return func(c *websocket.Conn) {
log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String())
@ -153,7 +159,12 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
session.Conversations[conversationID] = conversation
session.DefaultConversationID = conversationID
m, err := newModel(cl, ml, appConfig, model)
m, err := newModel(
application.BackendLoader(),
application.ModelLoader(),
application.ApplicationConfig(),
model,
)
if err != nil {
log.Error().Msgf("failed to load model: %s", err.Error())
sendError(c, "model_load_error", "Failed to load model", "", "")
@ -210,7 +221,13 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
sendError(c, "invalid_session_update", "Invalid session update format", "", "")
continue
}
if err := updateSession(session, &sessionUpdate, cl, ml, appConfig); err != nil {
if err := updateSession(
session,
&sessionUpdate,
application.BackendLoader(),
application.ModelLoader(),
application.ApplicationConfig(),
); err != nil {
log.Error().Msgf("failed to update session: %s", err.Error())
sendError(c, "session_update_error", "Failed to update session", "", "")
continue

View File

@ -59,7 +59,7 @@ func (m *wrappedModel) Predict(ctx context.Context, in *proto.PredictOptions, op
return m.LLMClient.Predict(ctx, in)
}
func (m *wrappedModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
func (m *wrappedModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(reply *proto.Reply), opts ...grpc.CallOption) error {
// TODO: Convert with pipeline (audio to text, text to llm, result to tts, and return it)
return m.LLMClient.PredictStream(ctx, in, f)
@ -69,7 +69,7 @@ func (m *anyToAnyModel) Predict(ctx context.Context, in *proto.PredictOptions, o
return m.LLMClient.Predict(ctx, in)
}
func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(reply *proto.Reply), opts ...grpc.CallOption) error {
return m.LLMClient.PredictStream(ctx, in, f)
}

View File

@ -2,7 +2,6 @@ package routes
import (
"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
"github.com/mudler/LocalAI/core/application"
"github.com/mudler/LocalAI/core/http/endpoints/localai"
"github.com/mudler/LocalAI/core/http/endpoints/openai"
@ -13,7 +12,7 @@ func RegisterOpenAIRoutes(app *fiber.App,
// openAI compatible API endpoint
// realtime
app.Get("/v1/realtime", websocket.New(openai.RegisterRealtime(cl, ml, appConfig)))
app.Get("/v1/realtime", openai.Realtime(application))
// chat
app.Post("/v1/chat/completions",

1
go.mod
View File

@ -100,7 +100,6 @@ require (
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
github.com/go-viper/mapstructure/v2 v2.0.0 // indirect
github.com/gofiber/contrib/websocket v1.3.2 // indirect
github.com/gofiber/websocket/v2 v2.2.1 // indirect
github.com/google/s2a-go v0.1.7 // indirect
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
github.com/googleapis/gax-go/v2 v2.12.4 // indirect

2
go.sum
View File

@ -167,6 +167,7 @@ github.com/fasthttp/websocket v1.5.3 h1:TPpQuLwJYfd4LJPXvHDYPMFWbLjsT91n3GpWtCQt
github.com/fasthttp/websocket v1.5.3/go.mod h1:46gg/UBmTU1kUaTcwQXpUxtRwG2PvIZYeA8oL6vF3Fs=
github.com/fasthttp/websocket v1.5.8 h1:k5DpirKkftIF/w1R8ZzjSgARJrs54Je9YJK37DL/Ah8=
github.com/fasthttp/websocket v1.5.8/go.mod h1:d08g8WaT6nnyvg9uMm8K9zMYyDjfKyj3170AtPRuVU0=
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
@ -410,6 +411,7 @@ github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+
github.com/labstack/echo/v4 v4.13.3 h1:pwhpCPrTl5qry5HRdM5FwdXnhXSLSY+WE+YQSeCaafY=
github.com/labstack/echo/v4 v4.13.3/go.mod h1:o90YNEeQWjDozo584l7AwhJMHN0bOC4tAfg+Xox9q5g=
github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0=
github.com/labstack/echo/v4 v4.12.0/go.mod h1:UP9Cr2DJXbOK3Kr9ONYzNowSh7HP0aG0ShAyycHSJvM=
github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU=
github.com/lib/pq v0.0.0-20180327071824-d34b9ff171c2 h1:hRGSmZu7j271trc9sneMrpOW7GN5ngLm8YUZIPzf394=
github.com/lib/pq v0.0.0-20180327071824-d34b9ff171c2/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=