mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-21 05:33:09 +00:00
5866fc8ded
Signed-off-by: Sertac Ozercan <sozercan@gmail.com>
300 lines
7.2 KiB
Go
300 lines
7.2 KiB
Go
package openai
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/gofiber/fiber/v2"
|
|
"github.com/mudler/LocalAI/core/config"
|
|
fiberContext "github.com/mudler/LocalAI/core/http/ctx"
|
|
"github.com/mudler/LocalAI/core/schema"
|
|
"github.com/mudler/LocalAI/pkg/functions"
|
|
model "github.com/mudler/LocalAI/pkg/model"
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
func readRequest(c *fiber.Ctx, ml *model.ModelLoader, o *config.ApplicationConfig, firstModel bool) (string, *schema.OpenAIRequest, error) {
|
|
input := new(schema.OpenAIRequest)
|
|
|
|
// Get input data from the request body
|
|
if err := c.BodyParser(input); err != nil {
|
|
return "", nil, fmt.Errorf("failed parsing request body: %w", err)
|
|
}
|
|
|
|
received, _ := json.Marshal(input)
|
|
|
|
ctx, cancel := context.WithCancel(o.Context)
|
|
input.Context = ctx
|
|
input.Cancel = cancel
|
|
|
|
log.Debug().Msgf("Request received: %s", string(received))
|
|
|
|
modelFile, err := fiberContext.ModelFromContext(c, ml, input.Model, firstModel)
|
|
|
|
return modelFile, input, err
|
|
}
|
|
|
|
// getBase64Image returns the base64 payload of the image referenced by s.
//
// s may be:
//   - a string starting with "http" (an http or https URL): the resource is
//     downloaded into memory and its bytes are returned base64-encoded with
//     the standard, padded alphabet. A non-2xx response is an error.
//   - a base64 data URI "data:image/<subtype>[;params];base64,<data>": the
//     <data> part is returned as-is (it is not re-validated). Any image
//     subtype is accepted, not only jpeg and png.
//
// Any other input yields an error.
func getBase64Image(s string) (string, error) {
	if strings.HasPrefix(s, "http") {
		// NOTE(review): http.Get uses the default client, which has no timeout;
		// a slow remote server can block the calling request indefinitely.
		resp, err := http.Get(s)
		if err != nil {
			return "", err
		}
		defer resp.Body.Close()

		// Without this check an error page (e.g. a 404 HTML body) would be
		// encoded and handed to the backend as if it were an image.
		if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
			return "", fmt.Errorf("downloading image: unexpected status %s", resp.Status)
		}

		// read the image data into memory
		data, err := io.ReadAll(resp.Body)
		if err != nil {
			return "", err
		}

		return base64.StdEncoding.EncodeToString(data), nil
	}

	// For a data URI, drop everything up to and including ";base64,".
	const (
		dataURIPrefix = "data:image/"
		base64Marker  = ";base64,"
	)
	if strings.HasPrefix(s, dataURIPrefix) {
		rest := s[len(dataURIPrefix):]
		// Require a non-empty media subtype, and no ',' before the marker
		// (a ',' there would mean the payload is not base64-encoded).
		if idx := strings.Index(rest, base64Marker); idx > 0 && !strings.Contains(rest[:idx], ",") {
			return rest[idx+len(base64Marker):], nil
		}
	}

	return "", fmt.Errorf("not valid string")
}
|
|
|
|
func updateRequestConfig(config *config.BackendConfig, input *schema.OpenAIRequest) {
|
|
if input.Echo {
|
|
config.Echo = input.Echo
|
|
}
|
|
if input.TopK != nil {
|
|
config.TopK = input.TopK
|
|
}
|
|
if input.TopP != nil {
|
|
config.TopP = input.TopP
|
|
}
|
|
|
|
if input.Backend != "" {
|
|
config.Backend = input.Backend
|
|
}
|
|
|
|
if input.ClipSkip != 0 {
|
|
config.Diffusers.ClipSkip = input.ClipSkip
|
|
}
|
|
|
|
if input.ModelBaseName != "" {
|
|
config.AutoGPTQ.ModelBaseName = input.ModelBaseName
|
|
}
|
|
|
|
if input.NegativePromptScale != 0 {
|
|
config.NegativePromptScale = input.NegativePromptScale
|
|
}
|
|
|
|
if input.UseFastTokenizer {
|
|
config.UseFastTokenizer = input.UseFastTokenizer
|
|
}
|
|
|
|
if input.NegativePrompt != "" {
|
|
config.NegativePrompt = input.NegativePrompt
|
|
}
|
|
|
|
if input.RopeFreqBase != 0 {
|
|
config.RopeFreqBase = input.RopeFreqBase
|
|
}
|
|
|
|
if input.RopeFreqScale != 0 {
|
|
config.RopeFreqScale = input.RopeFreqScale
|
|
}
|
|
|
|
if input.Grammar != "" {
|
|
config.Grammar = input.Grammar
|
|
}
|
|
|
|
if input.Temperature != nil {
|
|
config.Temperature = input.Temperature
|
|
}
|
|
|
|
if input.Maxtokens != nil {
|
|
config.Maxtokens = input.Maxtokens
|
|
}
|
|
|
|
if input.ResponseFormat != nil {
|
|
switch responseFormat := input.ResponseFormat.(type) {
|
|
case string:
|
|
config.ResponseFormat = responseFormat
|
|
case map[string]interface{}:
|
|
config.ResponseFormatMap = responseFormat
|
|
}
|
|
}
|
|
|
|
switch stop := input.Stop.(type) {
|
|
case string:
|
|
if stop != "" {
|
|
config.StopWords = append(config.StopWords, stop)
|
|
}
|
|
case []interface{}:
|
|
for _, pp := range stop {
|
|
if s, ok := pp.(string); ok {
|
|
config.StopWords = append(config.StopWords, s)
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(input.Tools) > 0 {
|
|
for _, tool := range input.Tools {
|
|
input.Functions = append(input.Functions, tool.Function)
|
|
}
|
|
}
|
|
|
|
if input.ToolsChoice != nil {
|
|
var toolChoice functions.Tool
|
|
|
|
switch content := input.ToolsChoice.(type) {
|
|
case string:
|
|
_ = json.Unmarshal([]byte(content), &toolChoice)
|
|
case map[string]interface{}:
|
|
dat, _ := json.Marshal(content)
|
|
_ = json.Unmarshal(dat, &toolChoice)
|
|
}
|
|
input.FunctionCall = map[string]interface{}{
|
|
"name": toolChoice.Function.Name,
|
|
}
|
|
}
|
|
|
|
// Decode each request's message content
|
|
index := 0
|
|
for i, m := range input.Messages {
|
|
switch content := m.Content.(type) {
|
|
case string:
|
|
input.Messages[i].StringContent = content
|
|
case []interface{}:
|
|
dat, _ := json.Marshal(content)
|
|
c := []schema.Content{}
|
|
json.Unmarshal(dat, &c)
|
|
for _, pp := range c {
|
|
if pp.Type == "text" {
|
|
input.Messages[i].StringContent = pp.Text
|
|
} else if pp.Type == "image_url" {
|
|
// Detect if pp.ImageURL is an URL, if it is download the image and encode it in base64:
|
|
base64, err := getBase64Image(pp.ImageURL.URL)
|
|
if err == nil {
|
|
input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff
|
|
// set a placeholder for each image
|
|
input.Messages[i].StringContent = fmt.Sprintf("[img-%d]", index) + input.Messages[i].StringContent
|
|
index++
|
|
} else {
|
|
log.Error().Msgf("Failed encoding image: %s", err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if input.RepeatPenalty != 0 {
|
|
config.RepeatPenalty = input.RepeatPenalty
|
|
}
|
|
|
|
if input.FrequencyPenalty != 0 {
|
|
config.FrequencyPenalty = input.FrequencyPenalty
|
|
}
|
|
|
|
if input.PresencePenalty != 0 {
|
|
config.PresencePenalty = input.PresencePenalty
|
|
}
|
|
|
|
if input.Keep != 0 {
|
|
config.Keep = input.Keep
|
|
}
|
|
|
|
if input.Batch != 0 {
|
|
config.Batch = input.Batch
|
|
}
|
|
|
|
if input.IgnoreEOS {
|
|
config.IgnoreEOS = input.IgnoreEOS
|
|
}
|
|
|
|
if input.Seed != nil {
|
|
config.Seed = input.Seed
|
|
}
|
|
|
|
if input.TypicalP != nil {
|
|
config.TypicalP = input.TypicalP
|
|
}
|
|
|
|
switch inputs := input.Input.(type) {
|
|
case string:
|
|
if inputs != "" {
|
|
config.InputStrings = append(config.InputStrings, inputs)
|
|
}
|
|
case []interface{}:
|
|
for _, pp := range inputs {
|
|
switch i := pp.(type) {
|
|
case string:
|
|
config.InputStrings = append(config.InputStrings, i)
|
|
case []interface{}:
|
|
tokens := []int{}
|
|
for _, ii := range i {
|
|
tokens = append(tokens, int(ii.(float64)))
|
|
}
|
|
config.InputToken = append(config.InputToken, tokens)
|
|
}
|
|
}
|
|
}
|
|
|
|
// Can be either a string or an object
|
|
switch fnc := input.FunctionCall.(type) {
|
|
case string:
|
|
if fnc != "" {
|
|
config.SetFunctionCallString(fnc)
|
|
}
|
|
case map[string]interface{}:
|
|
var name string
|
|
n, exists := fnc["name"]
|
|
if exists {
|
|
nn, e := n.(string)
|
|
if e {
|
|
name = nn
|
|
}
|
|
}
|
|
config.SetFunctionCallNameString(name)
|
|
}
|
|
|
|
switch p := input.Prompt.(type) {
|
|
case string:
|
|
config.PromptStrings = append(config.PromptStrings, p)
|
|
case []interface{}:
|
|
for _, pp := range p {
|
|
if s, ok := pp.(string); ok {
|
|
config.PromptStrings = append(config.PromptStrings, s)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func mergeRequestWithConfig(modelFile string, input *schema.OpenAIRequest, cm *config.BackendConfigLoader, loader *model.ModelLoader, debug bool, threads, ctx int, f16 bool) (*config.BackendConfig, *schema.OpenAIRequest, error) {
|
|
cfg, err := cm.LoadBackendConfigFileByName(modelFile, loader.ModelPath,
|
|
config.LoadOptionDebug(debug),
|
|
config.LoadOptionThreads(threads),
|
|
config.LoadOptionContextSize(ctx),
|
|
config.LoadOptionF16(f16),
|
|
config.ModelPath(loader.ModelPath),
|
|
)
|
|
|
|
// Set the parameters for the language model prediction
|
|
updateRequestConfig(cfg, input)
|
|
|
|
return cfg, input, err
|
|
}
|