|
import argparse |
|
from datetime import datetime |
|
import gc |
|
import random |
|
import os |
|
import re |
|
import time |
|
import math |
|
import copy |
|
from types import ModuleType, SimpleNamespace |
|
from typing import Tuple, Optional, List, Union, Any, Dict |
|
|
|
import torch |
|
import accelerate |
|
from accelerate import Accelerator |
|
from safetensors.torch import load_file, save_file |
|
from safetensors import safe_open |
|
from PIL import Image |
|
import cv2 |
|
import numpy as np |
|
import torchvision.transforms.functional as TF |
|
from tqdm import tqdm |
|
|
|
from networks import lora_wan |
|
from utils.safetensors_utils import mem_eff_save_file, load_safetensors |
|
from wan.configs import WAN_CONFIGS, SUPPORTED_SIZES |
|
import wan |
|
from wan.modules.model import WanModel, load_wan_model, detect_wan_sd_dtype |
|
from wan.modules.vae import WanVAE |
|
from wan.modules.t5 import T5EncoderModel |
|
from wan.modules.clip import CLIPModel |
|
from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler |
|
from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps |
|
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler |
|
|
|
try:
    from lycoris.kohya import create_network_from_weights
except ImportError:
    # LyCORIS is optional; it is only required when --lycoris is specified
    pass
|
|
|
from utils.model_utils import str_to_dtype |
|
from utils.device_utils import clean_memory_on_device |
|
from hv_generate_video import save_images_grid, save_videos_grid, synchronize_device |
|
from dataset.image_video_dataset import load_video |
|
|
|
import logging |
|
|
|
logger = logging.getLogger(__name__) |
|
logging.basicConfig(level=logging.INFO) |
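# Illustrative invocation (a sketch only; the script name and all paths are placeholders):
#   python wan_generate_video.py --task t2v-14B --dit path/to/dit.safetensors --vae path/to/vae.pth \
#       --t5 path/to/t5_encoder.pth --prompt "a cat walking on the grass" --video_size 480 832 \
#       --video_length 81 --infer_steps 20 --save_path output --output_type video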
|
|
|
|
|
class GenerationSettings: |
|
def __init__( |
|
self, device: torch.device, cfg, dit_dtype: torch.dtype, dit_weight_dtype: Optional[torch.dtype], vae_dtype: torch.dtype |
|
): |
|
self.device = device |
|
self.cfg = cfg |
|
self.dit_dtype = dit_dtype |
|
self.dit_weight_dtype = dit_weight_dtype |
|
self.vae_dtype = vae_dtype |
|
|
|
|
|
def parse_args() -> argparse.Namespace: |
|
"""parse command line arguments""" |
|
parser = argparse.ArgumentParser(description="Wan 2.1 inference script") |
|
|
|
|
|
parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory (Wan 2.1 official).") |
|
parser.add_argument("--task", type=str, default="t2v-14B", choices=list(WAN_CONFIGS.keys()), help="The task to run.") |
|
parser.add_argument( |
|
"--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla"], help="The solver used to sample." |
|
) |
|
|
|
parser.add_argument("--dit", type=str, default=None, help="DiT checkpoint path") |
|
parser.add_argument("--vae", type=str, default=None, help="VAE checkpoint path") |
|
parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is bfloat16") |
|
parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU") |
|
parser.add_argument("--t5", type=str, default=None, help="text encoder (T5) checkpoint path") |
|
parser.add_argument("--clip", type=str, default=None, help="text encoder (CLIP) checkpoint path") |
|
|
|
parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path") |
|
parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier") |
|
parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns") |
|
parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns") |
|
parser.add_argument( |
|
"--save_merged_model", |
|
type=str, |
|
default=None, |
|
help="Save merged model to path. If specified, no inference will be performed.", |
|
) |
|
|
|
|
|
parser.add_argument("--prompt", type=str, default=None, help="prompt for generation") |
|
parser.add_argument( |
|
"--negative_prompt", |
|
type=str, |
|
default=None, |
|
help="negative prompt for generation, use default negative prompt if not specified", |
|
) |
|
parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width") |
|
parser.add_argument("--video_length", type=int, default=None, help="video length, Default depends on task") |
|
parser.add_argument("--fps", type=int, default=16, help="video fps, Default is 16") |
|
parser.add_argument("--infer_steps", type=int, default=None, help="number of inference steps") |
|
parser.add_argument("--save_path", type=str, required=True, help="path to save generated video") |
|
parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.") |
|
parser.add_argument( |
|
"--cpu_noise", action="store_true", help="Use CPU to generate noise (compatible with ComfyUI). Default is False." |
|
) |
|
parser.add_argument( |
|
"--guidance_scale", |
|
type=float, |
|
default=5.0, |
|
help="Guidance scale for classifier free guidance. Default is 5.0.", |
|
) |
|
parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference") |
|
parser.add_argument("--image_path", type=str, default=None, help="path to image for image2video inference") |
|
parser.add_argument("--end_image_path", type=str, default=None, help="path to end image for image2video inference") |
|
parser.add_argument( |
|
"--control_path", |
|
type=str, |
|
default=None, |
|
help="path to control video for inference with controlnet. video file or directory with images", |
|
) |
|
parser.add_argument("--trim_tail_frames", type=int, default=0, help="trim tail N frames from the video before saving") |
|
parser.add_argument( |
|
"--cfg_skip_mode", |
|
type=str, |
|
default="none", |
|
choices=["early", "late", "middle", "early_late", "alternate", "none"], |
|
help="CFG skip mode. each mode skips different parts of the CFG. " |
|
" early: initial steps, late: later steps, middle: middle steps, early_late: both early and late, alternate: alternate, none: no skip (default)", |
|
) |
|
parser.add_argument( |
|
"--cfg_apply_ratio", |
|
type=float, |
|
default=None, |
|
help="The ratio of steps to apply CFG (0.0 to 1.0). Default is None (apply all steps).", |
|
) |
|
parser.add_argument( |
|
"--slg_layers", type=str, default=None, help="Skip block (layer) indices for SLG (Skip Layer Guidance), comma separated" |
|
) |
|
parser.add_argument( |
|
"--slg_scale", |
|
type=float, |
|
default=3.0, |
|
help="scale for SLG classifier free guidance. Default is 3.0. Ignored if slg_mode is None or uncond", |
|
) |
|
parser.add_argument("--slg_start", type=float, default=0.0, help="start ratio for inference steps for SLG. Default is 0.0.") |
|
parser.add_argument("--slg_end", type=float, default=0.3, help="end ratio for inference steps for SLG. Default is 0.3.") |
|
parser.add_argument( |
|
"--slg_mode", |
|
type=str, |
|
default=None, |
|
choices=["original", "uncond"], |
|
help="SLG mode. original: same as SD3, uncond: replace uncond pred with SLG pred", |
|
) |
|
|
|
|
|
parser.add_argument( |
|
"--flow_shift", |
|
type=float, |
|
default=None, |
|
help="Shift factor for flow matching schedulers. Default depends on task.", |
|
) |
|
|
|
parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model") |
|
parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8") |
|
parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled") |
|
parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model") |
|
parser.add_argument( |
|
"--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU" |
|
) |
|
parser.add_argument( |
|
"--attn_mode", |
|
type=str, |
|
default="torch", |
|
choices=["flash", "flash2", "flash3", "torch", "sageattn", "xformers", "sdpa"], |
|
help="attention mode", |
|
) |
|
parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model") |
|
parser.add_argument( |
|
"--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type" |
|
) |
|
parser.add_argument("--no_metadata", action="store_true", help="do not save metadata") |
|
parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference") |
|
parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference") |
|
parser.add_argument("--compile", action="store_true", help="Enable torch.compile") |
|
parser.add_argument( |
|
"--compile_args", |
|
nargs=4, |
|
metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"), |
|
default=["inductor", "max-autotune-no-cudagraphs", "False", "False"], |
|
help="Torch.compile settings", |
|
) |
|
|
|
|
|
parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file") |
|
parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console") |
|
|
|
args = parser.parse_args() |
|
|
|
|
|
if args.from_file and args.interactive: |
|
raise ValueError("Cannot use both --from_file and --interactive at the same time") |
|
|
|
if args.prompt is None and not args.from_file and not args.interactive and args.latent_path is None: |
|
raise ValueError("Either --prompt, --from_file, --interactive, or --latent_path must be specified") |
|
|
|
assert (args.latent_path is None or len(args.latent_path) == 0) or ( |
|
args.output_type == "images" or args.output_type == "video" |
|
), "latent_path is only supported for images or video output" |
|
|
|
return args |
|
|
|
|
|
def parse_prompt_line(line: str) -> Dict[str, Any]: |
|
"""Parse a prompt line into a dictionary of argument overrides |
|
|
|
Args: |
|
line: Prompt line with options |
|
|
|
Returns: |
|
Dict[str, Any]: Dictionary of argument overrides |
|
""" |
|
|
|
parts = line.split(" --") |
|
prompt = parts[0].strip() |
|
|
|
|
|
overrides = {"prompt": prompt} |
|
|
|
for part in parts[1:]: |
|
if not part.strip(): |
|
continue |
|
option_parts = part.split(" ", 1) |
|
option = option_parts[0].strip() |
|
value = option_parts[1].strip() if len(option_parts) > 1 else "" |
|
|
|
|
|
if option == "w": |
|
overrides["video_size_width"] = int(value) |
|
elif option == "h": |
|
overrides["video_size_height"] = int(value) |
|
elif option == "f": |
|
overrides["video_length"] = int(value) |
|
elif option == "d": |
|
overrides["seed"] = int(value) |
|
elif option == "s": |
|
overrides["infer_steps"] = int(value) |
|
elif option == "g" or option == "l": |
|
overrides["guidance_scale"] = float(value) |
|
elif option == "fs": |
|
overrides["flow_shift"] = float(value) |
|
elif option == "i": |
|
overrides["image_path"] = value |
|
elif option == "cn": |
|
overrides["control_path"] = value |
|
elif option == "n": |
|
overrides["negative_prompt"] = value |
|
|
|
return overrides |
|
|
|
|
|
def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace: |
|
"""Apply overrides to args |
|
|
|
Args: |
|
args: Original arguments |
|
overrides: Dictionary of overrides |
|
|
|
Returns: |
|
argparse.Namespace: New arguments with overrides applied |
|
""" |
|
args_copy = copy.deepcopy(args) |
|
|
|
for key, value in overrides.items(): |
|
if key == "video_size_width": |
|
args_copy.video_size[1] = value |
|
elif key == "video_size_height": |
|
args_copy.video_size[0] = value |
|
else: |
|
setattr(args_copy, key, value) |
|
|
|
return args_copy |
|
|
|
|
|
def get_task_defaults(task: str, size: Optional[Tuple[int, int]] = None) -> Tuple[int, float, int, bool]: |
|
"""Return default values for each task |
|
|
|
Args: |
|
task: task name (t2v, t2i, i2v etc.) |
|
        size: size of the video (height, width), i.e. args.video_size
|
|
|
Returns: |
|
Tuple[int, float, int, bool]: (infer_steps, flow_shift, video_length, needs_clip) |
|
""" |
|
    height, width = size if size else (0, 0)
|
|
|
if "t2i" in task: |
|
return 50, 5.0, 1, False |
|
elif "i2v" in task: |
|
flow_shift = 3.0 if (width == 832 and height == 480) or (width == 480 and height == 832) else 5.0 |
|
return 40, flow_shift, 81, True |
|
else: |
|
return 50, 5.0, 81, False |
|
|
|
|
|
def setup_args(args: argparse.Namespace) -> argparse.Namespace: |
|
"""Validate and set default values for optional arguments |
|
|
|
Args: |
|
args: command line arguments |
|
|
|
Returns: |
|
argparse.Namespace: updated arguments |
|
""" |
|
|
|
infer_steps, flow_shift, video_length, _ = get_task_defaults(args.task, tuple(args.video_size)) |
|
|
|
|
|
if args.infer_steps is None: |
|
args.infer_steps = infer_steps |
|
if args.flow_shift is None: |
|
args.flow_shift = flow_shift |
|
if args.video_length is None: |
|
args.video_length = video_length |
|
|
|
|
|
if "t2i" in args.task: |
|
assert args.video_length == 1, f"video_length should be 1 for task {args.task}" |
|
|
|
|
|
if args.slg_layers is not None: |
|
args.slg_layers = list(map(int, args.slg_layers.split(","))) |
|
|
|
return args |
|
|
|
|
|
def check_inputs(args: argparse.Namespace) -> Tuple[int, int, int]: |
|
"""Validate video size and length |
|
|
|
Args: |
|
args: command line arguments |
|
|
|
Returns: |
|
Tuple[int, int, int]: (height, width, video_length) |
|
""" |
|
height = args.video_size[0] |
|
width = args.video_size[1] |
|
size = f"{width}*{height}" |
|
|
|
if size not in SUPPORTED_SIZES[args.task]: |
|
logger.warning(f"Size {size} is not supported for task {args.task}. Supported sizes are {SUPPORTED_SIZES[args.task]}.") |
|
|
|
video_length = args.video_length |
|
|
|
if height % 8 != 0 or width % 8 != 0: |
|
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") |
|
|
|
return height, width, video_length |
|
|
|
|
|
def calculate_dimensions(video_size: Tuple[int, int], video_length: int, config) -> Tuple[Tuple[int, int, int, int], int]: |
|
"""calculate dimensions for the generation |
|
|
|
Args: |
|
video_size: video frame size (height, width) |
|
video_length: number of frames in the video |
|
config: model configuration |
|
|
|
Returns: |
|
Tuple[Tuple[int, int, int, int], int]: |
|
            ((latent channels, latent frames, latent height, latent width), seq_len)
|
""" |
|
height, width = video_size |
|
frames = video_length |
|
|
|
|
|
lat_f = (frames - 1) // config.vae_stride[0] + 1 |
|
lat_h = height // config.vae_stride[1] |
|
lat_w = width // config.vae_stride[2] |
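    # seq_len below is the number of patch tokens the DiT processes:
    # each latent frame contributes (lat_h / patch_h) * (lat_w / patch_w) tokens.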
|
|
|
|
|
seq_len = math.ceil((lat_h * lat_w) / (config.patch_size[1] * config.patch_size[2]) * lat_f) |
|
|
|
return ((16, lat_f, lat_h, lat_w), seq_len) |
|
|
|
|
|
def load_vae(args: argparse.Namespace, config, device: torch.device, dtype: torch.dtype) -> WanVAE: |
|
"""load VAE model |
|
|
|
Args: |
|
args: command line arguments |
|
config: model configuration |
|
device: device to use |
|
dtype: data type for the model |
|
|
|
Returns: |
|
WanVAE: loaded VAE model |
|
""" |
|
vae_path = args.vae if args.vae is not None else os.path.join(args.ckpt_dir, config.vae_checkpoint) |
|
|
|
logger.info(f"Loading VAE model from {vae_path}") |
|
cache_device = torch.device("cpu") if args.vae_cache_cpu else None |
|
vae = WanVAE(vae_path=vae_path, device=device, dtype=dtype, cache_device=cache_device) |
|
return vae |
|
|
|
|
|
def load_text_encoder(args: argparse.Namespace, config, device: torch.device) -> T5EncoderModel: |
|
"""load text encoder (T5) model |
|
|
|
Args: |
|
args: command line arguments |
|
config: model configuration |
|
device: device to use |
|
|
|
Returns: |
|
T5EncoderModel: loaded text encoder model |
|
""" |
|
checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_checkpoint) |
|
tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.t5_tokenizer) |
|
|
|
text_encoder = T5EncoderModel( |
|
text_len=config.text_len, |
|
dtype=config.t5_dtype, |
|
device=device, |
|
checkpoint_path=checkpoint_path, |
|
tokenizer_path=tokenizer_path, |
|
weight_path=args.t5, |
|
fp8=args.fp8_t5, |
|
) |
|
|
|
return text_encoder |
|
|
|
|
|
def load_clip_model(args: argparse.Namespace, config, device: torch.device) -> CLIPModel: |
|
"""load CLIP model (for I2V only) |
|
|
|
Args: |
|
args: command line arguments |
|
config: model configuration |
|
device: device to use |
|
|
|
Returns: |
|
CLIPModel: loaded CLIP model |
|
""" |
|
checkpoint_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_checkpoint) |
|
tokenizer_path = None if args.ckpt_dir is None else os.path.join(args.ckpt_dir, config.clip_tokenizer) |
|
|
|
clip = CLIPModel( |
|
dtype=config.clip_dtype, |
|
device=device, |
|
checkpoint_path=checkpoint_path, |
|
tokenizer_path=tokenizer_path, |
|
weight_path=args.clip, |
|
) |
|
|
|
return clip |
|
|
|
|
|
def load_dit_model( |
|
args: argparse.Namespace, |
|
config, |
|
device: torch.device, |
|
dit_dtype: torch.dtype, |
|
dit_weight_dtype: Optional[torch.dtype] = None, |
|
is_i2v: bool = False, |
|
) -> WanModel: |
|
"""load DiT model |
|
|
|
Args: |
|
args: command line arguments |
|
config: model configuration |
|
device: device to use |
|
dit_dtype: data type for the model |
|
dit_weight_dtype: data type for the model weights. None for as-is |
|
is_i2v: I2V mode |
|
|
|
Returns: |
|
WanModel: loaded DiT model |
|
""" |
|
loading_device = "cpu" |
|
if args.blocks_to_swap == 0 and args.lora_weight is None and not args.fp8_scaled: |
|
loading_device = device |
|
|
|
loading_weight_dtype = dit_weight_dtype |
|
if args.fp8_scaled or args.lora_weight is not None: |
|
loading_weight_dtype = dit_dtype |
|
|
|
|
|
model = load_wan_model(config, device, args.dit, args.attn_mode, False, loading_device, loading_weight_dtype, False) |
|
|
|
return model |
|
|
|
|
|
def merge_lora_weights(lora_module: ModuleType, model: torch.nn.Module, args: argparse.Namespace, device: torch.device) -> None: |
|
"""merge LoRA weights to the model |
|
|
|
    Args:
        lora_module: LoRA network module to use (e.g. networks.lora_wan)
        model: DiT model
        args: command line arguments
        device: device to use
|
""" |
|
if args.lora_weight is None or len(args.lora_weight) == 0: |
|
return |
|
|
|
for i, lora_weight in enumerate(args.lora_weight): |
|
if args.lora_multiplier is not None and len(args.lora_multiplier) > i: |
|
lora_multiplier = args.lora_multiplier[i] |
|
else: |
|
lora_multiplier = 1.0 |
|
|
|
logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}") |
|
weights_sd = load_file(lora_weight) |
|
|
|
|
|
original_key_count = len(weights_sd.keys()) |
|
if args.include_patterns is not None and len(args.include_patterns) > i: |
|
include_pattern = args.include_patterns[i] |
|
regex_include = re.compile(include_pattern) |
|
weights_sd = {k: v for k, v in weights_sd.items() if regex_include.search(k)} |
|
logger.info(f"Filtered keys with include pattern {include_pattern}: {original_key_count} -> {len(weights_sd.keys())}") |
|
if args.exclude_patterns is not None and len(args.exclude_patterns) > i: |
|
original_key_count_ex = len(weights_sd.keys()) |
|
exclude_pattern = args.exclude_patterns[i] |
|
regex_exclude = re.compile(exclude_pattern) |
|
weights_sd = {k: v for k, v in weights_sd.items() if not regex_exclude.search(k)} |
|
logger.info( |
|
f"Filtered keys with exclude pattern {exclude_pattern}: {original_key_count_ex} -> {len(weights_sd.keys())}" |
|
) |
|
if len(weights_sd) != original_key_count: |
|
remaining_keys = list(set([k.split(".", 1)[0] for k in weights_sd.keys()])) |
|
remaining_keys.sort() |
|
logger.info(f"Remaining LoRA modules after filtering: {remaining_keys}") |
|
if len(weights_sd) == 0: |
|
logger.warning(f"No keys left after filtering.") |
|
|
|
if args.lycoris: |
|
lycoris_net, _ = create_network_from_weights( |
|
multiplier=lora_multiplier, |
|
file=None, |
|
weights_sd=weights_sd, |
|
unet=model, |
|
text_encoder=None, |
|
vae=None, |
|
for_inference=True, |
|
) |
|
lycoris_net.merge_to(None, model, weights_sd, dtype=None, device=device) |
|
else: |
|
network = lora_module.create_arch_network_from_weights(lora_multiplier, weights_sd, unet=model, for_inference=True) |
|
network.merge_to(None, model, weights_sd, device=device, non_blocking=True) |
|
|
|
synchronize_device(device) |
|
logger.info("LoRA weights loaded") |
|
|
|
|
|
if args.save_merged_model: |
|
logger.info(f"Saving merged model to {args.save_merged_model}") |
|
mem_eff_save_file(model.state_dict(), args.save_merged_model) |
|
logger.info("Merged model saved") |
|
|
|
|
|
def optimize_model( |
|
    model: WanModel, args: argparse.Namespace, device: torch.device, dit_dtype: torch.dtype, dit_weight_dtype: Optional[torch.dtype]
|
) -> None: |
|
"""optimize the model (FP8 conversion, device move etc.) |
|
|
|
Args: |
|
model: dit model |
|
args: command line arguments |
|
device: device to use |
|
dit_dtype: dtype for the model |
|
        dit_weight_dtype: dtype for the model weights (None keeps the weights as loaded)
|
""" |
|
if args.fp8_scaled: |
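        # Convert the weights to scaled FP8 via the model's fp8_optimization helper,
        # then reload the converted state dict with assign=True to avoid extra copies.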
|
|
|
state_dict = model.state_dict() |
|
|
|
|
|
move_to_device = args.blocks_to_swap == 0 |
|
state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=args.fp8_fast) |
|
|
|
info = model.load_state_dict(state_dict, strict=True, assign=True) |
|
logger.info(f"Loaded FP8 optimized weights: {info}") |
|
|
|
if args.blocks_to_swap == 0: |
|
model.to(device) |
|
else: |
|
|
|
target_dtype = None |
|
target_device = None |
|
|
|
if dit_weight_dtype is not None: |
|
logger.info(f"Convert model to {dit_weight_dtype}") |
|
target_dtype = dit_weight_dtype |
|
|
|
if args.blocks_to_swap == 0: |
|
logger.info(f"Move model to device: {device}") |
|
target_device = device |
|
|
|
model.to(target_device, target_dtype) |
|
|
|
if args.compile: |
|
compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args |
|
logger.info( |
|
f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]" |
|
) |
|
torch._dynamo.config.cache_size_limit = 32 |
|
for i in range(len(model.blocks)): |
|
model.blocks[i] = torch.compile( |
|
model.blocks[i], |
|
backend=compile_backend, |
|
mode=compile_mode, |
|
dynamic=compile_dynamic.lower() in "true", |
|
fullgraph=compile_fullgraph.lower() in "true", |
|
) |
|
|
|
if args.blocks_to_swap > 0: |
|
logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}") |
|
model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False) |
|
model.move_to_device_except_swap_blocks(device) |
|
model.prepare_block_swap_before_forward() |
|
else: |
|
|
|
model.to(device) |
|
|
|
model.eval().requires_grad_(False) |
|
clean_memory_on_device(device) |
|
|
|
|
|
def prepare_t2v_inputs( |
|
args: argparse.Namespace, |
|
config, |
|
accelerator: Accelerator, |
|
device: torch.device, |
|
vae: Optional[WanVAE] = None, |
|
encoded_context: Optional[Dict] = None, |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]: |
|
"""Prepare inputs for T2V |
|
|
|
Args: |
|
args: command line arguments |
|
config: model configuration |
|
accelerator: Accelerator instance |
|
device: device to use |
|
vae: VAE model for control video encoding |
|
encoded_context: Pre-encoded text context |
|
|
|
Returns: |
|
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]: |
|
(noise, context, context_null, (arg_c, arg_null)) |
|
""" |
|
|
|
|
|
height, width = args.video_size |
|
frames = args.video_length |
|
(_, lat_f, lat_h, lat_w), seq_len = calculate_dimensions(args.video_size, args.video_length, config) |
|
target_shape = (16, lat_f, lat_h, lat_w) |
|
|
|
|
|
n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt |
|
|
|
|
|
seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) |
|
if not args.cpu_noise: |
|
seed_g = torch.Generator(device=device) |
|
seed_g.manual_seed(seed) |
|
else: |
|
|
|
seed_g = torch.manual_seed(seed) |
|
|
|
if encoded_context is None: |
|
|
|
text_encoder = load_text_encoder(args, config, device) |
|
text_encoder.model.to(device) |
|
|
|
|
|
with torch.no_grad(): |
|
if args.fp8_t5: |
|
with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype): |
|
context = text_encoder([args.prompt], device) |
|
context_null = text_encoder([n_prompt], device) |
|
else: |
|
context = text_encoder([args.prompt], device) |
|
context_null = text_encoder([n_prompt], device) |
|
|
|
|
|
del text_encoder |
|
clean_memory_on_device(device) |
|
else: |
|
|
|
context = encoded_context["context"] |
|
context_null = encoded_context["context_null"] |
|
|
|
|
|
if config.is_fun_control: |
|
|
|
logger.info(f"Encoding control video to latent space") |
|
|
|
control_video = load_control_video(args.control_path, frames, height, width).to(device) |
|
vae.to_device(device) |
|
with torch.autocast(device_type=device.type, dtype=vae.dtype), torch.no_grad(): |
|
control_latent = vae.encode([control_video])[0] |
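        # Fun-Control T2V: pair the control-video latent with a zero block of the same shape,
        # which appears to stand in for the image-conditioning latent that T2V does not have.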
|
y = torch.concat([control_latent, torch.zeros_like(control_latent)], dim=0) |
|
vae.to_device("cpu") |
|
else: |
|
y = None |
|
|
|
|
|
noise = torch.randn(target_shape, dtype=torch.float32, generator=seed_g, device=device if not args.cpu_noise else "cpu") |
|
noise = noise.to(device) |
|
|
|
|
|
arg_c = {"context": context, "seq_len": seq_len} |
|
arg_null = {"context": context_null, "seq_len": seq_len} |
|
if y is not None: |
|
arg_c["y"] = [y] |
|
arg_null["y"] = [y] |
|
|
|
return noise, context, context_null, (arg_c, arg_null) |
|
|
|
|
|
def prepare_i2v_inputs( |
|
args: argparse.Namespace, |
|
config, |
|
accelerator: Accelerator, |
|
device: torch.device, |
|
vae: WanVAE, |
|
encoded_context: Optional[Dict] = None, |
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]: |
|
"""Prepare inputs for I2V |
|
|
|
Args: |
|
args: command line arguments |
|
config: model configuration |
|
accelerator: Accelerator instance |
|
device: device to use |
|
vae: VAE model, used for image encoding |
|
encoded_context: Pre-encoded text context |
|
|
|
Returns: |
|
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[dict, dict]]: |
|
(noise, context, context_null, y, (arg_c, arg_null)) |
|
""" |
|
|
|
height, width = args.video_size |
|
frames = args.video_length |
|
max_area = width * height |
|
|
|
|
|
img = Image.open(args.image_path).convert("RGB") |
|
|
|
|
|
img_cv2 = np.array(img) |
|
|
|
|
|
img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device) |
|
|
|
|
|
if args.end_image_path is not None: |
|
end_img = Image.open(args.end_image_path).convert("RGB") |
|
end_img_cv2 = np.array(end_img) |
|
else: |
|
end_img = None |
|
end_img_cv2 = None |
|
has_end_image = end_img is not None |
|
|
|
|
|
height, width = img_tensor.shape[1:] |
|
aspect_ratio = height / width |
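    # Pick latent dimensions that preserve the image aspect ratio, fit within max_area pixels,
    # and are multiples of the patch size; pixel height/width are then derived from the VAE stride.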
|
lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1]) |
|
lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2]) |
|
height = lat_h * config.vae_stride[1] |
|
width = lat_w * config.vae_stride[2] |
|
lat_f = (frames - 1) // config.vae_stride[0] + 1 |
|
max_seq_len = (lat_f + (1 if has_end_image else 0)) * lat_h * lat_w // (config.patch_size[1] * config.patch_size[2]) |
|
|
|
|
|
seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) |
|
if not args.cpu_noise: |
|
seed_g = torch.Generator(device=device) |
|
seed_g.manual_seed(seed) |
|
else: |
|
|
|
seed_g = torch.manual_seed(seed) |
|
|
|
|
|
noise = torch.randn( |
|
16, |
|
lat_f + (1 if has_end_image else 0), |
|
lat_h, |
|
lat_w, |
|
dtype=torch.float32, |
|
generator=seed_g, |
|
device=device if not args.cpu_noise else "cpu", |
|
) |
|
noise = noise.to(device) |
|
|
|
|
|
n_prompt = args.negative_prompt if args.negative_prompt else config.sample_neg_prompt |
|
|
|
if encoded_context is None: |
|
|
|
text_encoder = load_text_encoder(args, config, device) |
|
text_encoder.model.to(device) |
|
|
|
|
|
with torch.no_grad(): |
|
if args.fp8_t5: |
|
with torch.amp.autocast(device_type=device.type, dtype=config.t5_dtype): |
|
context = text_encoder([args.prompt], device) |
|
context_null = text_encoder([n_prompt], device) |
|
else: |
|
context = text_encoder([args.prompt], device) |
|
context_null = text_encoder([n_prompt], device) |
|
|
|
|
|
del text_encoder |
|
clean_memory_on_device(device) |
|
|
|
|
|
clip = load_clip_model(args, config, device) |
|
clip.model.to(device) |
|
|
|
|
|
logger.info(f"Encoding image to CLIP context") |
|
with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad(): |
|
clip_context = clip.visual([img_tensor[:, None, :, :]]) |
|
logger.info(f"Encoding complete") |
|
|
|
|
|
del clip |
|
clean_memory_on_device(device) |
|
else: |
|
|
|
context = encoded_context["context"] |
|
context_null = encoded_context["context_null"] |
|
clip_context = encoded_context["clip_context"] |
|
|
|
|
|
logger.info(f"Encoding image to latent space") |
|
vae.to_device(device) |
|
|
|
|
|
interpolation = cv2.INTER_AREA if height < img_cv2.shape[0] else cv2.INTER_CUBIC |
|
img_resized = cv2.resize(img_cv2, (width, height), interpolation=interpolation) |
|
img_resized = TF.to_tensor(img_resized).sub_(0.5).div_(0.5).to(device) |
|
img_resized = img_resized.unsqueeze(1) |
|
|
|
if has_end_image: |
|
        interpolation = cv2.INTER_AREA if height < end_img_cv2.shape[0] else cv2.INTER_CUBIC
|
end_img_resized = cv2.resize(end_img_cv2, (width, height), interpolation=interpolation) |
|
end_img_resized = TF.to_tensor(end_img_resized).sub_(0.5).div_(0.5).to(device) |
|
end_img_resized = end_img_resized.unsqueeze(1) |
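    # The mask built below marks which latent frames are given as conditioning:
    # 1 for the first frame (and the last frame when an end image is provided), 0 for frames to generate.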
|
|
|
|
|
msk = torch.zeros(4, lat_f + (1 if has_end_image else 0), lat_h, lat_w, device=device) |
|
msk[:, 0] = 1 |
|
if has_end_image: |
|
msk[:, -1] = 1 |
|
|
|
|
|
with accelerator.autocast(), torch.no_grad(): |
|
|
|
padding_frames = frames - 1 |
|
img_resized = torch.concat([img_resized, torch.zeros(3, padding_frames, height, width, device=device)], dim=1) |
|
y = vae.encode([img_resized])[0] |
|
|
|
if has_end_image: |
|
y_end = vae.encode([end_img_resized])[0] |
|
y = torch.concat([y, y_end], dim=1) |
|
|
|
y = torch.concat([msk, y]) |
|
logger.info(f"Encoding complete") |
|
|
|
|
|
if config.is_fun_control: |
|
|
|
logger.info(f"Encoding control video to latent space") |
|
|
|
control_video = load_control_video(args.control_path, frames + (1 if has_end_image else 0), height, width).to(device) |
|
with accelerator.autocast(), torch.no_grad(): |
|
control_latent = vae.encode([control_video])[0] |
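        # For Fun-Control, drop the mask channels from y, zero every image latent frame except the
        # first (and last, if an end image is used), then prepend the control-video latent.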
|
y = y[msk.shape[0] :] |
|
if has_end_image: |
|
y[:, 1:-1] = 0 |
|
else: |
|
y[:, 1:] = 0 |
|
y = torch.concat([control_latent, y], dim=0) |
|
|
|
|
|
arg_c = { |
|
"context": [context[0]], |
|
"clip_fea": clip_context, |
|
"seq_len": max_seq_len, |
|
"y": [y], |
|
} |
|
|
|
arg_null = { |
|
"context": context_null, |
|
"clip_fea": clip_context, |
|
"seq_len": max_seq_len, |
|
"y": [y], |
|
} |
|
|
|
vae.to_device("cpu") |
|
clean_memory_on_device(device) |
|
|
|
return noise, context, context_null, y, (arg_c, arg_null) |
|
|
|
|
|
def load_control_video(control_path: str, frames: int, height: int, width: int) -> torch.Tensor: |
|
"""load control video to latent space |
|
|
|
Args: |
|
control_path: path to control video |
|
frames: number of frames in the video |
|
height: height of the video |
|
width: width of the video |
|
|
|
Returns: |
|
        torch.Tensor: control video in pixel space, CFHW, normalized to [-1, 1]
|
""" |
|
logger.info(f"Load control video from {control_path}") |
|
video = load_video(control_path, 0, frames, bucket_reso=(width, height)) |
|
if len(video) < frames: |
|
raise ValueError(f"Video length is less than {frames}") |
|
|
|
video = torch.stack([TF.to_tensor(frame).sub_(0.5).div_(0.5) for frame in video], dim=0) |
|
video = video.permute(1, 0, 2, 3) |
|
return video |
|
|
|
|
|
def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]: |
|
"""setup scheduler for sampling |
|
|
|
Args: |
|
args: command line arguments |
|
config: model configuration |
|
device: device to use |
|
|
|
Returns: |
|
Tuple[Any, torch.Tensor]: (scheduler, timesteps) |
|
""" |
|
if args.sample_solver == "unipc": |
|
scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False) |
|
scheduler.set_timesteps(args.infer_steps, device=device, shift=args.flow_shift) |
|
timesteps = scheduler.timesteps |
|
elif args.sample_solver == "dpm++": |
|
scheduler = FlowDPMSolverMultistepScheduler( |
|
num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False |
|
) |
|
sampling_sigmas = get_sampling_sigmas(args.infer_steps, args.flow_shift) |
|
timesteps, _ = retrieve_timesteps(scheduler, device=device, sigmas=sampling_sigmas) |
|
elif args.sample_solver == "vanilla": |
|
scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=config.num_train_timesteps, shift=args.flow_shift) |
|
scheduler.set_timesteps(args.infer_steps, device=device) |
|
timesteps = scheduler.timesteps |
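        # step_wrapper below ignores the generator kwarg so the sampling loop can call
        # every solver's step() with the same signature.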
|
|
|
|
|
org_step = scheduler.step |
|
|
|
def step_wrapper( |
|
model_output: torch.Tensor, |
|
timestep: Union[int, torch.Tensor], |
|
sample: torch.Tensor, |
|
return_dict: bool = True, |
|
generator=None, |
|
): |
|
return org_step(model_output, timestep, sample, return_dict=return_dict) |
|
|
|
scheduler.step = step_wrapper |
|
else: |
|
raise NotImplementedError("Unsupported solver.") |
|
|
|
return scheduler, timesteps |
|
|
|
|
|
def run_sampling( |
|
model: WanModel, |
|
noise: torch.Tensor, |
|
scheduler: Any, |
|
timesteps: torch.Tensor, |
|
args: argparse.Namespace, |
|
inputs: Tuple[dict, dict], |
|
device: torch.device, |
|
seed_g: torch.Generator, |
|
accelerator: Accelerator, |
|
is_i2v: bool = False, |
|
use_cpu_offload: bool = True, |
|
) -> torch.Tensor: |
|
"""run sampling |
|
Args: |
|
model: dit model |
|
noise: initial noise |
|
scheduler: scheduler for sampling |
|
timesteps: time steps for sampling |
|
args: command line arguments |
|
inputs: model input (arg_c, arg_null) |
|
device: device to use |
|
seed_g: random generator |
|
accelerator: Accelerator instance |
|
is_i2v: I2V mode (False means T2V mode) |
|
use_cpu_offload: Whether to offload tensors to CPU during processing |
|
Returns: |
|
torch.Tensor: generated latent |
|
""" |
|
arg_c, arg_null = inputs |
|
|
|
latent = noise |
|
latent_storage_device = device if not use_cpu_offload else "cpu" |
|
latent = latent.to(latent_storage_device) |
|
|
|
|
|
apply_cfg_array = [] |
|
num_timesteps = len(timesteps) |
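    # Decide at which sampling steps classifier-free guidance (the extra unconditional forward
    # pass) is applied: cfg_apply_ratio is the fraction of steps that keep CFG, and
    # cfg_skip_mode controls where the skipped steps are placed in the schedule.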
|
|
|
if args.cfg_skip_mode != "none" and args.cfg_apply_ratio is not None: |
|
|
|
apply_steps = int(num_timesteps * args.cfg_apply_ratio) |
|
|
|
if args.cfg_skip_mode == "early": |
|
|
|
start_index = num_timesteps - apply_steps |
|
end_index = num_timesteps |
|
elif args.cfg_skip_mode == "late": |
|
|
|
start_index = 0 |
|
end_index = apply_steps |
|
elif args.cfg_skip_mode == "early_late": |
|
|
|
start_index = (num_timesteps - apply_steps) // 2 |
|
end_index = start_index + apply_steps |
|
elif args.cfg_skip_mode == "middle": |
|
|
|
skip_steps = num_timesteps - apply_steps |
|
middle_start = (num_timesteps - skip_steps) // 2 |
|
middle_end = middle_start + skip_steps |
|
|
|
w = 0.0 |
|
for step_idx in range(num_timesteps): |
|
if args.cfg_skip_mode == "alternate": |
|
|
|
w += args.cfg_apply_ratio |
|
apply = w >= 1.0 |
|
if apply: |
|
w -= 1.0 |
|
elif args.cfg_skip_mode == "middle": |
|
|
|
apply = step_idx < middle_start or step_idx >= middle_end |
|
else: |
|
|
|
apply = step_idx >= start_index and step_idx < end_index |
|
|
|
apply_cfg_array.append(apply) |
|
|
|
pattern = ["A" if apply else "S" for apply in apply_cfg_array] |
|
pattern = "".join(pattern) |
|
logger.info(f"CFG skip mode: {args.cfg_skip_mode}, apply ratio: {args.cfg_apply_ratio}, pattern: {pattern}") |
|
else: |
|
|
|
apply_cfg_array = [True] * num_timesteps |
|
|
|
|
|
slg_start_step = int(args.slg_start * num_timesteps) |
|
slg_end_step = int(args.slg_end * num_timesteps) |
|
|
|
for i, t in enumerate(tqdm(timesteps)): |
|
|
|
latent_model_input = [latent.to(device)] |
|
timestep = torch.stack([t]).to(device) |
|
|
|
with accelerator.autocast(), torch.no_grad(): |
|
noise_pred_cond = model(latent_model_input, t=timestep, **arg_c)[0].to(latent_storage_device) |
|
|
|
apply_cfg = apply_cfg_array[i] |
|
if apply_cfg: |
|
apply_slg = i >= slg_start_step and i < slg_end_step |
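                # Skip Layer Guidance: in "original" mode, do standard CFG and then push the result
                # away from a prediction computed with the given blocks skipped on the unconditional
                # inputs; in "uncond" mode, that skip-layer prediction replaces the unconditional one.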
|
|
|
if args.slg_mode == "original" and apply_slg: |
|
noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to(latent_storage_device) |
|
|
|
|
|
|
|
noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) |
|
|
|
|
|
skip_layer_out = model(latent_model_input, t=timestep, skip_block_indices=args.slg_layers, **arg_null)[0].to( |
|
latent_storage_device |
|
) |
|
|
|
|
|
|
|
noise_pred = noise_pred + args.slg_scale * (noise_pred_cond - skip_layer_out) |
|
elif args.slg_mode == "uncond" and apply_slg: |
|
|
|
noise_pred_uncond = model(latent_model_input, t=timestep, skip_block_indices=args.slg_layers, **arg_null)[0].to( |
|
latent_storage_device |
|
) |
|
|
|
|
|
noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) |
|
|
|
else: |
|
|
|
noise_pred_uncond = model(latent_model_input, t=timestep, **arg_null)[0].to(latent_storage_device) |
|
|
|
|
|
noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond) |
|
else: |
|
noise_pred = noise_pred_cond |
|
|
|
|
|
latent_input = latent.unsqueeze(0) |
|
temp_x0 = scheduler.step(noise_pred.unsqueeze(0), t, latent_input, return_dict=False, generator=seed_g)[0] |
|
|
|
|
|
latent = temp_x0.squeeze(0) |
|
|
|
return latent |
|
|
|
|
|
def generate(args: argparse.Namespace, gen_settings: GenerationSettings, shared_models: Optional[Dict] = None) -> torch.Tensor: |
|
"""main function for generation |
|
|
|
    Args:
        args: command line arguments
        gen_settings: generation settings (device, model config, dtypes)
        shared_models: dictionary containing pre-loaded models and encoded data
|
|
|
Returns: |
|
torch.Tensor: generated latent |
|
""" |
|
device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = ( |
|
gen_settings.device, |
|
gen_settings.cfg, |
|
gen_settings.dit_dtype, |
|
gen_settings.dit_weight_dtype, |
|
gen_settings.vae_dtype, |
|
) |
|
|
|
|
|
mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16" |
|
accelerator = accelerate.Accelerator(mixed_precision=mixed_precision) |
|
|
|
|
|
is_i2v = "i2v" in args.task |
|
|
|
|
|
seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1) |
|
args.seed = seed |
|
|
|
|
|
if shared_models is not None: |
|
|
|
vae = shared_models.get("vae") |
|
model = shared_models.get("model") |
|
encoded_context = shared_models.get("encoded_contexts", {}).get(args.prompt) |
|
|
|
|
|
if is_i2v: |
|
|
|
noise, context, context_null, y, inputs = prepare_i2v_inputs(args, cfg, accelerator, device, vae, encoded_context) |
|
else: |
|
|
|
noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae, encoded_context) |
|
else: |
|
|
|
if is_i2v: |
|
|
|
vae = load_vae(args, cfg, device, vae_dtype) |
|
noise, context, context_null, y, inputs = prepare_i2v_inputs(args, cfg, accelerator, device, vae) |
|
|
|
else: |
|
|
|
vae = None |
|
if cfg.is_fun_control: |
|
|
|
vae = load_vae(args, cfg, device, vae_dtype) |
|
noise, context, context_null, inputs = prepare_t2v_inputs(args, cfg, accelerator, device, vae) |
|
|
|
|
|
model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v) |
|
|
|
|
|
if args.lora_weight is not None and len(args.lora_weight) > 0: |
|
merge_lora_weights(lora_wan, model, args, device) |
|
|
|
|
|
if args.save_merged_model: |
|
return None |
|
|
|
|
|
optimize_model(model, args, device, dit_dtype, dit_weight_dtype) |
|
|
|
|
|
scheduler, timesteps = setup_scheduler(args, cfg, device) |
|
|
|
|
|
seed_g = torch.Generator(device=device) |
|
seed_g.manual_seed(seed) |
|
|
|
|
|
latent = run_sampling(model, noise, scheduler, timesteps, args, inputs, device, seed_g, accelerator, is_i2v) |
|
|
|
|
|
if shared_models is None: |
|
|
|
del model |
|
del scheduler |
|
synchronize_device(device) |
|
|
|
|
|
logger.info("Waiting for 5 seconds to finish block swap") |
|
time.sleep(5) |
|
|
|
gc.collect() |
|
clean_memory_on_device(device) |
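    # Keep a reference to the VAE (possibly None) on args so decode_latent() can reuse it
    # instead of loading it from disk again.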
|
|
|
|
|
    args._vae = vae
|
|
|
return latent |
|
|
|
|
|
def decode_latent(latent: torch.Tensor, args: argparse.Namespace, cfg) -> torch.Tensor: |
|
"""decode latent |
|
|
|
Args: |
|
latent: latent tensor |
|
args: command line arguments |
|
cfg: model configuration |
|
|
|
Returns: |
|
torch.Tensor: decoded video or image |
|
""" |
|
device = torch.device(args.device) |
|
|
|
|
|
vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else torch.bfloat16 |
|
if hasattr(args, "_vae") and args._vae is not None: |
|
vae = args._vae |
|
else: |
|
vae = load_vae(args, cfg, device, vae_dtype) |
|
|
|
vae.to_device(device) |
|
|
|
logger.info(f"Decoding video from latents: {latent.shape}") |
|
x0 = latent.to(device) |
|
|
|
with torch.autocast(device_type=device.type, dtype=vae_dtype), torch.no_grad(): |
|
videos = vae.decode(x0) |
|
|
|
|
|
if args.trim_tail_frames: |
|
videos[0] = videos[0][:, : -args.trim_tail_frames] |
|
|
|
logger.info(f"Decoding complete") |
|
video = videos[0] |
|
del videos |
|
video = video.to(torch.float32).cpu() |
|
|
|
return video |
|
|
|
|
|
def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str: |
|
"""Save latent to file |
|
|
|
Args: |
|
latent: latent tensor |
|
args: command line arguments |
|
height: height of frame |
|
width: width of frame |
|
|
|
Returns: |
|
str: Path to saved latent file |
|
""" |
|
save_path = args.save_path |
|
os.makedirs(save_path, exist_ok=True) |
|
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") |
|
|
|
seed = args.seed |
|
video_length = args.video_length |
|
latent_path = f"{save_path}/{time_flag}_{seed}_latent.safetensors" |
|
|
|
if args.no_metadata: |
|
metadata = None |
|
else: |
|
metadata = { |
|
"seeds": f"{seed}", |
|
"prompt": f"{args.prompt}", |
|
"height": f"{height}", |
|
"width": f"{width}", |
|
"video_length": f"{video_length}", |
|
"infer_steps": f"{args.infer_steps}", |
|
"guidance_scale": f"{args.guidance_scale}", |
|
} |
|
if args.negative_prompt is not None: |
|
metadata["negative_prompt"] = f"{args.negative_prompt}" |
|
|
|
sd = {"latent": latent} |
|
save_file(sd, latent_path, metadata=metadata) |
|
logger.info(f"Latent saved to: {latent_path}") |
|
|
|
return latent_path |
|
|
|
|
|
def save_video(video: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str: |
|
"""Save video to file |
|
|
|
Args: |
|
video: Video tensor |
|
args: command line arguments |
|
original_base_name: Original base name (if latents are loaded from files) |
|
|
|
Returns: |
|
str: Path to saved video file |
|
""" |
|
save_path = args.save_path |
|
os.makedirs(save_path, exist_ok=True) |
|
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") |
|
|
|
seed = args.seed |
|
original_name = "" if original_base_name is None else f"_{original_base_name}" |
|
video_path = f"{save_path}/{time_flag}_{seed}{original_name}.mp4" |
|
|
|
video = video.unsqueeze(0) |
|
save_videos_grid(video, video_path, fps=args.fps, rescale=True) |
|
logger.info(f"Video saved to: {video_path}") |
|
|
|
return video_path |
|
|
|
|
|
def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str: |
|
"""Save images to directory |
|
|
|
Args: |
|
sample: Video tensor |
|
args: command line arguments |
|
original_base_name: Original base name (if latents are loaded from files) |
|
|
|
Returns: |
|
str: Path to saved images directory |
|
""" |
|
save_path = args.save_path |
|
os.makedirs(save_path, exist_ok=True) |
|
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S") |
|
|
|
seed = args.seed |
|
original_name = "" if original_base_name is None else f"_{original_base_name}" |
|
image_name = f"{time_flag}_{seed}{original_name}" |
|
sample = sample.unsqueeze(0) |
|
save_images_grid(sample, save_path, image_name, rescale=True) |
|
logger.info(f"Sample images saved to: {save_path}/{image_name}") |
|
|
|
return f"{save_path}/{image_name}" |
|
|
|
|
|
def save_output( |
|
latent: torch.Tensor, args: argparse.Namespace, cfg, height: int, width: int, original_base_names: Optional[List[str]] = None |
|
) -> None: |
|
"""save output |
|
|
|
Args: |
|
latent: latent tensor |
|
args: command line arguments |
|
cfg: model configuration |
|
height: height of frame |
|
width: width of frame |
|
original_base_names: original base names (if latents are loaded from files) |
|
""" |
|
if args.output_type == "latent" or args.output_type == "both": |
|
|
|
save_latent(latent, args, height, width) |
|
|
|
if args.output_type == "video" or args.output_type == "both": |
|
|
|
sample = decode_latent(latent.unsqueeze(0), args, cfg) |
|
original_name = "" if original_base_names is None else f"_{original_base_names[0]}" |
|
save_video(sample, args, original_name) |
|
|
|
elif args.output_type == "images": |
|
|
|
sample = decode_latent(latent.unsqueeze(0), args, cfg) |
|
original_name = "" if original_base_names is None else f"_{original_base_names[0]}" |
|
save_images(sample, args, original_name) |
|
|
|
|
|
def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]: |
|
"""Process multiple prompts for batch mode |
|
|
|
Args: |
|
prompt_lines: List of prompt lines |
|
base_args: Base command line arguments |
|
|
|
Returns: |
|
List[Dict]: List of prompt data dictionaries |
|
""" |
|
prompts_data = [] |
|
|
|
for line in prompt_lines: |
|
line = line.strip() |
|
if not line or line.startswith("#"): |
|
continue |
|
|
|
|
|
prompt_data = parse_prompt_line(line) |
|
logger.info(f"Parsed prompt data: {prompt_data}") |
|
prompts_data.append(prompt_data) |
|
|
|
return prompts_data |
|
|
|
|
|
def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> None: |
|
"""Process multiple prompts with model reuse |
|
|
|
Args: |
|
prompts_data: List of prompt data dictionaries |
|
args: Base command line arguments |
|
""" |
|
if not prompts_data: |
|
logger.warning("No valid prompts found") |
|
return |
|
|
|
|
|
gen_settings = get_generation_settings(args) |
|
device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = ( |
|
gen_settings.device, |
|
gen_settings.cfg, |
|
gen_settings.dit_dtype, |
|
gen_settings.dit_weight_dtype, |
|
gen_settings.vae_dtype, |
|
) |
|
is_i2v = "i2v" in args.task |
|
|
|
|
|
logger.info("Loading text encoder to encode all prompts") |
|
text_encoder = load_text_encoder(args, cfg, device) |
|
text_encoder.model.to(device) |
|
|
|
encoded_contexts = {} |
|
|
|
with torch.no_grad(): |
|
for prompt_data in prompts_data: |
|
prompt = prompt_data["prompt"] |
|
prompt_args = apply_overrides(args, prompt_data) |
|
n_prompt = prompt_data.get( |
|
"negative_prompt", prompt_args.negative_prompt if prompt_args.negative_prompt else cfg.sample_neg_prompt |
|
) |
|
|
|
if args.fp8_t5: |
|
with torch.amp.autocast(device_type=device.type, dtype=cfg.t5_dtype): |
|
context = text_encoder([prompt], device) |
|
context_null = text_encoder([n_prompt], device) |
|
else: |
|
context = text_encoder([prompt], device) |
|
context_null = text_encoder([n_prompt], device) |
|
|
|
encoded_contexts[prompt] = {"context": context, "context_null": context_null} |
|
|
|
|
|
del text_encoder |
|
clean_memory_on_device(device) |
|
|
|
|
|
vae = None |
|
if is_i2v: |
|
logger.info("Loading VAE and CLIP for I2V preprocessing") |
|
vae = load_vae(args, cfg, device, vae_dtype) |
|
vae.to_device(device) |
|
|
|
clip = load_clip_model(args, cfg, device) |
|
clip.model.to(device) |
|
|
|
|
|
for prompt_data in prompts_data: |
|
if "image_path" not in prompt_data: |
|
continue |
|
|
|
prompt_args = apply_overrides(args, prompt_data) |
|
if not os.path.exists(prompt_args.image_path): |
|
logger.warning(f"Image path not found: {prompt_args.image_path}") |
|
continue |
|
|
|
|
|
img = Image.open(prompt_args.image_path).convert("RGB") |
|
img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device) |
|
|
|
with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad(): |
|
clip_context = clip.visual([img_tensor[:, None, :, :]]) |
|
|
|
encoded_contexts[prompt_data["prompt"]]["clip_context"] = clip_context |
|
|
|
|
|
del clip |
|
clean_memory_on_device(device) |
|
|
|
|
|
vae.to_device("cpu") |
|
elif cfg.is_fun_control: |
|
|
|
vae = load_vae(args, cfg, device, vae_dtype) |
|
vae.to_device("cpu") |
|
|
|
|
|
logger.info("Loading DiT model") |
|
model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v) |
|
|
|
|
|
if args.lora_weight is not None and len(args.lora_weight) > 0: |
|
merge_lora_weights(lora_wan, model, args, device) |
|
if args.save_merged_model: |
|
logger.info("Model merged and saved. Exiting.") |
|
return |
|
|
|
|
|
optimize_model(model, args, device, dit_dtype, dit_weight_dtype) |
|
|
|
|
|
shared_models = {"vae": vae, "model": model, "encoded_contexts": encoded_contexts} |
|
|
|
|
|
all_latents = [] |
|
all_prompt_args = [] |
|
|
|
for i, prompt_data in enumerate(prompts_data): |
|
logger.info(f"Processing prompt {i+1}/{len(prompts_data)}: {prompt_data['prompt'][:50]}...") |
|
|
|
|
|
prompt_args = apply_overrides(args, prompt_data) |
|
|
|
|
|
latent = generate(prompt_args, gen_settings, shared_models) |
|
|
|
|
|
height, width, _ = check_inputs(prompt_args) |
|
if prompt_args.output_type == "latent" or prompt_args.output_type == "both": |
|
save_latent(latent, prompt_args, height, width) |
|
|
|
all_latents.append(latent) |
|
all_prompt_args.append(prompt_args) |
|
|
|
|
|
del model |
|
clean_memory_on_device(device) |
|
synchronize_device(device) |
|
|
|
|
|
logger.info("Waiting for 5 seconds to finish block swap") |
|
time.sleep(5) |
|
|
|
gc.collect() |
|
clean_memory_on_device(device) |
|
|
|
|
|
if args.output_type != "latent": |
|
logger.info("Decoding latents to videos/images") |
|
|
|
if vae is None: |
|
vae = load_vae(args, cfg, device, vae_dtype) |
|
|
|
vae.to_device(device) |
|
|
|
for i, (latent, prompt_args) in enumerate(zip(all_latents, all_prompt_args)): |
|
logger.info(f"Decoding output {i+1}/{len(all_latents)}") |
|
|
|
|
|
video = decode_latent(latent.unsqueeze(0), prompt_args, cfg) |
|
|
|
|
|
if prompt_args.output_type == "video" or prompt_args.output_type == "both": |
|
save_video(video, prompt_args) |
|
elif prompt_args.output_type == "images": |
|
save_images(video, prompt_args) |
|
|
|
|
|
del vae |
|
|
|
clean_memory_on_device(device) |
|
gc.collect() |
|
|
|
|
|
def process_interactive(args: argparse.Namespace) -> None: |
|
"""Process prompts in interactive mode |
|
|
|
Args: |
|
args: Base command line arguments |
|
""" |
|
gen_settings = get_generation_settings(args) |
|
device, cfg, dit_dtype, dit_weight_dtype, vae_dtype = ( |
|
gen_settings.device, |
|
gen_settings.cfg, |
|
gen_settings.dit_dtype, |
|
gen_settings.dit_weight_dtype, |
|
gen_settings.vae_dtype, |
|
) |
|
is_i2v = "i2v" in args.task |
|
|
|
|
|
text_encoder = None |
|
vae = None |
|
model = None |
|
clip = None |
|
|
|
print("Interactive mode. Enter prompts (Ctrl+D to exit):") |
|
|
|
try: |
|
while True: |
|
try: |
|
line = input("> ") |
|
if not line.strip(): |
|
continue |
|
|
|
|
|
prompt_data = parse_prompt_line(line) |
|
prompt_args = apply_overrides(args, prompt_data) |
|
|
|
|
|
|
|
|
|
if text_encoder is None: |
|
logger.info("Loading text encoder") |
|
text_encoder = load_text_encoder(args, cfg, device) |
|
|
|
text_encoder.model.to(device) |
|
|
|
|
|
n_prompt = prompt_data.get( |
|
"negative_prompt", prompt_args.negative_prompt if prompt_args.negative_prompt else cfg.sample_neg_prompt |
|
) |
|
|
|
with torch.no_grad(): |
|
if args.fp8_t5: |
|
with torch.amp.autocast(device_type=device.type, dtype=cfg.t5_dtype): |
|
context = text_encoder([prompt_data["prompt"]], device) |
|
context_null = text_encoder([n_prompt], device) |
|
else: |
|
context = text_encoder([prompt_data["prompt"]], device) |
|
context_null = text_encoder([n_prompt], device) |
|
|
|
encoded_context = {"context": context, "context_null": context_null} |
|
|
|
|
|
text_encoder.model.to("cpu") |
|
|
|
|
|
if is_i2v: |
|
if clip is None: |
|
logger.info("Loading CLIP model") |
|
clip = load_clip_model(args, cfg, device) |
|
|
|
clip.model.to(device) |
|
|
|
|
|
if prompt_args.image_path and os.path.exists(prompt_args.image_path): |
|
img = Image.open(prompt_args.image_path).convert("RGB") |
|
img_tensor = TF.to_tensor(img).sub_(0.5).div_(0.5).to(device) |
|
|
|
with torch.amp.autocast(device_type=device.type, dtype=torch.float16), torch.no_grad(): |
|
clip_context = clip.visual([img_tensor[:, None, :, :]]) |
|
|
|
encoded_context["clip_context"] = clip_context |
|
|
|
|
|
clip.model.to("cpu") |
|
|
|
|
|
if vae is None: |
|
logger.info("Loading VAE model") |
|
vae = load_vae(args, cfg, device, vae_dtype) |
|
elif cfg.is_fun_control and vae is None: |
|
|
|
logger.info("Loading VAE model for Fun-Control") |
|
vae = load_vae(args, cfg, device, vae_dtype) |
|
|
|
|
|
if model is None: |
|
logger.info("Loading DiT model") |
|
model = load_dit_model(args, cfg, device, dit_dtype, dit_weight_dtype, is_i2v) |
|
|
|
|
|
if args.lora_weight is not None and len(args.lora_weight) > 0: |
|
merge_lora_weights(lora_wan, model, args, device) |
|
|
|
|
|
optimize_model(model, args, device, dit_dtype, dit_weight_dtype) |
|
else: |
|
|
|
model.to(device) |
|
|
|
|
|
shared_models = {"vae": vae, "model": model, "encoded_contexts": {prompt_data["prompt"]: encoded_context}} |
|
|
|
|
|
latent = generate(prompt_args, gen_settings, shared_models) |
|
|
|
|
|
model.to("cpu") |
|
|
|
|
|
height, width, _ = check_inputs(prompt_args) |
|
if prompt_args.output_type == "latent" or prompt_args.output_type == "both": |
|
save_latent(latent, prompt_args, height, width) |
|
|
|
|
|
if prompt_args.output_type != "latent": |
|
if vae is None: |
|
vae = load_vae(args, cfg, device, vae_dtype) |
|
|
|
vae.to_device(device) |
|
video = decode_latent(latent.unsqueeze(0), prompt_args, cfg) |
|
|
|
if prompt_args.output_type == "video" or prompt_args.output_type == "both": |
|
save_video(video, prompt_args) |
|
elif prompt_args.output_type == "images": |
|
save_images(video, prompt_args) |
|
|
|
|
|
vae.to_device("cpu") |
|
|
|
clean_memory_on_device(device) |
|
|
|
except KeyboardInterrupt: |
|
print("\nInterrupted. Continue (Ctrl+D or Ctrl+Z (Windows) to exit)") |
|
continue |
|
|
|
except EOFError: |
|
print("\nExiting interactive mode") |
|
|
|
|
|
if text_encoder is not None: |
|
del text_encoder |
|
if clip is not None: |
|
del clip |
|
if vae is not None: |
|
del vae |
|
if model is not None: |
|
del model |
|
|
|
clean_memory_on_device(device) |
|
gc.collect() |
|
|
|
|
|
def get_generation_settings(args: argparse.Namespace) -> GenerationSettings: |
|
device = torch.device(args.device) |
|
|
|
cfg = WAN_CONFIGS[args.task] |
|
|
|
|
|
dit_dtype = detect_wan_sd_dtype(args.dit) if args.dit is not None else torch.bfloat16 |
|
if dit_dtype.itemsize == 1: |
|
|
|
dit_dtype = torch.bfloat16 |
|
if args.fp8_scaled: |
|
raise ValueError( |
|
"DiT weights is already in fp8 format, cannot scale to fp8. Please use fp16/bf16 weights / DiTの重みはすでにfp8形式です。fp8にスケーリングできません。fp16/bf16の重みを使用してください" |
|
) |
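    # dit_weight_dtype = None means "keep the weights as loaded" so that optimize_model() can
    # perform the scaled-FP8 conversion itself; plain --fp8 casts the weights to float8_e4m3fn here.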
|
|
|
dit_weight_dtype = dit_dtype |
|
if args.fp8_scaled: |
|
dit_weight_dtype = None |
|
elif args.fp8: |
|
dit_weight_dtype = torch.float8_e4m3fn |
|
|
|
vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else dit_dtype |
|
logger.info( |
|
f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}, VAE precision: {vae_dtype}" |
|
) |
|
|
|
gen_settings = GenerationSettings( |
|
device=device, |
|
cfg=cfg, |
|
dit_dtype=dit_dtype, |
|
dit_weight_dtype=dit_weight_dtype, |
|
vae_dtype=vae_dtype, |
|
) |
|
return gen_settings |
|
|
|
|
|
def main(): |
|
|
|
args = parse_args() |
|
|
|
|
|
latents_mode = args.latent_path is not None and len(args.latent_path) > 0 |
|
|
|
|
|
device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu" |
|
device = torch.device(device) |
|
logger.info(f"Using device: {device}") |
|
args.device = device |
|
|
|
if latents_mode: |
|
|
|
cfg = WAN_CONFIGS[args.task] |
|
original_base_names = [] |
|
latents_list = [] |
|
seeds = [] |
|
|
|
assert len(args.latent_path) == 1, "Only one latent path is supported for now" |
|
|
|
for latent_path in args.latent_path: |
|
original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0]) |
|
seed = 0 |
|
|
|
if os.path.splitext(latent_path)[1] != ".safetensors": |
|
latents = torch.load(latent_path, map_location="cpu") |
|
else: |
|
latents = load_file(latent_path)["latent"] |
|
with safe_open(latent_path, framework="pt") as f: |
|
metadata = f.metadata() |
|
if metadata is None: |
|
metadata = {} |
|
logger.info(f"Loaded metadata: {metadata}") |
|
|
|
if "seeds" in metadata: |
|
seed = int(metadata["seeds"]) |
|
if "height" in metadata and "width" in metadata: |
|
height = int(metadata["height"]) |
|
width = int(metadata["width"]) |
|
args.video_size = [height, width] |
|
if "video_length" in metadata: |
|
args.video_length = int(metadata["video_length"]) |
|
|
|
seeds.append(seed) |
|
latents_list.append(latents) |
|
|
|
logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}") |
|
|
|
latent = torch.stack(latents_list, dim=0) |
|
|
|
height = latents.shape[-2] |
|
width = latents.shape[-1] |
|
height *= cfg.patch_size[1] * cfg.vae_stride[1] |
|
width *= cfg.patch_size[2] * cfg.vae_stride[2] |
|
video_length = latents.shape[1] |
|
video_length = (video_length - 1) * cfg.vae_stride[0] + 1 |
|
args.seed = seeds[0] |
|
|
|
|
|
save_output(latent[0], args, cfg, height, width, original_base_names) |
|
|
|
elif args.from_file: |
|
|
|
args = setup_args(args) |
|
|
|
|
|
with open(args.from_file, "r", encoding="utf-8") as f: |
|
prompt_lines = f.readlines() |
|
|
|
|
|
prompts_data = preprocess_prompts_for_batch(prompt_lines, args) |
|
process_batch_prompts(prompts_data, args) |
|
|
|
elif args.interactive: |
|
|
|
args = setup_args(args) |
|
process_interactive(args) |
|
|
|
else: |
|
|
|
args = setup_args(args) |
|
height, width, video_length = check_inputs(args) |
|
|
|
logger.info( |
|
f"Video size: {height}x{width}@{video_length} (HxW@F), fps: {args.fps}, " |
|
f"infer_steps: {args.infer_steps}, flow_shift: {args.flow_shift}" |
|
) |
|
|
|
|
|
gen_settings = get_generation_settings(args) |
|
latent = generate(args, gen_settings) |
|
|
|
|
|
gc.collect() |
|
clean_memory_on_device(args.device) |
|
|
|
|
|
if args.save_merged_model: |
|
return |
|
|
|
|
|
    save_output(latent, args, WAN_CONFIGS[args.task], height, width)
|
|
|
logger.info("Done!") |
|
|
|
|
|
if __name__ == "__main__": |
|
main() |
|
|