package openai

import (
	"context"
	"encoding/json"
	"fmt"
	"strings"

	"github.com/gofiber/fiber/v2"
	"github.com/google/uuid"
	"github.com/mudler/LocalAI/core/config"
	fiberContext "github.com/mudler/LocalAI/core/http/ctx"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/pkg/functions"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/templates"
	"github.com/mudler/LocalAI/pkg/utils"
	"github.com/rs/zerolog/log"
)

type correlationIDKeyType string

// CorrelationIDKey is the context key under which readRequest stores the
// request's correlation ID, used to track a request across process boundaries.
const CorrelationIDKey correlationIDKeyType = "correlationID"

// readRequest parses the OpenAI-style request body of c and resolves the model
// it targets via fiberContext.ModelFromContext (firstModel is forwarded there).
//
// The returned request carries input.Context, a cancellable child of
// o.Context tagged with the correlation ID from the X-Correlation-ID header
// (a fresh UUID when the header is absent), and input.Cancel, which releases
// it. If model resolution fails, the context is cancelled before returning so
// it does not outlive the failed request.
//
// Returns the resolved model name, the parsed request and any error. On a body
// parse error the request is nil.
func readRequest(c *fiber.Ctx, cl *config.BackendConfigLoader, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
	input := new(schema.OpenAIRequest)

	// Get input data from the request body
	if err := c.BodyParser(input); err != nil {
		return "", nil, fmt.Errorf("failed parsing request body: %w", err)
	}

	// The body may embed large base64 payloads; only serialize it when debug
	// logging is actually on. This runs before Context/Cancel are attached,
	// so they are never part of the logged JSON.
	if ev := log.Debug(); ev.Enabled() {
		received, _ := json.Marshal(input) // best effort: only used for the log line
		ev.Msgf("Request received: %s", string(received))
	}

	// Extract or generate the correlation ID
	correlationID := c.Get("X-Correlation-ID", uuid.New().String())

	ctx, cancel := context.WithCancel(o.Context)
	// Add the correlation ID to the new context
	input.Context = context.WithValue(ctx, CorrelationIDKey, correlationID)
	input.Cancel = cancel

	modelFile, err := fiberContext.ModelFromContext(c, cl, ml, input.Model, firstModel)
	if err != nil {
		// The request will not be served; release the derived context here
		// instead of relying on the caller to invoke input.Cancel on this path.
		cancel()
	}
	return modelFile, input, err
}

