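"""Latent caching script for the Wan architecture (ARCHITECTURE_WAN).

Encodes dataset images/videos with the Wan VAE and, when --clip is given (I2V training),
also caches CLIP visual features and the masked image-latent conditioning tensor.

Example invocation (script name and paths are illustrative, not taken from this file):

    python wan_cache_latents.py --dataset_config /path/to/dataset.toml --vae /path/to/wan_vae.pth --clip /path/to/clip.pth

Arguments such as dataset_config, vae, vae_dtype, device and debug_mode are defined by
cache_latents.setup_parser_common(); wan_setup_parser() below adds --vae_cache_cpu and --clip.
"""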
import argparse
import os
import glob
from typing import Optional, Union
import numpy as np
import torch
from tqdm import tqdm
from dataset import config_utils
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
from PIL import Image
import logging
from dataset.image_video_dataset import ItemInfo, save_latent_cache_wan, ARCHITECTURE_WAN
from utils.model_utils import str_to_dtype
from wan.configs import wan_i2v_14B
from wan.modules.vae import WanVAE
from wan.modules.clip import CLIPModel
import cache_latents
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def encode_and_save_batch(vae: WanVAE, clip: Optional[CLIPModel], batch: list[ItemInfo]):
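    """Encode a batch of images/videos with the Wan VAE and save per-item latent caches.

    When a CLIP model is given (I2V training), CLIP visual features and a masked image-latent
    conditioning tensor `y` are computed as well; control videos, if present, are encoded too.
    """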
    contents = torch.stack([torch.from_numpy(item.content) for item in batch])
    if len(contents.shape) == 4:
        contents = contents.unsqueeze(1)  # B, H, W, C -> B, F, H, W, C

    contents = contents.permute(0, 4, 1, 2, 3).contiguous()  # B, C, F, H, W
    contents = contents.to(vae.device, dtype=vae.dtype)
    contents = contents / 127.5 - 1.0  # normalize to [-1, 1]

    h, w = contents.shape[3], contents.shape[4]
    if h < 8 or w < 8:
        item = batch[0]  # other items should have the same size
        raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")

    # print(f"encode batch: {contents.shape}")
    with torch.amp.autocast(device_type=vae.device.type, dtype=vae.dtype), torch.no_grad():
        latent = vae.encode(contents)  # list of Tensor[C, F, H, W]
        latent = torch.stack(latent, dim=0)  # B, C, F, H, W
        latent = latent.to(vae.dtype)  # convert to bfloat16, we are not sure if this is correct
    if clip is not None:
        # extract first frame of contents
        images = contents[:, :, 0:1, :, :]  # B, C, F, H, W, non contiguous view is fine

        with torch.amp.autocast(device_type=clip.device.type, dtype=torch.float16), torch.no_grad():
            clip_context = clip.visual(images)
            clip_context = clip_context.to(torch.float16)  # convert to fp16

        # encode image latent for I2V
        B, _, _, lat_h, lat_w = latent.shape
        F = contents.shape[2]

        # Create mask for the required number of frames
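        # the first mask frame is repeated 4x so the reshape below into groups of 4 lines up with
        # the VAE's 4x temporal compression ((F + 3) // 4 latent frames); only the first input
        # frame is marked as given for I2V conditioning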
        msk = torch.ones(1, F, lat_h, lat_w, dtype=vae.dtype, device=vae.device)
        msk[:, 1:] = 0
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)  # 1, F, 4, H, W -> 1, 4, F, H, W
        msk = msk.repeat(B, 1, 1, 1, 1)  # B, 4, F, H, W

        # Zero padding for the required number of frames only
        padding_frames = F - 1  # The first frame is the input image
        images_resized = torch.concat([images, torch.zeros(B, 3, padding_frames, h, w, device=vae.device)], dim=2)
        with torch.amp.autocast(device_type=vae.device.type, dtype=vae.dtype), torch.no_grad():
            y = vae.encode(images_resized)
            y = torch.stack(y, dim=0)  # B, C, F, H, W

        y = y[:, :, :F]  # may not be needed
        y = y.to(vae.dtype)  # convert to bfloat16
        y = torch.concat([msk, y], dim=1)  # B, 4 + C, F, H, W
    else:
        clip_context = None
        y = None
    # control videos
    if batch[0].control_content is not None:
        control_contents = torch.stack([torch.from_numpy(item.control_content) for item in batch])
        if len(control_contents.shape) == 4:
            control_contents = control_contents.unsqueeze(1)
        control_contents = control_contents.permute(0, 4, 1, 2, 3).contiguous()  # B, C, F, H, W
        control_contents = control_contents.to(vae.device, dtype=vae.dtype)
        control_contents = control_contents / 127.5 - 1.0  # normalize to [-1, 1]
        with torch.amp.autocast(device_type=vae.device.type, dtype=vae.dtype), torch.no_grad():
            control_latent = vae.encode(control_contents)  # list of Tensor[C, F, H, W]
            control_latent = torch.stack(control_latent, dim=0)  # B, C, F, H, W
            control_latent = control_latent.to(vae.dtype)  # convert to bfloat16
    else:
        control_latent = None
    # # debug: decode and save
    # with torch.no_grad():
    #     latent_to_decode = latent / vae.config.scaling_factor
    #     images = vae.decode(latent_to_decode, return_dict=False)[0]
    #     images = (images / 2 + 0.5).clamp(0, 1)
    #     images = images.cpu().float().numpy()
    #     images = (images * 255).astype(np.uint8)
    #     images = images.transpose(0, 2, 3, 4, 1)  # B, C, F, H, W -> B, F, H, W, C
    #     for b in range(images.shape[0]):
    #         for f in range(images.shape[1]):
    #             fln = os.path.splitext(os.path.basename(batch[b].item_key))[0]
    #             img = Image.fromarray(images[b, f])
    #             img.save(f"./logs/decode_{fln}_{b}_{f:03d}.jpg")
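    # write one latent cache file per item; CLIP/I2V tensors are only present when clip was given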
    for i, item in enumerate(batch):
        l = latent[i]
        cctx = clip_context[i] if clip is not None else None
        y_i = y[i] if clip is not None else None
        control_latent_i = control_latent[i] if control_latent is not None else None
        # print(f"save latent cache: {item.latent_cache_path}, latent shape: {l.shape}")
        save_latent_cache_wan(item, l, cctx, y_i, control_latent_i)


def main(args):
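    """Load the dataset config, set up the Wan VAE (and CLIP if requested), and cache latents."""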
    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Load dataset config
    blueprint_generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    user_config = config_utils.load_user_config(args.dataset_config)
    blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_WAN)
    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    datasets = train_dataset_group.datasets

    if args.debug_mode is not None:
        cache_latents.show_datasets(
            datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images, fps=16
        )
        return

    assert args.vae is not None, "vae checkpoint is required"
    vae_path = args.vae

    logger.info(f"Loading VAE model from {vae_path}")
    vae_dtype = torch.bfloat16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
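    # when --vae_cache_cpu is set, the VAE caches its features on the CPU (typically to save VRAM)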
    cache_device = torch.device("cpu") if args.vae_cache_cpu else None
    vae = WanVAE(vae_path=vae_path, device=device, dtype=vae_dtype, cache_device=cache_device)

    if args.clip is not None:
        clip_dtype = wan_i2v_14B.i2v_14B["clip_dtype"]
        clip = CLIPModel(dtype=clip_dtype, device=device, weight_path=args.clip)
    else:
        clip = None

    # Encode images
    def encode(one_batch: list[ItemInfo]):
        encode_and_save_batch(vae, clip, one_batch)

    cache_latents.encode_datasets(datasets, encode, args)


def wan_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
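    """Add Wan-specific arguments on top of the common cache_latents parser."""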
parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU")
parser.add_argument(
"--clip",
type=str,
default=None,
help="text encoder (CLIP) checkpoint path, optional. If training I2V model, this is required",
)
return parser


if __name__ == "__main__":
    parser = cache_latents.setup_parser_common()
    parser = wan_setup_parser(parser)
    args = parser.parse_args()

    main(args)