2024-01-05 17:04:46 +00:00
|
|
|
package openai
|
|
|
|
|
|
|
|
import (
|
2024-04-17 21:33:49 +00:00
|
|
|
"bufio"
|
|
|
|
"encoding/base64"
|
2024-01-05 17:04:46 +00:00
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
2024-04-17 21:33:49 +00:00
|
|
|
"io"
|
|
|
|
"net/http"
|
|
|
|
"os"
|
|
|
|
"path/filepath"
|
|
|
|
"strconv"
|
|
|
|
"strings"
|
|
|
|
"time"
|
2024-01-05 17:04:46 +00:00
|
|
|
|
2024-04-17 21:33:49 +00:00
|
|
|
"github.com/go-skynet/LocalAI/core/config"
|
|
|
|
"github.com/go-skynet/LocalAI/core/schema"
|
|
|
|
"github.com/google/uuid"
|
2024-01-05 17:04:46 +00:00
|
|
|
|
2024-02-21 01:21:19 +00:00
|
|
|
"github.com/go-skynet/LocalAI/core/backend"
|
2024-03-01 15:19:53 +00:00
|
|
|
|
2024-04-17 21:33:49 +00:00
|
|
|
model "github.com/go-skynet/LocalAI/pkg/model"
|
2024-01-05 17:04:46 +00:00
|
|
|
"github.com/gofiber/fiber/v2"
|
|
|
|
"github.com/rs/zerolog/log"
|
|
|
|
)
|
|
|
|
|
2024-04-17 21:33:49 +00:00
|
|
|
// downloadFile fetches url with an HTTP GET and stores the response body in a
// new temporary file created with os.CreateTemp("", "image"). It returns the
// path of that file; the caller owns it and is responsible for removing it.
//
// A non-2xx response is reported as an error, so that an error page is never
// handed on as if it were the requested data. On any error no file is left
// behind and the returned path is "".
//
// NOTE(review): http.Get uses the default client, which has no timeout; a
// stalled server can block the caller indefinitely.
func downloadFile(url string) (string, error) {
	resp, err := http.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode > 299 {
		return "", fmt.Errorf("downloading %s: unexpected HTTP status %s", url, resp.Status)
	}

	out, err := os.CreateTemp("", "image")
	if err != nil {
		return "", err
	}

	if _, err := io.Copy(out, resp.Body); err != nil {
		// Best-effort cleanup: the partial file is useless to the caller.
		out.Close()
		os.Remove(out.Name())
		return "", err
	}
	// A failed Close on a written file can mean the data never reached disk.
	if err := out.Close(); err != nil {
		os.Remove(out.Name())
		return "", err
	}
	return out.Name(), nil
}
|
|
|
|
|
|
|
|
//
|
2024-01-05 17:04:46 +00:00
|
|
|
|
|
|
|
/*
|
|
|
|
*
|
|
|
|
|
|
|
|
curl http://localhost:8080/v1/images/generations \
|
|
|
|
-H "Content-Type: application/json" \
|
|
|
|
-d '{
|
|
|
|
"prompt": "A cute baby sea otter",
|
|
|
|
"n": 1,
|
|
|
|
"size": "512x512"
|
|
|
|
}'
|
|
|
|
|
|
|
|
*
|
|
|
|
*/
|
2024-03-29 21:29:33 +00:00
|
|
|
// ImageEndpoint is the OpenAI Image generation API endpoint https://platform.openai.com/docs/api-reference/images/create
|
|
|
|
// @Summary Creates an image given a prompt.
|
|
|
|
// @Param request body schema.OpenAIRequest true "query params"
|
|
|
|
// @Success 200 {object} schema.OpenAIResponse "Response"
|
|
|
|
// @Router /v1/images/generations [post]
|
2024-04-17 21:33:49 +00:00
|
|
|
func ImageEndpoint(cl *config.BackendConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) func(c *fiber.Ctx) error {
|
2024-01-05 17:04:46 +00:00
|
|
|
return func(c *fiber.Ctx) error {
|
2024-04-17 21:33:49 +00:00
|
|
|
m, input, err := readRequest(c, ml, appConfig, false)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
2024-01-05 17:04:46 +00:00
|
|
|
}
|
2024-04-17 21:33:49 +00:00
|
|
|
|
|
|
|
if m == "" {
|
|
|
|
m = model.StableDiffusionBackend
|
|
|
|
}
|
|
|
|
log.Debug().Msgf("Loading model: %+v", m)
|
|
|
|
|
|
|
|
config, input, err := mergeRequestWithConfig(m, input, cl, ml, appConfig.Debug, 0, 0, false)
|
2024-01-05 17:04:46 +00:00
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed reading parameters from request:%w", err)
|
|
|
|
}
|
|
|
|
|
2024-04-17 21:33:49 +00:00
|
|
|
src := ""
|
|
|
|
if input.File != "" {
|
|
|
|
|
|
|
|
fileData := []byte{}
|
|
|
|
// check if input.File is an URL, if so download it and save it
|
|
|
|
// to a temporary file
|
|
|
|
if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
|
|
|
|
out, err := downloadFile(input.File)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed downloading file:%w", err)
|
|
|
|
}
|
|
|
|
defer os.RemoveAll(out)
|
|
|
|
|
|
|
|
fileData, err = os.ReadFile(out)
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("failed reading file:%w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
} else {
|
|
|
|
// base 64 decode the file and write it somewhere
|
|
|
|
// that we will cleanup
|
|
|
|
fileData, err = base64.StdEncoding.DecodeString(input.File)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
}
|
2024-01-05 17:04:46 +00:00
|
|
|
|
2024-04-17 21:33:49 +00:00
|
|
|
// Create a temporary file
|
|
|
|
outputFile, err := os.CreateTemp(appConfig.ImageDir, "b64")
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
// write the base64 result
|
|
|
|
writer := bufio.NewWriter(outputFile)
|
|
|
|
_, err = writer.Write(fileData)
|
|
|
|
if err != nil {
|
|
|
|
outputFile.Close()
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
outputFile.Close()
|
|
|
|
src = outputFile.Name()
|
|
|
|
defer os.RemoveAll(src)
|
2024-01-05 17:04:46 +00:00
|
|
|
}
|
|
|
|
|
2024-04-17 21:33:49 +00:00
|
|
|
log.Debug().Msgf("Parameter Config: %+v", config)
|
|
|
|
|
|
|
|
switch config.Backend {
|
|
|
|
case "stablediffusion":
|
|
|
|
config.Backend = model.StableDiffusionBackend
|
|
|
|
case "tinydream":
|
|
|
|
config.Backend = model.TinyDreamBackend
|
|
|
|
case "":
|
|
|
|
config.Backend = model.StableDiffusionBackend
|
|
|
|
}
|
|
|
|
|
|
|
|
sizeParts := strings.Split(input.Size, "x")
|
|
|
|
if len(sizeParts) != 2 {
|
|
|
|
return fmt.Errorf("invalid value for 'size'")
|
|
|
|
}
|
|
|
|
width, err := strconv.Atoi(sizeParts[0])
|
2024-01-05 17:04:46 +00:00
|
|
|
if err != nil {
|
2024-04-17 21:33:49 +00:00
|
|
|
return fmt.Errorf("invalid value for 'size'")
|
2024-01-05 17:04:46 +00:00
|
|
|
}
|
2024-04-17 21:33:49 +00:00
|
|
|
height, err := strconv.Atoi(sizeParts[1])
|
|
|
|
if err != nil {
|
|
|
|
return fmt.Errorf("invalid value for 'size'")
|
|
|
|
}
|
|
|
|
|
2024-05-29 12:40:54 +00:00
|
|
|
b64JSON := config.ResponseFormat == "b64_json"
|
|
|
|
|
2024-04-17 21:33:49 +00:00
|
|
|
// src and clip_skip
|
|
|
|
var result []schema.Item
|
|
|
|
for _, i := range config.PromptStrings {
|
|
|
|
n := input.N
|
|
|
|
if input.N == 0 {
|
|
|
|
n = 1
|
|
|
|
}
|
|
|
|
for j := 0; j < n; j++ {
|
|
|
|
prompts := strings.Split(i, "|")
|
|
|
|
positive_prompt := prompts[0]
|
|
|
|
negative_prompt := ""
|
|
|
|
if len(prompts) > 1 {
|
|
|
|
negative_prompt = prompts[1]
|
|
|
|
}
|
|
|
|
|
|
|
|
mode := 0
|
|
|
|
step := config.Step
|
|
|
|
if step == 0 {
|
|
|
|
step = 15
|
|
|
|
}
|
|
|
|
|
|
|
|
if input.Mode != 0 {
|
|
|
|
mode = input.Mode
|
|
|
|
}
|
|
|
|
|
|
|
|
if input.Step != 0 {
|
|
|
|
step = input.Step
|
|
|
|
}
|
|
|
|
|
|
|
|
tempDir := ""
|
|
|
|
if !b64JSON {
|
|
|
|
tempDir = appConfig.ImageDir
|
|
|
|
}
|
|
|
|
// Create a temporary file
|
|
|
|
outputFile, err := os.CreateTemp(tempDir, "b64")
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
outputFile.Close()
|
|
|
|
output := outputFile.Name() + ".png"
|
|
|
|
// Rename the temporary file
|
|
|
|
err = os.Rename(outputFile.Name(), output)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
baseURL := c.BaseURL()
|
|
|
|
|
|
|
|
fn, err := backend.ImageGeneration(height, width, mode, step, *config.Seed, positive_prompt, negative_prompt, src, output, ml, *config, appConfig)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
if err := fn(); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
item := &schema.Item{}
|
|
|
|
|
|
|
|
if b64JSON {
|
|
|
|
defer os.RemoveAll(output)
|
|
|
|
data, err := os.ReadFile(output)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
item.B64JSON = base64.StdEncoding.EncodeToString(data)
|
|
|
|
} else {
|
|
|
|
base := filepath.Base(output)
|
|
|
|
item.URL = baseURL + "/generated-images/" + base
|
|
|
|
}
|
|
|
|
|
|
|
|
result = append(result, *item)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
id := uuid.New().String()
|
|
|
|
created := int(time.Now().Unix())
|
|
|
|
resp := &schema.OpenAIResponse{
|
|
|
|
ID: id,
|
|
|
|
Created: created,
|
|
|
|
Data: result,
|
|
|
|
}
|
|
|
|
|
|
|
|
jsonResult, _ := json.Marshal(resp)
|
2024-01-05 17:04:46 +00:00
|
|
|
log.Debug().Msgf("Response: %s", jsonResult)
|
2024-04-17 21:33:49 +00:00
|
|
|
|
2024-01-05 17:04:46 +00:00
|
|
|
// Return the prediction in the response body
|
2024-04-17 21:33:49 +00:00
|
|
|
return c.JSON(resp)
|
2024-01-05 17:04:46 +00:00
|
|
|
}
|
|
|
|
}
|