File size: 8,009 Bytes
5602c9a
 
 
 
 
 
 
 
 
 
 
 
33f757a
 
5602c9a
 
 
 
33f757a
 
5602c9a
 
33f757a
5602c9a
 
33f757a
5602c9a
 
 
33f757a
5602c9a
 
 
 
 
 
 
 
33f757a
 
 
 
 
 
5602c9a
33f757a
 
 
 
 
 
 
 
 
 
 
 
 
5602c9a
 
 
 
 
33f757a
5602c9a
 
 
 
 
33f757a
5602c9a
33f757a
 
5602c9a
 
 
 
 
 
 
 
 
33f757a
5602c9a
 
 
 
 
 
 
 
 
 
 
 
 
 
33f757a
5602c9a
 
 
 
33f757a
 
 
 
 
 
5602c9a
33f757a
 
5602c9a
 
33f757a
 
5602c9a
 
 
 
33f757a
5602c9a
 
33f757a
5602c9a
33f757a
 
5602c9a
33f757a
 
5602c9a
 
 
 
 
 
 
 
 
 
 
 
 
 
33f757a
 
 
5602c9a
 
33f757a
5602c9a
 
33f757a
5602c9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33f757a
5602c9a
 
 
 
33f757a
 
 
5602c9a
 
 
 
33f757a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import os
import numpy as np
from PIL import Image
from einops import rearrange
from pathlib import Path

import torch
from torch.utils.data import Dataset

from .transform import short_size_scale, random_crop, center_crop, offset_crop
from ..common.image_util import IMAGE_EXTENSION
import cv2
import imageio
import shutil

