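"""Cache Wan2.1 T5 text encoder outputs for the datasets described in a dataset config.

Summary of what the code below does (docstring added for orientation): load the dataset
config, prepare cache file paths, load the T5 encoder (optionally in fp8), encode every
caption in batches, write the results via save_text_encoder_output_cache_wan, and finally
remove cache files that no longer belong to the dataset.
"""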
import argparse
import logging
import os
from typing import Optional, Union

import accelerate
import numpy as np
import torch
from tqdm import tqdm

from dataset import config_utils
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
from dataset.image_video_dataset import ARCHITECTURE_WAN, ItemInfo, save_text_encoder_output_cache_wan
from utils.model_utils import str_to_dtype

# for t5 config: all Wan2.1 models have the same config for t5
from wan.configs import wan_t2v_14B
from wan.modules.t5 import T5EncoderModel

import cache_text_encoder_outputs

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def encode_and_save_batch(
    text_encoder: T5EncoderModel, batch: list[ItemInfo], device: torch.device, accelerator: Optional[accelerate.Accelerator]
):
    prompts = [item.caption for item in batch]
    # print(prompts)

    # encode prompt
    with torch.no_grad():
        if accelerator is not None:
            with accelerator.autocast():
                context = text_encoder(prompts, device)
        else:
            context = text_encoder(prompts, device)

    # save prompt cache
    for item, ctx in zip(batch, context):
        save_text_encoder_output_cache_wan(item, ctx)


def main(args):
    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Load dataset config
    blueprint_generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    user_config = config_utils.load_user_config(args.dataset_config)
    blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_WAN)
    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    datasets = train_dataset_group.datasets

    # define accelerator for fp8 inference
    config = wan_t2v_14B.t2v_14B  # all Wan2.1 models have the same config for t5
    accelerator = None
    if args.fp8_t5:
        accelerator = accelerate.Accelerator(mixed_precision="bf16" if config.t5_dtype == torch.bfloat16 else "fp16")

    # prepare cache files and paths: all_cache_files_for_dataset = existing cache files, all_cache_paths_for_dataset = all cache paths in the dataset
    all_cache_files_for_dataset, all_cache_paths_for_dataset = cache_text_encoder_outputs.prepare_cache_files_and_paths(datasets)

    # Load T5
    logger.info(f"Loading T5: {args.t5}")
    text_encoder = T5EncoderModel(
        text_len=config.text_len, dtype=config.t5_dtype, device=device, weight_path=args.t5, fp8=args.fp8_t5
    )

    # Encode with T5
    logger.info("Encoding with T5")

    def encode_for_text_encoder(batch: list[ItemInfo]):
        encode_and_save_batch(text_encoder, batch, device, accelerator)

    cache_text_encoder_outputs.process_text_encoder_batches(
        args.num_workers,
        args.skip_existing,
        args.batch_size,
        datasets,
        all_cache_files_for_dataset,
        all_cache_paths_for_dataset,
        encode_for_text_encoder,
    )
    del text_encoder

    # remove cache files not in dataset
    cache_text_encoder_outputs.post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset)


def wan_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    parser.add_argument("--t5", type=str, default=None, required=True, help="text encoder (T5) checkpoint path")
    parser.add_argument("--fp8_t5", action="store_true", help="use fp8 for Text Encoder model")
    return parser


if __name__ == "__main__":
    parser = cache_text_encoder_outputs.setup_parser_common()
    parser = wan_setup_parser(parser)
    args = parser.parse_args()
    main(args)
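
# Illustrative usage sketch (not part of the script's logic). --t5 and --fp8_t5 are the
# flags defined above; the remaining flag names (--dataset_config, --batch_size,
# --num_workers, --skip_existing, --device) are assumptions inferred from the args
# attributes read in main() and from cache_text_encoder_outputs.setup_parser_common().
# The script filename, paths, and checkpoint name below are placeholders.
#
#   python wan_cache_text_encoder_outputs.py \
#       --dataset_config path/to/dataset.toml \
#       --t5 path/to/t5_checkpoint.pth \
#       --batch_size 16 --num_workers 2 --skip_existing --fp8_t5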