diff --git a/core/http/endpoints/openai/realtime.go b/core/http/endpoints/openai/realtime.go index 0ba28699..2b401dc3 100644 --- a/core/http/endpoints/openai/realtime.go +++ b/core/http/endpoints/openai/realtime.go @@ -108,6 +108,12 @@ func RegisterRealtime(cl *config.BackendConfigLoader, ml *model.ModelLoader, app return func(c *websocket.Conn) { // Generate a unique session ID sessionID := generateSessionID() + + modelFile, input, err := readWSRequest(c, cl, ml, appConfig, true) + if err != nil { + return fmt.Errorf("failed reading parameters from request:%w", err) + } + session := &Session{ ID: sessionID, Model: "gpt-4o", // default model diff --git a/core/http/endpoints/openai/request.go b/core/http/endpoints/openai/request.go index 2451f15f..e1b25c51 100644 --- a/core/http/endpoints/openai/request.go +++ b/core/http/endpoints/openai/request.go @@ -6,6 +6,7 @@ import ( "fmt" "github.com/gofiber/fiber/v2" + "github.com/gofiber/websocket/v2" "github.com/google/uuid" "github.com/mudler/LocalAI/core/config" fiberContext "github.com/mudler/LocalAI/core/http/ctx" @@ -48,6 +49,25 @@ func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLo return modelFile, input, err } +func readWSRequest(c *websocket.Conn, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) { + input := new(schema.OpenAIRequest) + + input.Model = c.Query("name") + + received, _ := json.Marshal(input) + + ctx, cancel := context.WithCancel(o.Context) + + input.Context = ctx + input.Cancel = cancel + + log.Debug().Msgf("Request received: %s", string(received)) + + modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel) + + return modelFile, input, err +} + func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) { if input.Echo { config.Echo = input.Echo