feat(diffusers): support flux models (#3129)

* feat(diffusers): support flux models

This adds support for FLUX models. For instance:
https://huggingface.co/black-forest-labs/FLUX.1-dev

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(diffusers): support FluxTransformer2DModel

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Small fixups

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
Ettore Di Giacinto 2024-08-11 01:31:53 +02:00 committed by GitHub
parent 7ba4a78fcc
commit 74eaf02484
6 changed files with 45 additions and 7 deletions


@@ -18,13 +18,13 @@ import backend_pb2_grpc
 import grpc
 from diffusers import StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \
-    EulerAncestralDiscreteScheduler
+    EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
 from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline
 from diffusers.pipelines.stable_diffusion import safety_checker
 from diffusers.utils import load_image, export_to_video
 from compel import Compel, ReturnedEmbeddingsType
-from transformers import CLIPTextModel
+from optimum.quanto import freeze, qfloat8, quantize
+from transformers import CLIPTextModel, T5EncoderModel
 from safetensors.torch import load_file

 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
@@ -163,6 +163,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         modelFile = request.Model

         self.cfg_scale = 7
+        self.PipelineType = request.PipelineType

         if request.CFGScale != 0:
             self.cfg_scale = request.CFGScale
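The new PipelineType field is simply copied off the load request and stored on the servicer so the generation path can branch on it later. As a rough illustration (not part of this change), the field can be exercised directly over the backend's gRPC API. This sketch assumes the generated backend_pb2/backend_pb2_grpc stubs imported above, a diffusers backend listening on localhost:50051, and a service stub named BackendStub (an assumption):

```python
# Hedged sketch: drive the new PipelineType field over gRPC.
# Assumptions: backend_pb2 / backend_pb2_grpc are the generated stubs used by
# the backend above, the service stub is named BackendStub, and the diffusers
# backend is listening on localhost:50051.
import grpc
import backend_pb2
import backend_pb2_grpc

with grpc.insecure_channel("localhost:50051") as channel:
    stub = backend_pb2_grpc.BackendStub(channel)
    result = stub.LoadModel(backend_pb2.ModelOptions(
        Model="black-forest-labs/FLUX.1-dev",
        PipelineType="FluxPipeline",  # stored as self.PipelineType in LoadModel
        LowVRAM=True,                 # triggers enable_model_cpu_offload() below
    ))
    print(result)
```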
@@ -244,6 +246,30 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                     torch_dtype=torchType,
                     use_safetensors=True,
                     variant=variant)
+            elif request.PipelineType == "FluxPipeline":
+                self.pipe = FluxPipeline.from_pretrained(
+                    request.Model,
+                    torch_dtype=torch.bfloat16)
+                if request.LowVRAM:
+                    self.pipe.enable_model_cpu_offload()
+            elif request.PipelineType == "FluxTransformer2DModel":
+                dtype = torch.bfloat16
+                # specify from environment or default to "ChuckMcSneed/FLUX.1-dev"
+                bfl_repo = os.environ.get("BFL_REPO", "ChuckMcSneed/FLUX.1-dev")
+                transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=dtype)
+                quantize(transformer, weights=qfloat8)
+                freeze(transformer)
+                text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
+                quantize(text_encoder_2, weights=qfloat8)
+                freeze(text_encoder_2)
+                self.pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
+                self.pipe.transformer = transformer
+                self.pipe.text_encoder_2 = text_encoder_2
+                if request.LowVRAM:
+                    self.pipe.enable_model_cpu_offload()

             if CLIPSKIP and request.CLIPSkip != 0:
                 self.clip_skip = request.CLIPSkip
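Outside the server, the FluxTransformer2DModel branch above amounts to the following standalone sketch: load the transformer from a single .safetensors file, quantize it and the large T5 text encoder to qfloat8 with optimum-quanto, then attach both to a FluxPipeline built without them. The local checkpoint name flux1-dev.safetensors is hypothetical, and black-forest-labs/FLUX.1-dev stands in for whatever BFL_REPO points at:

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel
from optimum.quanto import freeze, qfloat8, quantize
from transformers import T5EncoderModel

dtype = torch.bfloat16
repo = "black-forest-labs/FLUX.1-dev"   # stands in for the BFL_REPO default above
single_file = "flux1-dev.safetensors"   # hypothetical local single-file checkpoint

# Load the transformer (the largest component) and quantize it to qfloat8
transformer = FluxTransformer2DModel.from_single_file(single_file, torch_dtype=dtype)
quantize(transformer, weights=qfloat8)
freeze(transformer)

# Quantize the T5 text encoder the same way
text_encoder_2 = T5EncoderModel.from_pretrained(repo, subfolder="text_encoder_2", torch_dtype=dtype)
quantize(text_encoder_2, weights=qfloat8)
freeze(text_encoder_2)

# Build the pipeline without those two components, then attach the quantized ones
pipe = FluxPipeline.from_pretrained(repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
pipe.transformer = transformer
pipe.text_encoder_2 = text_encoder_2
pipe.enable_model_cpu_offload()  # same effect as the LowVRAM path above
```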
@@ -399,6 +425,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 request.seed
             )

+        if self.PipelineType == "FluxPipeline":
+            kwargs["max_sequence_length"] = 256
+
+        if self.PipelineType == "FluxTransformer2DModel":
+            kwargs["output_type"] = "pil"
+            kwargs["generator"] = torch.Generator("cpu").manual_seed(0)
+
         if self.img2vid:
             # Load the conditioning image
             image = load_image(request.src)
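At generation time only a couple of Flux-specific kwargs are added. A minimal call, assuming a pipe assembled as in the sketch above (prompt and step count are illustrative), looks roughly like:

```python
import torch

image = pipe(
    "a tiny astronaut hatching from an egg on the moon",  # illustrative prompt
    guidance_scale=3.5,
    num_inference_steps=20,
    max_sequence_length=256,                          # T5 prompt-length cap set for FluxPipeline above
    generator=torch.Generator("cpu").manual_seed(0),  # fixed CPU-side seed, as in the FluxTransformer2DModel path
    output_type="pil",
).images[0]
image.save("flux.png")
```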


@@ -5,4 +5,5 @@ accelerate
 compel
 peft
 sentencepiece
 torch
+optimum-quanto


@@ -6,4 +6,5 @@ transformers
 accelerate
 compel
 peft
 sentencepiece
+optimum-quanto


@@ -5,4 +5,5 @@ transformers
 accelerate
 compel
 peft
 sentencepiece
+optimum-quanto


@@ -8,3 +8,4 @@ accelerate
 compel
 peft
 sentencepiece
+optimum-quanto


@@ -10,4 +10,5 @@ transformers
 accelerate
 compel
 peft
 sentencepiece
+optimum-quanto