mirror of
https://github.com/mudler/LocalAI.git
synced 2025-01-27 14:49:39 +00:00
Small adaptations
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
06e438d68b
commit
c526f05de5
@ -10,7 +10,9 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/go-audio/audio"
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/config"
|
||||
"github.com/mudler/LocalAI/pkg/grpc/proto"
|
||||
model "github.com/mudler/LocalAI/pkg/model"
|
||||
@ -121,10 +123,14 @@ var sessionLock sync.Mutex
|
||||
type Model interface {
|
||||
VAD(ctx context.Context, in *proto.VADRequest, opts ...grpc.CallOption) (*proto.VADResponse, error)
|
||||
Predict(ctx context.Context, in *proto.PredictOptions, opts ...grpc.CallOption) (*proto.Reply, error)
|
||||
PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error
|
||||
PredictStream(ctx context.Context, in *proto.PredictOptions, f func(*proto.Reply), opts ...grpc.CallOption) error
|
||||
}
|
||||
|
||||
func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *websocket.Conn) {
|
||||
func Realtime(application *application.Application) fiber.Handler {
|
||||
return websocket.New(registerRealtime(application))
|
||||
}
|
||||
|
||||
func registerRealtime(application *application.Application) func(c *websocket.Conn) {
|
||||
return func(c *websocket.Conn) {
|
||||
|
||||
log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String())
|
||||
@ -153,7 +159,12 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
|
||||
session.Conversations[conversationID] = conversation
|
||||
session.DefaultConversationID = conversationID
|
||||
|
||||
m, err := newModel(cl, ml, appConfig, model)
|
||||
m, err := newModel(
|
||||
application.BackendLoader(),
|
||||
application.ModelLoader(),
|
||||
application.ApplicationConfig(),
|
||||
model,
|
||||
)
|
||||
if err != nil {
|
||||
log.Error().Msgf("failed to load model: %s", err.Error())
|
||||
sendError(c, "model_load_error", "Failed to load model", "", "")
|
||||
@ -210,7 +221,13 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
|
||||
sendError(c, "invalid_session_update", "Invalid session update format", "", "")
|
||||
continue
|
||||
}
|
||||
if err := updateSession(session, &sessionUpdate, cl, ml, appConfig); err != nil {
|
||||
if err := updateSession(
|
||||
session,
|
||||
&sessionUpdate,
|
||||
application.BackendLoader(),
|
||||
application.ModelLoader(),
|
||||
application.ApplicationConfig(),
|
||||
); err != nil {
|
||||
log.Error().Msgf("failed to update session: %s", err.Error())
|
||||
sendError(c, "session_update_error", "Failed to update session", "", "")
|
||||
continue
|
||||
|
@ -59,7 +59,7 @@ func (m *wrappedModel) Predict(ctx context.Context, in *proto.PredictOptions, op
|
||||
return m.LLMClient.Predict(ctx, in)
|
||||
}
|
||||
|
||||
func (m *wrappedModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
|
||||
func (m *wrappedModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(reply *proto.Reply), opts ...grpc.CallOption) error {
|
||||
// TODO: Convert with pipeline (audio to text, text to llm, result to tts, and return it)
|
||||
|
||||
return m.LLMClient.PredictStream(ctx, in, f)
|
||||
@ -69,7 +69,7 @@ func (m *anyToAnyModel) Predict(ctx context.Context, in *proto.PredictOptions, o
|
||||
return m.LLMClient.Predict(ctx, in)
|
||||
}
|
||||
|
||||
func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(s []byte), opts ...grpc.CallOption) error {
|
||||
func (m *anyToAnyModel) PredictStream(ctx context.Context, in *proto.PredictOptions, f func(reply *proto.Reply), opts ...grpc.CallOption) error {
|
||||
return m.LLMClient.PredictStream(ctx, in, f)
|
||||
}
|
||||
|
||||
|
@ -2,7 +2,6 @@ package routes
|
||||
|
||||
import (
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/websocket/v2"
|
||||
"github.com/mudler/LocalAI/core/application"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/localai"
|
||||
"github.com/mudler/LocalAI/core/http/endpoints/openai"
|
||||
@ -13,7 +12,7 @@ func RegisterOpenAIRoutes(app *fiber.App,
|
||||
// openAI compatible API endpoint
|
||||
|
||||
// realtime
|
||||
app.Get("/v1/realtime", websocket.New(openai.RegisterRealtime(cl, ml, appConfig)))
|
||||
app.Get("/v1/realtime", openai.Realtime(application))
|
||||
|
||||
// chat
|
||||
app.Post("/v1/chat/completions",
|
||||
|
1
go.mod
1
go.mod
@ -100,7 +100,6 @@ require (
|
||||
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
|
||||
github.com/go-viper/mapstructure/v2 v2.0.0 // indirect
|
||||
github.com/gofiber/contrib/websocket v1.3.2 // indirect
|
||||
github.com/gofiber/websocket/v2 v2.2.1 // indirect
|
||||
github.com/google/s2a-go v0.1.7 // indirect
|
||||
github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
|
||||
github.com/googleapis/gax-go/v2 v2.12.4 // indirect
|
||||
|
2
go.sum
2
go.sum
@ -167,6 +167,7 @@ github.com/fasthttp/websocket v1.5.3 h1:TPpQuLwJYfd4LJPXvHDYPMFWbLjsT91n3GpWtCQt
|
||||
github.com/fasthttp/websocket v1.5.3/go.mod h1:46gg/UBmTU1kUaTcwQXpUxtRwG2PvIZYeA8oL6vF3Fs=
|
||||
github.com/fasthttp/websocket v1.5.8 h1:k5DpirKkftIF/w1R8ZzjSgARJrs54Je9YJK37DL/Ah8=
|
||||
github.com/fasthttp/websocket v1.5.8/go.mod h1:d08g8WaT6nnyvg9uMm8K9zMYyDjfKyj3170AtPRuVU0=
|
||||
github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
|
||||
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
|
||||
github.com/flynn/noise v1.1.0 h1:KjPQoQCEFdZDiP03phOvGi11+SVVhBG2wOWAorLsstg=
|
||||
github.com/flynn/noise v1.1.0/go.mod h1:xbMo+0i6+IGbYdJhF31t2eR1BIU0CYc12+BNAKwUTag=
|
||||
@ -410,6 +411,7 @@ github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+
|
||||
github.com/labstack/echo/v4 v4.13.3 h1:pwhpCPrTl5qry5HRdM5FwdXnhXSLSY+WE+YQSeCaafY=
|
||||
github.com/labstack/echo/v4 v4.13.3/go.mod h1:o90YNEeQWjDozo584l7AwhJMHN0bOC4tAfg+Xox9q5g=
|
||||
github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0=
|
||||
github.com/labstack/echo/v4 v4.12.0/go.mod h1:UP9Cr2DJXbOK3Kr9ONYzNowSh7HP0aG0ShAyycHSJvM=
|
||||
github.com/labstack/gommon v0.4.2/go.mod h1:QlUFxVM+SNXhDL/Z7YhocGIBYOiwB0mXm1+1bAPHPyU=
|
||||
github.com/lib/pq v0.0.0-20180327071824-d34b9ff171c2 h1:hRGSmZu7j271trc9sneMrpOW7GN5ngLm8YUZIPzf394=
|
||||
github.com/lib/pq v0.0.0-20180327071824-d34b9ff171c2/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
|
Loading…
x
Reference in New Issue
Block a user