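"""Cache WAN latents for training datasets.

Encodes the images/videos described by a dataset config with the WAN VAE and,
when a CLIP checkpoint is supplied, also computes the CLIP visual context and
the image-conditioning latent ("y") used for I2V training. Results are written
to latent cache files via save_latent_cache_wan.
"""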
import argparse
import os
import glob
from typing import Optional, Union

import numpy as np
import torch
from tqdm import tqdm

from dataset import config_utils
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
from PIL import Image

import logging

from dataset.image_video_dataset import ItemInfo, save_latent_cache_wan, ARCHITECTURE_WAN
from utils.model_utils import str_to_dtype
from wan.configs import wan_i2v_14B
from wan.modules.vae import WanVAE
from wan.modules.clip import CLIPModel
import cache_latents

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def encode_and_save_batch(vae: WanVAE, clip: Optional[CLIPModel], batch: list[ItemInfo]):
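    """Encode a batch of images/videos with the WAN VAE and write latent caches.

    If `clip` is provided (I2V training), also computes the CLIP visual context
    of the first frame and the image-conditioning latent `y` (a frame mask
    concatenated with the encoded first frame padded with zeros).
    """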
    contents = torch.stack([torch.from_numpy(item.content) for item in batch])
    if len(contents.shape) == 4:
        contents = contents.unsqueeze(1)  # B, H, W, C -> B, F, H, W, C

    contents = contents.permute(0, 4, 1, 2, 3).contiguous()  # B, C, F, H, W
    contents = contents.to(vae.device, dtype=vae.dtype)
    contents = contents / 127.5 - 1.0  # normalize to [-1, 1]

    h, w = contents.shape[3], contents.shape[4]
    if h < 8 or w < 8:
        item = batch[0]  # other items should have the same size
        raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")

    # print(f"encode batch: {contents.shape}")
    with torch.amp.autocast(device_type=vae.device.type, dtype=vae.dtype), torch.no_grad():
        latent = vae.encode(contents)  # list of Tensor[C, F, H, W]
    latent = torch.stack(latent, dim=0)  # B, C, F, H, W
    latent = latent.to(vae.dtype)  # cast to the VAE dtype (bfloat16 by default)

    if clip is not None:
        # extract first frame of contents
        images = contents[:, :, 0:1, :, :]  # B, C, 1, H, W; a non-contiguous view is fine here

        with torch.amp.autocast(device_type=clip.device.type, dtype=torch.float16), torch.no_grad():
            clip_context = clip.visual(images)
        clip_context = clip_context.to(torch.float16)  # convert to fp16

        # encode image latent for I2V
        B, _, _, lat_h, lat_w = latent.shape
        F = contents.shape[2]

        # Create mask for the required number of frames
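        # The WAN VAE compresses time by a factor of 4 ((F + 3) // 4 latent frames),
        # so the first (conditioning) frame's mask entry is repeated 4 times and the
        # flat mask is folded into groups of 4 to match the latent frame count.
        # Ones mark the conditioning frame, zeros mark the frames to be generated.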
        msk = torch.ones(1, F, lat_h, lat_w, dtype=vae.dtype, device=vae.device)
        msk[:, 1:] = 0
        msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
        msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
        msk = msk.transpose(1, 2)  # 1, F_lat, 4, H, W -> 1, 4, F_lat, H, W (F_lat = (F + 3) // 4)
        msk = msk.repeat(B, 1, 1, 1, 1)  # B, 4, F, H, W

        # Zero padding for the required number of frames only
        padding_frames = F - 1  # The first frame is the input image
        images_resized = torch.concat([images, torch.zeros(B, 3, padding_frames, h, w, device=vae.device)], dim=2)
        with torch.amp.autocast(device_type=vae.device.type, dtype=vae.dtype), torch.no_grad():
            y = vae.encode(images_resized)
        y = torch.stack(y, dim=0)  # B, C, F, H, W

        y = y[:, :, :F]  # may not be needed
        y = y.to(vae.dtype)  # convert to bfloat16
        y = torch.concat([msk, y], dim=1)  # B, 4 + C, F, H, W

    else:
        clip_context = None
        y = None

    # control videos
    if batch[0].control_content is not None:
        control_contents = torch.stack([torch.from_numpy(item.control_content) for item in batch])
        if len(control_contents.shape) == 4:
            control_contents = control_contents.unsqueeze(1)
        control_contents = control_contents.permute(0, 4, 1, 2, 3).contiguous()  # B, C, F, H, W
        control_contents = control_contents.to(vae.device, dtype=vae.dtype)
        control_contents = control_contents / 127.5 - 1.0  # normalize to [-1, 1]
        with torch.amp.autocast(device_type=vae.device.type, dtype=vae.dtype), torch.no_grad():
            control_latent = vae.encode(control_contents)  # list of Tensor[C, F, H, W]
        control_latent = torch.stack(control_latent, dim=0)  # B, C, F, H, W
        control_latent = control_latent.to(vae.dtype)  # convert to bfloat16
    else:
        control_latent = None

    # # debug: decode and save
    # with torch.no_grad():
    #     latent_to_decode = latent / vae.config.scaling_factor
    #     images = vae.decode(latent_to_decode, return_dict=False)[0]
    #     images = (images / 2 + 0.5).clamp(0, 1)
    #     images = images.cpu().float().numpy()
    #     images = (images * 255).astype(np.uint8)
    #     images = images.transpose(0, 2, 3, 4, 1)  # B, C, F, H, W -> B, F, H, W, C
    #     for b in range(images.shape[0]):
    #         for f in range(images.shape[1]):
    #             fln = os.path.splitext(os.path.basename(batch[b].item_key))[0]
    #             img = Image.fromarray(images[b, f])
    #             img.save(f"./logs/decode_{fln}_{b}_{f:03d}.jpg")

    for i, item in enumerate(batch):
        l = latent[i]
        cctx = clip_context[i] if clip is not None else None
        y_i = y[i] if clip is not None else None
        control_latent_i = control_latent[i] if control_latent is not None else None
        # print(f"save latent cache: {item.latent_cache_path}, latent shape: {l.shape}")
        save_latent_cache_wan(item, l, cctx, y_i, control_latent_i)


def main(args):
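    """Load the dataset config, set up the WAN VAE (and optional CLIP), and cache latents."""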
    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Load dataset config
    blueprint_generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    user_config = config_utils.load_user_config(args.dataset_config)
    blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_WAN)
    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    datasets = train_dataset_group.datasets

    if args.debug_mode is not None:
        cache_latents.show_datasets(
            datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images, fps=16
        )
        return

    assert args.vae is not None, "vae checkpoint is required"

    vae_path = args.vae

    logger.info(f"Loading VAE model from {vae_path}")
    vae_dtype = torch.bfloat16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
    cache_device = torch.device("cpu") if args.vae_cache_cpu else None
    vae = WanVAE(vae_path=vae_path, device=device, dtype=vae_dtype, cache_device=cache_device)

    if args.clip is not None:
        clip_dtype = wan_i2v_14B.i2v_14B["clip_dtype"]
        clip = CLIPModel(dtype=clip_dtype, device=device, weight_path=args.clip)
    else:
        clip = None

    # Encode images
    def encode(one_batch: list[ItemInfo]):
        encode_and_save_batch(vae, clip, one_batch)

    cache_latents.encode_datasets(datasets, encode, args)


def wan_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
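    """Add WAN-specific arguments on top of the common cache_latents parser."""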
    parser.add_argument("--vae_cache_cpu", action="store_true", help="cache features in VAE on CPU")
    parser.add_argument(
        "--clip",
        type=str,
        default=None,
        help="text encoder (CLIP) checkpoint path, optional. If training I2V model, this is required",
    )
    return parser


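# Example invocation (a sketch; the script file name is illustrative, paths are
# placeholders, and the common flags --dataset_config / --vae are assumed to be
# defined in cache_latents.setup_parser_common):
#
#   python wan_cache_latents.py --dataset_config dataset.toml \
#       --vae /path/to/wan_vae.safetensors --clip /path/to/wan_clip.safetensors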
if __name__ == "__main__":
    parser = cache_latents.setup_parser_common()
    parser = wan_setup_parser(parser)

    args = parser.parse_args()
    main(args)