class ImageSequenceDataset(Dataset):
    """Serve fixed-length clips of video frames together with layout masks.

    The input video and every layout mask source may be given either as a
    directory of images or as an ``.mp4`` file; mp4 files are first dumped to
    PNG frames in fixed working directories (``./input-video`` and
    ``./layout_masks/<n>``).  Each item is a dict holding the transformed
    frames, the per-layout masks, their per-frame union, and the prompt ids.
    """

    def __init__(
        self,
        path: str,  # input video; an mp4 is converted into './input-video'
        layout_files: list,  # layout mask sources (mp4 or directory); mp4s go to './layout_masks/1', './layout_masks/2', ...
        prompt_ids: torch.Tensor,
        prompt: str,
        start_sample_frame: int = 0,
        n_sample_frame: int = 8,
        sampling_rate: int = 1,
        stride: int = -1,  # used during tuning to sample clips from a long video
        image_mode: str = "RGB",
        image_size: int = 512,
        crop: str = "center",
        offset: dict = None,
        **args
    ):
        """Index the frame/mask directories and precompute the clip layout.

        Args:
            path: Directory of frames, or an ``.mp4`` file to be converted.
            layout_files: Non-empty list of mask directories or ``.mp4`` files,
                in the order the layouts should appear in the output.
            prompt_ids: Stored as-is and returned with every item.
            prompt: Stored as-is on the instance.
            start_sample_frame: Index of the first frame of the first clip.
            n_sample_frame: Frames per clip; a negative value means "all frames".
            sampling_rate: Step between consecutive frames inside a clip.
            stride: Step between the starts of consecutive clips; a
                non-positive value yields a single clip.
            image_mode: Mode passed to ``PIL.Image.convert``.
            image_size: Target size passed to the scale and crop helpers.
            crop: ``"center"`` or ``"random"``.
            offset: Keyword arguments forwarded to ``offset_crop``
                (``left``/``right``/``top``/``bottom``); defaults to all zeros.
            **args: Accepted and ignored.

        Raises:
            ValueError: If ``layout_files`` is empty, if there are not enough
                frames for one clip, or if ``crop`` is not supported.
        """
        # Convert an mp4 input into the fixed './input-video' frame directory.
        if path.endswith('.mp4'):
            self.path = self.mp4_to_png(path, target_dir='./input-video')
        else:
            self.path = path
        self.images = self.get_image_list(self.path)

        if not layout_files:
            raise ValueError("layout_files must contain at least one layout mask source")

        # Resolve every layout source to a directory of PNG masks; mp4 sources
        # are converted into './layout_masks/{position}' (1-based).
        self.layout_mask_dirs = []
        for idx, file in enumerate(layout_files):
            if file.endswith('.mp4'):
                folder = self.mp4_to_png(file, target_dir=f'./layout_masks/{idx+1}')
            else:
                folder = file
            self.layout_mask_dirs.append(folder)
        # The upload order is the layout order (represented by plain indices).
        self.layout_mask_order = list(range(len(self.layout_mask_dirs)))
        # Image listing of the first layout directory (gives the mask frame count).
        self.masks_index = self.get_image_list(self.layout_mask_dirs[0])

        self.n_images = len(self.images)
        # Build a fresh dict per instance instead of sharing one mutable default.
        if offset is None:
            offset = {"left": 0, "right": 0, "top": 0, "bottom": 0}
        self.offset = offset
        self.start_sample_frame = start_sample_frame
        if n_sample_frame < 0:
            n_sample_frame = len(self.images)
        self.n_sample_frame = n_sample_frame
        self.sampling_rate = sampling_rate

        # Number of source frames spanned by one clip.
        self.sequence_length = (n_sample_frame - 1) * sampling_rate + 1
        if self.n_images < self.sequence_length:
            raise ValueError(f"self.n_images {self.n_images} < self.sequence_length {self.sequence_length}: Required number of frames {self.sequence_length} larger than total frames in the dataset {self.n_images}")

        # Frames before the start offset can never be part of a clip, so they
        # must not count towards the number of clips that fit.
        first_frame = max(start_sample_frame or 0, 0)
        usable_frames = self.n_images - first_frame
        if usable_frames < self.sequence_length:
            raise ValueError(
                f"start_sample_frame {first_frame} leaves only {usable_frames} frames, "
                f"fewer than the required sequence_length {self.sequence_length}"
            )

        # A non-positive stride means "one clip only": any stride larger than
        # the video makes the floor division below evaluate to zero.
        self.stride = stride if stride > 0 else (self.n_images + 1)
        self.video_len = (usable_frames - self.sequence_length) // self.stride + 1

        self.image_mode = image_mode
        self.image_size = image_size
        crop_methods = {
            "center": center_crop,
            "random": random_crop,
        }
        if crop not in crop_methods:
            raise ValueError("Unsupported crop method")
        self.crop = crop_methods[crop]

        self.prompt = prompt
        self.prompt_ids = prompt_ids

    def __len__(self):
        """Return the number of clips, or ``num_class_images`` if that is larger.

        ``num_class_images`` is not set by this class; it is only honoured when
        something else has attached it to the instance.
        """
        max_len = self.video_len
        if hasattr(self, 'num_class_images'):
            max_len = max(max_len, self.num_class_images)
        return max_len

    def __getitem__(self, index):
        """Load one clip.

        Returns:
            dict with
              ``images``: transformed frames, laid out ``c f h w``;
              ``masks``: half-precision union of all layouts, ``c f h w``;
              ``layouts``: half-precision per-layout masks, ``f s c h w``;
              ``prompt_ids``: the ids passed to the constructor.
        """
        # Indices wrap around so that index may exceed the number of clips.
        frame_indices = list(self.get_frame_indices(index % self.video_len))
        frames = self.transform([self.load_frame(i) for i in frame_indices])

        # One stack per layout directory, in upload order:
        # (num_layouts, n_sample_frame, 1, h, w)
        layout_ = np.stack([
            np.stack([self._read_mask(layout_dir, i) for i in frame_indices])
            for layout_dir in self.layout_mask_dirs
        ])

        # A pixel belongs to the merged mask if any layout covers it.
        merged_masks = (np.sum(layout_, axis=0) > 0).astype(np.uint8)
        masks = rearrange(merged_masks, "f c h w -> c f h w")
        masks = torch.from_numpy(masks).half()

        layouts = rearrange(layout_, "s f c h w -> f s c h w")
        layouts = torch.from_numpy(layouts).half()

        return {
            "images": frames,
            "masks": masks,
            "layouts": layouts,
            "prompt_ids": self.prompt_ids,
        }

    def transform(self, frames):
        """Tensorize the frames, then apply offset crop, resize and final crop."""
        frames = self.tensorize_frames(frames)
        frames = offset_crop(frames, **self.offset)
        frames = short_size_scale(frames, size=self.image_size)
        frames = self.crop(frames, height=self.image_size, width=self.image_size)
        return frames

    @staticmethod
    def tensorize_frames(frames):
        """Stack ``h w c`` frames into a ``c f h w`` tensor scaled as x/255*2-1."""
        frames = rearrange(np.stack(frames), "f h w c -> c f h w")
        return torch.from_numpy(frames).div(255) * 2 - 1

    def _read_mask(self, mask_dir, index: int):
        """Read mask ``{index:05d}.png`` from *mask_dir* as a binary array.

        The mask is binarized (non-zero -> 1), shrunk to 1/8 of its size with
        nearest-neighbour sampling, and returned with a leading channel axis,
        i.e. shape ``(1, height // 8, width // 8)``.

        Raises:
            FileNotFoundError: If the mask image is missing or unreadable.
        """
        mask_path = os.path.join(mask_dir, f"{index:05d}.png")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read layout mask image: {mask_path}")
        mask = (mask > 0).astype(np.uint8)
        # Scale relative to the mask's own size (downsample by a factor of 8).
        height, width = mask.shape
        dest_size = (width // 8, height // 8)
        mask = cv2.resize(mask, dest_size, interpolation=cv2.INTER_NEAREST)
        return mask[np.newaxis, ...]

    def load_frame(self, index):
        """Open frame number *index* of the input video in ``self.image_mode``."""
        image_path = os.path.join(self.path, self.images[index])
        # The context manager closes the underlying file once converted.
        with Image.open(image_path) as image:
            return image.convert(self.image_mode)

    def load_class_frame(self, index):
        """Open ``self.class_images_path[index]`` in ``self.image_mode``.

        NOTE(review): ``class_images_path`` is not assigned anywhere in this
        class; presumably a subclass or caller provides it -- confirm.
        """
        image_path = self.class_images_path[index]
        with Image.open(image_path) as image:
            return image.convert(self.image_mode)

    def get_frame_indices(self, index):
        """Yield the source-frame indices of clip number *index*.

        The clip starts at ``start_sample_frame + stride * index`` (the start
        offset is skipped when it is ``None``) and advances by
        ``sampling_rate`` for ``n_sample_frame`` frames.
        """
        frame_start = self.stride * index
        if self.start_sample_frame is not None:
            frame_start += self.start_sample_frame
        return (frame_start + i * self.sampling_rate for i in range(self.n_sample_frame))

    def get_class_indices(self, index):
        """Yield ``n_sample_frame`` consecutive indices starting at *index*."""
        frame_start = index
        return (frame_start + i for i in range(self.n_sample_frame))

    @staticmethod
    def get_image_list(path):
        """Return the sorted image file names found in *path*.

        If *path* is an ``.mp4`` file it is first converted to PNG frames in
        ``./input-video`` (replacing whatever that directory contained).
        Files are kept when their name ends with ``IMAGE_EXTENSION``.
        """
        if path.endswith('.mp4'):
            path = ImageSequenceDataset.mp4_to_png(path, target_dir='./input-video')
        return [file for file in sorted(os.listdir(path)) if file.endswith(IMAGE_EXTENSION)]

    @staticmethod
    def mp4_to_png(video_source: str, target_dir: str):
        """
        Convert an mp4 video to a sequence of PNG images, storing them in target_dir.

        target_dir is a fixed path such as './input-video' or './layout_masks/1'.
        Any existing target_dir is deleted first, so stale frames from an
        earlier video cannot leak into the new sequence.  Frames are written
        as '00000.png', '00001.png', ...  Returns target_dir.
        """
        if os.path.exists(target_dir):
            shutil.rmtree(target_dir)
        os.makedirs(target_dir, exist_ok=True)

        reader = imageio.get_reader(video_source)
        try:
            for i, im in enumerate(reader):
                frame_path = os.path.join(target_dir, f"{i:05d}.png")
                # Reverse the channel axis before writing with OpenCV
                # (presumably RGB -> BGR, which is what cv2.imwrite expects).
                cv2.imwrite(frame_path, im[:, :, ::-1])
        finally:
            # Release the video file handle even if a frame fails to decode.
            reader.close()
        return target_dir