Spaces:
Configuration error
Configuration error
import os | |
import numpy as np | |
from PIL import Image | |
from einops import rearrange | |
from pathlib import Path | |
import torch | |
from torch.utils.data import Dataset | |
from .transform import short_size_scale, random_crop, center_crop, offset_crop | |
from ..common.image_util import IMAGE_EXTENSION | |
import cv2 | |
import imageio | |
import shutil | |
class ImageSequenceDataset(Dataset):
    """Dataset of frame windows from one video, paired with layout masks.

    Frames come from an image directory. An ``.mp4`` input is first extracted
    to ``./input-video``. Each entry of ``layout_files`` is either a directory
    of PNG masks or an ``.mp4`` (extracted to ``./layout_masks/{i+1}``). Mask
    files are read as ``{frame_index:05d}.png``.

    ``__getitem__`` returns a dict with:

    * ``"images"``: the frames after :meth:`transform`.
    * ``"masks"``: float16 tensor ``(1, n_sample_frame, mh // 8, mw // 8)``,
      the per-frame union of all layout masks. ``mh``/``mw`` is the mask
      image size.
    * ``"layouts"``: float16 tensor
      ``(n_sample_frame, num_layouts, 1, mh // 8, mw // 8)``.
    * ``"prompt_ids"``: the ``prompt_ids`` given to the constructor.
    """

    def __init__(
        self,
        path: str,  # input video: an image directory, or an .mp4 extracted to './input-video'
        layout_files: list,  # uploaded layout masks (.mp4 files or directories); .mp4s are extracted to './layout_masks/1', './layout_masks/2', ...
        prompt_ids: torch.Tensor,
        prompt: str,
        start_sample_frame: int = 0,
        n_sample_frame: int = 8,
        sampling_rate: int = 1,
        stride: int = -1,  # window stride used while tuning to sample long videos; <= 0 means one window
        image_mode: str = "RGB",
        image_size: int = 512,
        crop: str = "center",
        offset: "dict | None" = None,
        **args
    ):
        """Prepare frame/mask directories and the sampling parameters.

        Args:
            path: image directory or ``.mp4`` file with the input video.
            layout_files: non-empty list of layout mask sources, in upload order.
            prompt_ids: returned unchanged in every item.
            prompt: stored as ``self.prompt``.
            start_sample_frame: index of the first frame of window 0
                (``None`` is treated as 0).
            n_sample_frame: frames per item. A negative value means "as many
                as fit after ``start_sample_frame`` at ``sampling_rate``".
            sampling_rate: step between consecutive sampled frames.
            stride: step between the starts of consecutive windows. A value
                ``<= 0`` yields a single window.
            image_mode: PIL mode frames are converted to.
            image_size: target size used by :meth:`transform`.
            crop: ``"center"`` or ``"random"``.
            offset: keyword arguments for ``offset_crop``. Defaults to a
                fresh all-zero ``left/right/top/bottom`` dict.

        Raises:
            ValueError: ``layout_files`` is empty, no images are found, the
                requested window does not fit after ``start_sample_frame``,
                or ``crop`` is unsupported.
        """
        # Fail before doing any (destructive) video extraction.
        if not layout_files:
            raise ValueError(
                "layout_files must contain at least one layout mask (.mp4 file or directory)"
            )

        # An .mp4 input is extracted into the fixed frame directory './input-video'.
        if self._is_mp4(path):
            self.path = self.mp4_to_png(path, target_dir='./input-video')
        else:
            self.path = path
        self.images = self.get_image_list(self.path)

        # Normalise every layout source to a PNG directory; .mp4 #i goes to './layout_masks/{i+1}'.
        self.layout_mask_dirs = []
        for idx, source in enumerate(layout_files):
            if self._is_mp4(source):
                folder = self.mp4_to_png(source, target_dir=f'./layout_masks/{idx + 1}')
            else:
                folder = source
            self.layout_mask_dirs.append(folder)
        # Upload order, expressed as indices into layout_mask_dirs.
        self.layout_mask_order = list(range(len(self.layout_mask_dirs)))
        # Image file listing of the first layout directory.
        self.masks_index = self.get_image_list(self.layout_mask_dirs[0])

        self.n_images = len(self.images)
        if self.n_images == 0:
            raise ValueError(f"No image files found in {self.path!r}")

        # Build a fresh dict per instance instead of sharing a mutable default.
        if offset is None:
            offset = {"left": 0, "right": 0, "top": 0, "bottom": 0}
        self.offset = offset

        self.start_sample_frame = start_sample_frame
        start = start_sample_frame or 0
        available = self.n_images - start
        if available <= 0:
            raise ValueError(
                f"start_sample_frame {start} is beyond the last of the {self.n_images} frames"
            )
        if n_sample_frame < 0:
            # Largest count that fits; equals n_images when start == 0 and sampling_rate == 1.
            n_sample_frame = (available - 1) // sampling_rate + 1
        self.n_sample_frame = n_sample_frame
        self.sampling_rate = sampling_rate
        self.sequence_length = (n_sample_frame - 1) * sampling_rate + 1
        if available < self.sequence_length:
            raise ValueError(
                f"Required number of frames {self.sequence_length} larger than the "
                f"{available} frames available from start_sample_frame {start} "
                f"(total frames in the dataset: {self.n_images})"
            )

        # For long videos, windows start every `stride` frames; the default
        # stride exceeds the video length, so only one window exists.
        self.stride = stride if stride > 0 else (self.n_images + 1)
        self.video_len = self._num_windows()

        self.image_mode = image_mode
        self.image_size = image_size
        crop_methods = {
            "center": center_crop,
            "random": random_crop,
        }
        if crop not in crop_methods:
            raise ValueError("Unsupported crop method")
        self.crop = crop_methods[crop]
        self.prompt = prompt
        self.prompt_ids = prompt_ids

    def _num_windows(self):
        """Return how many ``sequence_length`` windows fit after ``start_sample_frame``."""
        start = self.start_sample_frame or 0
        return (self.n_images - start - self.sequence_length) // self.stride + 1

    def __len__(self):
        """Return the number of windows (at least ``num_class_images`` when that is set)."""
        max_len = self._num_windows()
        if hasattr(self, 'num_class_images'):
            max_len = max(max_len, self.num_class_images)
        return max_len

    def __getitem__(self, index):
        """Return frames, merged mask, per-layout masks and prompt ids for one window."""
        # The same frame indices are used for the RGB frames and every layout mask.
        frame_indices = list(self.get_frame_indices(index % self.video_len))
        frames = self.transform([self.load_frame(i) for i in frame_indices])

        # One (n_sample_frame, 1, mh // 8, mw // 8) stack per layout directory, in upload order.
        layout_ = np.stack([
            np.stack([self._read_mask(layout_dir, i) for i in frame_indices])
            for layout_dir in self.layout_mask_dirs
        ])  # (num_layouts, n_sample_frame, 1, mh // 8, mw // 8)

        # Per-frame union over all layouts -> binary uint8 mask of shape (f, c, h, w).
        merged = layout_.any(axis=0).astype(np.uint8)
        masks = torch.from_numpy(rearrange(merged, "f c h w -> c f h w")).half()
        layouts = torch.from_numpy(rearrange(layout_, "s f c h w -> f s c h w")).half()
        return {
            "images": frames,
            "masks": masks,
            "layouts": layouts,
            "prompt_ids": self.prompt_ids,
        }

    def transform(self, frames):
        """Tensorize, offset-crop, short-side-scale and crop the frames to ``image_size``."""
        frames = self.tensorize_frames(frames)
        frames = offset_crop(frames, **self.offset)
        frames = short_size_scale(frames, size=self.image_size)
        frames = self.crop(frames, height=self.image_size, width=self.image_size)
        return frames

    @staticmethod
    def tensorize_frames(frames):
        """Stack ``(h, w, c)`` frames into a ``(c, f, h, w)`` tensor mapped from [0, 255] to [-1, 1]."""
        frames = rearrange(np.stack(frames), "f h w c -> c f h w")
        return torch.from_numpy(frames).div(255) * 2 - 1

    def _read_mask(self, mask_dir, index: int):
        """Read ``{index:05d}.png`` from ``mask_dir`` as a ``(1, h // 8, w // 8)`` binary uint8 mask.

        Raises:
            FileNotFoundError: the mask file is missing or unreadable.
        """
        mask_path = os.path.join(mask_dir, f"{index:05d}.png")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread signals failure with None rather than raising.
            raise FileNotFoundError(f"Could not read layout mask {mask_path!r}")
        mask = (mask > 0).astype(np.uint8)
        # Downscale by 8 relative to the mask's own size; nearest keeps it binary.
        height, width = mask.shape
        dest_size = (width // 8, height // 8)
        mask = cv2.resize(mask, dest_size, interpolation=cv2.INTER_NEAREST)
        return mask[np.newaxis, ...]

    def load_frame(self, index):
        """Load ``self.images[index]`` from ``self.path``, converted to ``self.image_mode``."""
        image_path = os.path.join(self.path, self.images[index])
        # convert() returns an independent image, so the file can be closed here.
        with Image.open(image_path) as image:
            return image.convert(self.image_mode)

    def load_class_frame(self, index):
        """Load ``self.class_images_path[index]`` converted to ``self.image_mode``.

        NOTE(review): ``class_images_path`` is not set in this class; presumably
        a subclass or caller provides it.
        """
        image_path = self.class_images_path[index]
        with Image.open(image_path) as image:
            return image.convert(self.image_mode)

    def get_frame_indices(self, index):
        """Yield the ``n_sample_frame`` frame indices of window ``index``."""
        if self.start_sample_frame is not None:
            frame_start = self.start_sample_frame + self.stride * index
        else:
            frame_start = self.stride * index
        return (frame_start + i * self.sampling_rate for i in range(self.n_sample_frame))

    def get_class_indices(self, index):
        """Yield ``n_sample_frame`` consecutive indices starting at ``index``."""
        frame_start = index
        return (frame_start + i for i in range(self.n_sample_frame))

    @staticmethod
    def _is_mp4(source) -> bool:
        """Return True when ``source`` names an ``.mp4`` file (case-insensitive)."""
        return str(source).lower().endswith('.mp4')

    @staticmethod
    def get_image_list(path):
        """Return the sorted names of files in ``path`` that end with ``IMAGE_EXTENSION``.

        An ``.mp4`` path is first extracted to ``./input-video``, which replaces
        that directory's contents. The extracted directory is then listed.
        """
        if ImageSequenceDataset._is_mp4(path):
            path = ImageSequenceDataset.mp4_to_png(path, target_dir='./input-video')
        return [file for file in sorted(os.listdir(path)) if file.endswith(IMAGE_EXTENSION)]

    @staticmethod
    def mp4_to_png(video_source: str, target_dir: str):
        """Extract every frame of ``video_source`` into ``target_dir`` as PNG files.

        ``target_dir`` is a fixed path such as ``'./input-video'`` or
        ``'./layout_masks/1'``. It is deleted and recreated first, so previous
        contents are lost. Frames are written as ``00000.png``,
        ``00001.png``, ...

        Returns:
            ``target_dir``.
        """
        if os.path.exists(target_dir):
            shutil.rmtree(target_dir)
        os.makedirs(target_dir, exist_ok=True)
        reader = imageio.get_reader(video_source)
        try:
            for i, im in enumerate(reader):
                # Frames are assumed RGB(A); OpenCV writes BGR, so drop alpha and flip channels.
                if im.ndim == 3:
                    im = im[:, :, :3][:, :, ::-1]
                cv2.imwrite(os.path.join(target_dir, f"{i:05d}.png"), im)
        finally:
            reader.close()
        return target_dir