diff --git a/core/http/app.go b/core/http/app.go index 5ce7453b..19d3eb40 100644 --- a/core/http/app.go +++ b/core/http/app.go @@ -180,16 +180,26 @@ func API(application *application.Application) (*fiber.App, error) { Browse: true, })) - app.Use("/ws", func(c *fiber.Ctx) error { - // IsWebSocketUpgrade returns true if the client - // requested upgrade to the WebSocket protocol. + app.Use(func(c *fiber.Ctx) error { if websocket.IsWebSocketUpgrade(c) { - c.Locals("allowed", true) - return c.Next() + // Returns true if the client requested upgrade to the WebSocket protocol + c.Next() } - return fiber.ErrUpgradeRequired + + return nil }) + // app.Use("/v1/realtime", func(c *fiber.Ctx) error { + // fmt.Println("Hit upgrade from http") + // // IsWebSocketUpgrade returns true if the client + // // requested upgrade to the WebSocket protocol. + // if websocket.IsWebSocketUpgrade(c) { + // c.Locals("allowed", true) + // return c.Next() + // } + // return fiber.ErrUpgradeRequired + // }) + // Define a custom 404 handler // Note: keep this at the bottom! router.Use(notFoundHandler) diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go index 2b401dc3..9559e170 100644 --- a/core/http/endpoints/openai/realtime.go +++ b/core/http/endpoints/openai/realtime.go @@ -106,13 +106,16 @@ var sessionLock sync.Mutex func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *websocket.Conn) { return func(c *websocket.Conn) { + + log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String()) + // Generate a unique session ID sessionID := generateSessionID() - modelFile, input, err := readWSRequest(c, cl, ml, appConfig, true) - if err != nil { - return fmt.Errorf("failed reading parameters from request:%w", err) - } + // modelFile, input, err := readWSRequest(c, cl, ml, appConfig, true) + // if err != nil { + // return fmt.Errorf("failed reading parameters from request:%w", err) + // } session := &Session{ ID: sessionID, diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go index e1b25c51..548b015e 100644 --- a/core/http/endpoints/openai/request.go +++ b/core/http/endpoints/openai/request.go @@ -6,7 +6,6 @@ import ( "fmt" "github.com/gofiber/fiber/v2" - "github.com/gofiber/websocket/v2" "github.com/google/uuid" "github.com/mudler/LocalAI/core/config" fiberContext "github.com/mudler/LocalAI/core/http/ctx" @@ -49,24 +48,24 @@ func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLo return modelFile, input, err } -func readWSRequest(c *websocket.Conn, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) { - input := new(schema.OpenAIRequest) +// func readWSRequest(c *websocket.Conn, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) { +// input := new(schema.OpenAIRequest) - input.Model = c.Query("name") +// input.Model = c.Query("name") - received, _ := json.Marshal(input) +// received, _ := json.Marshal(input) - ctx, cancel := context.WithCancel(o.Context) +// ctx, cancel := context.WithCancel(o.Context) - input.Context = ctx - input.Cancel = cancel +// input.Context = ctx +// input.Cancel = cancel - log.Debug().Msgf("Request received: %s", string(received)) +// log.Debug().Msgf("Request received: %s", string(received)) - modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel) +// modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel) - return modelFile, input, err -} +// return modelFile, input, err +// } func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) { if input.Echo {