Mirror of https://github.com/mudler/LocalAI.git (synced 2025-04-19 16:41:01 +00:00)
feat(img2vid,txt2vid): Initial support for img2vid,txt2vid (#1442)
* feat(img2vid): Initial support for img2vid
* doc(SD): fix SDXL Example
* Minor fixups for img2vid
* docs(img2img): fix example curl call
* feat(txt2vid): initial support

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>

* diffusers: be retro-compatible with CUDA settings
* docs(img2vid, txt2vid): examples
* Add notice on docs

---------

Signed-off-by: Ettore Di Giacinto <mudler@users.noreply.github.com>
This commit is contained in:
parent
fb6a5bc620
commit
dd982acf2c
@@ -16,7 +16,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
 		model.WithContext(o.Context),
 		model.WithModel(c.Model),
 		model.WithLoadGRPCLoadModelOpts(&proto.ModelOptions{
-			CUDA:          c.CUDA,
+			CUDA:          c.CUDA || c.Diffusers.CUDA,
 			SchedulerType: c.Diffusers.SchedulerType,
 			PipelineType:  c.Diffusers.PipelineType,
 			CFGScale:      c.Diffusers.CFGScale,
@@ -68,6 +68,7 @@ type GRPC struct {
 }
 
 type Diffusers struct {
+	CUDA             bool   `yaml:"cuda"`
 	PipelineType     string `yaml:"pipeline_type"`
 	SchedulerType    string `yaml:"scheduler_type"`
 	EnableParameters string `yaml:"enable_parameters"` // A list of comma separated parameters to specify
@@ -5,6 +5,8 @@ import (
 	"encoding/base64"
 	"encoding/json"
 	"fmt"
+	"io"
+	"net/http"
 	"os"
 	"path/filepath"
 	"strconv"
@@ -22,6 +24,26 @@ import (
 	"github.com/rs/zerolog/log"
 )
 
+func downloadFile(url string) (string, error) {
+	// Get the data
+	resp, err := http.Get(url)
+	if err != nil {
+		return "", err
+	}
+	defer resp.Body.Close()
+
+	// Create the file
+	out, err := os.CreateTemp("", "image")
+	if err != nil {
+		return "", err
+	}
+	defer out.Close()
+
+	// Write the body to file
+	_, err = io.Copy(out, resp.Body)
+	return out.Name(), err
+}
+
 // https://platform.openai.com/docs/api-reference/images/create
 
 /*
@@ -56,12 +78,31 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
 
 		src := ""
 		if input.File != "" {
-			//base 64 decode the file and write it somewhere
-			// that we will cleanup
-			decoded, err := base64.StdEncoding.DecodeString(input.File)
-			if err != nil {
-				return err
-			}
+			fileData := []byte{}
+			// check if input.File is an URL, if so download it and save it
+			// to a temporary file
+			if strings.HasPrefix(input.File, "http://") || strings.HasPrefix(input.File, "https://") {
+				out, err := downloadFile(input.File)
+				if err != nil {
+					return fmt.Errorf("failed downloading file:%w", err)
+				}
+				defer os.RemoveAll(out)
+
+				fileData, err = os.ReadFile(out)
+				if err != nil {
+					return fmt.Errorf("failed reading file:%w", err)
+				}
+
+			} else {
+				// base 64 decode the file and write it somewhere
+				// that we will cleanup
+				fileData, err = base64.StdEncoding.DecodeString(input.File)
+				if err != nil {
+					return err
+				}
+			}
 
 			// Create a temporary file
 			outputFile, err := os.CreateTemp(o.ImageDir, "b64")
 			if err != nil {
@@ -69,7 +110,7 @@ func ImageEndpoint(cm *config.ConfigLoader, o *options.Option) func(c *fiber.Ctx
 			}
 			// write the base64 result
 			writer := bufio.NewWriter(outputFile)
-			_, err = writer.Write(decoded)
+			_, err = writer.Write(fileData)
 			if err != nil {
 				outputFile.Close()
 				return err
@@ -18,9 +18,9 @@ import backend_pb2_grpc
 import grpc
 
 from diffusers import StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
-from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel
+from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline
 from diffusers.pipelines.stable_diffusion import safety_checker
-from diffusers.utils import load_image
+from diffusers.utils import load_image,export_to_video
 from compel import Compel
 
 from transformers import CLIPTextModel
@@ -31,6 +31,10 @@ _ONE_DAY_IN_SECONDS = 60 * 60 * 24
 COMPEL=os.environ.get("COMPEL", "1") == "1"
 CLIPSKIP=os.environ.get("CLIPSKIP", "1") == "1"
 SAFETENSORS=os.environ.get("SAFETENSORS", "1") == "1"
+CHUNK_SIZE=os.environ.get("CHUNK_SIZE", "8")
+FPS=os.environ.get("FPS", "7")
+DISABLE_CPU_OFFLOAD=os.environ.get("DISABLE_CPU_OFFLOAD", "0") == "1"
+FRAMES=os.environ.get("FRAMES", "64")
 
 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1
 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
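Note that `CHUNK_SIZE` and `FPS` are read from the environment as strings here and passed through as-is later in this diff, while only `FRAMES` is cast with `int()` at its point of use. A minimal hardened sketch (not part of this commit) would cast the numeric knobs once at startup:

```python
import os

# Hypothetical variant of the env knobs above: cast numeric values once,
# since decode_chunk_size and fps are integer-valued downstream.
CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "8"))
FPS = int(os.environ.get("FPS", "7"))
FRAMES = int(os.environ.get("FRAMES", "64"))
DISABLE_CPU_OFFLOAD = os.environ.get("DISABLE_CPU_OFFLOAD", "0") == "1"
```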
@@ -163,7 +167,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             modelFile = request.ModelFile
 
         fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local
-
+        self.img2vid=False
+        self.txt2vid=False
         ## img2img
         if (request.PipelineType == "StableDiffusionImg2ImgPipeline") or (request.IMG2IMG and request.PipelineType == ""):
             if fromSingleFile:
@@ -179,6 +184,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 self.pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(request.Model,
                                                                              torch_dtype=torchType,
                                                                              guidance_scale=cfg_scale)
+        ## img2vid
+        elif request.PipelineType == "StableVideoDiffusionPipeline":
+            self.img2vid=True
+            self.pipe = StableVideoDiffusionPipeline.from_pretrained(
+                request.Model, torch_dtype=torchType, variant=variant
+            )
+            if not DISABLE_CPU_OFFLOAD:
+                self.pipe.enable_model_cpu_offload()
         ## text2img
         elif request.PipelineType == "AutoPipelineForText2Image" or request.PipelineType == "":
             self.pipe = AutoPipelineForText2Image.from_pretrained(request.Model,
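For context, the equivalent standalone diffusers flow for Stable Video Diffusion looks roughly like this — a sketch assuming the fp16 variant of the checkpoint used in the docs below, not code from this commit:

```python
import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video

# Load SVD; model CPU offload trades speed for a much smaller VRAM footprint,
# which is why the backend enables it unless DISABLE_CPU_OFFLOAD is set.
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.enable_model_cpu_offload()

# SVD is conditioned on a single image at its native 1024x576 resolution.
image = load_image("rocket.png").resize((1024, 576))
frames = pipe(image, decode_chunk_size=8,
              generator=torch.manual_seed(42)).frames[0]
export_to_video(frames, "generated.mp4", fps=7)
```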
@@ -199,6 +212,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 self.pipe = DiffusionPipeline.from_pretrained(request.Model,
                                                               torch_dtype=torchType,
                                                               guidance_scale=cfg_scale)
+        elif request.PipelineType == "VideoDiffusionPipeline":
+            self.txt2vid=True
+            self.pipe = DiffusionPipeline.from_pretrained(request.Model,
+                                                          torch_dtype=torchType,
+                                                          guidance_scale=cfg_scale)
         elif request.PipelineType == "StableDiffusionXLPipeline":
             if fromSingleFile:
                 self.pipe = StableDiffusionXLPipeline.from_single_file(modelFile,
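`VideoDiffusionPipeline` here is a LocalAI-side `pipeline_type` value: it loads a generic `DiffusionPipeline` and only flips the `txt2vid` flag, since `DiffusionPipeline` auto-resolves to the checkpoint's own pipeline class. A standalone sketch of the same flow, using the text-to-video checkpoint from the docs below (illustrative, not from this commit):

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video

# DiffusionPipeline resolves to the pipeline class registered by the checkpoint.
pipe = DiffusionPipeline.from_pretrained(
    "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16
).to("cuda")

frames = pipe("spiderman surfing", num_inference_steps=25, num_frames=16).frames
export_to_video(frames, "generated.mp4")
```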
@@ -222,7 +240,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         if request.SchedulerType != "":
             self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)
 
-        self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder)
+        if not self.img2vid:
+            self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder)
 
 
         if request.ControlNet:
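The new guard matters because `StableVideoDiffusionPipeline` is image-conditioned and exposes no `tokenizer`/`text_encoder`, which Compel needs for prompt weighting. For reference, a minimal sketch of what Compel adds for a text-conditioned pipeline `pipe` (illustrative only):

```python
from compel import Compel

# Compel turns weighted prompt syntax (e.g. "a forest++") into embeddings
# using the pipeline's tokenizer and text encoder -- both absent in SVD.
compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)
conditioning = compel.build_conditioning_tensor("a forest++ in autumn")
image = pipe(prompt_embeds=conditioning).images[0]
```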
@@ -331,7 +350,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             "num_inference_steps": steps,
         }
 
-        if request.src != "" and not self.controlnet:
+        if request.src != "" and not self.controlnet and not self.img2vid:
             image = Image.open(request.src)
             options["image"] = image
         elif self.controlnet and request.src:
@@ -359,6 +378,21 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 request.seed
             )
 
+        if self.img2vid:
+            # Load the conditioning image
+            image = load_image(request.src)
+            image = image.resize((1024, 576))
+
+            generator = torch.manual_seed(request.seed)
+            frames = self.pipe(image, decode_chunk_size=CHUNK_SIZE, generator=generator).frames[0]
+            export_to_video(frames, request.dst, fps=FPS)
+            return backend_pb2.Result(message="Media generated successfully", success=True)
+
+        if self.txt2vid:
+            video_frames = self.pipe(prompt, num_inference_steps=steps, num_frames=int(FRAMES)).frames
+            export_to_video(video_frames, request.dst)
+            return backend_pb2.Result(message="Media generated successfully", success=True)
+
         image = {}
         if COMPEL:
             conditioning = self.compel.build_conditioning_tensor(prompt)
@@ -377,7 +411,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         # save the result
         image.save(request.dst)
 
-        return backend_pb2.Result(message="Model loaded successfully", success=True)
+        return backend_pb2.Result(message="Media generated", success=True)
 
 def serve(address):
     server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
@@ -147,7 +147,6 @@ backend: diffusers
 # Force CPU usage - set to true for GPU
 f16: false
 diffusers:
   pipeline_type: StableDiffusionXLPipeline
-  cuda: false # Enable for GPU usage (CUDA)
   scheduler_type: euler_a
 ```
@@ -15,7 +15,6 @@ backend: diffusers
 # Force CPU usage - set to true for GPU
 f16: false
 diffusers:
   pipeline_type: StableDiffusionXLPipeline
-  cuda: false # Enable for GPU usage (CUDA)
   scheduler_type: dpm_2_a
 ```
@@ -27,12 +27,9 @@ name: animagine-xl
 parameters:
   model: Linaqruf/animagine-xl
 backend: diffusers
-
-# Force CPU usage - set to true for GPU
-f16: false
+cuda: true
+f16: true
 diffusers:
   pipeline_type: StableDiffusionXLPipeline
-  cuda: false # Enable for GPU usage (CUDA)
   scheduler_type: euler_a
 ```
@@ -47,9 +44,9 @@ parameters:
 backend: diffusers
 step: 30
 f16: true
+cuda: true
 diffusers:
   pipeline_type: StableDiffusionPipeline
-  cuda: true
   enable_parameters: "negative_prompt,num_inference_steps,clip_skip"
   scheduler_type: "k_dpmpp_sde"
   cfg_scale: 8
@@ -69,7 +66,7 @@ The following parameters are available in the configuration file:
 | `scheduler_type` | Scheduler type | `k_dpp_sde` |
 | `cfg_scale` | Configuration scale | `8` |
 | `clip_skip` | Clip skip | None |
-| `pipeline_type` | Pipeline type | `StableDiffusionPipeline` |
+| `pipeline_type` | Pipeline type | `AutoPipelineForText2Image` |
 
 There are available several types of schedulers:
@@ -131,17 +128,16 @@
   model: nitrosocke/Ghibli-Diffusion
 backend: diffusers
 step: 25
-
+cuda: true
 f16: true
 diffusers:
   pipeline_type: StableDiffusionImg2ImgPipeline
-  cuda: true
   enable_parameters: "negative_prompt,num_inference_steps,image"
 ```
 
 ```bash
 IMAGE_PATH=/path/to/your/image
-(echo -n '{"image": "'; base64 $IMAGE_PATH; echo '", "prompt": "a sky background","size": "512x512","model":"stablediffusion-edit"}') |
+(echo -n '{"file": "'; base64 $IMAGE_PATH; echo '", "prompt": "a sky background","size": "512x512","model":"stablediffusion-edit"}') |
 curl -H "Content-Type: application/json" -d @- http://localhost:8080/v1/images/generations
 ```
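The same request issued from Python, as a hedged sketch: it assumes a local server on port 8080 and the `requests` package, and mirrors the curl pipeline above:

```python
import base64
import requests

# Read and base64-encode the input image, as the bash example does with `base64`.
with open("/path/to/your/image", "rb") as f:
    encoded = base64.b64encode(f.read()).decode()

resp = requests.post(
    "http://localhost:8080/v1/images/generations",
    json={
        "file": encoded,
        "prompt": "a sky background",
        "size": "512x512",
        "model": "stablediffusion-edit",
    },
)
print(resp.json())
```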
@@ -157,14 +153,67 @@ backend: diffusers
 step: 50
 # Force CPU usage
 f16: true
+cuda: true
 diffusers:
   pipeline_type: StableDiffusionDepth2ImgPipeline
-  cuda: true
   enable_parameters: "negative_prompt,num_inference_steps,image"
   cfg_scale: 6
 ```
 
 ```bash
-(echo -n '{"image": "'; base64 ~/path/to/image.jpeg; echo '", "prompt": "a sky background","size": "512x512","model":"stablediffusion-depth"}') |
+(echo -n '{"file": "'; base64 ~/path/to/image.jpeg; echo '", "prompt": "a sky background","size": "512x512","model":"stablediffusion-depth"}') |
 curl -H "Content-Type: application/json" -d @- http://localhost:8080/v1/images/generations
 ```
 
+## img2vid
+
+{{% notice note %}}
+
+Experimental and available only on master builds. See: https://github.com/mudler/LocalAI/pull/1442
+
+{{% /notice %}}
+
+```yaml
+name: img2vid
+parameters:
+  model: stabilityai/stable-video-diffusion-img2vid
+backend: diffusers
+step: 25
+# Force CPU usage
+f16: true
+cuda: true
+diffusers:
+  pipeline_type: StableVideoDiffusionPipeline
+```
+
+```bash
+(echo -n '{"file": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true","size": "512x512","model":"img2vid"}') |
+curl -H "Content-Type: application/json" -X POST -d @- http://localhost:8080/v1/images/generations
+```
+
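The img2vid call from Python, with the same assumptions as the sketch above — note the `file` field carries a URL here, which the server now fetches itself via the new `downloadFile` helper:

```python
import requests

# Mirrors the curl example: "file" may be a URL, downloaded server-side.
resp = requests.post(
    "http://localhost:8080/v1/images/generations",
    json={
        "file": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket.png?download=true",
        "size": "512x512",
        "model": "img2vid",
    },
)
print(resp.json())
```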
+## txt2vid
+
+{{% notice note %}}
+
+Experimental and available only on master builds. See: https://github.com/mudler/LocalAI/pull/1442
+
+{{% /notice %}}
+
+```yaml
+name: txt2vid
+parameters:
+  model: damo-vilab/text-to-video-ms-1.7b
+backend: diffusers
+step: 25
+# Force CPU usage
+f16: true
+cuda: true
+diffusers:
+  pipeline_type: VideoDiffusionPipeline
+  cuda: true
+```
+
+```bash
+(echo -n '{"prompt": "spiderman surfing","size": "512x512","model":"txt2vid"}') |
+curl -H "Content-Type: application/json" -X POST -d @- http://localhost:8080/v1/images/generations
+```