XiangpengYang's picture
huggingface space
33f757a
raw
history blame
8.01 kB
import os
import numpy as np
from PIL import Image
from einops import rearrange
from pathlib import Path
import torch
from torch.utils.data import Dataset
from .transform import short_size_scale, random_crop, center_crop, offset_crop
from ..common.image_util import IMAGE_EXTENSION
import cv2
import imageio
import shutil
class ImageSequenceDataset(Dataset):
    """Dataset yielding fixed-length frame clips plus per-layout masks.

    Frames are read from a directory of images (an ``.mp4`` input is first
    unpacked into ``./input-video``).  Each entry of ``layout_files`` is a
    directory of mask PNGs named ``00000.png``, ``00001.png``, ... (an
    ``.mp4`` entry is first unpacked into ``./layout_masks/<k>`` where ``k``
    is its 1-based upload position).

    Every item is a dict with the keys ``"images"``, ``"masks"``,
    ``"layouts"`` and ``"prompt_ids"``.
    """

    def __init__(
        self,
        path: str,  # input video; an mp4 is converted into the fixed directory './input-video'
        layout_files: list,  # uploaded layout masks (mp4 files or directories); mp4s are converted into './layout_masks/1', './layout_masks/2', ...
        prompt_ids: torch.Tensor,
        prompt: str,
        start_sample_frame: int = 0,
        n_sample_frame: int = 8,
        sampling_rate: int = 1,
        stride: int = -1,  # used while tuning to sample several clips from a long video
        image_mode: str = "RGB",
        image_size: int = 512,
        crop: str = "center",
        offset: dict = None,
        **args
    ):
        """Index the frames and layout masks and validate the clip geometry.

        Args:
            path: Directory of frame images, or an ``.mp4`` file.
            layout_files: Non-empty list of mask directories and/or ``.mp4``
                mask videos, in upload order.
            prompt_ids: Tokenised prompt, returned unchanged with every item.
            prompt: Prompt text, stored on the instance.
            start_sample_frame: Index of the first frame of the first clip
                (``None`` is treated as 0).
            n_sample_frame: Frames per clip; a negative value means "as many
                frames as the dataset contains".
            sampling_rate: Step between consecutive frames inside a clip.
            stride: Step between the first frames of consecutive clips; a
                non-positive value yields a single clip.
            image_mode: PIL mode each frame is converted to.
            image_size: Side length passed to the scale and crop transforms.
            crop: ``"center"`` or ``"random"``.
            offset: Keyword arguments for ``offset_crop``; defaults to zero
                on all four sides.
            **args: Accepted and ignored.

        Raises:
            ValueError: If ``layout_files`` is empty, if the requested clip
                does not fit into the available frames, or if ``crop`` is
                not a supported method.
        """
        # Convert an mp4 input into the fixed directory './input-video'.
        if self._is_mp4(path):
            self.path = self.mp4_to_png(path, target_dir='./input-video')
        else:
            self.path = path
        self.images = self.get_image_list(self.path)

        if not layout_files:
            raise ValueError("layout_files must contain at least one layout mask video or directory")

        # Process every uploaded layout file; mp4 files are converted into
        # the fixed directory './layout_masks/{i+1}'.
        self.layout_mask_dirs = []
        for idx, file in enumerate(layout_files):
            if self._is_mp4(file):
                folder = self.mp4_to_png(file, target_dir=f'./layout_masks/{idx+1}')
            else:
                folder = file
            self.layout_mask_dirs.append(folder)

        # Upload order is kept as the layout order (represented by indices).
        self.layout_mask_order = list(range(len(self.layout_mask_dirs)))
        # Image listing of the first layout directory.
        self.masks_index = self.get_image_list(self.layout_mask_dirs[0])

        self.n_images = len(self.images)
        # A fresh dict per instance: a dict literal as the default argument
        # would be shared (and mutable) across all instances.
        if offset is None:
            offset = {"left": 0, "right": 0, "top": 0, "bottom": 0}
        self.offset = dict(offset)
        self.start_sample_frame = start_sample_frame
        if n_sample_frame < 0:
            n_sample_frame = len(self.images)
        self.n_sample_frame = n_sample_frame
        self.sampling_rate = sampling_rate
        # Number of source frames spanned by one clip.
        self.sequence_length = (n_sample_frame - 1) * sampling_rate + 1

        # The clip must fit into the frames that remain after the start
        # offset, otherwise load_frame would index past the end later on.
        first_frame = self.start_sample_frame or 0
        if self.n_images - first_frame < self.sequence_length:
            raise ValueError(
                f"Required number of frames {self.sequence_length} starting at frame {first_frame} "
                f"is larger than the total number of frames in the dataset {self.n_images}"
            )

        # A non-positive stride becomes larger than the video, so exactly
        # one clip is produced.
        self.stride = stride if stride > 0 else (self.n_images + 1)
        self.video_len = self._count_clips()

        self.image_mode = image_mode
        self.image_size = image_size
        crop_methods = {
            "center": center_crop,
            "random": random_crop,
        }
        if crop not in crop_methods:
            raise ValueError("Unsupported crop method")
        self.crop = crop_methods[crop]
        self.prompt = prompt
        self.prompt_ids = prompt_ids

    @staticmethod
    def _is_mp4(source):
        """Return True when ``source`` names an mp4 file (case-insensitive)."""
        return str(source).lower().endswith('.mp4')

    def _count_clips(self):
        """Return how many clips fit after the start offset at the current stride."""
        first_frame = self.start_sample_frame or 0
        return (self.n_images - first_frame - self.sequence_length) // self.stride + 1

    def __len__(self):
        """Return the number of clips, or ``num_class_images`` if that attribute exists and is larger."""
        max_len = self._count_clips()
        if hasattr(self, 'num_class_images'):
            max_len = max(max_len, self.num_class_images)
        return max_len

    def __getitem__(self, index):
        """Load one clip together with its layout masks.

        Returns:
            dict with

            * ``"images"``: output of :meth:`transform` for the clip frames,
            * ``"masks"``: half tensor ``(c, f, h, w)``; per-frame union of
              all layout masks (1 where any layout mask is non-zero),
            * ``"layouts"``: half tensor ``(f, s, c, h, w)`` of the
              individual layout masks, ``s`` following upload order,
            * ``"prompt_ids"``: the prompt ids given at construction.
        """
        # Materialise the indices once; frames and every layout directory
        # use the same ones.
        frame_indices = list(self.get_frame_indices(index % self.video_len))

        frames = self.transform([self.load_frame(i) for i in frame_indices])

        # One stack of per-frame masks for each layout directory, in upload
        # order: (num_layouts, n_sample_frame, c, h, w).
        layout_ = np.stack([
            np.stack([self._read_mask(layout_dir, i) for i in frame_indices])
            for layout_dir in self.layout_mask_dirs
        ])

        # Union over the layout axis: (n_sample_frame, c, h, w).
        merged_masks = layout_.any(axis=0).astype(np.uint8)
        masks = rearrange(merged_masks, "f c h w -> c f h w")
        masks = torch.from_numpy(masks).half()

        layouts = rearrange(layout_, "s f c h w -> f s c h w")
        layouts = torch.from_numpy(layouts).half()

        return {
            "images": frames,
            "masks": masks,
            "layouts": layouts,
            "prompt_ids": self.prompt_ids,
        }

    def transform(self, frames):
        """Tensorize the frames, then apply offset crop, short-side scale and the configured crop."""
        frames = self.tensorize_frames(frames)
        frames = offset_crop(frames, **self.offset)
        frames = short_size_scale(frames, size=self.image_size)
        frames = self.crop(frames, height=self.image_size, width=self.image_size)
        return frames

    @staticmethod
    def tensorize_frames(frames):
        """Stack ``(h, w, c)`` frames into a ``(c, f, h, w)`` tensor scaled from [0, 255] to [-1, 1]."""
        frames = rearrange(np.stack(frames), "f h w c -> c f h w")
        return torch.from_numpy(frames).div(255) * 2 - 1

    def _read_mask(self, mask_dir, index: int):
        """Read mask ``<index:05d>.png`` from ``mask_dir`` as a binary ``(1, h//8, w//8)`` uint8 array.

        Raises:
            FileNotFoundError: If the mask image is missing or unreadable.
        """
        # Mask files are addressed by zero-padded frame index (png format).
        mask_path = os.path.join(mask_dir, f"{index:05d}.png")
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read layout mask image: {mask_path}")
        mask = (mask > 0).astype(np.uint8)
        # Scale relative to the mask's own size (shrunk by a factor of 8).
        # NOTE(review): presumably 8 matches the model's latent downsampling
        # factor, and the masks do not go through the offset/crop pipeline
        # applied to the frames - confirm both against the consumer.
        height, width = mask.shape
        dest_size = (width // 8, height // 8)
        mask = cv2.resize(mask, dest_size, interpolation=cv2.INTER_NEAREST)
        return mask[np.newaxis, ...]

    def load_frame(self, index):
        """Open frame ``index`` of the image listing and convert it to ``image_mode``."""
        image_path = os.path.join(self.path, self.images[index])
        return Image.open(image_path).convert(self.image_mode)

    def load_class_frame(self, index):
        """Open ``self.class_images_path[index]`` and convert it to ``image_mode``.

        ``class_images_path`` is not assigned anywhere in this class; it has
        to be provided externally before this method is used.
        """
        image_path = self.class_images_path[index]
        return Image.open(image_path).convert(self.image_mode)

    def get_frame_indices(self, index):
        """Return a generator over the source-frame indices of clip ``index``."""
        if self.start_sample_frame is not None:
            frame_start = self.start_sample_frame + self.stride * index
        else:
            frame_start = self.stride * index
        return (frame_start + i * self.sampling_rate for i in range(self.n_sample_frame))

    def get_class_indices(self, index):
        """Return a generator over ``n_sample_frame`` consecutive indices starting at ``index``."""
        frame_start = index
        return (frame_start + i for i in range(self.n_sample_frame))

    @staticmethod
    def get_image_list(path):
        """Return the sorted names of the image files found in ``path``.

        If ``path`` is an mp4 file it is first converted into PNG frames
        under ``./input-video`` (replacing that directory's contents).
        """
        if ImageSequenceDataset._is_mp4(path):
            path = ImageSequenceDataset.mp4_to_png(path, target_dir='./input-video')
        return [file for file in sorted(os.listdir(path)) if file.endswith(IMAGE_EXTENSION)]

    @staticmethod
    def mp4_to_png(video_source: str, target_dir: str):
        """
        Convert an mp4 video to a sequence of PNG images, storing them in target_dir.

        target_dir is a fixed path such as './input-video' or
        './layout_masks/1'.  Any existing directory at that path is deleted
        first.  Frames are written as '00000.png', '00001.png', ...

        Returns:
            target_dir, unchanged.
        """
        if os.path.exists(target_dir):
            shutil.rmtree(target_dir)
        os.makedirs(target_dir, exist_ok=True)

        reader = imageio.get_reader(video_source)
        try:
            for i, im in enumerate(reader):
                frame_path = os.path.join(target_dir, f"{i:05d}.png")
                # Reverse the channel axis before handing the frame to OpenCV
                # (imageio yields RGB, cv2.imwrite expects BGR).
                cv2.imwrite(frame_path, im[:, :, ::-1])
        finally:
            # Release the underlying video handle even if a write fails.
            reader.close()
        return target_dir