// updateRequestConfig overlays the per-request settings in input onto config.
//
// Scalar and pointer settings are copied only when the request sets them
// (non-zero / non-nil), so unset fields keep the model's configured values.
// Stop words, input strings, input token lists and prompts are appended to
// whatever config already holds.
//
// input is mutated as well: each tool's function is appended to
// input.Functions, a tool_choice is translated into input.FunctionCall, and
// every message's content is decoded into its String* fields.
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
	if input.Echo {
		config.Echo = input.Echo
	}
	if input.TopK != nil {
		config.TopK = input.TopK
	}
	if input.TopP != nil {
		config.TopP = input.TopP
	}

	if input.Backend != "" {
		config.Backend = input.Backend
	}

	if input.ClipSkip != 0 {
		config.Diffusers.ClipSkip = input.ClipSkip
	}

	if input.ModelBaseName != "" {
		config.AutoGPTQ.ModelBaseName = input.ModelBaseName
	}

	if input.NegativePromptScale != 0 {
		config.NegativePromptScale = input.NegativePromptScale
	}

	if input.UseFastTokenizer {
		config.UseFastTokenizer = input.UseFastTokenizer
	}

	if input.NegativePrompt != "" {
		config.NegativePrompt = input.NegativePrompt
	}

	if input.RopeFreqBase != 0 {
		config.RopeFreqBase = input.RopeFreqBase
	}

	if input.RopeFreqScale != 0 {
		config.RopeFreqScale = input.RopeFreqScale
	}

	if input.Grammar != "" {
		config.Grammar = input.Grammar
	}

	if input.Temperature != nil {
		config.Temperature = input.Temperature
	}

	if input.Maxtokens != nil {
		config.Maxtokens = input.Maxtokens
	}

	// response_format may arrive either as a bare string or as an object.
	if input.ResponseFormat != nil {
		switch responseFormat := input.ResponseFormat.(type) {
		case string:
			config.ResponseFormat = responseFormat
		case map[string]interface{}:
			config.ResponseFormatMap = responseFormat
		}
	}

	// stop may be a single string or a list; non-string list items are ignored.
	switch stop := input.Stop.(type) {
	case string:
		if stop != "" {
			config.StopWords = append(config.StopWords, stop)
		}
	case []interface{}:
		for _, pp := range stop {
			if s, ok := pp.(string); ok {
				config.StopWords = append(config.StopWords, s)
			}
		}
	}

	for _, tool := range input.Tools {
		input.Functions = append(input.Functions, tool.Function)
	}

	if input.ToolsChoice != nil {
		var toolChoice functions.Tool

		switch content := input.ToolsChoice.(type) {
		case string:
			// NOTE(review): a plain value such as "auto" is not valid JSON, so
			// the unmarshal fails and the name below ends up empty — confirm
			// that SetFunctionCallNameString("") is the intended outcome.
			_ = json.Unmarshal([]byte(content), &toolChoice)
		case map[string]interface{}:
			dat, _ := json.Marshal(content)
			_ = json.Unmarshal(dat, &toolChoice)
		}
		input.FunctionCall = map[string]interface{}{
			"name": toolChoice.Function.Name,
		}
	}

	decodeMessageContents(config, input)

	if input.RepeatPenalty != 0 {
		config.RepeatPenalty = input.RepeatPenalty
	}

	if input.FrequencyPenalty != 0 {
		config.FrequencyPenalty = input.FrequencyPenalty
	}

	if input.PresencePenalty != 0 {
		config.PresencePenalty = input.PresencePenalty
	}

	if input.Keep != 0 {
		config.Keep = input.Keep
	}

	if input.Batch != 0 {
		config.Batch = input.Batch
	}

	if input.IgnoreEOS {
		config.IgnoreEOS = input.IgnoreEOS
	}

	if input.Seed != nil {
		config.Seed = input.Seed
	}

	if input.TypicalP != nil {
		config.TypicalP = input.TypicalP
	}

	// input may be a string, or a list mixing strings and token-id lists.
	switch inputs := input.Input.(type) {
	case string:
		if inputs != "" {
			config.InputStrings = append(config.InputStrings, inputs)
		}
	case []interface{}:
		for _, pp := range inputs {
			switch item := pp.(type) {
			case string:
				config.InputStrings = append(config.InputStrings, item)
			case []interface{}:
				tokens := make([]int, 0, len(item))
				for _, tok := range item {
					// Untyped JSON numbers decode as float64 with encoding/json;
					// skip anything else instead of panicking on bad client input.
					if f, ok := tok.(float64); ok {
						tokens = append(tokens, int(f))
					}
				}
				config.InputToken = append(config.InputToken, tokens)
			}
		}
	}

	// Can be either a string or an object
	switch fnc := input.FunctionCall.(type) {
	case string:
		if fnc != "" {
			config.SetFunctionCallString(fnc)
		}
	case map[string]interface{}:
		// A missing or non-string "name" yields "".
		name, _ := fnc["name"].(string)
		config.SetFunctionCallNameString(name)
	}

	switch p := input.Prompt.(type) {
	case string:
		config.PromptStrings = append(config.PromptStrings, p)
	case []interface{}:
		for _, pp := range p {
			if s, ok := pp.(string); ok {
				config.PromptStrings = append(config.PromptStrings, s)
			}
		}
	}
}

