# YiChen_FramePack_lora_early / wan_cache_text_encoder_outputs.py
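"""Cache T5 text encoder outputs for Wan2.1 training datasets.

Loads the dataset config, encodes every caption with the Wan T5 text encoder
(optionally in fp8), and writes the resulting embeddings to per-item cache files
via save_text_encoder_output_cache_wan, so training does not have to re-encode
prompts. Cache files that no longer correspond to a dataset item are removed at
the end.
"""
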
import argparse
import logging
from typing import Optional

import accelerate
import torch

import cache_text_encoder_outputs
from dataset import config_utils
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
from dataset.image_video_dataset import ARCHITECTURE_WAN, ItemInfo, save_text_encoder_output_cache_wan
# for t5 config: all Wan2.1 models have the same config for t5
from wan.configs import wan_t2v_14B
from wan.modules.t5 import T5EncoderModel

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def encode_and_save_batch(
text_encoder: T5EncoderModel, batch: list[ItemInfo], device: torch.device, accelerator: Optional[accelerate.Accelerator]
):
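    """Encode a batch of captions with T5 and write each item's embedding to its cache file."""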
    prompts = [item.caption for item in batch]

    # encode prompts with the T5 text encoder (under autocast when fp8 is enabled)
    with torch.no_grad():
if accelerator is not None:
with accelerator.autocast():
context = text_encoder(prompts, device)
else:
context = text_encoder(prompts, device)
# save prompt cache
for item, ctx in zip(batch, context):
save_text_encoder_output_cache_wan(item, ctx)


def main(args):
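    """Cache Wan T5 text encoder outputs for every dataset defined in the dataset config."""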
    device = args.device if args.device is not None else ("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device(device)
# Load dataset config
blueprint_generator = BlueprintGenerator(ConfigSanitizer())
logger.info(f"Load dataset config from {args.dataset_config}")
user_config = config_utils.load_user_config(args.dataset_config)
blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_WAN)
train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
datasets = train_dataset_group.datasets
# define accelerator for fp8 inference
config = wan_t2v_14B.t2v_14B # all Wan2.1 models have the same config for t5
accelerator = None
if args.fp8_t5:
accelerator = accelerate.Accelerator(mixed_precision="bf16" if config.t5_dtype == torch.bfloat16 else "fp16")
    # prepare cache files and paths: all_cache_files_for_dataset = existing cache files,
    # all_cache_paths_for_dataset = all cache paths in the dataset
all_cache_files_for_dataset, all_cache_paths_for_dataset = cache_text_encoder_outputs.prepare_cache_files_and_paths(datasets)
# Load T5
logger.info(f"Loading T5: {args.t5}")
text_encoder = T5EncoderModel(
text_len=config.text_len, dtype=config.t5_dtype, device=device, weight_path=args.t5, fp8=args.fp8_t5
)
# Encode with T5
logger.info("Encoding with T5")
def encode_for_text_encoder(batch: list[ItemInfo]):
encode_and_save_batch(text_encoder, batch, device, accelerator)
cache_text_encoder_outputs.process_text_encoder_batches(
args.num_workers,
args.skip_existing,
args.batch_size,
datasets,
all_cache_files_for_dataset,
all_cache_paths_for_dataset,
encode_for_text_encoder,
)
del text_encoder
# remove cache files not in dataset
cache_text_encoder_outputs.post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset)


def wan_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
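    """Add the Wan-specific arguments (--t5, --fp8_t5) to the common cache-script parser."""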
parser.add_argument("--t5", type=str, default=None, required=True, help="text encoder (T5) checkpoint path")
parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model")
return parser


if __name__ == "__main__":
parser = cache_text_encoder_outputs.setup_parser_common()
parser = wan_setup_parser(parser)
args = parser.parse_args()
main(args)
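

# Example invocation (a sketch, not the definitive CLI): --t5 and --fp8_t5 are defined
# in wan_setup_parser above; the remaining flag names are assumptions inferred from the
# attributes read off `args` in main() (dataset_config, batch_size, num_workers,
# skip_existing, device), which come from cache_text_encoder_outputs.setup_parser_common.
#
#   python wan_cache_text_encoder_outputs.py \
#       --dataset_config path/to/dataset_config.toml \
#       --t5 path/to/wan_t5_checkpoint.pth \
#       --batch_size 16 \
#       --skip_existing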