Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
Ettore Di Giacinto 2024-10-09 12:57:20 +02:00
parent dfe6d0e9c5
commit 4e19aab501
2 changed files with 26 additions and 0 deletions

View File

@ -108,6 +108,12 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app
return func(c *websocket.Conn) {
// Generate a unique session ID
sessionID := generateSessionID()
modelFile, input, err := readWSRequest(c, cl, ml, appConfig, true)
if err != nil {
return fmt.Errorf("failed reading parameters from request:%w", err)
}
session := &Session{
ID: sessionID,
Model: "gpt-4o", // default model

View File

@ -6,6 +6,7 @@ import (
"fmt"
"github.com/gofiber/fiber/v2"
"github.com/gofiber/websocket/v2"
"github.com/google/uuid"
"github.com/mudler/LocalAI/core/config"
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
@ -48,6 +49,25 @@ func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLo
return modelFile, input, err
}
func readWSRequest(c *websocket.Conn, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
input := new(schema.OpenAIRequest)
input.Model = c.Query("name")
received, _ := json.Marshal(input)
ctx, cancel := context.WithCancel(o.Context)
input.Context = ctx
input.Cancel = cancel
log.Debug().Msgf("Request received: %s", string(received))
modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel)
return modelFile, input, err
}
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
if input.Echo {
config.Echo = input.Echo