mirror of
https://github.com/ParisNeo/lollms.git
synced 2025-04-05 09:59:08 +00:00
removed useless stuff
This commit is contained in:
parent
5d9afbca17
commit
8a5a585954
@ -618,43 +618,52 @@ class LollmsApplication(LoLLMsCom):
|
||||
self.warning(f"Couldn't load vllm")
|
||||
ASCIIColors.execute_with_animation("Loading TTT services", start_ttt,ASCIIColors.color_blue)
|
||||
|
||||
def start_stt(*args, **kwargs):
|
||||
self.stt = self.load_service_from_folder(self.lollms_paths.services_zoo_path/"stt", self.config.active_stt_service)
|
||||
ASCIIColors.execute_with_animation("Loading loacal STT services", start_stt, ASCIIColors.color_blue)
|
||||
if self.config.active_stt_service:
|
||||
def start_stt(*args, **kwargs):
|
||||
self.stt = self.load_service_from_folder(self.lollms_paths.services_zoo_path/"stt", self.config.active_stt_service)
|
||||
ASCIIColors.execute_with_animation("Loading loacal STT services", start_stt, ASCIIColors.color_blue)
|
||||
|
||||
def start_tts(*args, **kwargs):
|
||||
if self.config.active_tts_service == "xtts":
|
||||
ASCIIColors.yellow("Loading XTTS")
|
||||
try:
|
||||
from lollms.services.tts.xtts.lollms_xtts import LollmsXTTS
|
||||
# def start_tts(*args, **kwargs):
|
||||
# if self.config.active_tts_service == "xtts":
|
||||
# ASCIIColors.yellow("Loading XTTS")
|
||||
# try:
|
||||
# from lollms.services.tts.xtts.lollms_xtts import LollmsXTTS
|
||||
|
||||
self.tts = LollmsXTTS(
|
||||
self
|
||||
)
|
||||
except Exception as ex:
|
||||
trace_exception(ex)
|
||||
self.warning(f"Couldn't load XTTS")
|
||||
if self.config.active_tts_service == "eleven_labs_tts":
|
||||
from lollms.services.tts.eleven_labs_tts.lollms_eleven_labs_tts import LollmsElevenLabsTTS
|
||||
self.tts = LollmsElevenLabsTTS(self)
|
||||
elif self.config.active_tts_service == "openai_tts":
|
||||
from lollms.services.tts.open_ai_tts.lollms_openai_tts import LollmsOpenAITTS
|
||||
self.tts = LollmsOpenAITTS(self)
|
||||
elif self.config.active_tts_service == "fish_tts":
|
||||
from lollms.services.tts.fish.lollms_fish_tts import LollmsFishAudioTTS
|
||||
self.tts = LollmsFishAudioTTS(self)
|
||||
# self.tts = LollmsXTTS(
|
||||
# self
|
||||
# )
|
||||
# except Exception as ex:
|
||||
# trace_exception(ex)
|
||||
# self.warning(f"Couldn't load XTTS")
|
||||
# if self.config.active_tts_service == "eleven_labs_tts":
|
||||
# from lollms.services.tts.eleven_labs_tts.lollms_eleven_labs_tts import LollmsElevenLabsTTS
|
||||
# self.tts = LollmsElevenLabsTTS(self)
|
||||
# elif self.config.active_tts_service == "openai_tts":
|
||||
# from lollms.services.tts.open_ai_tts.lollms_openai_tts import LollmsOpenAITTS
|
||||
# self.tts = LollmsOpenAITTS(self)
|
||||
# elif self.config.active_tts_service == "fish_tts":
|
||||
# from lollms.services.tts.fish.lollms_fish_tts import LollmsFishAudioTTS
|
||||
# self.tts = LollmsFishAudioTTS(self)
|
||||
if self.config.active_tts_service:
|
||||
def start_tts(*args, **kwargs):
|
||||
self.tti = self.load_service_from_folder(self.lollms_paths.services_zoo_path/"tts", self.config.active_tts_service)
|
||||
ASCIIColors.execute_with_animation("Loading TTS services", start_tts, ASCIIColors.color_blue)
|
||||
|
||||
ASCIIColors.execute_with_animation("Loading TTS services", start_tts, ASCIIColors.color_blue)
|
||||
if self.config.active_tti_service:
|
||||
def start_tti(*args, **kwargs):
|
||||
self.tti = self.load_service_from_folder(self.lollms_paths.services_zoo_path/"tti", self.config.active_tti_service)
|
||||
ASCIIColors.execute_with_animation("Loading loacal TTI services", start_tti, ASCIIColors.color_blue)
|
||||
|
||||
def start_tti(*args, **kwargs):
|
||||
self.tti = self.load_service_from_folder(self.lollms_paths.services_zoo_path/"tti", self.config.active_tti_service)
|
||||
ASCIIColors.execute_with_animation("Loading loacal TTI services", start_tti, ASCIIColors.color_blue)
|
||||
if self.config.active_ttm_service:
|
||||
def start_ttm(*args, **kwargs):
|
||||
self.ttv = self.load_service_from_folder(self.lollms_paths.services_zoo_path/"ttv", self.config.active_ttm_service)
|
||||
ASCIIColors.execute_with_animation("Loading loacal TTM services", start_ttm, ASCIIColors.color_blue)
|
||||
print("OK")
|
||||
|
||||
def start_ttv(*args, **kwargs):
|
||||
self.ttv = self.load_service_from_folder(self.lollms_paths.services_zoo_path/"ttv", self.config.active_ttv_service)
|
||||
|
||||
|
||||
ASCIIColors.execute_with_animation("Loading loacal TTV services", start_ttv, ASCIIColors.color_blue)
|
||||
if self.config.active_ttv_service:
|
||||
def start_ttv(*args, **kwargs):
|
||||
self.ttv = self.load_service_from_folder(self.lollms_paths.services_zoo_path/"ttv", self.config.active_ttv_service)
|
||||
ASCIIColors.execute_with_animation("Loading loacal TTV services", start_ttv, ASCIIColors.color_blue)
|
||||
print("OK")
|
||||
|
||||
|
||||
|
@ -22,7 +22,7 @@ services_zoo_repo = "https://github.com/ParisNeo/lollms_services_zoo.git"
|
||||
functions_zoo_repo = "https://github.com/ParisNeo/lollms_functions_zoo.git"
|
||||
gptqlora_repo = "https://github.com/ParisNeo/gptqlora.git"
|
||||
|
||||
lollms_webui_version = "v19 (codename Omni 🔗)"
|
||||
lollms_webui_version = "v19.2242 (codename Twins 🔗)"
|
||||
|
||||
# Now we speify the personal folders
|
||||
class LollmsPaths:
|
||||
|
@ -1,487 +0,0 @@
|
||||
# Title LollmsComfyUI
|
||||
# Licence: GPL-3.0
|
||||
# Author : Paris Neo
|
||||
# Forked from comfyanonymous's Comfyui nodes system
|
||||
# check it out : https://github.com/comfyanonymous/ComfyUI
|
||||
# Here is a copy of the LICENCE https://github.com/comfyanonymous/ComfyUI/blob/main/LICENSE
|
||||
# All rights are reserved
|
||||
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from lollms.app import LollmsApplication
|
||||
from lollms.paths import LollmsPaths
|
||||
from lollms.config import TypedConfig, ConfigTemplate, BaseConfig
|
||||
import time
|
||||
import io
|
||||
import sys
|
||||
import requests
|
||||
import os
|
||||
import base64
|
||||
import subprocess
|
||||
import time
|
||||
import json
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from PIL import Image, PngImagePlugin
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Any
|
||||
import uuid
|
||||
from ascii_colors import ASCIIColors, trace_exception
|
||||
from lollms.paths import LollmsPaths
|
||||
from lollms.utilities import git_pull, show_yes_no_dialog, PackageManager
|
||||
from lollms.tti import LollmsTTI
|
||||
import subprocess
|
||||
import shutil
|
||||
from tqdm import tqdm
|
||||
import threading
|
||||
import pipmaster as pm
|
||||
|
||||
if not pm.is_installed("websocket"):
|
||||
pm.install("websocket-client")
|
||||
import websocket
|
||||
if not pm.is_installed("urllib"):
|
||||
pm.install("urllib")
|
||||
from urllib import request, parse
|
||||
|
||||
|
||||
|
||||
def download_file(url, folder_path, local_filename):
|
||||
# Make sure 'folder_path' exists
|
||||
folder_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with requests.get(url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
total_size = int(r.headers.get('content-length', 0))
|
||||
progress_bar = tqdm(total=total_size, unit='B', unit_scale=True)
|
||||
with open(folder_path / local_filename, 'wb') as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
progress_bar.close()
|
||||
|
||||
return local_filename
|
||||
|
||||
|
||||
|
||||
|
||||
def get_comfyui(lollms_paths:LollmsPaths):
|
||||
root_dir = lollms_paths.personal_path
|
||||
shared_folder = root_dir/"shared"
|
||||
comfyui_folder = shared_folder / "comfyui"
|
||||
comfyui_script_path = comfyui_folder / "main.py"
|
||||
git_pull(comfyui_folder)
|
||||
|
||||
if comfyui_script_path.exists():
|
||||
ASCIIColors.success("comfyui found.")
|
||||
ASCIIColors.success("Loading source file...",end="")
|
||||
# use importlib to load the module from the file path
|
||||
from lollms.services.tti.comfyui.lollms_comfyui import LollmsComfyUI
|
||||
ASCIIColors.success("ok")
|
||||
return LollmsComfyUI
|
||||
|
||||
class LollmsComfyUI(LollmsTTI):
|
||||
has_controlnet = False
|
||||
def __init__(self, app:LollmsApplication, output_folder:str|Path=None):
|
||||
"""
|
||||
Initializes the LollmsDalle binding.
|
||||
|
||||
Args:
|
||||
api_key (str): The API key for authentication.
|
||||
output_folder (Path|str): The output folder where to put the generated data
|
||||
"""
|
||||
service_config = TypedConfig(
|
||||
ConfigTemplate([
|
||||
{
|
||||
"name": "base_url",
|
||||
"type": "str",
|
||||
"value": "http://127.0.0.1:8188/",
|
||||
"help": "The base URL for the service. This is the address where the service is hosted (e.g., http://127.0.0.1:8188/)."
|
||||
},
|
||||
{
|
||||
"name": "wm",
|
||||
"type": "str",
|
||||
"value": "lollms",
|
||||
"help": "Watermarking text or identifier to be used in the service."
|
||||
},
|
||||
{
|
||||
"name": "max_retries",
|
||||
"type": "int",
|
||||
"value": 50,
|
||||
"help": "The maximum number of retries to attempt before determining that the service is unavailable."
|
||||
},
|
||||
{
|
||||
"name": "local_service",
|
||||
"type": "bool",
|
||||
"value": False,
|
||||
"help": "If set to true, a local instance of the service will be installed and used."
|
||||
},
|
||||
{
|
||||
"name": "start_service_at_startup",
|
||||
"type": "bool",
|
||||
"value": False,
|
||||
"help": "If set to true, the service will automatically start at startup. This also enables the local service option."
|
||||
},
|
||||
{
|
||||
"name": "share",
|
||||
"type": "bool",
|
||||
"value": False,
|
||||
"help": "If set to true, the server will be accessible from outside your local machine (e.g., over the internet)."
|
||||
}
|
||||
]),
|
||||
BaseConfig(config={
|
||||
})
|
||||
)
|
||||
super().__init__("comfyui",app, service_config)
|
||||
# Get the current directory
|
||||
lollms_paths = app.lollms_paths
|
||||
self.app = app
|
||||
root_dir = lollms_paths.personal_path
|
||||
|
||||
# If this is requiring a local service then verify if it is on
|
||||
if self.service_config.local_service:
|
||||
if not self.verify_comfyui():
|
||||
self.install()
|
||||
|
||||
self.comfyui_url = self.service_config.base_url+"/comfyuiapi/v1"
|
||||
shared_folder = root_dir/"shared"
|
||||
self.comfyui_folder = shared_folder / "comfyui"
|
||||
self.output_dir = root_dir / "outputs/comfyui"
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ASCIIColors.red(" _ ____ _ _ __ __ _____ _____ __ _ ")
|
||||
ASCIIColors.red("| | / __ \| | | | | \/ |/ ____|/ ____| / _| (_)")
|
||||
ASCIIColors.red("| | | | | | | | | | \ / | (___ | | ___ _ __ ___ | |_ _ _ _ _ _ ")
|
||||
ASCIIColors.red("| | | | | | | | | | |\/| |\___ \| | / _ \| '_ ` _ \| _| | | | | | | |")
|
||||
ASCIIColors.red("| |___| |__| | |____| |____| | | |____) | |___| (_) | | | | | | | | |_| | |_| | |")
|
||||
ASCIIColors.red("|______\____/|______|______|_| |_|_____/ \_____\___/|_| |_| |_|_| \__, |\__,_|_|")
|
||||
ASCIIColors.red(" ______ __/ | ")
|
||||
ASCIIColors.red(" |______| |___/ ")
|
||||
|
||||
ASCIIColors.red(" Forked from comfyanonymous's Comfyui nodes system")
|
||||
ASCIIColors.red(" Integration in lollms by ParisNeo")
|
||||
|
||||
if not self.wait_for_service(1,False) and self.service_config.local_service and self.service_config.start_service_at_startup and self.service_config.base_url is None:
|
||||
ASCIIColors.info("Loading lollms_comfyui")
|
||||
if platform.system() == "Windows":
|
||||
ASCIIColors.info("Running on windows")
|
||||
script_path = self.comfyui_folder / "main.py"
|
||||
|
||||
if self.service_config.share:
|
||||
pass # TODO: implement
|
||||
#run_python_script_in_env("comfyui", str(script_path), cwd=self.comfyui_folder, wait=False)
|
||||
# subprocess.Popen("conda activate " + str(script_path) +" --share", cwd=self.comfyui_folder)
|
||||
else:
|
||||
pass # TODO: implement
|
||||
# run_python_script_in_env("comfyui", str(script_path), cwd=self.comfyui_folder, wait=False)
|
||||
# subprocess.Popen(script_path, cwd=self.comfyui_folder)
|
||||
else:
|
||||
ASCIIColors.info("Running on linux/MacOs")
|
||||
script_path = str(self.comfyui_folder / "lollms_comfyui.sh")
|
||||
ASCIIColors.info(f"launcher path: {script_path}")
|
||||
ASCIIColors.info(f"comfyui path: {self.comfyui_folder}")
|
||||
|
||||
if self.service_config.share:
|
||||
pass # TODO: implement
|
||||
# run_script_in_env("comfyui","bash " + script_path +" --share", cwd=self.comfyui_folder)
|
||||
# subprocess.Popen("conda activate " + str(script_path) +" --share", cwd=self.comfyui_folder)
|
||||
else:
|
||||
pass # TODO: implement
|
||||
# run_script_in_env("comfyui","bash " + script_path, cwd=self.comfyui_folder)
|
||||
ASCIIColors.info("Process done")
|
||||
ASCIIColors.success("Launching Comfyui succeeded")
|
||||
|
||||
# Wait until the service is available at http://127.0.0.1:8188//
|
||||
|
||||
def verify_comfyui(self):
|
||||
# Clone repository
|
||||
root_dir = self.app.lollms_paths.personal_path
|
||||
shared_folder = root_dir/"shared"
|
||||
comfyui_folder = shared_folder / "comfyui"
|
||||
return comfyui_folder.exists()
|
||||
|
||||
def install(self):
|
||||
root_dir = self.app.lollms_paths.personal_path
|
||||
shared_folder = root_dir/"shared"
|
||||
comfyui_folder = shared_folder / "comfyui"
|
||||
if comfyui_folder.exists():
|
||||
if show_yes_no_dialog("warning!","I have detected that there is a previous installation of Comfyui.\nShould I remove it and continue installing?"):
|
||||
shutil.rmtree(comfyui_folder)
|
||||
elif show_yes_no_dialog("warning!","Continue installation?"):
|
||||
ASCIIColors.cyan("Installing comfyui conda environment with python 3.10")
|
||||
create_conda_env("comfyui","3.10")
|
||||
ASCIIColors.cyan("Done")
|
||||
return
|
||||
else:
|
||||
return
|
||||
|
||||
subprocess.run(["git", "clone", "https://github.com/ParisNeo/ComfyUI.git", str(comfyui_folder)])
|
||||
subprocess.run(["git", "clone", "https://github.com/ParisNeo/ComfyUI-Manager.git", str(comfyui_folder/"custom_nodes/ComfyUI-Manager")])
|
||||
subprocess.run(["git", "clone", "https://github.com/AlekPet/ComfyUI_Custom_Nodes_AlekPet.git", str(comfyui_folder/"custom_nodes/ComfyUI_Custom_Nodes_AlekPet")])
|
||||
subprocess.run(["git", "clone", "https://github.com/ParisNeo/lollms_nodes_suite.git", str(comfyui_folder/"custom_nodes/lollms_nodes_suite")])
|
||||
|
||||
|
||||
subprocess.run(["git", "clone", "https://github.com/jags111/efficiency-nodes-comfyui.git", str(comfyui_folder/"custom_nodes/efficiency-nodes-comfyui")])
|
||||
subprocess.run(["git", "clone", "https://github.com/Kosinkadink/ComfyUI-Advanced-ControlNet.git", str(comfyui_folder/"custom_nodes/ComfyUI-Advanced-ControlNet")])
|
||||
subprocess.run(["git", "clone", "https://github.com/Kosinkadink/ComfyUI-VideoHelperSuite.git", str(comfyui_folder/"custom_nodes/ComfyUI-VideoHelperSuite")])
|
||||
subprocess.run(["git", "clone", "https://github.com/LykosAI/ComfyUI-Inference-Core-Nodes.git", str(comfyui_folder/"custom_nodes/ComfyUI-Inference-Core-Nodes")])
|
||||
subprocess.run(["git", "clone", "https://github.com/Fannovel16/comfyui_controlnet_aux.git", str(comfyui_folder/"custom_nodes/comfyui_controlnet_aux")])
|
||||
subprocess.run(["git", "clone", "https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved.git", str(comfyui_folder/"custom_nodes/ComfyUI-AnimateDiff-Evolved")])
|
||||
|
||||
if show_yes_no_dialog("warning!","You will need to install an image generation model.\nDo you want to install an image model from civitai?\nI suggest Juggernaut XL.\nIt is a very good model.\nyou can always install more models afterwards in your comfyui folder/models.checkpoints"):
|
||||
download_file("https://civitai.com/api/download/models/357609", comfyui_folder/"models/checkpoints","Juggernaut_XL.safetensors")
|
||||
|
||||
if show_yes_no_dialog("warning!","Do you want to install a video model from hugging face?\nIsuggest SVD XL."):
|
||||
download_file("https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt/blob/main/svd_xt.safetensors", comfyui_folder/"models/checkpoints","svd_xt.safetensors")
|
||||
|
||||
if show_yes_no_dialog("warning!","Do you want to install all control net models?"):
|
||||
(comfyui_folder/"models/controlnet").mkdir(parents=True, exist_ok=True)
|
||||
download_file("https://huggingface.co/thibaud/controlnet-openpose-sdxl-1.0/resolve/main/OpenPoseXL2.safetensors", comfyui_folder/"models/controlnet","OpenPoseXL2.safetensors")
|
||||
download_file("https://huggingface.co/diffusers/controlnet-depth-sdxl-1.0/resolve/main/diffusion_pytorch_model.safetensors", comfyui_folder/"models/controlnet","DepthMap_XL.safetensors")
|
||||
|
||||
|
||||
if show_yes_no_dialog("warning!","Do you want to install all animation models?"):
|
||||
(comfyui_folder/"models/animatediff_models").mkdir(parents=True, exist_ok=True)
|
||||
download_file("https://huggingface.co/guoyww/animatediff/resolve/cd71ae134a27ec6008b968d6419952b0c0494cf2/mm_sdxl_v10_beta.ckpt", comfyui_folder/"models/animatediff_models","mm_sdxl_v10_beta.ckpt")
|
||||
|
||||
# TODO: fix
|
||||
# create_conda_env("comfyui","3.10")
|
||||
# if self.app.config.hardware_mode in ["nvidia", "nvidia-tensorcores"]:
|
||||
# run_python_script_in_env("comfyui", "-m pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121")
|
||||
# if self.app.config.hardware_mode in ["amd", "amd-noavx"]:
|
||||
# run_python_script_in_env("comfyui", "-m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.7")
|
||||
# elif self.app.config.hardware_mode in ["cpu", "cpu-noavx"]:
|
||||
# run_python_script_in_env("comfyui", "-m pip install --pre torch torchvision torchaudio")
|
||||
# run_python_script_in_env("comfyui", f"-m pip install -r {comfyui_folder}/requirements.txt")
|
||||
|
||||
self.app.comfyui = LollmsComfyUI(self.app)
|
||||
ASCIIColors.green("Comfyui installed successfully")
|
||||
self.app.HideBlockingMessage()
|
||||
|
||||
|
||||
def upgrade(self):
|
||||
root_dir = self.app.lollms_paths.personal_path
|
||||
shared_folder = root_dir/"shared"
|
||||
comfyui_folder = shared_folder / "comfyui"
|
||||
if not comfyui_folder.exists():
|
||||
self.app.InfoMessage("Comfyui is not installed, install it first")
|
||||
return
|
||||
|
||||
subprocess.run(["git", "pull", str(comfyui_folder)])
|
||||
subprocess.run(["git", "pull", str(comfyui_folder/"custom_nodes/ComfyUI-Manager")])
|
||||
subprocess.run(["git", "pull", str(comfyui_folder/"custom_nodes/efficiency-nodes-comfyui")])
|
||||
ASCIIColors.success("DONE")
|
||||
|
||||
|
||||
def wait_for_service_in_another_thread(self, max_retries=150, show_warning=True):
|
||||
thread = threading.Thread(target=self.wait_for_service, args=(max_retries, show_warning))
|
||||
thread.start()
|
||||
return thread
|
||||
|
||||
@staticmethod
|
||||
def get_models_list(app):
|
||||
return [str(f.name) for f in (app.lollms_paths.personal_path/"shared"/"comfyui"/"models"/"checkpoints").iterdir()]
|
||||
|
||||
def wait_for_service(self, max_retries = 50, show_warning=True):
|
||||
url = f"{self.comfyui_base_url}"
|
||||
# Adjust this value as needed
|
||||
retries = 0
|
||||
|
||||
while retries < max_retries or max_retries<0:
|
||||
try:
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
print("Service is available.")
|
||||
if self.app is not None:
|
||||
self.app.success("Comfyui Service is now available.")
|
||||
return True
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
|
||||
retries += 1
|
||||
time.sleep(1)
|
||||
if show_warning:
|
||||
print("Service did not become available within the given time.")
|
||||
if self.app is not None:
|
||||
self.app.error("Comfyui Service did not become available within the given time.")
|
||||
return False
|
||||
|
||||
def paint(
|
||||
self,
|
||||
positive_prompt,
|
||||
negative_prompt,
|
||||
sampler_name="",
|
||||
seed=-1,
|
||||
scale=7.5,
|
||||
steps=20,
|
||||
img2img_denoising_strength=0.9,
|
||||
width=512,
|
||||
height=512,
|
||||
restore_faces=True,
|
||||
output_path=None
|
||||
):
|
||||
if output_path is None:
|
||||
output_path = self.output_dir
|
||||
client_id = str(uuid.uuid4())
|
||||
url = self.comfyui_base_url[7:-1]
|
||||
|
||||
def queue_prompt(prompt):
|
||||
p = {"prompt": prompt, "client_id": client_id}
|
||||
data = json.dumps(p).encode('utf-8')
|
||||
full_url = "http://{}/prompt".format(url)
|
||||
req = request.Request(full_url, data=data)
|
||||
output = request.urlopen(req).read()
|
||||
return json.loads(output)
|
||||
|
||||
def get_image(filename, subfolder):
|
||||
data = {"filename": filename, "subfolder": subfolder}
|
||||
url_values = parse.urlencode(data)
|
||||
full_url = "http://{}/view?{}".format(url, url_values)
|
||||
with request.urlopen(full_url) as response:
|
||||
return response.read()
|
||||
|
||||
def get_history(prompt_id):
|
||||
url_values = "http://{}/history/{}".format(url, prompt_id)
|
||||
with request.urlopen(url_values) as response:
|
||||
return json.loads(response.read())
|
||||
|
||||
def get_images(ws, prompt):
|
||||
prompt_id = queue_prompt(prompt)['prompt_id']
|
||||
output_images = {}
|
||||
while True:
|
||||
out = ws.recv()
|
||||
if isinstance(out, str):
|
||||
message = json.loads(out)
|
||||
if message['type'] == 'executing':
|
||||
data = message['data']
|
||||
if data['node'] is None and data['prompt_id'] == prompt_id:
|
||||
break #Execution is done
|
||||
else:
|
||||
continue #previews are binary data
|
||||
|
||||
history = get_history(prompt_id)[prompt_id]
|
||||
for o in history['outputs']:
|
||||
for node_id in history['outputs']:
|
||||
node_output = history['outputs'][node_id]
|
||||
if 'images' in node_output:
|
||||
images_output = []
|
||||
for image in node_output['images']:
|
||||
if image["type"]=="output":
|
||||
image_data = get_image(image['filename'], image['subfolder'])
|
||||
images_output.append(image_data)
|
||||
|
||||
return images_output
|
||||
|
||||
def save_images(images:dict, folder_path:str|Path):
|
||||
# Create the folder if it doesn't exist
|
||||
folder = Path(folder_path)
|
||||
folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Save each image to the folder
|
||||
for i, img_data in enumerate(images):
|
||||
img_path = folder / f'image_{i+1}.png'
|
||||
with open(img_path, 'wb') as img_file:
|
||||
img_file.write(img_data)
|
||||
|
||||
# Return the path to the first image
|
||||
return str(folder / 'image_1.png')
|
||||
prompt_text = """
|
||||
{
|
||||
"1": {
|
||||
"inputs": {
|
||||
"base_ckpt_name": "juggernaut.safetensors",
|
||||
"base_clip_skip": -2,
|
||||
"refiner_ckpt_name": "None",
|
||||
"refiner_clip_skip": -2,
|
||||
"positive_ascore": 6,
|
||||
"negative_ascore": 2,
|
||||
"vae_name": "Baked VAE",
|
||||
"positive": "smart robot icon, slick, flat design, high res, W in the center, black background",
|
||||
"negative": "ugly, deformed, badly rendered, fuzzy",
|
||||
"token_normalization": "none",
|
||||
"weight_interpretation": "comfy",
|
||||
"empty_latent_width": 1024,
|
||||
"empty_latent_height": 1024,
|
||||
"batch_size": 1
|
||||
},
|
||||
"class_type": "Eff. Loader SDXL",
|
||||
"_meta": {
|
||||
"title": "Eff. Loader SDXL"
|
||||
}
|
||||
},
|
||||
"2": {
|
||||
"inputs": {
|
||||
"noise_seed": 74738751167752,
|
||||
"steps": 20,
|
||||
"cfg": 7,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "normal",
|
||||
"start_at_step": 0,
|
||||
"refine_at_step": -1,
|
||||
"preview_method": "auto",
|
||||
"vae_decode": "true",
|
||||
"sdxl_tuple": [
|
||||
"1",
|
||||
0
|
||||
],
|
||||
"latent_image": [
|
||||
"1",
|
||||
1
|
||||
],
|
||||
"optional_vae": [
|
||||
"1",
|
||||
2
|
||||
]
|
||||
},
|
||||
"class_type": "KSampler SDXL (Eff.)",
|
||||
"_meta": {
|
||||
"title": "KSampler SDXL (Eff.)"
|
||||
}
|
||||
},
|
||||
"3": {
|
||||
"inputs": {
|
||||
"filename_prefix": "ComfyUI",
|
||||
"images": [
|
||||
"2",
|
||||
3
|
||||
]
|
||||
},
|
||||
"class_type": "SaveImage",
|
||||
"_meta": {
|
||||
"title": "Save Image"
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
prompt = json.loads(prompt_text)
|
||||
#set the text prompt for our positive CLIPTextEncode
|
||||
prompt["1"]["inputs"]["positive"] = prompt_text
|
||||
prompt["1"]["inputs"]["negative"] = negative_prompt
|
||||
prompt["1"]["inputs"]["empty_latent_width"] = width
|
||||
prompt["1"]["inputs"]["empty_latent_height"] = height
|
||||
|
||||
prompt["1"]["inputs"]["base_ckpt_name"] = self.app.config.comfyui_model
|
||||
|
||||
ws = websocket.WebSocket()
|
||||
ws.connect("ws://{}/ws?clientId={}".format(url, client_id))
|
||||
images = get_images(ws, prompt)
|
||||
|
||||
return save_images(images, output_path), {"prompt":prompt,"negative_prompt":negative_prompt}
|
||||
|
||||
|
||||
def paint_from_images(self, positive_prompt: str,
|
||||
images: List[str],
|
||||
negative_prompt: str = "",
|
||||
sampler_name="",
|
||||
seed=-1,
|
||||
scale=7.5,
|
||||
steps=20,
|
||||
img2img_denoising_strength=0.9,
|
||||
width=512,
|
||||
height=512,
|
||||
restore_faces=True,
|
||||
output_path=None
|
||||
) -> List[Dict[str, str]]:
|
||||
return None
|
||||
|
@ -1,204 +0,0 @@
|
||||
# Title LollmsDalle
|
||||
# Licence: Apache 2.0
|
||||
# Author : Paris Neo
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from lollms.app import LollmsApplication
|
||||
from lollms.paths import LollmsPaths
|
||||
from lollms.config import TypedConfig, ConfigTemplate, BaseConfig
|
||||
import time
|
||||
import io
|
||||
import sys
|
||||
import requests
|
||||
import os
|
||||
import base64
|
||||
import subprocess
|
||||
import time
|
||||
import json
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from PIL import Image, PngImagePlugin
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from ascii_colors import ASCIIColors, trace_exception
|
||||
from lollms.paths import LollmsPaths
|
||||
from lollms.utilities import PackageManager, find_next_available_filename
|
||||
from lollms.tti import LollmsTTI
|
||||
import subprocess
|
||||
import shutil
|
||||
from tqdm import tqdm
|
||||
import threading
|
||||
from io import BytesIO
|
||||
import os
|
||||
|
||||
|
||||
class LollmsDalle(LollmsTTI):
|
||||
def __init__(self, app, output_folder:str|Path=None):
|
||||
"""
|
||||
Initializes the LollmsDalle binding.
|
||||
|
||||
Args:
|
||||
api_key (str): The API key for authentication.
|
||||
output_folder (Path|str): The output folder where to put the generated data
|
||||
"""
|
||||
# Check for the OPENAI_KEY environment variable if no API key is provided
|
||||
api_key = os.getenv("OPENAI_KEY","")
|
||||
service_config = TypedConfig(
|
||||
ConfigTemplate([
|
||||
{"name":"api_key", "type":"str", "value":api_key, "help":"A valid Open AI key to generate text using anthropic api"},
|
||||
{"name":"generation_engine", "type":"str", "value":"dall-e-3", "options":["dall-e-2","dall-e-3"], "help":"The engine to be used"},
|
||||
]),
|
||||
BaseConfig(config={
|
||||
"api_key": "", # use avx2
|
||||
})
|
||||
)
|
||||
|
||||
super().__init__("dall-e", app, service_config, output_folder)
|
||||
|
||||
def settings_updated(self):
|
||||
pass
|
||||
|
||||
def paint(
|
||||
self,
|
||||
positive_prompt,
|
||||
negative_prompt,
|
||||
sampler_name="Euler",
|
||||
seed=-1,
|
||||
scale=7.5,
|
||||
steps=20,
|
||||
img2img_denoising_strength=0.9,
|
||||
width=512,
|
||||
height=512,
|
||||
restore_faces=True,
|
||||
output_path=None,
|
||||
generation_engine=None
|
||||
):
|
||||
if output_path is None:
|
||||
output_path = self.output_path
|
||||
if generation_engine is None:
|
||||
generation_engine = self.service_config.generation_engine
|
||||
if not PackageManager.check_package_installed("openai"):
|
||||
PackageManager.install_package("openai")
|
||||
import openai
|
||||
openai.api_key = self.service_config.api_key
|
||||
if generation_engine=="dall-e-2":
|
||||
supported_resolutions = [
|
||||
[512, 512],
|
||||
[1024, 1024],
|
||||
]
|
||||
# Find the closest resolution
|
||||
closest_resolution = min(supported_resolutions, key=lambda res: abs(res[0] - width) + abs(res[1] - height))
|
||||
|
||||
else:
|
||||
supported_resolutions = [
|
||||
[1024, 1024],
|
||||
[1024, 1792],
|
||||
[1792, 1024]
|
||||
]
|
||||
# Find the closest resolution
|
||||
if width>height:
|
||||
closest_resolution = [1792, 1024]
|
||||
elif width<height:
|
||||
closest_resolution = [1024, 1792]
|
||||
else:
|
||||
closest_resolution = [1024, 1024]
|
||||
|
||||
|
||||
# Update the width and height
|
||||
width = closest_resolution[0]
|
||||
height = closest_resolution[1]
|
||||
|
||||
|
||||
response = openai.images.generate(
|
||||
model=generation_engine,
|
||||
prompt=positive_prompt.strip(),
|
||||
quality="standard",
|
||||
size=f"{width}x{height}",
|
||||
n=1,
|
||||
|
||||
)
|
||||
# download image to outputs
|
||||
output_dir = Path(output_path)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
image_url = response.data[0].url
|
||||
|
||||
# Get the image data from the URL
|
||||
response = requests.get(image_url)
|
||||
|
||||
if response.status_code == 200:
|
||||
# Generate the full path for the image file
|
||||
file_name = output_dir/find_next_available_filename(output_dir, "img_dalle_") # You can change the filename if needed
|
||||
|
||||
# Save the image to the specified folder
|
||||
with open(file_name, "wb") as file:
|
||||
file.write(response.content)
|
||||
ASCIIColors.yellow(f"Image saved to {file_name}")
|
||||
else:
|
||||
ASCIIColors.red("Failed to download the image")
|
||||
|
||||
return file_name, {"positive_prompt":positive_prompt}
|
||||
|
||||
def paint_from_images(self, positive_prompt: str, images: List[str], negative_prompt: str = "") -> List[Dict[str, str]]:
|
||||
if output_path is None:
|
||||
output_path = self.output_path
|
||||
if not PackageManager.check_package_installed("openai"):
|
||||
PackageManager.install_package("openai")
|
||||
import openai
|
||||
openai.api_key = self.service_config.api_key
|
||||
generation_engine="dall-e-2"
|
||||
supported_resolutions = [
|
||||
[512, 512],
|
||||
[1024, 1024],
|
||||
]
|
||||
# Find the closest resolution
|
||||
closest_resolution = min(supported_resolutions, key=lambda res: abs(res[0] - width) + abs(res[1] - height))
|
||||
|
||||
|
||||
|
||||
# Update the width and height
|
||||
width = closest_resolution[0]
|
||||
height = closest_resolution[1]
|
||||
|
||||
# Read the image file from disk and resize it
|
||||
image = Image.open(images[0])
|
||||
width, height = width, height
|
||||
image = image.resize((width, height))
|
||||
|
||||
# Convert the image to a BytesIO object
|
||||
byte_stream = BytesIO()
|
||||
image.save(byte_stream, format='PNG')
|
||||
byte_array = byte_stream.getvalue()
|
||||
response = openai.images.create_variation(
|
||||
image=byte_array,
|
||||
n=1,
|
||||
model=generation_engine, # for now only dalle 2 supports variations
|
||||
size=f"{width}x{height}"
|
||||
)
|
||||
# download image to outputs
|
||||
output_dir = Path(output_path)
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
image_url = response.data[0].url
|
||||
|
||||
# Get the image data from the URL
|
||||
response = requests.get(image_url)
|
||||
|
||||
if response.status_code == 200:
|
||||
# Generate the full path for the image file
|
||||
file_name = output_dir/find_next_available_filename(output_dir, "img_dalle_") # You can change the filename if needed
|
||||
|
||||
# Save the image to the specified folder
|
||||
with open(file_name, "wb") as file:
|
||||
file.write(response.content)
|
||||
ASCIIColors.yellow(f"Image saved to {file_name}")
|
||||
else:
|
||||
ASCIIColors.red("Failed to download the image")
|
||||
|
||||
return file_name
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get(app:LollmsApplication):
|
||||
return LollmsDalle
|
@ -1,324 +0,0 @@
|
||||
# Title LollmsDiffusers
|
||||
# Licence: MIT
|
||||
# Author : Paris Neo
|
||||
# All rights are reserved
|
||||
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from lollms.app import LollmsApplication
|
||||
from lollms.utilities import PackageManager, check_and_install_torch, find_next_available_filename
|
||||
|
||||
from lollms.config import TypedConfig, ConfigTemplate, BaseConfig
|
||||
import sys
|
||||
import requests
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from ascii_colors import ASCIIColors, trace_exception
|
||||
from lollms.paths import LollmsPaths
|
||||
from lollms.tti import LollmsTTI
|
||||
from lollms.utilities import git_pull
|
||||
from tqdm import tqdm
|
||||
|
||||
import pipmaster as pm
|
||||
if not pm.is_installed("torch"):
|
||||
ASCIIColors.yellow("Diffusers: Torch not found. Installing it")
|
||||
pm.install_multiple(["torch","torchvision","torchaudio"], "https://download.pytorch.org/whl/cu121", force_reinstall=True)
|
||||
|
||||
import torch
|
||||
if not torch.cuda.is_available():
|
||||
ASCIIColors.yellow("Diffusers: Torch not using cuda. Reinstalling it")
|
||||
pm.install_multiple(["torch","torchvision","torchaudio"], "https://download.pytorch.org/whl/cu121", force_reinstall=True)
|
||||
if not pm.is_installed("transformers"):
|
||||
pm.install("transformers")
|
||||
|
||||
if not pm.is_installed("diffusers"):
|
||||
pm.install("diffusers")
|
||||
|
||||
|
||||
|
||||
def adjust_dimensions(value: int) -> int:
|
||||
"""Adjusts the given value to be divisible by 8."""
|
||||
return (value // 8) * 8
|
||||
|
||||
def download_file(url, folder_path, local_filename):
|
||||
# Make sure 'folder_path' exists
|
||||
folder_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with requests.get(url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
total_size = int(r.headers.get('content-length', 0))
|
||||
progress_bar = tqdm(total=total_size, unit='B', unit_scale=True)
|
||||
with open(folder_path / local_filename, 'wb') as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
progress_bar.close()
|
||||
|
||||
return local_filename
|
||||
|
||||
def install_model(lollms_app:LollmsApplication, model_url):
|
||||
root_dir = lollms_app.lollms_paths.personal_path
|
||||
shared_folder = root_dir/"shared"
|
||||
diffusers_folder = shared_folder / "diffusers"
|
||||
|
||||
|
||||
|
||||
import torch
|
||||
from diffusers import PixArtSigmaPipeline
|
||||
|
||||
# You can replace the checkpoint id with "PixArt-alpha/PixArt-Sigma-XL-2-512-MS" too.
|
||||
pipe = PixArtSigmaPipeline.from_pretrained(
|
||||
"PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def upgrade_diffusers(lollms_app:LollmsApplication):
|
||||
PackageManager.install_or_update("diffusers")
|
||||
PackageManager.install_or_update("xformers")
|
||||
|
||||
|
||||
class LollmsDiffusers(LollmsTTI):
|
||||
has_controlnet = False
|
||||
def __init__(self, app, output_folder:str|Path=None):
|
||||
"""
|
||||
Initializes the LollmsDalle binding.
|
||||
|
||||
Args:
|
||||
api_key (str): The API key for authentication.
|
||||
output_folder (Path|str): The output folder where to put the generated data
|
||||
"""
|
||||
service_config = TypedConfig(
|
||||
ConfigTemplate([
|
||||
{"name":"model", "type":"str", "value":"v2ray/stable-diffusion-3-medium-diffusers", "help":"The model to be used"},
|
||||
{"name":"wm", "type":"str", "value":"lollms", "help":"The water marking"},
|
||||
]),
|
||||
BaseConfig(config={
|
||||
"api_key": "", # use avx2
|
||||
})
|
||||
)
|
||||
|
||||
super().__init__("diffusers", app, service_config)
|
||||
if not pm.is_installed("torch"):
|
||||
pm.install("torch torchvision torchaudio", "https://download.pytorch.org/whl/cu121")
|
||||
|
||||
if not pm.is_installed("transformers"):
|
||||
pm.install("transformers")
|
||||
|
||||
if not pm.is_installed("diffusers"):
|
||||
pm.install("diffusers")
|
||||
|
||||
super().__init__("diffusers",app)
|
||||
self.ready = False
|
||||
# Get the current directory
|
||||
lollms_paths = app.lollms_paths
|
||||
root_dir = lollms_paths.personal_path
|
||||
|
||||
shared_folder = root_dir/"shared"
|
||||
self.diffusers_folder = shared_folder / "diffusers"
|
||||
self.output_dir = root_dir / "outputs/diffusers"
|
||||
self.tti_models_dir = self.diffusers_folder / "models"
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.tti_models_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ASCIIColors.red("")
|
||||
ASCIIColors.red(" _ _ _ _ _ __ __ ")
|
||||
ASCIIColors.red(" | | | | | | (_)/ _|/ _| ")
|
||||
ASCIIColors.red(" | | ___ | | |_ __ ___ ___ __| |_| |_| |_ _ _ ___ ___ _ __ ___ ")
|
||||
ASCIIColors.red(" | | / _ \| | | '_ ` _ \/ __| / _` | | _| _| | | / __|/ _ \ '__/ __| ")
|
||||
ASCIIColors.red(" | |___| (_) | | | | | | | \__ \| (_| | | | | | | |_| \__ \ __/ | \__ \ ")
|
||||
ASCIIColors.red(" |______\___/|_|_|_| |_| |_|___/ \__,_|_|_| |_| \__,_|___/\___|_| |___/ ")
|
||||
ASCIIColors.red(" ______ ")
|
||||
ASCIIColors.red(" |______| ")
|
||||
ASCIIColors.red("")
|
||||
ASCIIColors.yellow(f"Using model: {self.service_config.model}")
|
||||
import torch
|
||||
|
||||
try:
|
||||
if "stable-diffusion-3" in self.service_config.model:
|
||||
from diffusers import StableDiffusion3Pipeline # AutoPipelineForImage2Image#PixArtSigmaPipeline
|
||||
self.tti_model = StableDiffusion3Pipeline.from_pretrained(
|
||||
self.service_config.model, torch_dtype=torch.float16, cache_dir=self.tti_models_dir,
|
||||
use_safetensors=True,
|
||||
)
|
||||
self.iti_model = None
|
||||
else:
|
||||
from diffusers import AutoPipelineForText2Image # AutoPipelineForImage2Image#PixArtSigmaPipeline
|
||||
self.tti_model = AutoPipelineForText2Image.from_pretrained(
|
||||
self.service_config.model, torch_dtype=torch.float16, cache_dir=self.tti_models_dir,
|
||||
use_safetensors=True,
|
||||
)
|
||||
self.iti_model = None
|
||||
|
||||
# AutoPipelineForText2Image
|
||||
# self.tti_model = StableDiffusionPipeline.from_pretrained(
|
||||
# "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16, cache_dir=self.tti_models_dir,
|
||||
# use_safetensors=True,
|
||||
# ) # self.service_config.model
|
||||
# Enable memory optimizations.
|
||||
try:
|
||||
if app.config.diffusers_offloading_mode=="sequential_cpu_offload":
|
||||
self.tti_model.enable_sequential_cpu_offload()
|
||||
elif app.coinfig.diffusers_offloading_mode=="model_cpu_offload":
|
||||
self.tti_model.enable_model_cpu_offload()
|
||||
except:
|
||||
pass
|
||||
except Exception as ex:
|
||||
self.tti_model= None
|
||||
trace_exception(ex)
|
||||
|
||||
|
||||
def install_diffusers(self):
|
||||
root_dir = self.app.lollms_paths.personal_path
|
||||
shared_folder = root_dir/"shared"
|
||||
diffusers_folder = shared_folder / "diffusers"
|
||||
diffusers_folder.mkdir(exist_ok=True, parents=True)
|
||||
models_dir = diffusers_folder / "models"
|
||||
models_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
pm.install("diffusers")
|
||||
pm.install("xformers")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def verify(app:LollmsApplication):
|
||||
# Clone repository
|
||||
root_dir = app.lollms_paths.personal_path
|
||||
shared_folder = root_dir/"shared"
|
||||
diffusers_folder = shared_folder / "diffusers"
|
||||
return diffusers_folder.exists()
|
||||
|
||||
def get(app:LollmsApplication):
|
||||
root_dir = app.lollms_paths.personal_path
|
||||
shared_folder = root_dir/"shared"
|
||||
diffusers_folder = shared_folder / "diffusers"
|
||||
diffusers_script_path = diffusers_folder / "lollms_diffusers.py"
|
||||
git_pull(diffusers_folder)
|
||||
|
||||
if diffusers_script_path.exists():
|
||||
ASCIIColors.success("lollms_diffusers found.")
|
||||
ASCIIColors.success("Loading source file...",end="")
|
||||
# use importlib to load the module from the file path
|
||||
from lollms.services.tti.diffusers.lollms_diffusers import LollmsDiffusers
|
||||
ASCIIColors.success("ok")
|
||||
return LollmsDiffusers
|
||||
|
||||
def get_scheduler_by_name(self, scheduler_name="LMS"):
|
||||
if scheduler_name == "LMS":
|
||||
from diffusers import LMSDiscreteScheduler
|
||||
return LMSDiscreteScheduler(
|
||||
beta_start=0.00085,
|
||||
beta_end=0.012,
|
||||
beta_schedule="scaled_linear"
|
||||
)
|
||||
elif scheduler_name == "Euler":
|
||||
from diffusers import EulerDiscreteScheduler
|
||||
return LMSDiscreteScheduler()
|
||||
elif scheduler_name == "DDPMS":
|
||||
from diffusers import DDPMScheduler
|
||||
return DDPMScheduler()
|
||||
elif scheduler_name == "DDIMS":
|
||||
from diffusers import DDIMScheduler
|
||||
return DDIMScheduler()
|
||||
|
||||
|
||||
|
||||
def paint(
|
||||
self,
|
||||
positive_prompt,
|
||||
negative_prompt,
|
||||
sampler_name="",
|
||||
seed=-1,
|
||||
scale=7.5,
|
||||
steps=20,
|
||||
img2img_denoising_strength=0.9,
|
||||
width=512,
|
||||
height=512,
|
||||
restore_faces=True,
|
||||
output_path=None
|
||||
):
|
||||
import torch
|
||||
if sampler_name!="":
|
||||
sc = self.get_scheduler_by_name(sampler_name)
|
||||
if sc:
|
||||
self.tti_model.scheduler = sc
|
||||
width = adjust_dimensions(int(width))
|
||||
height = adjust_dimensions(int(height))
|
||||
|
||||
|
||||
def process_output_path(output_path, self_output_dir):
|
||||
if output_path is None:
|
||||
output_path = Path(self_output_dir)
|
||||
fn = find_next_available_filename(output_path, "diff_img_")
|
||||
else:
|
||||
output_path = Path(output_path)
|
||||
if output_path.is_file():
|
||||
fn = output_path
|
||||
elif output_path.is_dir():
|
||||
fn = find_next_available_filename(output_path, "diff_img_")
|
||||
else:
|
||||
# If the path doesn't exist, assume it's intended to be a file
|
||||
fn = output_path
|
||||
|
||||
return fn
|
||||
|
||||
|
||||
# Usage in the original context
|
||||
if output_path is None:
|
||||
output_path = self.output_dir
|
||||
|
||||
fn = process_output_path(output_path, self.output_dir)
|
||||
|
||||
if seed!=-1:
|
||||
generator = torch.Generator("cuda").manual_seed(seed)
|
||||
image = self.tti_model(positive_prompt, negative_prompt=negative_prompt, height=height, width=width, guidance_scale=scale, num_inference_steps=steps, generator=generator).images[0]
|
||||
else:
|
||||
image = self.tti_model(positive_prompt, negative_prompt=negative_prompt, height=height, width=width, guidance_scale=scale, num_inference_steps=steps).images[0]
|
||||
# Save the image
|
||||
image.save(fn)
|
||||
return fn, {"prompt":positive_prompt, "negative_prompt":negative_prompt}
|
||||
|
||||
def paint_from_images(self, positive_prompt: str,
|
||||
image: str,
|
||||
negative_prompt: str = "",
|
||||
sampler_name="",
|
||||
seed=-1,
|
||||
scale=7.5,
|
||||
steps=20,
|
||||
img2img_denoising_strength=0.9,
|
||||
width=512,
|
||||
height=512,
|
||||
restore_faces=True,
|
||||
output_path=None
|
||||
) -> List[Dict[str, str]]:
|
||||
import torch
|
||||
from diffusers.utils import make_image_grid, load_image
|
||||
|
||||
if not self.iti_model:
|
||||
from diffusers import AutoPipelineForImage2Image
|
||||
|
||||
self.iti_model = AutoPipelineForImage2Image.from_pretrained(
|
||||
self.self.service_config.model, torch_dtype=torch.float16, variant="fp16", use_safetensors=True
|
||||
)
|
||||
if sampler_name!="":
|
||||
sc = self.get_scheduler_by_name(sampler_name)
|
||||
if sc:
|
||||
self.iti_model.scheduler = sc
|
||||
|
||||
img = load_image(image)
|
||||
if output_path is None:
|
||||
output_path = self.output_dir
|
||||
if seed!=-1:
|
||||
generator = torch.Generator("cuda").manual_seed(seed)
|
||||
image = self.titi_model(positive_prompt,image=img, negative_prompt=negative_prompt, height=height, width=width, guidance_scale=scale, num_inference_steps=steps, generator=generator).images[0]
|
||||
else:
|
||||
image = self.iti_model(positive_prompt,image=img, negative_prompt=negative_prompt, height=height, width=width, guidance_scale=scale, num_inference_steps=steps).images[0]
|
||||
output_path = Path(output_path)
|
||||
fn = find_next_available_filename(output_path,"diff_img_")
|
||||
# Save the image
|
||||
image.save(fn)
|
||||
return fn, {"prompt":positive_prompt, "negative_prompt":negative_prompt}
|
||||
|
@ -1,238 +0,0 @@
|
||||
# Title LollmsDiffusers
|
||||
# Licence: MIT
|
||||
# Author : Paris Neo
|
||||
# All rights are reserved
|
||||
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from lollms.app import LollmsApplication
|
||||
from lollms.utilities import PackageManager, check_and_install_torch, find_next_available_filename, install_cuda, check_torch_version
|
||||
|
||||
from lollms.config import TypedConfig, ConfigTemplate, BaseConfig
|
||||
import sys
|
||||
import requests
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from ascii_colors import ASCIIColors, trace_exception
|
||||
from lollms.paths import LollmsPaths
|
||||
from lollms.tti import LollmsTTI
|
||||
from lollms.utilities import git_pull
|
||||
from tqdm import tqdm
|
||||
from PIL import Image
|
||||
import threading
|
||||
import base64
|
||||
from PIL import Image
|
||||
import io
|
||||
import pipmaster as pm
|
||||
|
||||
|
||||
def adjust_dimensions(value: int) -> int:
|
||||
"""Adjusts the given value to be divisible by 8."""
|
||||
return (value // 8) * 8
|
||||
|
||||
def download_file(url, folder_path, local_filename):
|
||||
# Make sure 'folder_path' exists
|
||||
folder_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with requests.get(url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
total_size = int(r.headers.get('content-length', 0))
|
||||
progress_bar = tqdm(total=total_size, unit='B', unit_scale=True)
|
||||
with open(folder_path / local_filename, 'wb') as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
progress_bar.close()
|
||||
|
||||
return local_filename
|
||||
|
||||
|
||||
def install_diffusers(lollms_app:LollmsApplication):
|
||||
root_dir = lollms_app.lollms_paths.personal_path
|
||||
shared_folder = root_dir/"shared"
|
||||
diffusers_folder = shared_folder / "diffusers"
|
||||
diffusers_folder.mkdir(exist_ok=True, parents=True)
|
||||
models_dir = diffusers_folder / "models"
|
||||
models_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
PackageManager.reinstall("diffusers")
|
||||
PackageManager.reinstall("xformers")
|
||||
|
||||
|
||||
|
||||
|
||||
def upgrade_diffusers(lollms_app:LollmsApplication):
|
||||
PackageManager.install_or_update("diffusers")
|
||||
PackageManager.install_or_update("xformers")
|
||||
|
||||
|
||||
class LollmsDiffusersClient(LollmsTTI):
|
||||
has_controlnet = False
|
||||
def __init__(self, app, output_folder:str|Path=None):
|
||||
"""
|
||||
Initializes the LollmsDalle binding.
|
||||
|
||||
Args:
|
||||
api_key (str): The API key for authentication.
|
||||
output_folder (Path|str): The output folder where to put the generated data
|
||||
"""
|
||||
service_config = TypedConfig(
|
||||
ConfigTemplate([
|
||||
{
|
||||
"name": "base_url",
|
||||
"type": "str",
|
||||
"value": "http://127.0.0.1:8188/",
|
||||
"help": "The base URL for the service. This is the address where the service is hosted (e.g., http://127.0.0.1:8188/)."
|
||||
},
|
||||
{
|
||||
"name": "wm",
|
||||
"type": "str",
|
||||
"value": "lollms",
|
||||
"help": "Watermarking text or identifier to be used in the service."
|
||||
},
|
||||
{"name":"model", "type":"str", "value":"v2ray/stable-diffusion-3-medium-diffusers", "help":"The model to be used"},
|
||||
{"name":"wm", "type":"str", "value":"lollms", "help":"The water marking"},
|
||||
]),
|
||||
BaseConfig(config={
|
||||
"api_key": "", # use avx2
|
||||
})
|
||||
)
|
||||
|
||||
super().__init__("diffusers_client", app, service_config)
|
||||
self.ready = False
|
||||
# Get the current directory
|
||||
lollms_paths = app.lollms_paths
|
||||
root_dir = lollms_paths.personal_path
|
||||
|
||||
|
||||
shared_folder = root_dir/"shared"
|
||||
self.diffusers_folder = shared_folder / "diffusers"
|
||||
self.output_dir = root_dir / "outputs/diffusers"
|
||||
self.models_dir = self.diffusers_folder / "models"
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.models_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
ASCIIColors.green(" _ _ _ _ _ __ __ ")
|
||||
ASCIIColors.green(" | | | | | | (_)/ _|/ _| ")
|
||||
ASCIIColors.green(" | | ___ | | |_ __ ___ ___ __| |_| |_| |_ _ _ ___ ___ _ __ ___ ")
|
||||
ASCIIColors.green(" | | / _ \| | | '_ ` _ \/ __| / _` | | _| _| | | / __|/ _ \ '__/ __| ")
|
||||
ASCIIColors.green(" | |___| (_) | | | | | | | \__ \| (_| | | | | | | |_| \__ \ __/ | \__ \ ")
|
||||
ASCIIColors.green(" |______\___/|_|_|_| |_| |_|___/ \__,_|_|_| |_| \__,_|___/\___|_| |___/ ")
|
||||
ASCIIColors.green(" ______ ")
|
||||
ASCIIColors.green(" |______| ")
|
||||
|
||||
|
||||
@staticmethod
|
||||
def verify(app:LollmsApplication):
|
||||
# Clone repository
|
||||
root_dir = app.lollms_paths.personal_path
|
||||
shared_folder = root_dir/"shared"
|
||||
diffusers_folder = shared_folder / "diffusers"
|
||||
return diffusers_folder.exists()
|
||||
|
||||
def get(app:LollmsApplication):
|
||||
root_dir = app.lollms_paths.personal_path
|
||||
shared_folder = root_dir/"shared"
|
||||
diffusers_folder = shared_folder / "diffusers"
|
||||
diffusers_script_path = diffusers_folder / "lollms_diffusers.py"
|
||||
git_pull(diffusers_folder)
|
||||
|
||||
if diffusers_script_path.exists():
|
||||
ASCIIColors.success("lollms_diffusers found.")
|
||||
ASCIIColors.success("Loading source file...",end="")
|
||||
# use importlib to load the module from the file path
|
||||
from lollms.services.tti.diffusers.lollms_diffusers import LollmsDiffusers
|
||||
ASCIIColors.success("ok")
|
||||
return LollmsDiffusers
|
||||
def paint(
|
||||
self,
|
||||
positive_prompt,
|
||||
negative_prompt="",
|
||||
sampler_name="",
|
||||
seed=-1,
|
||||
scale=7.5,
|
||||
steps=20,
|
||||
img2img_denoising_strength=0.9,
|
||||
width=512,
|
||||
height=512,
|
||||
restore_faces=True,
|
||||
output_path=None
|
||||
):
|
||||
url = f"{self.service_config.base_url}/generate-image"
|
||||
|
||||
payload = {
|
||||
"positive_prompt": positive_prompt,
|
||||
"negative_prompt": negative_prompt,
|
||||
"sampler_name": sampler_name,
|
||||
"seed": seed,
|
||||
"scale": scale,
|
||||
"steps": steps,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"restore_faces": restore_faces
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.post(url, json=payload)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
# Assuming the server returns the image path
|
||||
server_image_path = result['image_path']
|
||||
|
||||
# If output_path is not provided, use the server's image path
|
||||
if output_path is None:
|
||||
output_path = server_image_path
|
||||
else:
|
||||
# Copy the image from server path to output_path
|
||||
# This part needs to be implemented based on how you want to handle file transfer
|
||||
pass
|
||||
|
||||
return {
|
||||
"image_path": output_path,
|
||||
"prompt": result['prompt'],
|
||||
"negative_prompt": result['negative_prompt']
|
||||
}
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
print(f"An error occurred: {e}")
|
||||
return None
|
||||
|
||||
def save_image(self, image_data, output_path):
|
||||
image = Image.open(io.BytesIO(base64.b64decode(image_data)))
|
||||
image.save(output_path)
|
||||
print(f"Image saved to {output_path}")
|
||||
|
||||
def paint_from_images(self, positive_prompt: str,
|
||||
images: List[str],
|
||||
negative_prompt: str = "",
|
||||
sampler_name="",
|
||||
seed=-1,
|
||||
scale=7.5,
|
||||
steps=20,
|
||||
img2img_denoising_strength=0.9,
|
||||
width=512,
|
||||
height=512,
|
||||
restore_faces=True,
|
||||
output_path=None
|
||||
) -> List[Dict[str, str]]:
|
||||
import torch
|
||||
if sampler_name!="":
|
||||
sc = self.get_scheduler_by_name(sampler_name)
|
||||
if sc:
|
||||
self.model.scheduler = sc
|
||||
|
||||
if output_path is None:
|
||||
output_path = self.output_dir
|
||||
if seed!=-1:
|
||||
generator = torch.Generator("cuda").manual_seed(seed)
|
||||
image = self.model(positive_prompt, negative_prompt=negative_prompt, height=height, width=width, guidance_scale=scale, num_inference_steps=steps, generator=generator).images[0]
|
||||
else:
|
||||
image = self.model(positive_prompt, negative_prompt=negative_prompt, height=height, width=width, guidance_scale=scale, num_inference_steps=steps).images[0]
|
||||
output_path = Path(output_path)
|
||||
fn = find_next_available_filename(output_path,"diff_img_")
|
||||
# Save the image
|
||||
image.save(fn)
|
||||
return fn, {"prompt":positive_prompt, "negative_prompt":negative_prompt}
|
||||
|
@ -1,277 +0,0 @@
|
||||
# Title LollmsFooocus
|
||||
# Licence: MIT
|
||||
# Author : Paris Neo
|
||||
# All rights are reserved
|
||||
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from lollms.app import LollmsApplication
|
||||
from lollms.paths import LollmsPaths
|
||||
from lollms.config import TypedConfig, ConfigTemplate, BaseConfig
|
||||
from lollms.utilities import PackageManager, check_and_install_torch, find_next_available_filename
|
||||
import time
|
||||
import io
|
||||
import sys
|
||||
import requests
|
||||
import os
|
||||
import base64
|
||||
import subprocess
|
||||
import time
|
||||
import json
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from PIL import Image, PngImagePlugin
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from ascii_colors import ASCIIColors, trace_exception
|
||||
from lollms.paths import LollmsPaths
|
||||
from lollms.tti import LollmsTTI
|
||||
from lollms.utilities import git_pull, show_yes_no_dialog, run_script_in_env, create_conda_env
|
||||
import subprocess
|
||||
import shutil
|
||||
from tqdm import tqdm
|
||||
import threading
|
||||
|
||||
|
||||
from lollms.config import TypedConfig, ConfigTemplate, BaseConfig
|
||||
|
||||
def download_file(url, folder_path, local_filename):
|
||||
# Make sure 'folder_path' exists
|
||||
folder_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with requests.get(url, stream=True) as r:
|
||||
r.raise_for_status()
|
||||
total_size = int(r.headers.get('content-length', 0))
|
||||
progress_bar = tqdm(total=total_size, unit='B', unit_scale=True)
|
||||
with open(folder_path / local_filename, 'wb') as f:
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
f.write(chunk)
|
||||
progress_bar.update(len(chunk))
|
||||
progress_bar.close()
|
||||
|
||||
return local_filename
|
||||
|
||||
|
||||
def install_fooocus(lollms_app:LollmsApplication):
|
||||
root_dir = lollms_app.lollms_paths.personal_path
|
||||
shared_folder = root_dir/"shared"
|
||||
fooocus_folder = shared_folder / "fooocus"
|
||||
fooocus_folder.mkdir(exist_ok=True, parents=True)
|
||||
if not PackageManager.check_package_installed("fooocus"):
|
||||
PackageManager.install_or_update("gradio_client")
|
||||
|
||||
|
||||
def upgrade_fooocus(lollms_app:LollmsApplication):
|
||||
PackageManager.install_or_update("fooocus")
|
||||
PackageManager.install_or_update("xformers")
|
||||
|
||||
|
||||
class LollmsFooocus(LollmsTTI):
|
||||
def __init__(self, app, output_folder:str|Path=None):
|
||||
"""
|
||||
Initializes the LollmsDalle binding.
|
||||
|
||||
Args:
|
||||
api_key (str): The API key for authentication.
|
||||
output_folder (Path|str): The output folder where to put the generated data
|
||||
"""
|
||||
service_config = TypedConfig(
|
||||
ConfigTemplate([
|
||||
{
|
||||
"name": "base_url",
|
||||
"type": "str",
|
||||
"value": "localhost:1024",
|
||||
"help": "The base URL for the service. This is the address where the service is hosted (e.g., http://127.0.0.1:8188/)."
|
||||
},
|
||||
{
|
||||
"name": "wm",
|
||||
"type": "str",
|
||||
"value": "lollms",
|
||||
"help": "Watermarking text or identifier to be used in the service."
|
||||
},
|
||||
{"name":"model", "type":"str", "value":"v2ray/stable-diffusion-3-medium-diffusers", "help":"The model to be used"},
|
||||
{"name":"wm", "type":"str", "value":"lollms", "help":"The water marking"},
|
||||
]),
|
||||
BaseConfig(config={
|
||||
"api_key": "", # use avx2
|
||||
})
|
||||
)
|
||||
|
||||
super().__init__("fooocus", app, service_config)
|
||||
self.ready = False
|
||||
# Get the current directory
|
||||
lollms_paths = app.lollms_paths
|
||||
root_dir = lollms_paths.personal_path
|
||||
|
||||
shared_folder = root_dir/"shared"
|
||||
self.fooocus_folder = shared_folder / "fooocus"
|
||||
self.output_dir = root_dir / "outputs/fooocus"
|
||||
self.models_dir = self.fooocus_folder / "models"
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
self.models_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
ASCIIColors.red(" _ _ _ ______ ")
|
||||
ASCIIColors.red("| | | | | | ___| ")
|
||||
ASCIIColors.red("| | ___ | | |_ __ ___ ___ | |_ ___ ___ ___ ___ _ _ ___ ")
|
||||
ASCIIColors.red("| | / _ \| | | '_ ` _ \/ __| | _/ _ \ / _ \ / _ \ / __| | | / __|")
|
||||
ASCIIColors.red("| |___| (_) | | | | | | | \__ \ | || (_) | (_) | (_) | (__| |_| \__ \ ")
|
||||
ASCIIColors.red("\_____/\___/|_|_|_| |_| |_|___/ \_| \___/ \___/ \___/ \___|\__,_|___/")
|
||||
ASCIIColors.red(" ______ ")
|
||||
ASCIIColors.red(" |______| ")
|
||||
if not PackageManager.check_package_installed("gradio_client"):
|
||||
PackageManager.install_or_update("gradio_client")
|
||||
from gradio_client import Client
|
||||
self.client = Client(base_url)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def verify(app:LollmsApplication):
|
||||
# Clone repository
|
||||
root_dir = app.lollms_paths.personal_path
|
||||
shared_folder = root_dir/"shared"
|
||||
fooocus_folder = shared_folder / "fooocus"
|
||||
return fooocus_folder.exists()
|
||||
|
||||
def get(app:LollmsApplication):
|
||||
root_dir = app.lollms_paths.personal_path
|
||||
shared_folder = root_dir/"shared"
|
||||
fooocus_folder = shared_folder / "fooocus"
|
||||
fooocus_script_path = fooocus_folder / "lollms_fooocus.py"
|
||||
git_pull(fooocus_folder)
|
||||
|
||||
if fooocus_script_path.exists():
|
||||
ASCIIColors.success("lollms_fooocus found.")
|
||||
ASCIIColors.success("Loading source file...",end="")
|
||||
# use importlib to load the module from the file path
|
||||
from lollms.services.fooocus.lollms_fooocus import LollmsFooocus
|
||||
ASCIIColors.success("ok")
|
||||
return LollmsFooocus
|
||||
|
||||
|
||||
def paint(
|
||||
self,
|
||||
positive_prompt,
|
||||
negative_prompt,
|
||||
sampler_name="Euler",
|
||||
seed=-1,
|
||||
scale=7.5,
|
||||
steps=20,
|
||||
img2img_denoising_strength=0.9,
|
||||
width=512,
|
||||
height=512,
|
||||
restore_faces=True,
|
||||
output_path=None
|
||||
):
|
||||
if output_path is None:
|
||||
output_path = self.output_dir
|
||||
|
||||
self.client.predict()
|
||||
image = self.model(positive_prompt, negative_prompt=negative_prompt, guidance_scale=scale, num_inference_steps=steps,).images[0]
|
||||
output_path = Path(output_path)
|
||||
fn = find_next_available_filename(output_path,"diff_img_")
|
||||
# Save the image
|
||||
image.save(fn)
|
||||
return fn, {"prompt":positive_prompt, "negative_prompt":negative_prompt}
|
||||
|
||||
|
||||
|
||||
# from gradio_client import Client
|
||||
|
||||
# client = Client("https://fooocus.mydomain.fr/",verify_ssl=False)
|
||||
# result = client.predict(
|
||||
# True, # bool in 'Generate Image Grid for Each Batch' Checkbox component
|
||||
# "Howdy!", # str in 'parameter_11' Textbox component
|
||||
# "Howdy!", # str in 'Negative Prompt' Textbox component
|
||||
# ["Fooocus V2"], # List[str] in 'Selected Styles' Checkboxgroup component
|
||||
# "Quality", # str in 'Performance' Radio component
|
||||
# '704×1408 <span style="color: grey;"> ∣ 1:2</span>', # str in 'Aspect Ratios' Radio component
|
||||
# 1, # int | float (numeric value between 1 and 32) in 'Image Number' Slider component
|
||||
# "png", # str in 'Output Format' Radio component
|
||||
# "Howdy!", # str in 'Seed' Textbox component
|
||||
# True, # bool in 'Read wildcards in order' Checkbox component
|
||||
# 0, # int | float (numeric value between 0.0 and 30.0) in 'Image Sharpness' Slider component
|
||||
# 1, # int | float (numeric value between 1.0 and 30.0) in 'Guidance Scale' Slider component
|
||||
# "animaPencilXL_v100.safetensors", # str (Option from: ['animaPencilXL_v100.safetensors', 'juggernautXL_v8Rundiffusion.safetensors', 'realisticStockPhoto_v20.safetensors', 'sd_xl_base_1.0_0.9vae.safetensors', 'sd_xl_refiner_1.0_0.9vae.safetensors']) in 'Base Model (SDXL only)' Dropdown component
|
||||
# "None", # str (Option from: ['None', 'animaPencilXL_v100.safetensors', 'juggernautXL_v8Rundiffusion.safetensors', 'realisticStockPhoto_v20.safetensors', 'sd_xl_base_1.0_0.9vae.safetensors', 'sd_xl_refiner_1.0_0.9vae.safetensors']) in 'Refiner (SDXL or SD 1.5)' Dropdown component
|
||||
# 0.1, # int | float (numeric value between 0.1 and 1.0) in 'Refiner Switch At' Slider component
|
||||
# True, # bool in 'Enable' Checkbox component
|
||||
# "None", # str (Option from: ['None', 'sd_xl_offset_example-lora_1.0.safetensors', 'SDXL_FILM_PHOTOGRAPHY_STYLE_BetaV0.4.safetensors', 'sdxl_lcm_lora.safetensors', 'sdxl_lightning_4step_lora.safetensors']) in 'LoRA 1' Dropdown component
|
||||
# -2, # int | float (numeric value between -2 and 2) in 'Weight' Slider component
|
||||
# True, # bool in 'Enable' Checkbox component
|
||||
# "None", # str (Option from: ['None', 'sd_xl_offset_example-lora_1.0.safetensors', 'SDXL_FILM_PHOTOGRAPHY_STYLE_BetaV0.4.safetensors', 'sdxl_lcm_lora.safetensors', 'sdxl_lightning_4step_lora.safetensors']) in 'LoRA 2' Dropdown component
|
||||
# -2, # int | float (numeric value between -2 and 2) in 'Weight' Slider component
|
||||
# True, # bool in 'Enable' Checkbox component
|
||||
# "None", # str (Option from: ['None', 'sd_xl_offset_example-lora_1.0.safetensors', 'SDXL_FILM_PHOTOGRAPHY_STYLE_BetaV0.4.safetensors', 'sdxl_lcm_lora.safetensors', 'sdxl_lightning_4step_lora.safetensors']) in 'LoRA 3' Dropdown component
|
||||
# -2, # int | float (numeric value between -2 and 2) in 'Weight' Slider component
|
||||
# True, # bool in 'Enable' Checkbox component
|
||||
# "None", # str (Option from: ['None', 'sd_xl_offset_example-lora_1.0.safetensors', 'SDXL_FILM_PHOTOGRAPHY_STYLE_BetaV0.4.safetensors', 'sdxl_lcm_lora.safetensors', 'sdxl_lightning_4step_lora.safetensors']) in 'LoRA 4' Dropdown component
|
||||
# -2, # int | float (numeric value between -2 and 2) in 'Weight' Slider component
|
||||
# True, # bool in 'Enable' Checkbox component
|
||||
# "None", # str (Option from: ['None', 'sd_xl_offset_example-lora_1.0.safetensors', 'SDXL_FILM_PHOTOGRAPHY_STYLE_BetaV0.4.safetensors', 'sdxl_lcm_lora.safetensors', 'sdxl_lightning_4step_lora.safetensors']) in 'LoRA 5' Dropdown component
|
||||
# -2, # int | float (numeric value between -2 and 2) in 'Weight' Slider component
|
||||
# True, # bool in 'Input Image' Checkbox component
|
||||
# "Howdy!", # str in 'parameter_91' Textbox component
|
||||
# "Disabled", # str in 'Upscale or Variation:' Radio component
|
||||
# "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", # str (filepath or URL to image) in 'Drag above image to here' Image component
|
||||
# ["Left"], # List[str] in 'Outpaint Direction' Checkboxgroup component
|
||||
# "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", # str (filepath or URL to image) in 'Drag inpaint or outpaint image to here' Image component
|
||||
# "Howdy!", # str in 'Inpaint Additional Prompt' Textbox component
|
||||
# "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", # str (filepath or URL to image) in 'Mask Upload' Image component
|
||||
# True, # bool in 'Disable Preview' Checkbox component
|
||||
# True, # bool in 'Disable Intermediate Results' Checkbox component
|
||||
# True, # bool in 'Disable seed increment' Checkbox component
|
||||
# 0.1, # int | float (numeric value between 0.1 and 3.0) in 'Positive ADM Guidance Scaler' Slider component
|
||||
# 0.1, # int | float (numeric value between 0.1 and 3.0) in 'Negative ADM Guidance Scaler' Slider component
|
||||
# 0, # int | float (numeric value between 0.0 and 1.0) in 'ADM Guidance End At Step' Slider component
|
||||
# 1, # int | float (numeric value between 1.0 and 30.0) in 'CFG Mimicking from TSNR' Slider component
|
||||
# "euler", # str (Option from: ['euler', 'euler_ancestral', 'heun', 'heunpp2', 'dpm_2', 'dpm_2_ancestral', 'lms', 'dpm_fast', 'dpm_adaptive', 'dpmpp_2s_ancestral', 'dpmpp_sde', 'dpmpp_sde_gpu', 'dpmpp_2m', 'dpmpp_2m_sde', 'dpmpp_2m_sde_gpu', 'dpmpp_3m_sde', 'dpmpp_3m_sde_gpu', 'ddpm', 'lcm', 'ddim', 'uni_pc', 'uni_pc_bh2']) in 'Sampler' Dropdown component
|
||||
# "normal", # str (Option from: ['normal', 'karras', 'exponential', 'sgm_uniform', 'simple', 'ddim_uniform', 'lcm', 'turbo']) in 'Scheduler' Dropdown component
|
||||
# -1, # int | float (numeric value between -1 and 200) in 'Forced Overwrite of Sampling Step' Slider component
|
||||
# -1, # int | float (numeric value between -1 and 200) in 'Forced Overwrite of Refiner Switch Step' Slider component
|
||||
# -1, # int | float (numeric value between -1 and 2048) in 'Forced Overwrite of Generating Width' Slider component
|
||||
# -1, # int | float (numeric value between -1 and 2048) in 'Forced Overwrite of Generating Height' Slider component
|
||||
# -1, # int | float (numeric value between -1 and 1.0) in 'Forced Overwrite of Denoising Strength of "Vary"' Slider component
|
||||
# -1, # int | float (numeric value between -1 and 1.0) in 'Forced Overwrite of Denoising Strength of "Upscale"' Slider component
|
||||
# True, # bool in 'Mixing Image Prompt and Vary/Upscale' Checkbox component
|
||||
# True, # bool in 'Mixing Image Prompt and Inpaint' Checkbox component
|
||||
# True, # bool in 'Debug Preprocessors' Checkbox component
|
||||
# True, # bool in 'Skip Preprocessors' Checkbox component
|
||||
# 1, # int | float (numeric value between 1 and 255) in 'Canny Low Threshold' Slider component
|
||||
# 1, # int | float (numeric value between 1 and 255) in 'Canny High Threshold' Slider component
|
||||
# "joint", # str (Option from: ['joint', 'separate', 'vae']) in 'Refiner swap method' Dropdown component
|
||||
# 0, # int | float (numeric value between 0.0 and 1.0) in 'Softness of ControlNet' Slider component
|
||||
# True, # bool in 'Enabled' Checkbox component
|
||||
# 0, # int | float (numeric value between 0 and 2) in 'B1' Slider component
|
||||
# 0, # int | float (numeric value between 0 and 2) in 'B2' Slider component
|
||||
# 0, # int | float (numeric value between 0 and 4) in 'S1' Slider component
|
||||
# 0, # int | float (numeric value between 0 and 4) in 'S2' Slider component
|
||||
# True, # bool in 'Debug Inpaint Preprocessing' Checkbox component
|
||||
# True, # bool in 'Disable initial latent in inpaint' Checkbox component
|
||||
# "None", # str (Option from: ['None', 'v1', 'v2.5', 'v2.6']) in 'Inpaint Engine' Dropdown component
|
||||
# 0, # int | float (numeric value between 0.0 and 1.0) in 'Inpaint Denoising Strength' Slider component
|
||||
# 0, # int | float (numeric value between 0.0 and 1.0) in 'Inpaint Respective Field' Slider component
|
||||
# True, # bool in 'Enable Mask Upload' Checkbox component
|
||||
# True, # bool in 'Invert Mask' Checkbox component
|
||||
# -64, # int | float (numeric value between -64 and 64) in 'Mask Erode or Dilate' Slider component
|
||||
# True, # bool in 'Save Metadata to Images' Checkbox component
|
||||
# "fooocus", # str in 'Metadata Scheme' Radio component
|
||||
# "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", # str (filepath or URL to image) in 'Image' Image component
|
||||
# 0, # int | float (numeric value between 0.0 and 1.0) in 'Stop At' Slider component
|
||||
# 0, # int | float (numeric value between 0.0 and 2.0) in 'Weight' Slider component
|
||||
# "ImagePrompt", # str in 'Type' Radio component
|
||||
# "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", # str (filepath or URL to image) in 'Image' Image component
|
||||
# 0, # int | float (numeric value between 0.0 and 1.0) in 'Stop At' Slider component
|
||||
# 0, # int | float (numeric value between 0.0 and 2.0) in 'Weight' Slider component
|
||||
# "ImagePrompt", # str in 'Type' Radio component
|
||||
# "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", # str (filepath or URL to image)in 'Image' Image component
|
||||
# 0, # int | float (numeric value between 0.0 and 1.0) in 'Stop At' Slider component
|
||||
# 0, # int | float (numeric value between 0.0 and 2.0) in 'Weight' Slider component
|
||||
# "ImagePrompt", # str in 'Type' Radio component
|
||||
# "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", # str (filepath or URL to image) in 'Image' Image component
|
||||
# 0, # int | float (numeric value between 0.0 and 1.0) in 'Stop At' Slider component
|
||||
# 0, # int | float (numeric value between 0.0 and 2.0) in 'Weight' Slider component
|
||||
# "ImagePrompt", # str in 'Type' Radio component
|
||||
# fn_index=40
|
||||
# )
|
||||
# print(result)
|
@ -1,334 +0,0 @@
|
||||
# Title LollmsMidjourney
|
||||
# Licence: Apache 2.0
|
||||
# Author : Paris Neo
|
||||
|
||||
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from lollms.app import LollmsApplication
|
||||
from lollms.paths import LollmsPaths
|
||||
from lollms.config import TypedConfig, ConfigTemplate, BaseConfig
|
||||
import time
|
||||
import io
|
||||
import sys
|
||||
import requests
|
||||
import os
|
||||
import base64
|
||||
import subprocess
|
||||
import time
|
||||
import json
|
||||
import platform
|
||||
from dataclasses import dataclass
|
||||
from PIL import Image, PngImagePlugin
|
||||
from enum import Enum
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from ascii_colors import ASCIIColors, trace_exception
|
||||
from lollms.paths import LollmsPaths
|
||||
from lollms.utilities import PackageManager, find_next_available_filename
|
||||
from lollms.tti import LollmsTTI
|
||||
import subprocess
|
||||
import shutil
|
||||
from tqdm import tqdm
|
||||
|
||||
import os
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
|
||||
MIDJOURNEY_API_URL = "https://api.imaginepro.ai/api/v1/nova"
|
||||
|
||||
|
||||
|
||||
def split_image(file_path, folder_path, i):
|
||||
with Image.open(file_path) as img:
|
||||
width, height = img.size
|
||||
|
||||
# Calculate the size of each quadrant
|
||||
quad_width = width // 2
|
||||
quad_height = height // 2
|
||||
|
||||
quadrants = [
|
||||
(0, 0, quad_width, quad_height),
|
||||
(quad_width, 0, width, quad_height),
|
||||
(0, quad_height, quad_width, height),
|
||||
(quad_width, quad_height, width, height)
|
||||
]
|
||||
|
||||
split_paths = []
|
||||
for index, box in enumerate(quadrants):
|
||||
quadrant = img.crop(box)
|
||||
split_path = os.path.join(folder_path, f"midjourney_{i}_{index+1}.png")
|
||||
quadrant.save(split_path)
|
||||
split_paths.append(split_path)
|
||||
|
||||
return split_paths
|
||||
|
||||
class LollmsMidjourney(LollmsTTI):
|
||||
def __init__(self, app, output_folder:str|Path=None):
|
||||
"""
|
||||
Initializes the LollmsDalle binding.
|
||||
|
||||
Args:
|
||||
api_key (str): The API key for authentication.
|
||||
output_folder (Path|str): The output folder where to put the generated data
|
||||
"""
|
||||
# Check for the MIDJOURNEY_KEY environment variable if no API key is provided
|
||||
api_key = os.getenv("MIDJOURNEY_KEY","")
|
||||
service_config = TypedConfig(
|
||||
ConfigTemplate([
|
||||
{
|
||||
"name": "api_key",
|
||||
"type": "str",
|
||||
"value": api_key,
|
||||
"help": "A valid API key for Midjourney, used to access the text generation service via the Anthropic API."
|
||||
},
|
||||
{
|
||||
"name": "timeout",
|
||||
"type": "int",
|
||||
"value": 300,
|
||||
"help": "The maximum time (in seconds) to wait for a response from the API before timing out."
|
||||
},
|
||||
{
|
||||
"name": "retries",
|
||||
"type": "int",
|
||||
"value": 2,
|
||||
"help": "The number of times to retry the request if it fails or times out."
|
||||
},
|
||||
{
|
||||
"name": "interval",
|
||||
"type": "int",
|
||||
"value": 1,
|
||||
"help": "The time interval (in seconds) between retry attempts."
|
||||
}
|
||||
]),
|
||||
BaseConfig(config={
|
||||
"api_key": "", # use avx2
|
||||
})
|
||||
)
|
||||
|
||||
super().__init__("midjourney", app, service_config)
|
||||
self.output_folder = output_folder
|
||||
|
||||
self.session = requests.Session()
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.service_config.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
def settings_updated(self):
|
||||
self.session = requests.Session()
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.service_config.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
def send_prompt(self, prompt: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Send a prompt to the MidJourney API to generate an image.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt for image generation.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The response from the API.
|
||||
"""
|
||||
url = f"{MIDJOURNEY_API_URL}/imagine"
|
||||
payload = {"prompt": prompt}
|
||||
response = self.session.post(url, json=payload, headers=self.headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def check_progress(self, message_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Check the progress of the image generation.
|
||||
|
||||
Args:
|
||||
message_id (str): The message ID from the initial request.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The response from the API.
|
||||
"""
|
||||
url = f"{MIDJOURNEY_API_URL}/message/{message_id}"
|
||||
response = self.session.get(url, headers=self.headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def upscale_image(self, message_id: str, button: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Upscale the generated image.
|
||||
|
||||
Args:
|
||||
message_id (str): The message ID from the initial request.
|
||||
button (str): The button action for upscaling.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The response from the API.
|
||||
"""
|
||||
url = f"{MIDJOURNEY_API_URL}/button"
|
||||
payload = {"messageId": message_id, "button": button}
|
||||
response = self.session.post(url, json=payload, headers=self.headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def send_prompt_with_retry(self, prompt: str, retries: int = 3) -> Dict[str, Any]:
|
||||
"""
|
||||
Send a prompt to the MidJourney API with retry mechanism.
|
||||
|
||||
Args:
|
||||
prompt (str): The prompt for image generation.
|
||||
retries (int): Number of retry attempts.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The response from the API.
|
||||
"""
|
||||
for attempt in range(retries):
|
||||
try:
|
||||
return self.send_prompt(prompt)
|
||||
except requests.exceptions.RequestException as e:
|
||||
if attempt < retries - 1:
|
||||
ASCIIColors.warning(f"Attempt {attempt + 1} failed: {e}. Retrying...")
|
||||
time.sleep(2 ** attempt)
|
||||
else:
|
||||
ASCIIColors.error(f"All {retries} attempts failed.")
|
||||
raise e
|
||||
|
||||
def poll_progress(self, message_id: str, timeout: int = 300, interval: int = 5) -> Dict[str, Any]:
|
||||
"""
|
||||
Poll the progress of the image generation until it's done or timeout.
|
||||
|
||||
Args:
|
||||
message_id (str): The message ID from the initial request.
|
||||
timeout (int): The maximum time to wait for the image generation.
|
||||
interval (int): The interval between polling attempts.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The response from the API.
|
||||
"""
|
||||
start_time = time.time()
|
||||
with tqdm(total=100, desc="Image Generation Progress", unit="%") as pbar:
|
||||
while time.time() - start_time < timeout:
|
||||
progress_response = self.check_progress(message_id)
|
||||
if progress_response.get("status") == "DONE":
|
||||
pbar.update(100 - pbar.n) # Ensure the progress bar is complete
|
||||
print(progress_response)
|
||||
return progress_response
|
||||
elif progress_response.get("status") == "FAIL":
|
||||
ASCIIColors.error("Image generation failed.")
|
||||
return {"error": "Image generation failed"}
|
||||
|
||||
progress = progress_response.get("progress", 0)
|
||||
pbar.update(progress - pbar.n) # Update the progress bar
|
||||
time.sleep(interval)
|
||||
|
||||
ASCIIColors.error("Timeout while waiting for image generation.")
|
||||
return {"error": "Timeout while waiting for image generation"}
|
||||
|
||||
|
||||
|
||||
def download_image(self, uri, folder_path, split=False):
|
||||
if not os.path.exists(folder_path):
|
||||
os.makedirs(folder_path)
|
||||
|
||||
i = 1
|
||||
while True:
|
||||
file_path = os.path.join(folder_path, f"midjourney_{i}.png")
|
||||
if not os.path.exists(file_path):
|
||||
break
|
||||
i += 1
|
||||
|
||||
response = requests.get(uri)
|
||||
if response.status_code == 200:
|
||||
with open(file_path, 'wb') as file:
|
||||
file.write(response.content)
|
||||
print(f"Image downloaded and saved as {file_path}")
|
||||
|
||||
if split:
|
||||
return split_image(file_path, folder_path, i)
|
||||
else:
|
||||
return [file_path]
|
||||
else:
|
||||
print(f"Failed to download image. Status code: {response.status_code}")
|
||||
return None
|
||||
|
||||
def get_nearest_aspect_ratio(self, width: int, height: int) -> str:
|
||||
# Define the available aspect ratios
|
||||
aspect_ratios = {
|
||||
"1:2": 0.5,
|
||||
"2:3": 0.6667,
|
||||
"3:4": 0.75,
|
||||
"4:5": 0.8,
|
||||
"1:1": 1,
|
||||
"5:4": 1.25,
|
||||
"4:3": 1.3333,
|
||||
"3:2": 1.5,
|
||||
"16:9": 1.7778,
|
||||
"7:4": 1.75,
|
||||
"2:1": 2
|
||||
}
|
||||
|
||||
# Calculate the input aspect ratio
|
||||
input_ratio = width / height
|
||||
|
||||
# Find the nearest aspect ratio
|
||||
nearest_ratio = min(aspect_ratios.items(), key=lambda x: abs(x[1] - input_ratio))
|
||||
|
||||
# Return the formatted string
|
||||
return f"--ar {nearest_ratio[0]}"
|
||||
|
||||
def paint(
|
||||
self,
|
||||
positive_prompt,
|
||||
negative_prompt,
|
||||
sampler_name="Euler",
|
||||
seed=-1,
|
||||
scale=7.5,
|
||||
steps=20,
|
||||
img2img_denoising_strength=0.9,
|
||||
width=512,
|
||||
height=512,
|
||||
restore_faces=True,
|
||||
output_path=None
|
||||
):
|
||||
if output_path is None:
|
||||
output_path = self.output_path
|
||||
|
||||
try:
|
||||
# Send prompt and get initial response
|
||||
positive_prompt += self.get_nearest_aspect_ratio(width, height)
|
||||
initial_response = self.send_prompt_with_retry(positive_prompt, self.service_config.retries)
|
||||
message_id = initial_response.get("messageId")
|
||||
if not message_id:
|
||||
raise ValueError("No messageId returned from initial prompt")
|
||||
|
||||
# Poll progress until image generation is done
|
||||
progress_response = self.poll_progress(message_id, self.service_config.timeout, self.service_config.interval)
|
||||
if "error" in progress_response:
|
||||
raise ValueError(progress_response["error"])
|
||||
|
||||
if width<=1024:
|
||||
file_names = self.download_image(progress_response["uri"], output_path, True)
|
||||
|
||||
return file_names[0], {"prompt":positive_prompt, "negative_prompt":negative_prompt}
|
||||
|
||||
# Upscale the generated image
|
||||
upscale_response = self.upscale_image(message_id, "U1")
|
||||
message_id = upscale_response.get("messageId")
|
||||
if not message_id:
|
||||
raise ValueError("No messageId returned from initial prompt")
|
||||
|
||||
# Poll progress until image generation is done
|
||||
progress_response = self.poll_progress(message_id, self.service_config.timeout, self.service_config.interval)
|
||||
if "error" in progress_response:
|
||||
raise ValueError(progress_response["error"])
|
||||
|
||||
file_name = self.download_image(progress_response["uri"], output_path)
|
||||
return file_name, {"prompt":positive_prompt, "negative_prompt":negative_prompt}
|
||||
|
||||
except Exception as e:
|
||||
trace_exception(e)
|
||||
ASCIIColors.error(f"An error occurred: {e}")
|
||||
return "", {"prompt":positive_prompt, "negative_prompt":negative_prompt}
|
||||
|
||||
@staticmethod
|
||||
def get(app:LollmsApplication):
|
||||
return LollmsMidjourney
|
File diff suppressed because it is too large
Load Diff
@ -1,181 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
import pipmaster as pm
|
||||
|
||||
# Install required libraries if not already present
|
||||
if not pm.is_installed("torch"):
|
||||
pm.install_multiple(["torch","torchvision"," torchaudio"], "https://download.pytorch.org/whl/cu118") # Adjust CUDA version as needed
|
||||
if not pm.is_installed("diffusers"):
|
||||
pm.install("diffusers")
|
||||
if not pm.is_installed("transformers"):
|
||||
pm.install("transformers")
|
||||
if not pm.is_installed("accelerate"):
|
||||
pm.install("accelerate")
|
||||
if not pm.is_installed("imageio-ffmpeg"):
|
||||
pm.install("imageio-ffmpeg")
|
||||
|
||||
import torch
|
||||
from diffusers import CogVideoXPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
from lollms.app import LollmsApplication
|
||||
from lollms.main_config import LOLLMSConfig
|
||||
from lollms.config import TypedConfig, ConfigTemplate, BaseConfig
|
||||
from lollms.utilities import find_next_available_filename
|
||||
from lollms.ttv import LollmsTTV
|
||||
from ascii_colors import ASCIIColors
|
||||
|
||||
class LollmsCogVideoX(LollmsTTV):
|
||||
"""
|
||||
LollmsCogVideoX is an implementation of LollmsTTV using CogVideoX for text-to-video generation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: LollmsApplication,
|
||||
output_folder: str | Path = None
|
||||
):
|
||||
# Define service configuration
|
||||
service_config = TypedConfig(
|
||||
ConfigTemplate([
|
||||
{"name": "model_name", "type": "str", "value": "THUDM/CogVideoX-2b", "options": ["THUDM/CogVideoX-2b", "THUDM/CogVideoX-5b"], "help": "CogVideoX model to use"},
|
||||
{"name": "use_gpu", "type": "bool", "value": True, "help": "Use GPU if available"},
|
||||
{"name": "dtype", "type": "str", "value": "float16", "options": ["float16", "bfloat16"], "help": "Data type for model precision"},
|
||||
]),
|
||||
BaseConfig(config={
|
||||
"model_name": "THUDM/CogVideoX-2b", # Default to 2B model (less VRAM-intensive)
|
||||
"use_gpu": True,
|
||||
"dtype": "float16",
|
||||
})
|
||||
)
|
||||
super().__init__("cogvideox", app, service_config, output_folder)
|
||||
|
||||
# Initialize CogVideoX pipeline
|
||||
self.pipeline = None
|
||||
self.load_pipeline()
|
||||
|
||||
def load_pipeline(self):
|
||||
"""Loads or reloads the CogVideoX pipeline based on config."""
|
||||
try:
|
||||
dtype = torch.float16 if self.service_config.dtype == "float16" else torch.bfloat16
|
||||
self.pipeline = CogVideoXPipeline.from_pretrained(
|
||||
self.service_config.model_name,
|
||||
torch_dtype=dtype
|
||||
)
|
||||
if self.service_config.use_gpu and torch.cuda.is_available():
|
||||
self.pipeline.to("cuda")
|
||||
self.pipeline.enable_model_cpu_offload() # Optimize VRAM usage
|
||||
else:
|
||||
ASCIIColors.warning("GPU not available or disabled. Running on CPU (slower).")
|
||||
ASCIIColors.success(f"Loaded CogVideoX model: {self.service_config.model_name}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load CogVideoX pipeline: {str(e)}")
|
||||
|
||||
def settings_updated(self):
|
||||
"""Reloads the pipeline if settings change."""
|
||||
self.load_pipeline()
|
||||
|
||||
def generate_video(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str] = None,
|
||||
model_name: str = "",
|
||||
height: int = 480,
|
||||
width: int = 720,
|
||||
steps: int = 50,
|
||||
seed: int = -1,
|
||||
nb_frames: int = 49,
|
||||
output_dir: str | Path = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generates a video from a text prompt using CogVideoX.
|
||||
|
||||
Args:
|
||||
prompt (str): The text prompt describing the video content.
|
||||
negative_prompt (Optional[str]): Ignored (CogVideoX doesn't support it natively).
|
||||
model_name (str): Overrides config model if provided (optional).
|
||||
height (int): Desired height of the video (default 480).
|
||||
width (int): Desired width of the video (default 720).
|
||||
steps (int): Number of inference steps (default 50).
|
||||
seed (int): Random seed (default -1 for random).
|
||||
nb_frames (int): Number of frames (default 49, ~6 seconds at 8 fps).
|
||||
output_dir (str | Path): Optional custom output directory.
|
||||
|
||||
Returns:
|
||||
str: Path to the generated video file.
|
||||
"""
|
||||
output_path = Path(output_dir) if output_dir else self.output_folder
|
||||
output_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# Handle unsupported parameters
|
||||
if negative_prompt:
|
||||
ASCIIColors.warning("Warning: CogVideoX does not support negative prompts. Ignoring negative_prompt.")
|
||||
if model_name and model_name != self.service_config.model_name:
|
||||
ASCIIColors.warning(f"Overriding config model {self.service_config.model_name} with {model_name}")
|
||||
self.service_config.model_name = model_name
|
||||
self.load_pipeline()
|
||||
|
||||
# Generation parameters
|
||||
gen_params = {
|
||||
"prompt": prompt,
|
||||
"num_frames": nb_frames,
|
||||
"num_inference_steps": steps,
|
||||
"guidance_scale": 6.0, # Default value from CogVideoX docs
|
||||
"height": height,
|
||||
"width": width,
|
||||
}
|
||||
if seed != -1:
|
||||
gen_params["generator"] = torch.Generator(device="cuda" if self.service_config.use_gpu else "cpu").manual_seed(seed)
|
||||
|
||||
# Generate video
|
||||
try:
|
||||
ASCIIColors.info("Generating video with CogVideoX...")
|
||||
start_time = time.time()
|
||||
video_frames = self.pipeline(**gen_params).frames[0] # CogVideoX returns a list of frame batches
|
||||
output_filename = find_next_available_filename(output_path, "cogvideox_output.mp4")
|
||||
export_to_video(video_frames, output_filename, fps=8)
|
||||
elapsed_time = time.time() - start_time
|
||||
ASCIIColors.success(f"Video generated and saved to {output_filename} in {elapsed_time:.2f} seconds")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to generate video: {str(e)}")
|
||||
|
||||
return str(output_filename)
|
||||
|
||||
def generate_video_by_frames(self, prompts: List[str], frames: List[int], negative_prompt: str, fps: int = 8,
|
||||
num_inference_steps: int = 50, guidance_scale: float = 6.0,
|
||||
seed: Optional[int] = None) -> str:
|
||||
"""
|
||||
Generates a video from a list of prompts. Since CogVideoX doesn't natively support multi-prompt videos,
|
||||
this concatenates prompts into a single description.
|
||||
|
||||
Args:
|
||||
prompts (List[str]): List of prompts for each segment.
|
||||
frames (List[int]): Number of frames per segment (summed to total frames).
|
||||
negative_prompt (str): Ignored.
|
||||
fps (int): Frames per second (default 8).
|
||||
num_inference_steps (int): Inference steps (default 50).
|
||||
guidance_scale (float): Guidance scale (default 6.0).
|
||||
seed (Optional[int]): Random seed.
|
||||
|
||||
Returns:
|
||||
str: Path to the generated video file.
|
||||
"""
|
||||
if not prompts or not frames:
|
||||
raise ValueError("Prompts and frames lists cannot be empty.")
|
||||
|
||||
# Combine prompts into a single narrative
|
||||
combined_prompt = " ".join(prompts)
|
||||
total_frames = sum(frames)
|
||||
|
||||
return self.generate_video(
|
||||
prompt=combined_prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
steps=num_inference_steps,
|
||||
seed=seed if seed is not None else -1,
|
||||
nb_frames=total_frames
|
||||
)
|
||||
|
||||
def getModels(self):
|
||||
"""Returns available CogVideoX models."""
|
||||
return ["THUDM/CogVideoX-2b", "THUDM/CogVideoX-5b"]
|
@ -1,163 +0,0 @@
|
||||
import os
|
||||
import time
|
||||
import requests
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
import pipmaster as pm
|
||||
if not pm.is_installed("lumaai"):
|
||||
pm.install("lumaai")
|
||||
from lumaai import LumaAI
|
||||
from lollms.app import LollmsApplication
|
||||
from lollms.main_config import LOLLMSConfig
|
||||
from lollms.config import TypedConfig
|
||||
from lollms.utilities import find_next_available_filename
|
||||
from lollms.service import LollmsSERVICE
|
||||
from lollms.ttv import LollmsTTV
|
||||
from lollms.config import TypedConfig, ConfigTemplate, BaseConfig
|
||||
from ascii_colors import ASCIIColors
|
||||
class LollmsLumaLabs(LollmsTTV):
|
||||
"""
|
||||
LollmsLumaLabs is an implementation of LollmsTTV using LumaAI for text-to-image generation.
|
||||
Note: LumaAI currently supports image generation, so video output will be limited to single-frame representations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: LollmsApplication,
|
||||
output_folder: str | Path = None
|
||||
):
|
||||
|
||||
# Initialize LumaAI client
|
||||
api_key = os.environ.get("LUMAAI_API_KEY")
|
||||
service_config = TypedConfig(
|
||||
ConfigTemplate([
|
||||
{"name":"api_key", "type":"str", "value":api_key, "help":"A valid Novita AI key to generate text using anthropic api"},
|
||||
]),
|
||||
BaseConfig(config={
|
||||
"api_key": "", # use avx2
|
||||
})
|
||||
)
|
||||
super().__init__("lumalabs", app, service_config, output_folder)
|
||||
try:
|
||||
self.client = LumaAI(auth_token=self.service_config.api_key)
|
||||
except:
|
||||
ASCIIColors.error("Couldn't create a client")
|
||||
self.client = None
|
||||
|
||||
def settings_updated(self):
|
||||
try:
|
||||
self.client = LumaAI(auth_token=self.service_config.api_key)
|
||||
except:
|
||||
ASCIIColors.error("Couldn't create a client")
|
||||
self.client = None
|
||||
|
||||
|
||||
def generate_video(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str] = None,
|
||||
model_name: str = "",
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
steps: int = 20,
|
||||
seed: int = -1,
|
||||
nb_frames: int = None,
|
||||
output_dir: str | Path = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generates a 'video' from a text prompt using LumaAI. Currently limited to a single image due to API constraints.
|
||||
|
||||
Args:
|
||||
prompt (str): The text prompt describing the content.
|
||||
negative_prompt (Optional[str]): Text describing elements to avoid (not supported by LumaAI, ignored).
|
||||
model_name (str): Model name (not supported by LumaAI, ignored).
|
||||
height (int): Desired height of the output image (default 512, LumaAI may override).
|
||||
width (int): Desired width of the output image (default 512, LumaAI may override).
|
||||
steps (int): Number of inference steps (default 20, ignored by LumaAI).
|
||||
seed (int): Random seed for reproducibility (default -1, ignored by LumaAI).
|
||||
nb_frames (int): Number of frames (default None, limited to 1 due to LumaAI image-only support).
|
||||
output_dir (str | Path): Optional custom output directory.
|
||||
|
||||
Returns:
|
||||
str: Path to the generated image file (single-frame 'video').
|
||||
"""
|
||||
output_path = Path(output_dir) if output_dir else self.output_folder
|
||||
output_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
# Warn about unsupported features
|
||||
if negative_prompt:
|
||||
ASCIIColors.warning("Warning: LumaAI does not support negative prompts. Ignoring negative_prompt.")
|
||||
if model_name:
|
||||
ASCIIColors.warning("Warning: LumaAI does not support model selection in this implementation. Ignoring model_name.")
|
||||
if steps != 20:
|
||||
ASCIIColors.warning("Warning: LumaAI controls inference steps internally. Ignoring steps parameter.")
|
||||
if seed != -1:
|
||||
ASCIIColors.warning("Warning: LumaAI does not support seed specification. Ignoring seed.")
|
||||
if nb_frames and nb_frames > 1:
|
||||
ASCIIColors.warning("Warning: LumaAI only supports single-image generation. Generating 1 frame instead of requested nb_frames.")
|
||||
|
||||
# Note: LumaAI's API (as shown) doesn't support width/height directly in the provided example,
|
||||
# but we'll include them in case the API supports it in a newer version
|
||||
generation_params = {
|
||||
"prompt": prompt,
|
||||
# Uncomment and use these if LumaAI supports them in the future:
|
||||
# "height": height,
|
||||
# "width": width,
|
||||
}
|
||||
|
||||
# Create generation request
|
||||
try:
|
||||
generation = self.client.generations.image.create(**generation_params)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to initiate generation: {str(e)}")
|
||||
|
||||
# Poll for completion
|
||||
completed = False
|
||||
while not completed:
|
||||
try:
|
||||
generation = self.client.generations.get(id=generation.id)
|
||||
if generation.state == "completed":
|
||||
completed = True
|
||||
elif generation.state == "failed":
|
||||
raise RuntimeError(f"Generation failed: {generation.failure_reason}")
|
||||
print("Dreaming...")
|
||||
time.sleep(2)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error polling generation status: {str(e)}")
|
||||
|
||||
# Download the image
|
||||
image_url = generation.assets.image
|
||||
output_filename = find_next_available_filename(output_path, f"{generation.id}.jpg")
|
||||
|
||||
try:
|
||||
response = requests.get(image_url, stream=True)
|
||||
response.raise_for_status()
|
||||
with open(output_filename, 'wb') as file:
|
||||
file.write(response.content)
|
||||
print(f"File downloaded as {output_filename}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to download image: {str(e)}")
|
||||
|
||||
return str(output_filename)
|
||||
|
||||
def generate_video_by_frames(self, prompts: List[str], frames: List[int], negative_prompt: str, fps: int = 8,
|
||||
num_inference_steps: int = 50, guidance_scale: float = 6.0,
|
||||
seed: Optional[int] = None) -> str:
|
||||
"""
|
||||
Generates a 'video' from a list of prompts. Since LumaAI only supports single images,
|
||||
this will generate the first prompt's image and return it as a static representation.
|
||||
"""
|
||||
if not prompts:
|
||||
raise ValueError("Prompts list cannot be empty.")
|
||||
|
||||
return self.generate_video(
|
||||
prompt=prompts[0],
|
||||
negative_prompt=negative_prompt,
|
||||
seed=seed if seed is not None else -1
|
||||
)
|
||||
|
||||
def getModels(self):
|
||||
"""
|
||||
Gets the list of models. LumaAI doesn't expose model selection, so returns a placeholder.
|
||||
"""
|
||||
return ["LumaAI_Default"]
|
@ -1,371 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Dict, Any
|
||||
from lollms.ttv import LollmsTTV
|
||||
from lollms.app import LollmsApplication
|
||||
from lollms.config import TypedConfig, ConfigTemplate, BaseConfig
|
||||
from lollms.utilities import find_next_available_filename
|
||||
import requests
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
class LollmsNovitaAITextToVideo(LollmsTTV):
|
||||
"""
|
||||
A binding for the Novita.ai Text-to-Video API.
|
||||
This class allows generating videos from text prompts using the Novita.ai service.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
app:LollmsApplication,
|
||||
output_folder:str|Path=None
|
||||
):
|
||||
"""
|
||||
Initializes the NovitaAITextToVideo binding.
|
||||
|
||||
Args:
|
||||
api_key (str): The API key for authentication.
|
||||
base_url (str): The base URL for the Novita.ai API. Defaults to "https://api.novita.ai/v3/async".
|
||||
"""
|
||||
# Check for the NOVITA_AI_KEY environment variable if no API key is provided
|
||||
api_key = os.getenv("NOVITA_AI_KEY","")
|
||||
service_config = TypedConfig(
|
||||
ConfigTemplate([
|
||||
{"name":"api_key", "type":"str", "value":api_key, "help":"A valid Novita AI key to generate text using anthropic api"},
|
||||
{"name":"generation_engine","type":"str","value":"stable_diffusion", "options": ["stable_diffusion", "hunyuan-video-fast", "wan-t2v"], "help":"The engine name"},
|
||||
{"name":"sd_model_name","type":"str","value":"darkSushiMixMix_225D_64380.safetensors", "options": ["darkSushiMixMix_225D_64380.safetensors"], "help":"The model name"},
|
||||
{"name":"n_frames","type":"int","value":129, "help":"The number of frames in the video"},
|
||||
{"name":"guidance_scale", "type":"float", "value":7.5, "help":"The guidance scale for the generation"},
|
||||
{"name":"loras", "type":"str", "value":None, "help":"List of LoRA configurations"},
|
||||
{"name":"embeddings", "type":"str", "value":None, "help":"List of embedding configurations"},
|
||||
{"name":"closed_loop", "type":"bool", "value":False, "help":"Whether to use closed loop generation"},
|
||||
{"name":"clip_skip", "type":"int", "value":0, "help":"Number of layers to skip in CLIP"}
|
||||
]),
|
||||
BaseConfig(config={
|
||||
"api_key": "", # use avx2
|
||||
})
|
||||
)
|
||||
|
||||
super().__init__("novita_ai", app, service_config,output_folder)
|
||||
self.sd_model_name = self.service_config.sd_model_name
|
||||
self.base_url = "https://api.novita.ai/v3/async"
|
||||
|
||||
models = self.getModels()
|
||||
service_config.config_template["sd_model_name"]["options"] = [model["model_name"] for model in models]
|
||||
|
||||
def settings_updated(self):
|
||||
models = self.getModels()
|
||||
self.service_config.config_template["sd_model_name"]["options"] = models
|
||||
|
||||
def getModels(self):
|
||||
"""
|
||||
Gets the list of models
|
||||
"""
|
||||
url = "https://api.novita.ai/v3/model"
|
||||
headers = {
|
||||
"Content-Type": "<content-type>",
|
||||
"Authorization": f"Bearer {self.service_config.api_key}"
|
||||
}
|
||||
|
||||
response = requests.request("GET", url, headers=headers)
|
||||
js = response.json()
|
||||
return js["models"]
|
||||
|
||||
def generate_video(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str] = None,
|
||||
model_name: str = "",
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
steps: int = 20,
|
||||
seed: int = -1,
|
||||
nb_frames: int = None,
|
||||
output_dir:str | Path =None,
|
||||
) -> str:
|
||||
"""
|
||||
Generates a video from text prompts using the Novita.ai API.
|
||||
|
||||
Args:
|
||||
model_name (str): Name of the model checkpoint.
|
||||
height (int): Height of the video, range [256, 1024].
|
||||
width (int): Width of the video, range [256, 1024].
|
||||
steps (int): Number of denoising steps, range [1, 50].
|
||||
prompts (List[Dict[str, Any]]): List of prompts with frames and text descriptions.
|
||||
negative_prompt (Optional[str]): Text input to avoid in the video. Defaults to None.
|
||||
seed (int): Random seed for reproducibility. Defaults to -1.
|
||||
guidance_scale (Optional[float]): Controls adherence to the prompt. Defaults to None.
|
||||
loras (Optional[List[Dict[str, Any]]]): List of LoRA parameters. Defaults to None.
|
||||
embeddings (Optional[List[Dict[str, Any]]]): List of embeddings. Defaults to None.
|
||||
closed_loop (Optional[bool]): Controls animation loop behavior. Defaults to None.
|
||||
clip_skip (Optional[int]): Number of layers to skip during optimization. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The task_id for retrieving the generated video.
|
||||
"""
|
||||
if output_dir is None:
|
||||
output_dir = self.output_folder
|
||||
|
||||
if nb_frames is None:
|
||||
nb_frames =self.service_config.n_frames
|
||||
|
||||
if self.service_config.generation_engine=="hunyuan-video-fast":
|
||||
width, height, nb_frames, steps = self.pin_dimensions_frames_steps(width, height, nb_frames, steps)
|
||||
url = "https://api.novita.ai/v3/async/hunyuan-video-fast"
|
||||
|
||||
payload = {
|
||||
"model_name": "hunyuan-video-fast",
|
||||
"width": width,
|
||||
"height": height,
|
||||
"seed": seed,
|
||||
"steps": steps,
|
||||
"prompt": prompt,
|
||||
"frames": nb_frames
|
||||
}
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.service_config.api_key}"
|
||||
}
|
||||
|
||||
response = requests.request("POST", url, json=payload, headers=headers)
|
||||
elif self.service_config.generation_engine=="wan-t2v":
|
||||
width, height, nb_frames, steps = self.pin_dimensions_frames_steps(width, height, nb_frames, steps)
|
||||
url = "https://api.novita.ai/v3/async/wan-t2v"
|
||||
|
||||
payload = {
|
||||
"model_name": "wan-t2v",
|
||||
"width": width,
|
||||
"height": height,
|
||||
"seed": seed,
|
||||
"steps": steps,
|
||||
"prompt": prompt,
|
||||
"frames": nb_frames
|
||||
}
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.service_config.api_key}"
|
||||
}
|
||||
|
||||
response = requests.request("POST", url, json=payload, headers=headers)
|
||||
|
||||
elif self.service_config.generation_engine=="stable_diffusion":
|
||||
print(response.text)
|
||||
if model_name=="":
|
||||
model_name = self.sd_model_name
|
||||
|
||||
|
||||
url = f"{self.base_url}/txt2video"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.service_config.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"extra": {
|
||||
"response_video_type": "mp4", # gif
|
||||
"enterprise_plan": {"enabled": False}
|
||||
},
|
||||
"sd_model_name": model_name,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"steps": steps,
|
||||
"prompts": [
|
||||
{
|
||||
"frames": nb_frames,
|
||||
"prompt": prompt
|
||||
}
|
||||
],
|
||||
"negative_prompt": negative_prompt,
|
||||
"guidance_scale": self.service_config.guidance_scale,
|
||||
"seed": seed,
|
||||
"loras": self.service_config.loras,
|
||||
"embeddings": self.service_config.embeddings,
|
||||
"closed_loop": self.service_config.closed_loop,
|
||||
"clip_skip": self.service_config.clip_skip
|
||||
}
|
||||
# Remove None values from the payload to avoid sending null fields
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
|
||||
response = requests.post(url, headers=headers, data=json.dumps(payload))
|
||||
else:
|
||||
return "Unsupported engine name"
|
||||
|
||||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
task_id = response.json().get("task_id")
|
||||
|
||||
|
||||
url = f"https://api.novita.ai/v3/async/task-result?task_id={task_id}"
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.service_config.api_key}",
|
||||
}
|
||||
done = False
|
||||
while not done:
|
||||
response = requests.request("GET", url, headers=headers)
|
||||
infos = response.json()
|
||||
if infos["task"]["status"]=="TASK_STATUS_SUCCEED" or infos["task"]["status"]=="TASK_STATUS_FAILED":
|
||||
done = True
|
||||
time.sleep(1)
|
||||
if infos["task"]["status"]=="TASK_STATUS_SUCCEED":
|
||||
if output_dir:
|
||||
output_dir = Path(output_dir)
|
||||
file_name = output_dir/find_next_available_filename(output_dir, "vid_novita_","mp4") # You can change the filename if needed
|
||||
self.download_video(infos["videos"][0]["video_url"], file_name )
|
||||
return file_name
|
||||
return None
|
||||
|
||||
def pin_dimensions_frames_steps(self, width, height, nframes, steps):
|
||||
# Supported widths
|
||||
standard_widths = [480, 640, 720, 864, 1280]
|
||||
|
||||
# Width-to-height mapping
|
||||
width_height_map = {
|
||||
480: [640, 864], # 480 width supports 640 or 864 height
|
||||
640: [480], # 640 width supports 480 height
|
||||
720: [1280], # 720 width supports 1280 height
|
||||
864: [480], # 864 width supports 480 height
|
||||
1280: [720] # 1280 width supports 720 height
|
||||
}
|
||||
|
||||
# Supported nframes
|
||||
standard_nframes = [85, 129]
|
||||
|
||||
# Supported steps range
|
||||
min_steps, max_steps = 2, 30
|
||||
|
||||
# Pin the width to the nearest standard width
|
||||
pinned_width = min(standard_widths, key=lambda x: abs(x - width))
|
||||
|
||||
# Pin the height to the nearest supported height for the pinned width
|
||||
supported_heights = width_height_map[pinned_width]
|
||||
pinned_height = min(supported_heights, key=lambda x: abs(x - height))
|
||||
|
||||
# Pin the nframes to the nearest standard nframes
|
||||
pinned_nframes = min(standard_nframes, key=lambda x: abs(x - nframes))
|
||||
|
||||
# Pin the steps to the valid range (2 to 30)
|
||||
pinned_steps = max(min_steps, min(max_steps, steps))
|
||||
|
||||
return pinned_width, pinned_height, pinned_nframes, pinned_steps
|
||||
def generate_video_by_frames(self, prompts: List[str], frames: List[int], negative_prompt: str, fps: int = 8,
|
||||
sd_model_name: str = "",
|
||||
height: int = 512,
|
||||
width: int = 512,
|
||||
steps: int = 20,
|
||||
seed: int = -1,
|
||||
output_dir:str | Path =None,
|
||||
) -> str:
|
||||
"""
|
||||
Generates a video from a list of prompts and corresponding frames.
|
||||
|
||||
Args:
|
||||
prompts (List[str]): List of text prompts for each frame.
|
||||
frames (List[int]): List of frame indices corresponding to each prompt.
|
||||
negative_prompt (str): Text describing elements to avoid in the video.
|
||||
fps (int): Frames per second. Default is 8.
|
||||
num_inference_steps (int): Number of steps for the model to infer. Default is 50.
|
||||
guidance_scale (float): Controls how closely the model adheres to the prompt. Default is 6.0.
|
||||
seed (Optional[int]): Random seed for reproducibility. Default is None.
|
||||
|
||||
Returns:
|
||||
str: The path to the generated video.
|
||||
"""
|
||||
if sd_model_name=="":
|
||||
sd_model_name = self.sd_model_name
|
||||
if output_dir is None:
|
||||
output_dir = self.output_folder
|
||||
|
||||
|
||||
url = f"{self.base_url}/txt2video"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.service_config.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"extra": {
|
||||
"response_video_type": "mp4", # gif
|
||||
"enterprise_plan": {"enabled": False}
|
||||
},
|
||||
"sd_model_name": sd_model_name,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"steps": steps,
|
||||
"prompts": [
|
||||
{
|
||||
"frames": nb_frames,
|
||||
"prompt": prompt,
|
||||
|
||||
} for nb_frames, prompt in zip(prompts, frames)
|
||||
],
|
||||
"negative_prompt": negative_prompt,
|
||||
"guidance_scale": guidance_scale,
|
||||
"seed": seed,
|
||||
"loras": [],
|
||||
"embeddings": [],
|
||||
"closed_loop": closed_loop,
|
||||
"clip_skip": clip_skip
|
||||
}
|
||||
# Remove None values from the payload to avoid sending null fields
|
||||
payload = {k: v for k, v in payload.items() if v is not None}
|
||||
|
||||
response = requests.post(url, headers=headers, data=json.dumps(payload))
|
||||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
task_id = response.json().get("task_id")
|
||||
|
||||
|
||||
url = f"https://api.novita.ai/v3/async/task-result?task_id={task_id}"
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.service_config.api_key}",
|
||||
}
|
||||
done = False
|
||||
while not done:
|
||||
response = requests.request("GET", url, headers=headers)
|
||||
infos = response.json()
|
||||
if infos["task"]["status"]=="TASK_STATUS_SUCCEED" or infos["task"]["status"]=="TASK_STATUS_FAILED":
|
||||
done = True
|
||||
time.sleep(1)
|
||||
if infos["task"]["status"]=="TASK_STATUS_SUCCEED":
|
||||
if output_dir:
|
||||
output_dir = Path(output_dir)
|
||||
file_name = output_dir/find_next_available_filename(output_dir, "vid_novita_","mp4") # You can change the filename if needed
|
||||
self.download_video(infos["videos"][0]["video_url"], file_name )
|
||||
return file_name
|
||||
return None
|
||||
|
||||
def get_task_result(self, task_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieves the result of a video generation task using the task_id.
|
||||
|
||||
Args:
|
||||
task_id (str): The task_id returned by the generate_video method.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: The task result containing the video URL and other details.
|
||||
"""
|
||||
url = f"{self.base_url}/task-result"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.service_config.api_key}",
|
||||
}
|
||||
params = {
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
|
||||
return response.json()
|
||||
|
||||
def download_video(self, video_url: str, save_path: Path) -> None:
|
||||
"""
|
||||
Downloads the generated video from the provided URL and saves it to the specified path.
|
||||
|
||||
Args:
|
||||
video_url (str): The URL of the video to download.
|
||||
save_path (Path): The path where the video will be saved.
|
||||
"""
|
||||
response = requests.get(video_url)
|
||||
response.raise_for_status() # Raise an exception for HTTP errors
|
||||
|
||||
with open(save_path, "wb") as file:
|
||||
file.write(response.content)
|
179
lollms/ttm.py
179
lollms/ttm.py
@ -2,14 +2,15 @@
|
||||
Lollms TTM Module
|
||||
=================
|
||||
|
||||
This module is part of the Lollms library, designed to provide Text-to-Music (TTM) functionalities within the LollmsApplication framework. The base class `LollmsTTM` is intended to be inherited and implemented by other classes that provide specific TTM functionalities.
|
||||
This module is part of the Lollms library, designed to provide Text-to-Music (TTM) functionalities within the LollmsApplication framework. The base class `LollmsTTM` is intended to be inherited and implemented by other classes that provide specific TTM functionalities using various models or APIs.
|
||||
|
||||
Author: ParisNeo, a computer geek passionate about AI
|
||||
Inspired by the LollmsTTI structure.
|
||||
"""
|
||||
|
||||
from lollms.app import LollmsApplication
|
||||
from pathlib import Path
|
||||
from typing import List, Dict
|
||||
from typing import List, Dict, Optional
|
||||
from lollms.main_config import LOLLMSConfig
|
||||
from lollms.config import TypedConfig
|
||||
from lollms.service import LollmsSERVICE
|
||||
@ -17,100 +18,188 @@ from lollms.service import LollmsSERVICE
|
||||
class LollmsTTM(LollmsSERVICE):
|
||||
"""
|
||||
LollmsTTM is a base class for implementing Text-to-Music (TTM) functionalities within the LollmsApplication.
|
||||
Subclasses should implement the actual music generation logic by overriding the `generate` method.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name:str,
|
||||
name: str,
|
||||
app: LollmsApplication,
|
||||
service_config: TypedConfig,
|
||||
output_folder: str|Path=None
|
||||
output_folder: Optional[str | Path] = None
|
||||
):
|
||||
"""
|
||||
Initializes the LollmsTTI class with the given parameters.
|
||||
Initializes the LollmsTTM class with the given parameters.
|
||||
|
||||
Args:
|
||||
name (str): The unique name of the TTM service.
|
||||
app (LollmsApplication): The instance of the main Lollms application.
|
||||
model (str, optional): The TTI model to be used for image generation. Defaults to an empty string.
|
||||
api_key (str, optional): API key for accessing external TTI services. Defaults to an empty string.
|
||||
output_path (Path or str, optional): Path where the output image files will be saved. Defaults to None.
|
||||
service_config (TypedConfig): Configuration specific to this TTM service.
|
||||
output_folder (str | Path, optional): Path where the output audio files will be saved.
|
||||
If None, defaults to a subfolder named `name` within
|
||||
`app.lollms_paths.personal_outputs_path`.
|
||||
"""
|
||||
super().__init__(name, app, service_config)
|
||||
if output_folder is not None:
|
||||
self.output_folder = Path(output_folder)
|
||||
else:
|
||||
self.output_folder = app.lollms_paths.personal_outputs_path/name
|
||||
self.output_folder.mkdir(exist_ok=True, parents=True)
|
||||
# Default output path within the standard Lollms outputs structure
|
||||
self.output_folder = app.lollms_paths.personal_outputs_path / name
|
||||
|
||||
# Ensure the output directory exists
|
||||
self.output_folder.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
|
||||
def generate(self,
|
||||
positive_prompt: str,
|
||||
negative_prompt: str = "",
|
||||
duration=30,
|
||||
generation_engine=None,
|
||||
output_path = None) -> List[Dict[str, str]]:
|
||||
def generate(self,
|
||||
prompt: str,
|
||||
negative_prompt: str = "",
|
||||
duration_s: float = 10.0,
|
||||
seed: Optional[int] = None,
|
||||
# Add other common TTM parameters as needed by specific models
|
||||
# e.g., model_name: Optional[str] = None,
|
||||
# e.g., tempo_bpm: Optional[int] = None,
|
||||
# e.g., genre: Optional[str] = None,
|
||||
output_dir: Optional[str | Path] = None,
|
||||
output_file_name: Optional[str] = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Generates images based on the given positive and negative prompts.
|
||||
Generates audio based on the given text prompt.
|
||||
|
||||
This method must be implemented by subclasses to perform the actual text-to-music generation.
|
||||
|
||||
Args:
|
||||
positive_prompt (str): The positive prompt describing the desired image.
|
||||
negative_prompt (str, optional): The negative prompt describing what should be avoided in the image. Defaults to an empty string.
|
||||
prompt (str): The positive prompt describing the desired music.
|
||||
negative_prompt (str, optional): A prompt describing elements to avoid in the music. Defaults to "".
|
||||
duration_s (float, optional): The desired duration of the generated audio in seconds. Defaults to 10.0.
|
||||
seed (int, optional): A seed for reproducibility. If None, a random seed may be used. Defaults to None.
|
||||
output_dir (str | Path, optional): Directory to save the output file(s). If None, uses self.output_folder.
|
||||
output_file_name (str, optional): Desired name for the output file (without extension).
|
||||
If None, a unique name will be generated.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, str]]: A list of dictionaries containing image paths, URLs, and metadata.
|
||||
List[Dict[str, str]]: A list of dictionaries, each containing details about a generated audio file.
|
||||
Expected keys might include 'path' (local file path), 'url' (if served),
|
||||
'prompt', 'duration_s', 'seed', 'format', etc.
|
||||
"""
|
||||
pass
|
||||
# Base implementation does nothing - subclasses must override this
|
||||
raise NotImplementedError("Subclasses must implement the 'generate' method.")
|
||||
|
||||
def generate_from_samples(self, positive_prompt: str, samples: List[str], negative_prompt: str = "") -> List[Dict[str, str]]:
|
||||
"""
|
||||
Generates images based on the given positive prompt and reference images.
|
||||
|
||||
Args:
|
||||
positive_prompt (str): The positive prompt describing the desired image.
|
||||
images (List[str]): A list of paths to reference images.
|
||||
negative_prompt (str, optional): The negative prompt describing what should be avoided in the image. Defaults to an empty string.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, str]]: A list of dictionaries containing image paths, URLs, and metadata.
|
||||
"""
|
||||
pass
|
||||
# Optional: Add methods for other TTM functionalities like music continuation, variation, etc.
|
||||
# def generate_continuation(self, audio_path: str | Path, prompt: str, ...):
|
||||
# pass
|
||||
#
|
||||
# def generate_variation(self, audio_path: str | Path, prompt: str, ...):
|
||||
# pass
|
||||
|
||||
@staticmethod
|
||||
def verify(app: LollmsApplication) -> bool:
|
||||
"""
|
||||
Verifies if the TTM service is available.
|
||||
Verifies if the TTM service's dependencies are met or if it's ready to use (e.g., API key configured).
|
||||
|
||||
This base implementation returns True. Subclasses should override this
|
||||
to perform actual checks (e.g., check for installed libraries, API connectivity).
|
||||
|
||||
Args:
|
||||
app (LollmsApplication): The instance of the main Lollms application.
|
||||
|
||||
Returns:
|
||||
bool: True if the service is available, False otherwise.
|
||||
bool: True if the service is considered available/verified, False otherwise.
|
||||
"""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def install(app: LollmsApplication) -> bool:
|
||||
"""
|
||||
Installs the necessary components for the TTM service.
|
||||
Installs the necessary components or dependencies for the TTM service.
|
||||
|
||||
This base implementation returns True. Subclasses should override this
|
||||
to perform actual installation steps (e.g., pip install required packages).
|
||||
|
||||
Args:
|
||||
app (LollmsApplication): The instance of the main Lollms application.
|
||||
|
||||
Returns:
|
||||
bool: True if the installation was successful, False otherwise.
|
||||
bool: True if the installation was successful (or not needed), False otherwise.
|
||||
"""
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
|
||||
@staticmethod
|
||||
def get(app: LollmsApplication) -> 'LollmsTTM':
|
||||
"""
|
||||
Returns the LollmsTTM class.
|
||||
Returns the LollmsTTM class type itself.
|
||||
|
||||
Used for discovery or instantiation purposes within the Lollms framework.
|
||||
|
||||
Args:
|
||||
app (LollmsApplication): The instance of the main Lollms application.
|
||||
|
||||
Returns:
|
||||
LollmsTTM: The LollmsTTM class.
|
||||
LollmsTTM: The LollmsTTM class type.
|
||||
"""
|
||||
return LollmsTTM
|
||||
|
||||
# Example of how a specific TTM implementation might inherit (Conceptual)
|
||||
# class MySpecificTTM(LollmsTTM):
|
||||
# def __init__(self, app: LollmsApplication, service_config: TypedConfig, output_folder: Optional[str | Path] = None):
|
||||
# super().__init__("my_specific_ttm", app, service_config, output_folder)
|
||||
# # Initialize specific model, API client, etc.
|
||||
# # self.model = ... load model based on service_config ...
|
||||
#
|
||||
# def generate(self, prompt: str, negative_prompt: str = "", duration_s: float = 10.0, seed: Optional[int] = None, ...) -> List[Dict[str, str]]:
|
||||
# # ... Actual implementation using self.model or an API ...
|
||||
# self.app.ShowBlockingMessage(f"Generating music for: {prompt}")
|
||||
# try:
|
||||
# # 1. Prepare parameters for the specific model/API
|
||||
# # 2. Call the generation function
|
||||
# # audio_data = self.model.generate(prompt=prompt, neg_prompt=negative_prompt, duration=duration_s, seed=seed, ...)
|
||||
# # 3. Determine output path and filename
|
||||
# output_path = Path(output_dir or self.output_folder)
|
||||
# output_path.mkdir(parents=True, exist_ok=True)
|
||||
# if output_file_name:
|
||||
# base_filename = output_file_name
|
||||
# else:
|
||||
# # Generate a unique filename (e.g., using timestamp or hash)
|
||||
# import time
|
||||
# base_filename = f"ttm_output_{int(time.time())}"
|
||||
#
|
||||
# # Assume generated format is WAV for example
|
||||
# full_output_path = output_path / f"{base_filename}.wav"
|
||||
#
|
||||
# # 4. Save the generated audio data to the file
|
||||
# # save_audio(audio_data, full_output_path) # Replace with actual saving logic
|
||||
#
|
||||
# # 5. Prepare the result dictionary
|
||||
# result = {
|
||||
# "path": str(full_output_path),
|
||||
# # "url": Optional URL if served by Lollms web server
|
||||
# "prompt": prompt,
|
||||
# "negative_prompt": negative_prompt,
|
||||
# "duration_s": duration_s, # Actual duration might differ slightly
|
||||
# "seed": seed,
|
||||
# "format": "wav"
|
||||
# }
|
||||
# self.app.HideBlockingMessage()
|
||||
# return [result]
|
||||
# except Exception as e:
|
||||
# self.app.HideBlockingMessage()
|
||||
# self.app.ShowError(f"Error generating music: {e}")
|
||||
# # Log the error properly
|
||||
# print(f"Error in MySpecificTTM.generate: {e}") # Use app.error or logging framework
|
||||
# return []
|
||||
#
|
||||
# @staticmethod
|
||||
# def verify(app: LollmsApplication) -> bool:
|
||||
# # Check if 'my_specific_library' is installed
|
||||
# try:
|
||||
# import my_specific_library
|
||||
# return True
|
||||
# except ImportError:
|
||||
# return False
|
||||
#
|
||||
# @staticmethod
|
||||
# def install(app: LollmsApplication) -> bool:
|
||||
# # Install 'my_specific_library'
|
||||
# return app.binding_pip_install("my_specific_library")
|
||||
#
|
||||
# @staticmethod
|
||||
# def get(app: LollmsApplication) -> 'LollmsTTM':
|
||||
# return MySpecificTTM
|
@ -927,7 +927,148 @@ def find_first_available_file_index(folder_path, prefix, extension=""):
|
||||
return available_number
|
||||
|
||||
|
||||
def find_first_available_file_path(folder_path, prefix, extension=""):
|
||||
"""
|
||||
Finds the full path for the first available filename in a folder,
|
||||
based on a prefix and an optional extension.
|
||||
|
||||
The numbering starts from 1 (e.g., prefix1.ext, prefix2.ext, ...).
|
||||
|
||||
Args:
|
||||
folder_path (str or Path): The path to the folder.
|
||||
The folder will be created if it doesn't exist.
|
||||
prefix (str): The desired file prefix.
|
||||
extension (str, optional): The desired file extension (including the dot, e.g., ".txt").
|
||||
Defaults to "".
|
||||
|
||||
Returns:
|
||||
Path: A Path object representing the first available file path
|
||||
(e.g., /path/to/folder/prefix1.txt if it doesn't exist).
|
||||
Returns None if the folder cannot be created or accessed.
|
||||
"""
|
||||
try:
|
||||
# Ensure folder_path is a Path object
|
||||
folder = Path(folder_path)
|
||||
|
||||
# Create the folder if it doesn't exist
|
||||
# os.makedirs(folder, exist_ok=True) # Using exist_ok=True prevents errors if it already exists
|
||||
# Using Pathlib's equivalent:
|
||||
folder.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
available_number = 1
|
||||
while True:
|
||||
# Construct the potential file path using an f-string
|
||||
potential_path = folder / f"{prefix}{available_number}{extension}"
|
||||
|
||||
# Check if this path already exists (works for files and directories)
|
||||
if not potential_path.exists():
|
||||
# If it doesn't exist, this is the first available path
|
||||
return potential_path
|
||||
else:
|
||||
# If it exists, increment the number and try the next one
|
||||
available_number += 1
|
||||
except OSError as e:
|
||||
print(f"Error accessing or creating folder {folder_path}: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"An unexpected error occurred in find_first_available_file_path: {e}")
|
||||
return None
|
||||
|
||||
def is_file_path(path_string: Union[str, Path, None]) -> bool:
|
||||
"""
|
||||
Checks if a given string or Path object structurally resembles a file path.
|
||||
|
||||
This function performs basic checks:
|
||||
1. Is the input a non-empty string or Path?
|
||||
2. Does it contain path separators ('/' or '\\')?
|
||||
3. Does it likely have a filename component (doesn't end with a separator)?
|
||||
4. Does it potentially have an extension (contains a '.')?
|
||||
|
||||
It does NOT check if the path actually exists on the filesystem.
|
||||
It aims to distinguish plausible file paths from simple strings, URLs (basic check),
|
||||
or paths explicitly ending like directories.
|
||||
|
||||
Args:
|
||||
path_string: The string or Path object to check.
|
||||
|
||||
Returns:
|
||||
True if the string looks like a file path, False otherwise.
|
||||
"""
|
||||
# --- Basic Input Validation ---
|
||||
if path_string is None:
|
||||
return False
|
||||
|
||||
# Convert Path object to string for consistent processing
|
||||
if isinstance(path_string, Path):
|
||||
path_string = str(path_string)
|
||||
|
||||
if not isinstance(path_string, str):
|
||||
# If it's not None, not a Path, and not a string, it's invalid input
|
||||
return False
|
||||
|
||||
# Remove leading/trailing whitespace
|
||||
path_string = path_string.strip()
|
||||
|
||||
# Empty string is not a valid file path
|
||||
if not path_string:
|
||||
return False
|
||||
|
||||
# --- Structural Checks ---
|
||||
|
||||
# Very basic check to filter out obvious URLs (can be expanded if needed)
|
||||
if path_string.startswith(('http://', 'https://', 'ftp://', 'mailto:')):
|
||||
return False
|
||||
|
||||
# Check if it ends with a path separator (more likely a directory)
|
||||
# os.path.sep is the primary separator ('/' on Unix, '\' on Windows)
|
||||
# os.path.altsep is the alternative ('/' on Windows)
|
||||
ends_with_separator = path_string.endswith(os.path.sep)
|
||||
if os.path.altsep: # Check altsep only if it exists (it's None on Unix)
|
||||
ends_with_separator = ends_with_separator or path_string.endswith(os.path.altsep)
|
||||
|
||||
if ends_with_separator:
|
||||
return False # Paths ending in separators usually denote directories
|
||||
|
||||
# Check for the presence of path separators within the string
|
||||
has_separator = os.path.sep in path_string
|
||||
if os.path.altsep:
|
||||
has_separator = has_separator or os.path.altsep in path_string
|
||||
|
||||
# Use os.path.splitext to check for an extension
|
||||
# It splits "path/to/file.txt" into ("path/to/file", ".txt")
|
||||
# It splits "path/to/file" into ("path/to/file", "")
|
||||
# It splits "path/.bashrc" into ("path/.bashrc", "") - important edge case!
|
||||
# It splits "path/archive.tar.gz" into ("path/archive.tar", ".gz")
|
||||
base, extension = os.path.splitext(path_string)
|
||||
|
||||
# A simple filename like "file.txt" is a valid relative path
|
||||
# It won't have separators but will likely have an extension
|
||||
has_extension = bool(extension) and extension != '.' # Ensure extension is not just a single dot
|
||||
|
||||
# Check if the part *before* the extension (or the whole string if no extension)
|
||||
# contains a '.' which might indicate a hidden file like '.bashrc' when
|
||||
# there are no separators. We need the base name for this.
|
||||
filename = os.path.basename(path_string)
|
||||
is_likely_hidden_file = filename.startswith('.') and '.' not in filename[1:] and not has_separator
|
||||
|
||||
|
||||
# --- Decision Logic ---
|
||||
# It looks like a file path if:
|
||||
# 1. It contains separators (e.g., "folder/file", "folder/file.txt") OR
|
||||
# 2. It has a valid extension (e.g., "file.txt") OR
|
||||
# 3. It looks like a "hidden" file in the current directory (e.g., ".bashrc")
|
||||
# AND it doesn't end with a separator (checked earlier).
|
||||
if has_separator or has_extension or is_likely_hidden_file:
|
||||
# Further refinement: Avoid matching just "." or ".."
|
||||
if path_string == '.' or path_string == '..':
|
||||
return False
|
||||
return True
|
||||
else:
|
||||
# If it has no separators and no extension (e.g., "myfile"),
|
||||
# it's ambiguous - could be a directory name or a file without extension.
|
||||
# Let's default to False for this ambiguity unless separators are present.
|
||||
return False
|
||||
|
||||
# Prompting tools
|
||||
def detect_antiprompt(text:str, anti_prompts=["!@>"]) -> bool:
|
||||
@ -1544,3 +1685,51 @@ def remove_text_from_string(string: str, text_to_find:str):
|
||||
string = string[:index]
|
||||
|
||||
return string
|
||||
|
||||
|
||||
def short_desc(text: str, max_length: int = 80) -> str:
|
||||
"""
|
||||
Creates a shortened description of a text string, adding ellipsis if truncated.
|
||||
|
||||
Tries to break at a word boundary (space) if possible within the length limit.
|
||||
|
||||
Args:
|
||||
text: The input string. Can be None.
|
||||
max_length: The maximum desired length of the output string (including ellipsis).
|
||||
Must be at least 4 to accommodate "...".
|
||||
|
||||
Returns:
|
||||
The shortened string, or the original string if it's already short enough.
|
||||
Returns an empty string if the input is None or not a string.
|
||||
"""
|
||||
if text is None:
|
||||
return ""
|
||||
|
||||
# Ensure input is treated as a string
|
||||
if not isinstance(text, str):
|
||||
try:
|
||||
text = str(text)
|
||||
except Exception:
|
||||
return "[Invalid Input]" # Or return "" depending on desired behavior
|
||||
|
||||
# If text is already short enough, return it as is.
|
||||
if len(text) <= max_length:
|
||||
return text
|
||||
|
||||
# Ensure max_length is usable
|
||||
if max_length < 4:
|
||||
# Cannot add ellipsis, just truncate hard
|
||||
return text[:max_length]
|
||||
|
||||
# Calculate the ideal truncation point before adding "..."
|
||||
trunc_point = max_length - 3
|
||||
|
||||
# Find the last space character at or before the truncation point
|
||||
last_space = text.rfind(' ', 0, trunc_point + 1) # Include trunc_point itself
|
||||
|
||||
if last_space != -1:
|
||||
# Found a space, truncate there
|
||||
return text[:last_space] + "..."
|
||||
else:
|
||||
# No space found in the initial part, hard truncate
|
||||
return text[:trunc_point] + "..."
|
Loading…
x
Reference in New Issue
Block a user