mirror of
https://github.com/mudler/LocalAI.git
synced 2024-12-18 20:27:57 +00:00
feat(diffusers): allow multiple lora adapters (#4081)
Signed-off-by: Ettore Di Giacinto <mudler@localai.io>
This commit is contained in:
parent
20cd8814c1
commit
947224b952
@ -235,6 +235,9 @@ message ModelOptions {
|
||||
bool NoKVOffload = 57;
|
||||
|
||||
string ModelPath = 59;
|
||||
|
||||
repeated string LoraAdapters = 60;
|
||||
repeated float LoraScales = 61;
|
||||
}
|
||||
|
||||
message Result {
|
||||
|
@ -311,10 +311,24 @@ class BackendServicer(backend_pb2_grpc.BackendServicer):
|
||||
if request.LoraAdapter:
|
||||
# Check if its a local file and not a directory ( we load lora differently for a safetensor file )
|
||||
if os.path.exists(request.LoraAdapter) and not os.path.isdir(request.LoraAdapter):
|
||||
# self.load_lora_weights(request.LoraAdapter, 1, device, torchType)
|
||||
self.pipe.load_lora_weights(request.LoraAdapter)
|
||||
else:
|
||||
self.pipe.unet.load_attn_procs(request.LoraAdapter)
|
||||
if len(request.LoraAdapters) > 0:
|
||||
i = 0
|
||||
adapters_name = []
|
||||
adapters_weights = []
|
||||
for adapter in request.LoraAdapters:
|
||||
if not os.path.isabs(adapter):
|
||||
adapter = os.path.join(request.ModelPath, adapter)
|
||||
self.pipe.load_lora_weights(adapter, adapter_name=f"adapter_{i}")
|
||||
adapters_name.append(f"adapter_{i}")
|
||||
i += 1
|
||||
|
||||
for adapters_weight in request.LoraScales:
|
||||
adapters_weights.append(adapters_weight)
|
||||
|
||||
self.pipe.set_adapters(adapters_name, adapter_weights=adapters_weights)
|
||||
|
||||
if request.CUDA:
|
||||
self.pipe.to('cuda')
|
||||
|
@ -125,6 +125,8 @@ func grpcModelOpts(c config.BackendConfig) *pb.ModelOptions {
|
||||
CFGScale: c.Diffusers.CFGScale,
|
||||
LoraAdapter: c.LoraAdapter,
|
||||
LoraScale: c.LoraScale,
|
||||
LoraAdapters: c.LoraAdapters,
|
||||
LoraScales: c.LoraScales,
|
||||
F16Memory: f16,
|
||||
LoraBase: c.LoraBase,
|
||||
IMG2IMG: c.Diffusers.IMG2IMG,
|
||||
|
@ -134,23 +134,25 @@ type LLMConfig struct {
|
||||
TrimSpace []string `yaml:"trimspace"`
|
||||
TrimSuffix []string `yaml:"trimsuffix"`
|
||||
|
||||
ContextSize *int `yaml:"context_size"`
|
||||
NUMA bool `yaml:"numa"`
|
||||
LoraAdapter string `yaml:"lora_adapter"`
|
||||
LoraBase string `yaml:"lora_base"`
|
||||
LoraScale float32 `yaml:"lora_scale"`
|
||||
NoMulMatQ bool `yaml:"no_mulmatq"`
|
||||
DraftModel string `yaml:"draft_model"`
|
||||
NDraft int32 `yaml:"n_draft"`
|
||||
Quantization string `yaml:"quantization"`
|
||||
LoadFormat string `yaml:"load_format"`
|
||||
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
|
||||
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
|
||||
EnforceEager bool `yaml:"enforce_eager"` // vLLM
|
||||
SwapSpace int `yaml:"swap_space"` // vLLM
|
||||
MaxModelLen int `yaml:"max_model_len"` // vLLM
|
||||
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
|
||||
MMProj string `yaml:"mmproj"`
|
||||
ContextSize *int `yaml:"context_size"`
|
||||
NUMA bool `yaml:"numa"`
|
||||
LoraAdapter string `yaml:"lora_adapter"`
|
||||
LoraBase string `yaml:"lora_base"`
|
||||
LoraAdapters []string `yaml:"lora_adapters"`
|
||||
LoraScales []float32 `yaml:"lora_scales"`
|
||||
LoraScale float32 `yaml:"lora_scale"`
|
||||
NoMulMatQ bool `yaml:"no_mulmatq"`
|
||||
DraftModel string `yaml:"draft_model"`
|
||||
NDraft int32 `yaml:"n_draft"`
|
||||
Quantization string `yaml:"quantization"`
|
||||
LoadFormat string `yaml:"load_format"`
|
||||
GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization"` // vLLM
|
||||
TrustRemoteCode bool `yaml:"trust_remote_code"` // vLLM
|
||||
EnforceEager bool `yaml:"enforce_eager"` // vLLM
|
||||
SwapSpace int `yaml:"swap_space"` // vLLM
|
||||
MaxModelLen int `yaml:"max_model_len"` // vLLM
|
||||
TensorParallelSize int `yaml:"tensor_parallel_size"` // vLLM
|
||||
MMProj string `yaml:"mmproj"`
|
||||
|
||||
FlashAttention bool `yaml:"flash_attention"`
|
||||
NoKVOffloading bool `yaml:"no_kv_offloading"`
|
||||
|
Loading…
Reference in New Issue
Block a user