mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-18 20:27:57 +00:00
feat(diffusers): update, add autopipeline, controlnet (#1432)
* feat(diffusers): update, add autopipeline, controlenet * tests with AutoPipeline * simplify logic
This commit is contained in:
parent
72325fd0a3
commit
7641f92cde
2
Makefile
2
Makefile
@ -383,7 +383,7 @@ help: ## Show this help.
|
||||
protogen: protogen-go protogen-python
|
||||
|
||||
protogen-go:
|
||||
protoc --go_out=. --go_opt=paths=source_relative --go-grpc_out=. --go-grpc_opt=paths=source_relative \
|
||||
protoc -Ibackend/ --go_out=pkg/grpc/proto/ --go_opt=paths=source_relative --go-grpc_out=pkg/grpc/proto/ --go-grpc_opt=paths=source_relative \
|
||||
backend/backend.proto
|
||||
|
||||
protogen-python:
|
||||
|
@ -27,6 +27,7 @@ func ImageGeneration(height, width, mode, step, seed int, positive_prompt, negat
|
||||
CLIPModel: c.Diffusers.ClipModel,
|
||||
CLIPSubfolder: c.Diffusers.ClipSubFolder,
|
||||
CLIPSkip: int32(c.Diffusers.ClipSkip),
|
||||
ControlNet: c.Diffusers.ControlNet,
|
||||
}),
|
||||
})
|
||||
|
||||
|
@ -38,8 +38,7 @@ type Config struct {
|
||||
|
||||
// Diffusers
|
||||
Diffusers Diffusers `yaml:"diffusers"`
|
||||
|
||||
Step int `yaml:"step"`
|
||||
Step int `yaml:"step"`
|
||||
|
||||
// GRPC Options
|
||||
GRPC GRPC `yaml:"grpc"`
|
||||
@ -77,6 +76,7 @@ type Diffusers struct {
|
||||
ClipSkip int `yaml:"clip_skip"` // Skip every N frames
|
||||
ClipModel string `yaml:"clip_model"` // Clip model to use
|
||||
ClipSubFolder string `yaml:"clip_subfolder"` // Subfolder to use for clip model
|
||||
ControlNet string `yaml:"control_net"`
|
||||
}
|
||||
|
||||
type LLMConfig struct {
|
||||
|
@ -110,6 +110,7 @@ message ModelOptions {
|
||||
string CLIPModel = 31;
|
||||
string CLIPSubfolder = 32;
|
||||
int32 CLIPSkip = 33;
|
||||
string ControlNet = 48;
|
||||
|
||||
// RWKV
|
||||
string Tokenizer = 34;
|
||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
@ -18,9 +18,9 @@ import backend_pb2_grpc
|
||||
import grpc
|
||||
|
||||
from diffusers import StableDiffusionXLPipeline, StableDiffusionDepth2ImgPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline, DiffusionPipeline, EulerAncestralDiscreteScheduler
|
||||
from diffusers import StableDiffusionImg2ImgPipeline
|
||||
from diffusers import StableDiffusionImg2ImgPipeline, AutoPipelineForText2Image, ControlNetModel
|
||||
from diffusers.pipelines.stable_diffusion import safety_checker
|
||||
|
||||
from diffusers.utils import load_image
|
||||
from compel import Compel
|
||||
|
||||
from transformers import CLIPTextModel
|
||||
@ -30,6 +30,7 @@ from safetensors.torch import load_file
|
||||
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
|
||||
COMPEL=os.environ.get("COMPEL", "1") == "1"
|
||||
CLIPSKIP=os.environ.get("CLIPSKIP", "1") == "1"
|
||||
SAFETENSORS=os.environ.get("SAFETENSORS", "1") == "1"
|
||||
|
||||
# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
|
||||
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
|
||||
@ -135,8 +136,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
print(f"Loading model {request.Model}...", file=sys.stderr)
|
||||
print(f"Request {request}", file=sys.stderr)
|
||||
torchType = torch.float32
|
||||
variant = None
|
||||
|
||||
if request.F16Memory:
|
||||
torchType = torch.float16
|
||||
variant="fp16"
|
||||
|
||||
local = False
|
||||
modelFile = request.Model
|
||||
@ -160,14 +164,8 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
|
||||
fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local
|
||||
|
||||
if request.IMG2IMG and request.PipelineType == "":
|
||||
request.PipelineType == "StableDiffusionImg2ImgPipeline"
|
||||
|
||||
if request.PipelineType == "":
|
||||
request.PipelineType == "StableDiffusionPipeline"
|
||||
|
||||
## img2img
|
||||
if request.PipelineType == "StableDiffusionImg2ImgPipeline":
|
||||
if (request.PipelineType == "StableDiffusionImg2ImgPipeline") or (request.IMG2IMG and request.PipelineType == ""):
|
||||
if fromSingleFile:
|
||||
self.pipe = StableDiffusionImg2ImgPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType,
|
||||
@ -177,12 +175,18 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
|
||||
if request.PipelineType == "StableDiffusionDepth2ImgPipeline":
|
||||
elif request.PipelineType == "StableDiffusionDepth2ImgPipeline":
|
||||
self.pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
## text2img
|
||||
if request.PipelineType == "StableDiffusionPipeline":
|
||||
elif request.PipelineType == "AutoPipelineForText2Image" or request.PipelineType == "":
|
||||
self.pipe = AutoPipelineForText2Image.from_pretrained(request.Model,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=SAFETENSORS,
|
||||
variant=variant,
|
||||
guidance_scale=cfg_scale)
|
||||
elif request.PipelineType == "StableDiffusionPipeline":
|
||||
if fromSingleFile:
|
||||
self.pipe = StableDiffusionPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType,
|
||||
@ -191,13 +195,11 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
self.pipe = StableDiffusionPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
|
||||
if request.PipelineType == "DiffusionPipeline":
|
||||
elif request.PipelineType == "DiffusionPipeline":
|
||||
self.pipe = DiffusionPipeline.from_pretrained(request.Model,
|
||||
torch_dtype=torchType,
|
||||
guidance_scale=cfg_scale)
|
||||
|
||||
if request.PipelineType == "StableDiffusionXLPipeline":
|
||||
elif request.PipelineType == "StableDiffusionXLPipeline":
|
||||
if fromSingleFile:
|
||||
self.pipe = StableDiffusionXLPipeline.from_single_file(modelFile,
|
||||
torch_dtype=torchType, use_safetensors=True,
|
||||
@ -207,21 +209,34 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
request.Model,
|
||||
torch_dtype=torchType,
|
||||
use_safetensors=True,
|
||||
# variant="fp16"
|
||||
variant=variant,
|
||||
guidance_scale=cfg_scale)
|
||||
# https://github.com/huggingface/diffusers/issues/4446
|
||||
# do not use text_encoder in the constructor since then
|
||||
# https://github.com/huggingface/diffusers/issues/3212#issuecomment-1521841481
|
||||
|
||||
if CLIPSKIP and request.CLIPSkip != 0:
|
||||
text_encoder = CLIPTextModel.from_pretrained(clipmodel, num_hidden_layers=request.CLIPSkip, subfolder=clipsubfolder, torch_dtype=torchType)
|
||||
self.pipe.text_encoder=text_encoder
|
||||
self.clip_skip = request.CLIPSkip
|
||||
else:
|
||||
self.clip_skip = 0
|
||||
|
||||
# torch_dtype needs to be customized. float16 for GPU, float32 for CPU
|
||||
# TODO: this needs to be customized
|
||||
if request.SchedulerType != "":
|
||||
self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)
|
||||
|
||||
self.compel = Compel(tokenizer=self.pipe.tokenizer, text_encoder=self.pipe.text_encoder)
|
||||
|
||||
|
||||
if request.ControlNet:
|
||||
self.controlnet = ControlNetModel.from_pretrained(
|
||||
request.ControlNet, torch_dtype=torchType, variant=variant
|
||||
)
|
||||
self.pipe.controlnet = self.controlnet
|
||||
else:
|
||||
self.controlnet = None
|
||||
|
||||
if request.CUDA:
|
||||
self.pipe.to('cuda')
|
||||
if self.controlnet:
|
||||
self.controlnet.to('cuda')
|
||||
# Assume directory from request.ModelFile.
|
||||
# Only if request.LoraAdapter it's not an absolute path
|
||||
if request.LoraAdapter and request.ModelFile != "" and not os.path.isabs(request.LoraAdapter) and request.LoraAdapter:
|
||||
@ -316,9 +331,15 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
"num_inference_steps": steps,
|
||||
}
|
||||
|
||||
if request.src != "":
|
||||
if request.src != "" and not self.controlnet:
|
||||
image = Image.open(request.src)
|
||||
options["image"] = image
|
||||
elif self.controlnet and request.src:
|
||||
pose_image = load_image(request.src)
|
||||
options["image"] = pose_image
|
||||
|
||||
if CLIPSKIP and self.clip_skip != 0:
|
||||
options["clip_skip"]=self.clip_skip
|
||||
|
||||
# Get the keys that we will build the args for our pipe for
|
||||
keys = options.keys()
|
||||
|
File diff suppressed because one or more lines are too long
@ -53,7 +53,7 @@ class TestBackendServicer(unittest.TestCase):
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="runwayml/stable-diffusion-v1-5", PipelineType="StableDiffusionPipeline"))
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="runwayml/stable-diffusion-v1-5"))
|
||||
self.assertTrue(response.success)
|
||||
self.assertEqual(response.message, "Model loaded successfully")
|
||||
except Exception as err:
|
||||
@ -71,7 +71,7 @@ class TestBackendServicer(unittest.TestCase):
|
||||
self.setUp()
|
||||
with grpc.insecure_channel("localhost:50051") as channel:
|
||||
stub = backend_pb2_grpc.BackendStub(channel)
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="runwayml/stable-diffusion-v1-5", PipelineType="StableDiffusionPipeline"))
|
||||
response = stub.LoadModel(backend_pb2.ModelOptions(Model="runwayml/stable-diffusion-v1-5"))
|
||||
print(response.message)
|
||||
self.assertTrue(response.success)
|
||||
image_req = backend_pb2.GenerateImageRequest(positive_prompt="cat", width=16,height=16, dst="test.jpg")
|
||||
|
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because it is too large
Load Diff
@ -1,8 +1,8 @@
|
||||
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
|
||||
// versions:
|
||||
// - protoc-gen-go-grpc v1.2.0
|
||||
// - protoc v4.23.4
|
||||
// source: backend/backend.proto
|
||||
// - protoc v3.6.1
|
||||
// source: backend.proto
|
||||
|
||||
package proto
|
||||
|
||||
@ -453,5 +453,5 @@ var Backend_ServiceDesc = grpc.ServiceDesc{
|
||||
ServerStreams: true,
|
||||
},
|
||||
},
|
||||
Metadata: "backend/backend.proto",
|
||||
Metadata: "backend.proto",
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user