CogVideoX added to ttv

This commit is contained in:
Saifeddine ALOUI 2025-03-07 23:17:50 +01:00
parent 1ec1422b86
commit 4e9a06cc85
5 changed files with 313 additions and 166 deletions

View File

@@ -182,6 +182,7 @@ class LollmsContextDetails:
" }",
"}",
"```",
"It is mandatory to use the function markdown tag (not json) or it won't be executed."
"Important Notes:",
"- **Always** enclose the function call in a `function` markdown code block.",
"- Make sure the content of the function markdown code block is a valid json.",

View File

@@ -1,63 +1,181 @@
import os
import time
from pathlib import Path
from typing import List, Optional
import pipmaster as pm
import pkg_resources

# Check and install required packages
required_packages = [
    ["torch", "", "https://download.pytorch.org/whl/cu121"],
    ["diffusers", "0.30.1", None],
    ["transformers", "4.44.2", None],
    ["accelerate", "0.33.0", None],
    ["imageio-ffmpeg", "0.5.1", None]
]

for package, min_version, index_url in required_packages:
    if not pm.is_installed(package):
        pm.install_or_update(package, index_url)
    else:
        if min_version:
            if pkg_resources.parse_version(pm.get_installed_version(package)) < pkg_resources.parse_version(min_version):
                pm.install_or_update(package, index_url)
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video
from abc import ABC, abstractmethod
from lollms.app import LollmsApplication
from lollms.main_config import LOLLMSConfig
from lollms.config import TypedConfig, ConfigTemplate, BaseConfig
from lollms.utilities import find_next_available_filename
from lollms.ttv import LollmsTTV
from ascii_colors import ASCIIColors
class LollmsCogVideoX(LollmsTTV):
    """
    LollmsCogVideoX is an implementation of LollmsTTV using CogVideoX for text-to-video generation.
    """
    def __init__(
        self,
        app: LollmsApplication,
        output_folder: str | Path = None
    ):
        # Define service configuration
        service_config = TypedConfig(
            ConfigTemplate([
                {"name": "model_name", "type": "str", "value": "THUDM/CogVideoX-2b", "options": ["THUDM/CogVideoX-2b", "THUDM/CogVideoX-5b"], "help": "CogVideoX model to use"},
                {"name": "use_gpu", "type": "bool", "value": True, "help": "Use GPU if available"},
                {"name": "dtype", "type": "str", "value": "float16", "options": ["float16", "bfloat16"], "help": "Data type for model precision"},
            ]),
            BaseConfig(config={
                "model_name": "THUDM/CogVideoX-2b",  # Default to 2B model (less VRAM-intensive)
                "use_gpu": True,
                "dtype": "float16",
            })
        )
        super().__init__("cogvideox", app, service_config, output_folder)

        # Initialize CogVideoX pipeline
        self.pipeline = None
        self.load_pipeline()
    def load_pipeline(self):
        """Loads or reloads the CogVideoX pipeline based on config."""
        try:
            dtype = torch.float16 if self.service_config.dtype == "float16" else torch.bfloat16
            self.pipeline = CogVideoXPipeline.from_pretrained(
                self.service_config.model_name,
                torch_dtype=dtype
            )
            if self.service_config.use_gpu and torch.cuda.is_available():
                self.pipeline.to("cuda")
                self.pipeline.enable_model_cpu_offload()  # Optimize VRAM usage
            else:
                ASCIIColors.warning("GPU not available or disabled. Running on CPU (slower).")
            ASCIIColors.success(f"Loaded CogVideoX model: {self.service_config.model_name}")
        except Exception as e:
            raise RuntimeError(f"Failed to load CogVideoX pipeline: {str(e)}")

    def settings_updated(self):
        """Reloads the pipeline if settings change."""
        self.load_pipeline()
    def generate_video(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        model_name: str = "",
        height: int = 480,
        width: int = 720,
        steps: int = 50,
        seed: int = -1,
        nb_frames: int = 49,
        output_dir: str | Path = None,
    ) -> str:
        """
        Generates a video from a text prompt using CogVideoX.

        Args:
            prompt (str): The text prompt describing the video content.
            negative_prompt (Optional[str]): Ignored (CogVideoX doesn't support it natively).
            model_name (str): Overrides config model if provided (optional).
            height (int): Desired height of the video (default 480).
            width (int): Desired width of the video (default 720).
            steps (int): Number of inference steps (default 50).
            seed (int): Random seed (default -1 for random).
            nb_frames (int): Number of frames (default 49, ~6 seconds at 8 fps).
            output_dir (str | Path): Optional custom output directory.

        Returns:
            str: Path to the generated video file.
        """
        output_path = Path(output_dir) if output_dir else self.output_folder
        output_path.mkdir(exist_ok=True, parents=True)

        # Handle unsupported parameters
        if negative_prompt:
            ASCIIColors.warning("Warning: CogVideoX does not support negative prompts. Ignoring negative_prompt.")
        if model_name and model_name != self.service_config.model_name:
            ASCIIColors.warning(f"Overriding config model {self.service_config.model_name} with {model_name}")
            self.service_config.model_name = model_name
            self.load_pipeline()

        # Generation parameters
        gen_params = {
            "prompt": prompt,
            "num_frames": nb_frames,
            "num_inference_steps": steps,
            "guidance_scale": 6.0,  # Default value from CogVideoX docs
            "height": height,
            "width": width,
        }
        if seed != -1:
            gen_params["generator"] = torch.Generator(device="cuda" if self.service_config.use_gpu else "cpu").manual_seed(seed)

        # Generate video
        try:
            ASCIIColors.info("Generating video with CogVideoX...")
            start_time = time.time()
            video_frames = self.pipeline(**gen_params).frames[0]  # CogVideoX returns a list of frame batches
            output_filename = find_next_available_filename(output_path, "cogvideox_output.mp4")
            export_to_video(video_frames, output_filename, fps=8)
            elapsed_time = time.time() - start_time
            ASCIIColors.success(f"Video generated and saved to {output_filename} in {elapsed_time:.2f} seconds")
        except Exception as e:
            raise RuntimeError(f"Failed to generate video: {str(e)}")

        return str(output_filename)
    def generate_video_by_frames(self, prompts: List[str], frames: List[int], negative_prompt: str, fps: int = 8,
                                 num_inference_steps: int = 50, guidance_scale: float = 6.0,
                                 seed: Optional[int] = None) -> str:
        """
        Generates a video from a list of prompts. Since CogVideoX doesn't natively support multi-prompt videos,
        this concatenates prompts into a single description.

        Args:
            prompts (List[str]): List of prompts for each segment.
            frames (List[int]): Number of frames per segment (summed to total frames).
            negative_prompt (str): Ignored.
            fps (int): Frames per second (default 8).
            num_inference_steps (int): Inference steps (default 50).
            guidance_scale (float): Guidance scale (default 6.0).
            seed (Optional[int]): Random seed.

        Returns:
            str: Path to the generated video file.
        """
        if not prompts or not frames:
            raise ValueError("Prompts and frames lists cannot be empty.")

        # Combine prompts into a single narrative
        combined_prompt = " ".join(prompts)
        total_frames = sum(frames)

        return self.generate_video(
            prompt=combined_prompt,
            negative_prompt=negative_prompt,
            steps=num_inference_steps,
            seed=seed if seed is not None else -1,
            nb_frames=total_frames
        )
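    # Illustrative example (hypothetical values, not part of the commit):
    # prompts ["a cat wakes up", "the cat stretches"] with frames [25, 24]
    # are merged into the single prompt "a cat wakes up the cat stretches"
    # and rendered as one 49-frame clip (~6 seconds at 8 fps).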
    def getModels(self):
        """Returns available CogVideoX models."""
        return ["THUDM/CogVideoX-2b", "THUDM/CogVideoX-5b"]

View File

@@ -1,132 +1,155 @@
import os
import time
import requests
from pathlib import Path
from typing import List, Optional
import pipmaster as pm

if not pm.is_installed("lumaai"):
    pm.install("lumaai")

from lumaai import LumaAI
from lollms.app import LollmsApplication
from lollms.main_config import LOLLMSConfig
from lollms.config import TypedConfig, ConfigTemplate, BaseConfig
from lollms.utilities import find_next_available_filename
from lollms.ttv import LollmsTTV
from ascii_colors import ASCIIColors
class LollmsLumaLabs(LollmsTTV):
    """
    LollmsLumaLabs is an implementation of LollmsTTV using LumaAI for text-to-image generation.
    Note: LumaAI currently supports image generation, so video output will be limited to single-frame representations.
    """
    def __init__(
        self,
        app: LollmsApplication,
        output_folder: str | Path = None
    ):
        # Initialize LumaAI client
        api_key = os.environ.get("LUMAAI_API_KEY")
        service_config = TypedConfig(
            ConfigTemplate([
                {"name": "api_key", "type": "str", "value": api_key, "help": "A valid LumaLabs AI API key used to authenticate against the LumaAI service"},
            ]),
            BaseConfig(config={
                "api_key": "",
            })
        )
        super().__init__("lumalabs", app, service_config, output_folder)
        self.client = LumaAI(auth_token=api_key)
    def settings_updated(self):
        self.client = LumaAI(auth_token=self.service_config.api_key)

    def determine_aspect_ratio(self, width, height):
        # Define common aspect ratios and their tolerances
        aspect_ratios = {
            "1:1": (1, 1, 0.05),
            "4:3": (4, 3, 0.1),
            "16:9": (16, 9, 0.1),
            "16:10": (16, 10, 0.1),
            "21:9": (21, 9, 0.1),
            "3:2": (3, 2, 0.1),
            "5:4": (5, 4, 0.1),
            "2:1": (2, 1, 0.1)
        }

        # Calculate the aspect ratio of the input dimensions
        current_aspect = width / height
        best_match = None
        min_diff = float('inf')

        for ratio, (w, h, tolerance) in aspect_ratios.items():
            expected_aspect = w / h
            diff = abs(expected_aspect - current_aspect)
            if diff < min_diff and diff < tolerance:
                min_diff = diff
                best_match = ratio

        if best_match:
            return best_match
        else:
            # If no standard aspect ratio matches within tolerance, return the closest one
            return f"{int(width)}:{int(height)}"
    def generate_video(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        model_name: str = "",
        height: int = 512,
        width: int = 512,
        steps: int = 20,
        seed: int = -1,
        nb_frames: int = None,
        output_dir: str | Path = None,
    ) -> str:
        """
        Generates a 'video' from a text prompt using LumaAI. Currently limited to a single image due to API constraints.

        Args:
            prompt (str): The text prompt describing the content.
            negative_prompt (Optional[str]): Text describing elements to avoid (not supported by LumaAI, ignored).
            model_name (str): Model name (not supported by LumaAI, ignored).
            height (int): Desired height of the output image (default 512, LumaAI may override).
            width (int): Desired width of the output image (default 512, LumaAI may override).
            steps (int): Number of inference steps (default 20, ignored by LumaAI).
            seed (int): Random seed for reproducibility (default -1, ignored by LumaAI).
            nb_frames (int): Number of frames (default None, limited to 1 due to LumaAI image-only support).
            output_dir (str | Path): Optional custom output directory.

        Returns:
            str: Path to the generated image file (single-frame 'video').
        """
        output_path = Path(output_dir) if output_dir else self.output_folder
        output_path.mkdir(exist_ok=True, parents=True)

        # Warn about unsupported features
        if negative_prompt:
            ASCIIColors.warning("Warning: LumaAI does not support negative prompts. Ignoring negative_prompt.")
        if model_name:
            ASCIIColors.warning("Warning: LumaAI does not support model selection in this implementation. Ignoring model_name.")
        if steps != 20:
            ASCIIColors.warning("Warning: LumaAI controls inference steps internally. Ignoring steps parameter.")
        if seed != -1:
            ASCIIColors.warning("Warning: LumaAI does not support seed specification. Ignoring seed.")
        if nb_frames and nb_frames > 1:
            ASCIIColors.warning("Warning: LumaAI only supports single-image generation. Generating 1 frame instead of requested nb_frames.")

        # Note: LumaAI's API (as shown) doesn't support width/height directly,
        # so the requested dimensions are mapped to the closest supported aspect ratio.
        aspect_ratio = self.determine_aspect_ratio(width, height)
        generation_params = {
            "prompt": prompt,
            "aspect_ratio": aspect_ratio,
            # Uncomment and use these if LumaAI supports them in the future:
            # "height": height,
            # "width": width,
        }

        # Create generation request
        try:
            generation = self.client.generations.image.create(**generation_params)
        except Exception as e:
            raise RuntimeError(f"Failed to initiate generation: {str(e)}")
        # Poll for completion
        completed = False
        while not completed:
            try:
                generation = self.client.generations.get(id=generation.id)
            except Exception as e:
                raise RuntimeError(f"Error polling generation status: {str(e)}")
            if generation.state == "completed":
                completed = True
            elif generation.state == "failed":
                raise RuntimeError(f"Generation failed: {generation.failure_reason}")
            else:
                print("Dreaming...")
                time.sleep(2)
        # Download the image
        image_url = generation.assets.image
        output_filename = find_next_available_filename(output_path, f"{generation.id}.jpg")
        try:
            response = requests.get(image_url, stream=True)
            response.raise_for_status()
            with open(output_filename, 'wb') as file:
                file.write(response.content)
            print(f"File downloaded as {output_filename}")
        except Exception as e:
            raise RuntimeError(f"Failed to download image: {str(e)}")
        return str(output_filename)
    def generate_video_by_frames(self, prompts: List[str], frames: List[int], negative_prompt: str, fps: int = 8,
                                 num_inference_steps: int = 50, guidance_scale: float = 6.0,
                                 seed: Optional[int] = None) -> str:
        """
        Generates a 'video' from a list of prompts. Since LumaAI only supports single images,
        this will generate the first prompt's image and return it as a static representation.
        """
        if not prompts:
            raise ValueError("Prompts list cannot be empty.")

        return self.generate_video(
            prompt=prompts[0],
            negative_prompt=negative_prompt,
            seed=seed if seed is not None else -1
        )

    def getModels(self):
        """
        Gets the list of models. LumaAI doesn't expose model selection, so returns a placeholder.
        """
        return ["LumaAI_Default"]

View File

@@ -252,9 +252,6 @@ class LollmsNovitaAITextToVideo(LollmsTTV):
        width: int = 512,
        steps: int = 20,
        seed: int = -1,
        output_dir: str | Path = None,
    ) -> str:
        """

View File

@@ -38,10 +38,18 @@ class LollmsTTV(LollmsSERVICE):
    @abstractmethod
    def generate_video(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        model_name: str = "",
        height: int = 512,
        width: int = 512,
        steps: int = 20,
        seed: int = -1,
        nb_frames: int = None,
        output_dir: str | Path = None,
    ) -> str:
        """
        Generates a video from a single text prompt.