mirror of https://github.com/mudler/LocalAI.git
feat(diffusers): support flux models (#3129)
* feat(diffusers): support flux models

  This adds support for FLUX models. For instance:
  https://huggingface.co/black-forest-labs/FLUX.1-dev

  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* feat(diffusers): support FluxTransformer2DModel

  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

* Small fixups

  Signed-off-by: Ettore Di Giacinto <mudler@localai.io>

---------

Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
parent 7ba4a78fcc
commit 74eaf02484
@@ -18,13 +18,13 @@ import backend_pb2_grpc
 import grpc
 
 from diffusers import StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \
-    EulerAncestralDiscreteScheduler
+    EulerAncestralDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
 from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline
 from diffusers.pipelines.stable_diffusion import safety_checker
 from diffusers.utils import load_image, export_to_video
 from compel import Compel, ReturnedEmbeddingsType
-from transformers import CLIPTextModel
+from optimum.quanto import freeze, qfloat8, quantize
+from transformers import CLIPTextModel, T5EncoderModel
 from safetensors.torch import load_file
 
 _ONE_DAY_IN_SECONDS = 60 * 60 * 24
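
For context, a minimal sketch of plain FluxPipeline usage outside the gRPC backend, assuming a diffusers release with FLUX support; the prompt and step count are illustrative, and FLUX.1-dev is gated on Hugging Face behind a license acceptance:

    import torch
    from diffusers import FluxPipeline

    # Model id from the commit message; accepting the FLUX.1-dev license is required.
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.enable_model_cpu_offload()  # trade speed for VRAM, like the LowVRAM path below

    image = pipe(
        "a photo of a misty forest at dawn",              # example prompt
        num_inference_steps=20,                           # illustrative
        generator=torch.Generator("cpu").manual_seed(0),  # reproducible output
    ).images[0]
    image.save("flux.png")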
@@ -163,6 +163,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         modelFile = request.Model
 
         self.cfg_scale = 7
+        self.PipelineType = request.PipelineType
+
         if request.CFGScale != 0:
             self.cfg_scale = request.CFGScale
 
@@ -244,6 +246,30 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 torch_dtype=torchType,
                 use_safetensors=True,
                 variant=variant)
+        elif request.PipelineType == "FluxPipeline":
+            self.pipe = FluxPipeline.from_pretrained(
+                request.Model,
+                torch_dtype=torch.bfloat16)
+            if request.LowVRAM:
+                self.pipe.enable_model_cpu_offload()
+        elif request.PipelineType == "FluxTransformer2DModel":
+            dtype = torch.bfloat16
+            # specify from environment or default to "ChuckMcSneed/FLUX.1-dev"
+            bfl_repo = os.environ.get("BFL_REPO", "ChuckMcSneed/FLUX.1-dev")
+
+            transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=dtype)
+            quantize(transformer, weights=qfloat8)
+            freeze(transformer)
+            text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
+            quantize(text_encoder_2, weights=qfloat8)
+            freeze(text_encoder_2)
+
+            self.pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
+            self.pipe.transformer = transformer
+            self.pipe.text_encoder_2 = text_encoder_2
+
+            if request.LowVRAM:
+                self.pipe.enable_model_cpu_offload()
 
         if CLIPSKIP and request.CLIPSkip != 0:
             self.clip_skip = request.CLIPSkip
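
The FluxTransformer2DModel branch is the heart of this commit: the FLUX transformer and the T5 text encoder are quantized to 8-bit floats with optimum-quanto and frozen, then swapped into a FluxPipeline created without those components. A standalone sketch of the same pattern; the weights path is a placeholder:

    import torch
    from diffusers import FluxPipeline, FluxTransformer2DModel
    from optimum.quanto import freeze, qfloat8, quantize
    from transformers import T5EncoderModel

    dtype = torch.bfloat16
    repo = "ChuckMcSneed/FLUX.1-dev"    # the BFL_REPO default used above
    weights = "flux1-dev.safetensors"   # placeholder path to local weights

    # Load the transformer from a single safetensors file, quantize its
    # weights to qfloat8, and freeze so the quantization sticks.
    transformer = FluxTransformer2DModel.from_single_file(weights, torch_dtype=dtype)
    quantize(transformer, weights=qfloat8)
    freeze(transformer)

    # The T5 encoder is the other large component; same treatment.
    text_encoder_2 = T5EncoderModel.from_pretrained(
        repo, subfolder="text_encoder_2", torch_dtype=dtype)
    quantize(text_encoder_2, weights=qfloat8)
    freeze(text_encoder_2)

    # Build the pipeline without those two parts, then attach the
    # quantized versions in their place.
    pipe = FluxPipeline.from_pretrained(
        repo, transformer=None, text_encoder_2=None, torch_dtype=dtype)
    pipe.transformer = transformer
    pipe.text_encoder_2 = text_encoder_2
    pipe.enable_model_cpu_offload()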
@@ -399,6 +425,13 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 request.seed
             )
 
+        if self.PipelineType == "FluxPipeline":
+            kwargs["max_sequence_length"] = 256
+
+        if self.PipelineType == "FluxTransformer2DModel":
+            kwargs["output_type"] = "pil"
+            kwargs["generator"] = torch.Generator("cpu").manual_seed(0)
+
         if self.img2vid:
             # Load the conditioning image
             image = load_image(request.src)
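
A simplified, hypothetical view of how these per-pipeline kwargs are consumed; the real backend assembles `kwargs` from the gRPC request before expanding them into the pipeline call:

    import torch

    kwargs = {"prompt": "a tiny astronaut hatching from an egg",  # example prompt
              "num_inference_steps": 20}                          # illustrative

    pipeline_type = "FluxTransformer2DModel"
    if pipeline_type == "FluxPipeline":
        kwargs["max_sequence_length"] = 256  # cap the T5 prompt length
    if pipeline_type == "FluxTransformer2DModel":
        kwargs["output_type"] = "pil"
        kwargs["generator"] = torch.Generator("cpu").manual_seed(0)  # note: fixed seed 0

    # image = pipe(**kwargs).images[0]  # `pipe` as built in LoadModel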
@@ -5,4 +5,5 @@ accelerate
 compel
 peft
 sentencepiece
 torch
+optimum-quanto
@@ -6,4 +6,5 @@ transformers
 accelerate
 compel
 peft
 sentencepiece
+optimum-quanto
@@ -5,4 +5,5 @@ transformers
 accelerate
 compel
 peft
 sentencepiece
+optimum-quanto
@@ -8,3 +8,4 @@ accelerate
 compel
 peft
 sentencepiece
+optimum-quanto
@@ -10,4 +10,5 @@ transformers
 accelerate
 compel
 peft
 sentencepiece
+optimum-quanto
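
Each of the diffusers requirements variants gains the same dependency. A hypothetical import guard, not part of the commit, that makes a missing install fail with a clear message:

    # Fail early if optimum-quanto is absent, since only the Flux paths need it.
    try:
        from optimum.quanto import freeze, qfloat8, quantize  # noqa: F401
    except ImportError as err:
        raise SystemExit(
            "optimum-quanto is required for the FluxTransformer2DModel "
            "pipeline; install it with `pip install optimum-quanto`."
        ) from err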