From 7e2d101a465830e74a526a2a1e027da5e020b001 Mon Sep 17 00:00:00 2001
From: BobMaster
Date: Mon, 25 Dec 2023 02:24:52 +0800
Subject: [PATCH] fix: guidance_scale not work in sd (#1488)

Signed-off-by: hibobmaster <32976627+hibobmaster@users.noreply.github.com>
---
 backend/python/diffusers/backend_diffusers.py | 45 ++++++++-----------
 1 file changed, 19 insertions(+), 26 deletions(-)

diff --git a/backend/python/diffusers/backend_diffusers.py b/backend/python/diffusers/backend_diffusers.py
index c66b2476..6780cae6 100755
--- a/backend/python/diffusers/backend_diffusers.py
+++ b/backend/python/diffusers/backend_diffusers.py
@@ -149,9 +149,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         local = False
         modelFile = request.Model
 
-        cfg_scale = 7
+        self.cfg_scale = 7
         if request.CFGScale != 0:
-            cfg_scale = request.CFGScale
+            self.cfg_scale = request.CFGScale
 
         clipmodel = "runwayml/stable-diffusion-v1-5"
         if request.CLIPModel != "":
@@ -173,17 +173,14 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
         if (request.PipelineType == "StableDiffusionImg2ImgPipeline") or (request.IMG2IMG and request.PipelineType == ""):
             if fromSingleFile:
                 self.pipe = StableDiffusionImg2ImgPipeline.from_single_file(modelFile,
-                                                                            torch_dtype=torchType,
-                                                                            guidance_scale=cfg_scale)
+                                                                            torch_dtype=torchType)
             else:
                 self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(request.Model,
-                                                                           torch_dtype=torchType,
-                                                                           guidance_scale=cfg_scale)
+                                                                           torch_dtype=torchType)
 
         elif request.PipelineType == "StableDiffusionDepth2ImgPipeline":
             self.pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(request.Model,
-                                                                         torch_dtype=torchType,
-                                                                         guidance_scale=cfg_scale)
+                                                                         torch_dtype=torchType)
         ## img2vid
         elif request.PipelineType == "StableVideoDiffusionPipeline":
             self.img2vid=True
@@ -197,38 +194,32 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             self.pipe = AutoPipelineForText2Image.from_pretrained(request.Model,
                                                                   torch_dtype=torchType,
                                                                   use_safetensors=SAFETENSORS,
-                                                                  variant=variant,
-                                                                  guidance_scale=cfg_scale)
+                                                                  variant=variant)
         elif request.PipelineType == "StableDiffusionPipeline":
             if fromSingleFile:
                 self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
-                                                                     torch_dtype=torchType,
-                                                                     guidance_scale=cfg_scale)
+                                                                     torch_dtype=torchType)
             else:
                 self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
-                                                                    torch_dtype=torchType,
-                                                                    guidance_scale=cfg_scale)
+                                                                    torch_dtype=torchType)
         elif request.PipelineType == "DiffusionPipeline":
             self.pipe = DiffusionPipeline.from_pretrained(request.Model,
-                                                          torch_dtype=torchType,
-                                                          guidance_scale=cfg_scale)
+                                                          torch_dtype=torchType)
         elif request.PipelineType == "VideoDiffusionPipeline":
             self.txt2vid=True
             self.pipe = DiffusionPipeline.from_pretrained(request.Model,
-                                                          torch_dtype=torchType,
-                                                          guidance_scale=cfg_scale)
+                                                          torch_dtype=torchType)
         elif request.PipelineType == "StableDiffusionXLPipeline":
             if fromSingleFile:
                 self.pipe = StableDiffusionXLPipeline.from_single_file(modelFile,
-                                                                       torch_dtype=torchType, use_safetensors=True,
-                                                                       guidance_scale=cfg_scale)
+                                                                       torch_dtype=torchType,
+                                                                       use_safetensors=True)
             else:
                 self.pipe = StableDiffusionXLPipeline.from_pretrained(
                     request.Model,
                     torch_dtype=torchType,
                     use_safetensors=True,
-                    variant=variant,
-                    guidance_scale=cfg_scale)
+                    variant=variant)
 
         if CLIPSKIP and request.CLIPSkip != 0:
             self.clip_skip = request.CLIPSkip
@@ -384,12 +375,12 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
                 image = image.resize((1024, 576))
 
             generator = torch.manual_seed(request.seed)
-            frames = self.pipe(image, decode_chunk_size=CHUNK_SIZE, generator=generator).frames[0]
+            frames = self.pipe(image, guidance_scale=self.cfg_scale, decode_chunk_size=CHUNK_SIZE, generator=generator).frames[0]
             export_to_video(frames, request.dst, fps=FPS)
             return backend_pb2.Result(message="Media generated successfully", success=True)
 
         if self.txt2vid:
-            video_frames = self.pipe(prompt, num_inference_steps=steps, num_frames=int(FRAMES)).frames
+            video_frames = self.pipe(prompt, guidance_scale=self.cfg_scale, num_inference_steps=steps, num_frames=int(FRAMES)).frames
             export_to_video(video_frames, request.dst)
             return backend_pb2.Result(message="Media generated successfully", success=True)
 
@@ -398,13 +389,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
             conditioning = self.compel.build_conditioning_tensor(prompt)
             kwargs["prompt_embeds"]= conditioning
             # pass the kwargs dictionary to the self.pipe method
-            image = self.pipe(
+            image = self.pipe(
+                guidance_scale=self.cfg_scale,
                 **kwargs
             ).images[0]
         else:
             # pass the kwargs dictionary to the self.pipe method
             image = self.pipe(
-                prompt,
+                prompt,
+                guidance_scale=self.cfg_scale,
                 **kwargs
             ).images[0]
 
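Note on the fix (not part of the patch): in diffusers, guidance_scale is a parameter of a pipeline's __call__, not of from_pretrained() or from_single_file(), so passing it at load time as the old code did has no effect on generation. The patch therefore stores the value as self.cfg_scale when the model is loaded and forwards it on every inference call. A minimal standalone sketch of the corrected pattern, assuming the runwayml/stable-diffusion-v1-5 checkpoint the patch already references; the prompt and output filename are illustrative placeholders:

# Sketch only: shows that guidance_scale belongs on the pipeline call,
# not on from_pretrained(). Model id is taken from the patch; the prompt
# and output filename are placeholders.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")  # assumes a CUDA device is available

# Passing guidance_scale here is what actually changes the CFG strength;
# the same kwarg given to from_pretrained() would simply not be applied.
image = pipe("an astronaut riding a horse", guidance_scale=7.5).images[0]
image.save("astronaut.png")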