// decodeMessageContents fills StringContent, StringVideos, StringAudios and
// StringImages for every message in input.
//
// String content is copied as-is. List content has all of its "text" parts
// concatenated in order; each video/audio/image part is converted with
// utils.GetContentURIAsBase64 (parts that fail are logged and skipped) and
// contributes one placeholder, rendered with the model's template from
// cfg.TemplateConfig or a built-in default. Placeholder IDs are counted per
// media kind across the whole conversation, not per message.
func decodeMessageContents(cfg *config.BackendConfig, input *schema.OpenAIRequest) {
	// placeholder is a media marker waiting to be rendered into the text.
	type placeholder struct {
		tmpl string
		id   int
	}

	imgIndex, vidIndex, audioIndex := 0, 0, 0
	for i, m := range input.Messages {
		switch content := m.Content.(type) {
		case string:
			input.Messages[i].StringContent = content
		case []interface{}:
			dat, _ := json.Marshal(content)
			var parts []schema.Content
			if err := json.Unmarshal(dat, &parts); err != nil {
				// Keep going with whatever decoded; the error is no longer silent.
				log.Error().Msgf("Failed decoding message content: %s", err)
			}

			var text strings.Builder
			var placeholders []placeholder

			for _, part := range parts {
				switch part.Type {
				case "text":
					text.WriteString(part.Text)
				case "video", "video_url":
					// Decode content as base64 either if it's an URL or base64 text
					b64, err := utils.GetContentURIAsBase64(part.VideoURL.URL)
					if err != nil {
						log.Error().Msgf("Failed encoding video: %s", err)
						continue
					}
					input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, b64) // TODO: make sure that we only return base64 stuff
					t := "[vid-{{.ID}}]{{.Text}}"
					if cfg.TemplateConfig.Video != "" {
						t = cfg.TemplateConfig.Video
					}
					// set a placeholder for each video
					placeholders = append(placeholders, placeholder{tmpl: t, id: vidIndex})
					vidIndex++
				case "audio_url", "audio":
					// Decode content as base64 either if it's an URL or base64 text
					b64, err := utils.GetContentURIAsBase64(part.AudioURL.URL)
					if err != nil {
						log.Error().Msgf("Failed encoding audio: %s", err)
						continue
					}
					input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, b64) // TODO: make sure that we only return base64 stuff
					t := "[audio-{{.ID}}]{{.Text}}"
					if cfg.TemplateConfig.Audio != "" {
						t = cfg.TemplateConfig.Audio
					}
					// set a placeholder for each audio clip
					placeholders = append(placeholders, placeholder{tmpl: t, id: audioIndex})
					audioIndex++
				case "image_url", "image":
					// Decode content as base64 either if it's an URL or base64 text
					b64, err := utils.GetContentURIAsBase64(part.ImageURL.URL)
					if err != nil {
						log.Error().Msgf("Failed encoding image: %s", err)
						continue
					}
					input.Messages[i].StringImages = append(input.Messages[i].StringImages, b64) // TODO: make sure that we only return base64 stuff
					t := "[img-{{.ID}}]{{.Text}}"
					if cfg.TemplateConfig.Image != "" {
						t = cfg.TemplateConfig.Image
					}
					// set a placeholder for each image
					placeholders = append(placeholders, placeholder{tmpl: t, id: imgIndex})
					imgIndex++
				}
			}

			// Render placeholders only after every text part has been
			// collected, so a text part that follows a media part can no longer
			// overwrite that part's marker. Each template is rendered over the
			// text built so far, in part order.
			rendered := text.String()
			for _, p := range placeholders {
				out, err := templates.TemplateMultiModal(p.tmpl, p.id, rendered)
				if err != nil {
					log.Error().Msgf("Failed rendering multimodal placeholder: %s", err)
					continue
				}
				rendered = out
			}
			input.Messages[i].StringContent = rendered
		}
	}
}

// mergeRequestWithConfig loads the backend configuration for modelFile (using
// the given debug/threads/context-size/f16 load options and the loader's model
// path), overlays the request settings from input onto it, and validates the
// result.
//
// Returns the merged configuration and the (possibly mutated) request. When
// the configuration cannot be loaded or fails validation, both are nil and a
// non-nil error is returned.
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) {
	cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath,
		config.LoadOptionDebug(debug),
		config.LoadOptionThreads(threads),
		config.LoadOptionContextSize(ctx),
		config.LoadOptionF16(f16),
		config.ModelPath(loader.ModelPath),
	)
	if err != nil {
		// cfg must not be touched on error: it may be nil.
		return nil, nil, fmt.Errorf("loading config for model %q: %w", modelFile, err)
	}

	// Set the parameters for the language model prediction
	updateRequestConfig(cfg, input)

	if !cfg.Validate() {
		return nil, nil, fmt.Errorf("failed to validate config")
	}

	return cfg, input, nil
}