import os
import numpy as np
from PIL import Image
from einops import rearrange
from pathlib import Path

import torch
from torch.utils.data import Dataset

from .transform import short_size_scale, random_crop, center_crop, offset_crop
from ..common.image_util import IMAGE_EXTENSION
import cv2
import imageio
import shutil


class ImageSequenceDataset(Dataset):
    """Frame-sequence dataset that pairs video frames with per-layout masks.

    Each item contains the sampled frames, a merged binary mask, the stack of
    per-layout masks, and the tokenized prompt ids.
    """

    def __init__(
        self,
        path: str,  # input video; if it is an mp4, it is converted into the fixed directory './input-video'
        layout_files: list,  # uploaded layout mask files (mp4 or directory); mp4s are converted into './layout_masks/1', './layout_masks/2', ...
        prompt_ids: torch.Tensor,
        prompt: str,
        start_sample_frame: int = 0,
        n_sample_frame: int = 8,
        sampling_rate: int = 1,
        stride: int = -1,  # used to subsample long videos during tuning
        image_mode: str = "RGB",
        image_size: int = 512,
        crop: str = "center",
        offset: dict = {
            "left": 0,
            "right": 0,
            "top": 0,
            "bottom": 0,
        },
        **args
    ):
        # If the input video is an mp4, convert it into the fixed directory './input-video'
        if path.endswith('.mp4'):
            self.path = self.mp4_to_png(path, target_dir='./input-video')
        else:
            self.path = path
        self.images = self.get_image_list(self.path)

        # Process each uploaded layout file.
        # If it is an mp4, convert it into the fixed directory './layout_masks/{i+1}'.
        self.layout_mask_dirs = []
        for idx, file in enumerate(layout_files):
            if file.endswith('.mp4'):
                folder = self.mp4_to_png(file, target_dir=f'./layout_masks/{idx + 1}')
            else:
                folder = file
            self.layout_mask_dirs.append(folder)

        # Keep the upload order as layout_mask_order (represented here simply by indices)
        self.layout_mask_order = list(range(len(self.layout_mask_dirs)))

        # Use the first layout mask directory to index the mask images (used to determine the frame count)
        self.masks_index = self.get_image_list(self.layout_mask_dirs[0])

        self.n_images = len(self.images)
        self.offset = offset
        self.start_sample_frame = start_sample_frame
        if n_sample_frame < 0:
            n_sample_frame = len(self.images)
        self.n_sample_frame = n_sample_frame
        self.sampling_rate = sampling_rate

        self.sequence_length = (n_sample_frame - 1) * sampling_rate + 1
        if self.n_images < self.sequence_length:
            raise ValueError(
                f"self.n_images {self.n_images} < self.sequence_length {self.sequence_length}: "
                f"required number of frames {self.sequence_length} is larger than the total "
                f"number of frames in the dataset {self.n_images}"
            )

        # If the video is too long, sample it globally with the given stride
        self.stride = stride if stride > 0 else (self.n_images + 1)
        self.video_len = (self.n_images - self.sequence_length) // self.stride + 1

        self.image_mode = image_mode
        self.image_size = image_size
        crop_methods = {
            "center": center_crop,
            "random": random_crop,
        }
        if crop not in crop_methods:
            raise ValueError("Unsupported crop method")
        self.crop = crop_methods[crop]

        self.prompt = prompt
        self.prompt_ids = prompt_ids

    def __len__(self):
        max_len = (self.n_images - self.sequence_length) // self.stride + 1
        if hasattr(self, 'num_class_images'):
            max_len = max(max_len, self.num_class_images)
        return max_len

    def __getitem__(self, index):
        return_batch = {}
        frame_indices = self.get_frame_indices(index % self.video_len)
        frames = [self.load_frame(i) for i in frame_indices]
        frames = self.transform(frames)

        layout_ = []
        # Iterate over the layout mask directories (same order as the user's uploads)
        for layout_dir in self.layout_mask_dirs:
            # For each layout directory, read the mask images (PNG files) matching the frame indices
            frame_indices_local = self.get_frame_indices(index % self.video_len)
            mask = [self._read_mask(layout_dir, i) for i in frame_indices_local]
            masks = np.stack(mask)  # shape: (n_sample_frame, c, h, w)
            layout_.append(masks)
        layout_ = np.stack(layout_)  # shape: (num_layouts, n_sample_frame, c, h, w)

        # Merge the per-layout masks into a single binary mask per frame
        merged_masks = []
        for i in range(int(self.n_sample_frame)):
            merged_mask_frame = np.sum(layout_[:, i, :, :, :], axis=0)
            merged_mask_frame = (merged_mask_frame > 0).astype(np.uint8)
            merged_masks.append(merged_mask_frame)
        masks = rearrange(np.stack(merged_masks), "f c h w -> c f h w")
        masks = torch.from_numpy(masks).half()

        layouts = rearrange(layout_, "s f c h w -> f s c h w")
        layouts = torch.from_numpy(layouts).half()

        return_batch.update({
            "images": frames,
            "masks": masks,
            "layouts": layouts,
            "prompt_ids": self.prompt_ids,
        })
        return return_batch

    def transform(self, frames):
        frames = self.tensorize_frames(frames)
        frames = offset_crop(frames, **self.offset)
        frames = short_size_scale(frames, size=self.image_size)
        frames = self.crop(frames, height=self.image_size, width=self.image_size)
        return frames

    @staticmethod
    def tensorize_frames(frames):
        # Stack PIL frames into (c, f, h, w) and normalize pixel values to [-1, 1]
        frames = rearrange(np.stack(frames), "f h w c -> c f h w")
        return torch.from_numpy(frames).div(255) * 2 - 1

    def _read_mask(self, mask_dir, index: int):
        # Build the mask file name (PNG format)
        mask_path = os.path.join(mask_dir, f"{index:05d}.png")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = (mask > 0).astype(np.uint8)
        # Downscale relative to the original image size (by a factor of 8 here)
        height, width = mask.shape
        dest_size = (width // 8, height // 8)
        mask = cv2.resize(mask, dest_size, interpolation=cv2.INTER_NEAREST)
        mask = mask[np.newaxis, ...]
        return mask

    def load_frame(self, index):
        image_path = os.path.join(self.path, self.images[index])
        return Image.open(image_path).convert(self.image_mode)

    def load_class_frame(self, index):
        image_path = self.class_images_path[index]
        return Image.open(image_path).convert(self.image_mode)

    def get_frame_indices(self, index):
        if self.start_sample_frame is not None:
            frame_start = self.start_sample_frame + self.stride * index
        else:
            frame_start = self.stride * index
        return (frame_start + i * self.sampling_rate for i in range(self.n_sample_frame))

    def get_class_indices(self, index):
        frame_start = index
        return (frame_start + i for i in range(self.n_sample_frame))

    @staticmethod
    def get_image_list(path):
        images = []
        # If the given path is an mp4 file, first convert it into a directory of PNG images
        if path.endswith('.mp4'):
            path = ImageSequenceDataset.mp4_to_png(path, target_dir='./input-video')
        for file in sorted(os.listdir(path)):
            if file.endswith(IMAGE_EXTENSION):
                images.append(file)
        return images

    @staticmethod
    def mp4_to_png(video_source: str, target_dir: str):
        """
        Convert an mp4 video to a sequence of PNG images, storing them in target_dir.
        target_dir is a fixed path, e.g. './input-video' or './layout_masks/1'.
        """
        if os.path.exists(target_dir):
            shutil.rmtree(target_dir)
        os.makedirs(target_dir, exist_ok=True)
        reader = imageio.get_reader(video_source)
        for i, im in enumerate(reader):
            path = os.path.join(target_dir, f"{i:05d}.png")
            # imageio yields RGB frames; cv2.imwrite expects BGR, so reverse the channel axis
            cv2.imwrite(path, im[:, :, ::-1])
        return target_dir
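

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The file paths, the
    # example prompt, and the use of transformers.CLIPTokenizer below are
    # assumptions made for illustration only; any text tokenizer that produces
    # a tensor of token ids would work the same way. It shows the expected
    # inputs: prompt_ids from a tokenizer, an input video (mp4 or image
    # directory), and one mp4/directory per layout mask, where the upload order
    # defines the layout order.
    from transformers import CLIPTokenizer  # assumed dependency for this example

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    prompt = "a car driving down a road"  # hypothetical prompt
    prompt_ids = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    ).input_ids[0]

    dataset = ImageSequenceDataset(
        path="./data/source_video.mp4",  # hypothetical input video
        layout_files=[
            "./data/mask_foreground.mp4",  # hypothetical layout masks;
            "./data/mask_background.mp4",  # their order defines the layout order
        ],
        prompt_ids=prompt_ids,
        prompt=prompt,
        n_sample_frame=8,
        sampling_rate=1,
        image_size=512,
    )

    sample = dataset[0]
    # Expected shapes: images (c, f, h, w) in [-1, 1]; masks (1, f, h/8, w/8);
    # layouts (f, num_layouts, 1, h/8, w/8); prompt_ids (max_length,)
    print({k: tuple(v.shape) for k, v in sample.items()})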