Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-10-16 09:02:14 +02:00
parent 1ec5d1f51c
commit 38a417b65c
3 changed files with 34 additions and 22 deletions

View File

@ -180,16 +180,26 @@ func API(application *application.Application) (*fiber.App, error) {
Browse: true, Browse: true,
})) }))
app.Use("/ws", func(c *fiber.Ctx) error { app.Use(func(c *fiber.Ctx) error {
// IsWebSocketUpgrade returns true if the client
// requested upgrade to the WebSocket protocol.
if websocket.IsWebSocketUpgrade(c) { if websocket.IsWebSocketUpgrade(c) {
c.Locals("allowed", true) // Returns true if the client requested upgrade to the WebSocket protocol
return c.Next() c.Next()
} }
return fiber.ErrUpgradeRequired
return nil
}) })
// app.Use("/v1/realtime", func(c *fiber.Ctx) error {
// fmt.Println("Hit upgrade from http")
// // IsWebSocketUpgrade returns true if the client
// // requested upgrade to the WebSocket protocol.
// if websocket.IsWebSocketUpgrade(c) {
// c.Locals("allowed", true)
// return c.Next()
// }
// return fiber.ErrUpgradeRequired
// })
// Define a custom 404 handler // Define a custom 404 handler
// Note: keep this at the bottom! // Note: keep this at the bottom!
router.Use(notFoundHandler) router.Use(notFoundHandler)

View File

@ -106,13 +106,16 @@ var sessionLock sync.Mutex
func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *websocket.Conn) { func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *websocket.Conn) {
return func(c *websocket.Conn) { return func(c *websocket.Conn) {
log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String())
// Generate a unique session ID // Generate a unique session ID
sessionID := generateSessionID() sessionID := generateSessionID()
modelFile, input, err := readWSRequest(c, cl, ml, appConfig, true) // modelFile, input, err := readWSRequest(c, cl, ml, appConfig, true)
if err != nil { // if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err) // return fmt.Errorf("failed reading parameters from request:%w", err)
} // }
session := &Session{ session := &Session{
ID: sessionID, ID: sessionID,

View File

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"github.com/gofiber/fiber/v2" "github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx" fiberContext "github.com/mudler/LocalAI/core/http/ctx"
@ -49,24 +48,24 @@ func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLo
return modelFile, input, err return modelFile, input, err
} }
func readWSRequest(c *websocket.Conn, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) { // func readWSRequest(c *websocket.Conn, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest) // input := new(schema.OpenAIRequest)
input.Model = c.Query("name") // input.Model = c.Query("name")
received, _ := json.Marshal(input) // received, _ := json.Marshal(input)
ctx, cancel := context.WithCancel(o.Context) // ctx, cancel := context.WithCancel(o.Context)
input.Context = ctx // input.Context = ctx
input.Cancel = cancel // input.Cancel = cancel
log.Debug().Msgf("Request received: %s", string(received)) // log.Debug().Msgf("Request received: %s", string(received))
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel) // modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel)
return modelFile, input, err // return modelFile, input, err
} // }
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) { func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
if input.Echo { if input.Echo {