"""Cache text encoder outputs for HunyuanVideo datasets.

Runs every caption in the dataset through Text Encoder 1 (LLM) and Text
Encoder 2, and saves the resulting hidden states and attention masks to
per-item cache files so training does not have to re-encode prompts.
"""

import argparse
import logging
import os
from typing import Callable, Optional, Union

import torch
from tqdm import tqdm

import accelerate

from dataset import config_utils
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
from dataset.image_video_dataset import ARCHITECTURE_HUNYUAN_VIDEO, BaseDataset, ItemInfo, save_text_encoder_output_cache
from hunyuan_model import text_encoder as text_encoder_module
from hunyuan_model.text_encoder import TextEncoder
from utils.model_utils import str_to_dtype

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def encode_prompt(text_encoder: TextEncoder, prompt: Union[str, list[str]]):
    data_type = "video"  # video only; image is not supported
    text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)

    with torch.no_grad():
        prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type)

    return prompt_outputs.hidden_state, prompt_outputs.attention_mask


def encode_and_save_batch(
    text_encoder: TextEncoder, batch: list[ItemInfo], is_llm: bool, accelerator: Optional[accelerate.Accelerator]
):
    prompts = [item.caption for item in batch]

    # encode prompts, under autocast when an accelerator is provided (fp8 LLM path)
    if accelerator is not None:
        with accelerator.autocast():
            prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)
    else:
        prompt_embeds, prompt_mask = encode_prompt(text_encoder, prompts)

    # save prompt cache, one file per item
    for item, embed, mask in zip(batch, prompt_embeds, prompt_mask):
        save_text_encoder_output_cache(item, embed, mask, is_llm)


def prepare_cache_files_and_paths(datasets: list[BaseDataset]):
    all_cache_files_for_dataset = []  # existing cache files on disk
    all_cache_paths_for_dataset = []  # cache paths belonging to the current dataset
    for dataset in datasets:
        all_cache_files = [os.path.normpath(file) for file in dataset.get_all_text_encoder_output_cache_files()]
        all_cache_files_for_dataset.append(set(all_cache_files))
        all_cache_paths_for_dataset.append(set())
    return all_cache_files_for_dataset, all_cache_paths_for_dataset


def process_text_encoder_batches(
    num_workers: Optional[int],
    skip_existing: bool,
    batch_size: Optional[int],
    datasets: list[BaseDataset],
    all_cache_files_for_dataset: list[set],
    all_cache_paths_for_dataset: list[set],
    encode: Callable[[list[ItemInfo]], None],
):
    num_workers = num_workers if num_workers is not None else max(1, (os.cpu_count() or 1) - 1)
    for i, dataset in enumerate(datasets):
        logger.info(f"Encoding dataset [{i}]")
        all_cache_files = all_cache_files_for_dataset[i]
        all_cache_paths = all_cache_paths_for_dataset[i]
        for batch in tqdm(dataset.retrieve_text_encoder_output_cache_batches(num_workers)):
            # record the cache paths for this batch (updating multiple times is harmless)
            all_cache_paths.update(os.path.normpath(item.text_encoder_output_cache_path) for item in batch)

            # skip items whose cache files already exist
            if skip_existing:
                filtered_batch = [
                    item for item in batch if os.path.normpath(item.text_encoder_output_cache_path) not in all_cache_files
                ]
                if len(filtered_batch) == 0:
                    continue
                batch = filtered_batch

            bs = batch_size if batch_size is not None else len(batch)
            for start in range(0, len(batch), bs):
                encode(batch[start : start + bs])
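# Cache bookkeeping, as recorded by the two functions above:
# prepare_cache_files_and_paths captures which cache files already exist on
# disk, and process_text_encoder_batches records which cache paths belong to
# the current dataset. Any file in the first set but not the second is a stale
# leftover from a previous run; post_process_cache_files below deletes it
# unless --keep_cache is passed.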
def post_process_cache_files(
    datasets: list[BaseDataset],
    all_cache_files_for_dataset: list[set],
    all_cache_paths_for_dataset: list[set],
    keep_cache: bool,
):
    for i in range(len(datasets)):
        all_cache_files = all_cache_files_for_dataset[i]
        all_cache_paths = all_cache_paths_for_dataset[i]
        for cache_file in all_cache_files:
            if cache_file not in all_cache_paths:
                if keep_cache:
                    logger.info(f"Keep cache file not in the dataset: {cache_file}")
                else:
                    os.remove(cache_file)
                    logger.info(f"Removed old cache file: {cache_file}")


def main(args):
    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # load dataset config
    blueprint_generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    user_config = config_utils.load_user_config(args.dataset_config)
    blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    datasets = train_dataset_group.datasets

    # define accelerator for fp8 inference
    accelerator = None
    if args.fp8_llm:
        accelerator = accelerate.Accelerator(mixed_precision="fp16")

    # prepare cache bookkeeping: existing cache files on disk vs. cache paths in the dataset
    all_cache_files_for_dataset, all_cache_paths_for_dataset = prepare_cache_files_and_paths(datasets)

    # load Text Encoder 1 (LLM)
    text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else str_to_dtype(args.text_encoder_dtype)
    logger.info(f"loading text encoder 1: {args.text_encoder1}")
    text_encoder_1 = text_encoder_module.load_text_encoder_1(args.text_encoder1, device, args.fp8_llm, text_encoder_dtype)
    text_encoder_1.to(device=device)

    # encode with Text Encoder 1 (LLM)
    logger.info("Encoding with Text Encoder 1")

    def encode_for_text_encoder_1(batch: list[ItemInfo]):
        encode_and_save_batch(text_encoder_1, batch, is_llm=True, accelerator=accelerator)

    process_text_encoder_batches(
        args.num_workers,
        args.skip_existing,
        args.batch_size,
        datasets,
        all_cache_files_for_dataset,
        all_cache_paths_for_dataset,
        encode_for_text_encoder_1,
    )
    del text_encoder_1  # free memory before loading the second encoder

    # load Text Encoder 2
    logger.info(f"loading text encoder 2: {args.text_encoder2}")
    text_encoder_2 = text_encoder_module.load_text_encoder_2(args.text_encoder2, device, text_encoder_dtype)
    text_encoder_2.to(device=device)

    # encode with Text Encoder 2
    logger.info("Encoding with Text Encoder 2")

    def encode_for_text_encoder_2(batch: list[ItemInfo]):
        encode_and_save_batch(text_encoder_2, batch, is_llm=False, accelerator=None)

    process_text_encoder_batches(
        args.num_workers,
        args.skip_existing,
        args.batch_size,
        datasets,
        all_cache_files_for_dataset,
        all_cache_paths_for_dataset,
        encode_for_text_encoder_2,
    )
    del text_encoder_2

    # remove cache files that are no longer part of the dataset
    post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset, args.keep_cache)


def setup_parser_common():
    parser = argparse.ArgumentParser()

    parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
    parser.add_argument("--device", type=str, default=None, help="device to use; default is cuda if available")
    parser.add_argument(
        "--batch_size",
        type=int,
        default=None,
        help="batch size; overrides the dataset config batch size when that value is larger",
    )
    parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset; default is CPU count - 1")
    parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
    parser.add_argument("--keep_cache", action="store_true", help="keep cache files not in dataset")
    return parser
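# The parser is built in two stages: setup_parser_common above holds the flags
# that are not model-specific, while the function below (the "hv" prefix
# presumably stands for HunyuanVideo, matching ARCHITECTURE_HUNYUAN_VIDEO)
# adds the encoder paths and dtype options for this architecture.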
default is cpu count-1") parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files") parser.add_argument("--keep_cache", action="store_true", help="keep cache files not in dataset") return parser def hv_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory") parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory") parser.add_argument("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16") parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)") return parser if __name__ == "__main__": parser = setup_parser_common() parser = hv_setup_parser(parser) args = parser.parse_args() main(args)
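# Example invocation. The script name and all paths are illustrative
# placeholders, not values defined by this file; only the flags themselves
# come from the parsers above:
#
#   python cache_text_encoder_outputs.py \
#       --dataset_config path/to/dataset.toml \
#       --text_encoder1 path/to/text_encoder1 \
#       --text_encoder2 path/to/text_encoder2 \
#       --batch_size 16 --skip_existing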