Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-10-16 09:02:14 +02:00
parent 1ec5d1f51c
commit 38a417b65c
3 changed files with 34 additions and 22 deletions

View File

@ -180,16 +180,26 @@ func API(application *application.Application) (*fiber.App, error) {
Browse: true,
}))
app.Use("/ws", func(c *fiber.Ctx) error {
// IsWebSocketUpgrade returns true if the client
// requested upgrade to the WebSocket protocol.
app.Use(func(c *fiber.Ctx) error {
if websocket.IsWebSocketUpgrade(c) {
c.Locals("allowed", true)
return c.Next()
// Returns true if the client requested upgrade to the WebSocket protocol
c.Next()
}
return fiber.ErrUpgradeRequired
return nil
})
// app.Use("/v1/realtime", func(c *fiber.Ctx) error {
// fmt.Println("Hit upgrade from http")
// // IsWebSocketUpgrade returns true if the client
// // requested upgrade to the WebSocket protocol.
// if websocket.IsWebSocketUpgrade(c) {
// c.Locals("allowed", true)
// return c.Next()
// }
// return fiber.ErrUpgradeRequired
// })
// Define a custom 404 handler
// Note: keep this at the bottom!
router.Use(notFoundHandler)

View File

@ -106,13 +106,16 @@ var sessionLock sync.Mutex
func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *websocket.Conn) {
return func(c *websocket.Conn) {
log.Debug().Msgf("WebSocket connection established with '%s'", c.RemoteAddr().String())
// Generate a unique session ID
sessionID := generateSessionID()
modelFile, input, err := readWSRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
// modelFile, input, err := readWSRequest(c, cl, ml, appConfig, true)
// if err != nil {
// return fmt.Errorf("failed reading parameters from request:%w", err)
// }
session := &Session{
ID: sessionID,

View File

@ -6,7 +6,6 @@ import (
"fmt"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
@ -49,24 +48,24 @@ func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLo
return modelFile, input, err
}
func readWSRequest(c *websocket.Conn, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest)
// func readWSRequest(c *websocket.Conn, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
// input := new(schema.OpenAIRequest)
input.Model = c.Query("name")
// input.Model = c.Query("name")
received, _ := json.Marshal(input)
// received, _ := json.Marshal(input)
ctx, cancel := context.WithCancel(o.Context)
// ctx, cancel := context.WithCancel(o.Context)
input.Context = ctx
input.Cancel = cancel
// input.Context = ctx
// input.Cancel = cancel
log.Debug().Msgf("Request received: %s", string(received))
// log.Debug().Msgf("Request received: %s", string(received))
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel)
// modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel)
return modelFile, input, err
}
// return modelFile, input, err
// }
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
if input.Echo {