mirror of
https://github.com/mudler/LocalAI.git
synced 2025-02-18 16:40:17 +00:00
fix: Lora loading (#2893)
- Fixed Lora loading Co-authored-by: Alex <alex@akhbar.home>
This commit is contained in:
parent
f521e50fa8
commit
4e84764787
@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
from concurrent import futures
|
||||
|
||||
import traceback
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
@ -17,35 +17,39 @@ import backend_pb2_grpc
|
||||
|
||||
import grpc
|
||||
|
||||
from diffusers import StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
|
||||
from diffusers import StableDiffusion3Pipeline, StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, \
|
||||
EulerAncestralDiscreteScheduler
|
||||
from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel, StableVideoDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import safety_checker
|
||||
from diffusers.utils import load_image,export_to_video
|
||||
from diffusers.utils import load_image, export_to_video
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
|
||||
from transformers import CLIPTextModel
|
||||
from safetensors.torch import load_file
|
||||
|
||||
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
COMPEL=os.environ.get("COMPEL", "0") == "1"
|
||||
XPU=os.environ.get("XPU", "0") == "1"
|
||||
CLIPSKIP=os.environ.get("CLIPSKIP", "1") == "1"
|
||||
SAFETENSORS=os.environ.get("SAFETENSORS", "1") == "1"
|
||||
CHUNK_SIZE=os.environ.get("CHUNK_SIZE", "8")
|
||||
FPS=os.environ.get("FPS", "7")
|
||||
DISABLE_CPU_OFFLOAD=os.environ.get("DISABLE_CPU_OFFLOAD", "0") == "1"
|
||||
FRAMES=os.environ.get("FRAMES", "64")
|
||||
COMPEL = os.environ.get("COMPEL", "0") == "1"
|
||||
XPU = os.environ.get("XPU", "0") == "1"
|
||||
CLIPSKIP = os.environ.get("CLIPSKIP", "1") == "1"
|
||||
SAFETENSORS = os.environ.get("SAFETENSORS", "1") == "1"
|
||||
CHUNK_SIZE = os.environ.get("CHUNK_SIZE", "8")
|
||||
FPS = os.environ.get("FPS", "7")
|
||||
DISABLE_CPU_OFFLOAD = os.environ.get("DISABLE_CPU_OFFLOAD", "0") == "1"
|
||||
FRAMES = os.environ.get("FRAMES", "64")
|
||||
|
||||
if XPU:
|
||||
import intel_extension_for_pytorch as ipex
|
||||
|
||||
print(ipex.xpu.get_device_name(0))
|
||||
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
|
||||
|
||||
# https://github.com/CompVis/stable-diffusion/issues/239#issuecomment-1627615287
|
||||
def sc(self, clip_input, images) : return images, [False for i in images]
|
||||
def sc(self, clip_input, images): return images, [False for i in images]
|
||||
|
||||
|
||||
# edit the StableDiffusionSafetyChecker class so that, when called, it just returns the images and an array of True values
|
||||
safety_checker.StableDiffusionSafetyChecker.forward = sc
|
||||
|
||||
@ -62,6 +66,8 @@ from diffusers.schedulers import (
|
||||
PNDMScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
)
|
||||
|
||||
|
||||
# The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
|
||||
# Credits to https://github.com/neggles
|
||||
# See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111
|
||||
@ -136,10 +142,12 @@ def get_scheduler(name: str, config: dict = {}):
|
||||
|
||||
return sched_class.from_config(config)
|
||||
|
||||
|
||||
# Implement the BackendServicer class with the service methods
|
||||
class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
def Health(self, request, context):
|
||||
return backend_pb2.Reply(message=bytes("OK", 'utf-8'))
|
||||
|
||||
def LoadModel(self, request, context):
|
||||
try:
|
||||
print(f"Loading model {request.Model}...", file=sys.stderr)
|
||||
@ -149,7 +157,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
if request.F16Memory:
|
||||
torchType = torch.float16
|
||||
variant="fp16"
|
||||
variant = "fp16"
|
||||
|
||||
local = False
|
||||
modelFile = request.Model
|
||||
@ -157,38 +165,38 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
self.cfg_scale = 7
|
||||
if request.CFGScale != 0:
|
||||
self.cfg_scale = request.CFGScale
|
||||
|
||||
|
||||
clipmodel = "runwayml/stable-diffusion-v1-5"
|
||||
if request.CLIPModel != "":
|
||||
clipmodel = request.CLIPModel
|
||||
clipsubfolder = "text_encoder"
|
||||
if request.CLIPSubfolder != "":
|
||||
clipsubfolder = request.CLIPSubfolder
|
||||
|
||||
|
||||
# Check if ModelFile exists
|
||||
if request.ModelFile != "":
|
||||
if os.path.exists(request.ModelFile):
|
||||
local = True
|
||||
modelFile = request.ModelFile
|
||||
|
||||
|
||||
fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local
|
||||
self.img2vid=False
|
||||
self.txt2vid=False
|
||||
self.img2vid = False
|
||||
self.txt2vid = False
|
||||
## img2img
|
||||
if (request.PipelineType == "StableDiffusionImg2ImgPipeline") or (request.IMG2IMG and request.PipelineType == ""):
|
||||
if fromSingleFile:
|
||||
self.pipe = StableDiffusionImg2ImgPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType)
|
||||
torch_dtype=torchType)
|
||||
else:
|
||||
self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType)
|
||||
torch_dtype=torchType)
|
||||
|
||||
elif request.PipelineType == "StableDiffusionDepth2ImgPipeline":
|
||||
self.pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType)
|
||||
torch_dtype=torchType)
|
||||
## img2vid
|
||||
elif request.PipelineType == "StableVideoDiffusionPipeline":
|
||||
self.img2vid=True
|
||||
self.img2vid = True
|
||||
self.pipe = StableVideoDiffusionPipeline.from_pretrained(
|
||||
request.Model, torch_dtype=torchType, variant=variant
|
||||
)
|
||||
@ -197,64 +205,63 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
## text2img
|
||||
elif request.PipelineType == "AutoPipelineForText2Image" or request.PipelineType == "":
|
||||
self.pipe = AutoPipelineForText2Image.from_pretrained(request.Model,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=SAFETENSORS,
|
||||
variant=variant)
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=SAFETENSORS,
|
||||
variant=variant)
|
||||
elif request.PipelineType == "StableDiffusionPipeline":
|
||||
if fromSingleFile:
|
||||
self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType)
|
||||
torch_dtype=torchType)
|
||||
else:
|
||||
self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType)
|
||||
torch_dtype=torchType)
|
||||
elif request.PipelineType == "DiffusionPipeline":
|
||||
self.pipe = DiffusionPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType)
|
||||
torch_dtype=torchType)
|
||||
elif request.PipelineType == "VideoDiffusionPipeline":
|
||||
self.txt2vid=True
|
||||
self.txt2vid = True
|
||||
self.pipe = DiffusionPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType)
|
||||
torch_dtype=torchType)
|
||||
elif request.PipelineType == "StableDiffusionXLPipeline":
|
||||
if fromSingleFile:
|
||||
self.pipe = StableDiffusionXLPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=True)
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=True)
|
||||
else:
|
||||
self.pipe = StableDiffusionXLPipeline.from_pretrained(
|
||||
request.Model,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=True,
|
||||
request.Model,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=True,
|
||||
variant=variant)
|
||||
elif request.PipelineType == "StableDiffusion3Pipeline":
|
||||
if fromSingleFile:
|
||||
self.pipe = StableDiffusion3Pipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=True)
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=True)
|
||||
else:
|
||||
self.pipe = StableDiffusion3Pipeline.from_pretrained(
|
||||
request.Model,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=True,
|
||||
request.Model,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=True,
|
||||
variant=variant)
|
||||
|
||||
if CLIPSKIP and request.CLIPSkip != 0:
|
||||
self.clip_skip = request.CLIPSkip
|
||||
else:
|
||||
self.clip_skip = 0
|
||||
|
||||
|
||||
# torch_dtype needs to be customized. float16 for GPU, float32 for CPU
|
||||
# TODO: this needs to be customized
|
||||
if request.SchedulerType != "":
|
||||
self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)
|
||||
|
||||
|
||||
if COMPEL:
|
||||
self.compel = Compel(
|
||||
tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2 ],
|
||||
tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
|
||||
text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
|
||||
returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
|
||||
requires_pooled=[False, True]
|
||||
)
|
||||
|
||||
)
|
||||
|
||||
if request.ControlNet:
|
||||
self.controlnet = ControlNetModel.from_pretrained(
|
||||
@ -263,13 +270,6 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
self.pipe.controlnet = self.controlnet
|
||||
else:
|
||||
self.controlnet = None
|
||||
|
||||
if request.CUDA:
|
||||
self.pipe.to('cuda')
|
||||
if self.controlnet:
|
||||
self.controlnet.to('cuda')
|
||||
if XPU:
|
||||
self.pipe = self.pipe.to("xpu")
|
||||
# Assume directory from request.ModelFile.
|
||||
# Only if request.LoraAdapter it's not an absolute path
|
||||
if request.LoraAdapter and request.ModelFile != "" and not os.path.isabs(request.LoraAdapter) and request.LoraAdapter:
|
||||
@ -282,10 +282,17 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if request.LoraAdapter:
|
||||
# Check if its a local file and not a directory ( we load lora differently for a safetensor file )
|
||||
if os.path.exists(request.LoraAdapter) and not os.path.isdir(request.LoraAdapter):
|
||||
self.load_lora_weights(request.LoraAdapter, 1, device, torchType)
|
||||
# self.load_lora_weights(request.LoraAdapter, 1, device, torchType)
|
||||
self.pipe.load_lora_weights(request.LoraAdapter)
|
||||
else:
|
||||
self.pipe.unet.load_attn_procs(request.LoraAdapter)
|
||||
|
||||
if request.CUDA:
|
||||
self.pipe.to('cuda')
|
||||
if self.controlnet:
|
||||
self.controlnet.to('cuda')
|
||||
if XPU:
|
||||
self.pipe = self.pipe.to("xpu")
|
||||
except Exception as err:
|
||||
return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
|
||||
# Implement your logic here for the LoadModel service
|
||||
@ -358,9 +365,9 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
# create a dictionary of values for the parameters
|
||||
options = {
|
||||
"negative_prompt": request.negative_prompt,
|
||||
"width": request.width,
|
||||
"height": request.height,
|
||||
"negative_prompt": request.negative_prompt,
|
||||
"width": request.width,
|
||||
"height": request.height,
|
||||
"num_inference_steps": steps,
|
||||
}
|
||||
|
||||
@ -372,7 +379,7 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
options["image"] = pose_image
|
||||
|
||||
if CLIPSKIP and self.clip_skip != 0:
|
||||
options["clip_skip"]=self.clip_skip
|
||||
options["clip_skip"] = self.clip_skip
|
||||
|
||||
# Get the keys that we will build the args for our pipe for
|
||||
keys = options.keys()
|
||||
@ -416,20 +423,21 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
image = self.pipe(
|
||||
guidance_scale=self.cfg_scale,
|
||||
**kwargs
|
||||
).images[0]
|
||||
).images[0]
|
||||
else:
|
||||
# pass the kwargs dictionary to the self.pipe method
|
||||
image = self.pipe(
|
||||
prompt,
|
||||
guidance_scale=self.cfg_scale,
|
||||
**kwargs
|
||||
).images[0]
|
||||
).images[0]
|
||||
|
||||
# save the result
|
||||
image.save(request.dst)
|
||||
|
||||
return backend_pb2.Result(message="Media generated", success=True)
|
||||
|
||||
|
||||
def serve(address):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS))
|
||||
backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
|
||||
@ -453,6 +461,7 @@ def serve(address):
|
||||
except KeyboardInterrupt:
|
||||
server.stop(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Run the gRPC server.")
|
||||
parser.add_argument(
|
||||
@ -460,4 +469,4 @@ if __name__ == "__main__":
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
serve(args.addr)
|
||||
serve(args.addr)
|
||||
|
@ -1,5 +1,7 @@
|
||||
setuptools
|
||||
accelerate
|
||||
compel
|
||||
peft
|
||||
diffusers
|
||||
grpcio==1.65.0
|
||||
opencv-python
|
||||
|
Loading…
x
Reference in New Issue
Block a user