|
import os |
|
import sys |
|
import signal |
|
import subprocess |
|
import shutil |
|
import concurrent.futures |
|
import multiprocessing |
|
from contextlib import contextmanager |
|
|
|
sys.path.append(os.getcwd()) |
|
|
|
import argparse |
|
import csv |
|
import json |
|
from importlib.resources import files |
|
from pathlib import Path |
|
|
|
import torchaudio |
|
from tqdm import tqdm |
|
from datasets.arrow_writer import ArrowWriter |
|
|
|
from f5_tts.model.utils import ( |
|
convert_char_to_pinyin, |
|
) |
|
|
|
|
|
# Vocabulary file of the pretrained model, located relative to the f5_tts package.
# save_prepped_dataset() copies it into the output directory for fine-tuning runs.
PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
|
|
|
|
|
def is_csv_wavs_format(input_dataset_dir):
    """Return True when the directory holds a ``metadata.csv`` file and a ``wavs`` folder."""
    root = Path(input_dataset_dir)
    # is_file()/is_dir() are already False for missing paths, so no separate
    # existence check is needed.
    has_metadata = (root / "metadata.csv").is_file()
    has_wavs_dir = (root / "wavs").is_dir()
    return has_metadata and has_wavs_dir
|
|
|
|
|
|
|
# Number of texts handed to convert_char_to_pinyin() per call.
BATCH_SIZE = 100
# Leave one core free, but always keep at least one worker.
MAX_WORKERS = max(multiprocessing.cpu_count() - 1, 1)
# Name prefix given to the worker threads of the audio-processing pool.
THREAD_NAME_PREFIX = "AudioProcessor"
# Number of audio files submitted to the pool per progress-bar chunk.
CHUNK_SIZE = 100

# Currently active thread pool (if any); read by the signal/interrupt handlers.
executor = None
|
|
|
|
|
@contextmanager
def graceful_exit():
    """Context manager for graceful shutdown on signals.

    While the context is active, SIGINT and SIGTERM trigger a handler that
    shuts down the module-level ``executor`` (cancelling pending futures) and
    exits the process with status 1.  On leaving the context the executor, if
    one is still registered, is shut down without waiting and the signal
    handlers that were active before entering are restored, so the cleanup
    handler does not outlive the work it protects.
    """

    def signal_handler(signum, frame):
        print("\nReceived signal to terminate. Cleaning up...")
        if executor is not None:
            print("Shutting down executor...")
            executor.shutdown(wait=False, cancel_futures=True)
        sys.exit(1)

    previous_handlers = {}
    try:
        for sig in (signal.SIGINT, signal.SIGTERM):
            # signal.signal() returns the handler that was installed before.
            previous_handlers[sig] = signal.signal(sig, signal_handler)
        yield
    finally:
        if executor is not None:
            executor.shutdown(wait=False)
        for sig, handler in previous_handlers.items():
            # A previous handler of None was not installed from Python and
            # cannot be re-installed through signal.signal(); leave ours.
            if handler is not None:
                signal.signal(sig, handler)
|
|
|
|
|
def process_audio_file(audio_path, text, polyphone):
    """Validate one audio file and measure its duration.

    Returns ``(audio_path, text, duration)`` on success, or ``None`` when the
    file is missing, its duration is not positive, or the duration lookup
    fails.  ``polyphone`` is accepted but not used here.
    """
    if Path(audio_path).exists():
        try:
            duration = get_audio_duration(audio_path)
            if duration <= 0:
                raise ValueError(f"Duration {duration} is non-positive.")
        except Exception as e:
            print(f"Warning: Failed to process {audio_path} due to error: {e}. Skipping corrupt file.")
            return None
        return (audio_path, text, duration)

    print(f"audio {audio_path} not found, skipping")
    return None
|
|
|
|
|
def batch_convert_texts(texts, polyphone, batch_size=BATCH_SIZE):
    """Run ``convert_char_to_pinyin`` over ``texts`` in slices of ``batch_size``.

    The converted slices are concatenated in input order and returned as one list.
    """
    pinyin_texts = []
    for start in range(0, len(texts), batch_size):
        text_slice = texts[start : start + batch_size]
        pinyin_texts += convert_char_to_pinyin(text_slice, polyphone=polyphone)
    return pinyin_texts
|
|
|
|
|
def prepare_csv_wavs_dir(input_dir, num_workers=None):
    """Scan a csv_wavs dataset directory and build the training records.

    Reads ``metadata.csv`` from ``input_dir``, checks every referenced audio
    file in a thread pool (files that are missing or unreadable are skipped),
    converts the remaining texts with ``batch_convert_texts`` and collects the
    vocabulary.

    Args:
        input_dir: Directory containing ``metadata.csv`` and ``wavs/``.
        num_workers: Thread count; when ``None`` it is
            ``min(MAX_WORKERS, number of files)``, but never less than 1.

    Returns:
        A tuple ``(sub_result, durations, vocab_set)`` where ``sub_result`` is
        a list of ``{"audio_path", "text", "duration"}`` dicts, ``durations``
        the matching list of durations and ``vocab_set`` the set of tokens
        found in the converted texts.

    Raises:
        AssertionError: If ``input_dir`` is not in csv_wavs format.
        RuntimeError: If no audio file could be processed.
    """
    global executor
    assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
    input_dir = Path(input_dir)
    metadata_path = input_dir / "metadata.csv"
    audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())

    polyphone = True
    total_files = len(audio_path_text_pairs)
    total_chunks = (total_files + CHUNK_SIZE - 1) // CHUNK_SIZE

    if num_workers is not None:
        worker_count = num_workers
    else:
        # ThreadPoolExecutor rejects max_workers=0, which an empty metadata
        # file would otherwise produce; keep one worker so the "no valid
        # audio files" error below is the one that gets reported.
        worker_count = max(1, min(MAX_WORKERS, total_files))
    print(f"\nProcessing {total_files} audio files using {worker_count} workers...")

    results = []
    with graceful_exit():
        try:
            with concurrent.futures.ThreadPoolExecutor(
                max_workers=worker_count, thread_name_prefix=THREAD_NAME_PREFIX
            ) as pool:
                # Publish the pool so the signal handlers can shut it down.
                executor = pool

                # Submit and drain one chunk at a time to keep the number of
                # outstanding futures bounded.
                for chunk_index, start in enumerate(range(0, total_files, CHUNK_SIZE), start=1):
                    chunk = audio_path_text_pairs[start : start + CHUNK_SIZE]
                    chunk_futures = [
                        pool.submit(process_audio_file, audio_path, text, polyphone)
                        for audio_path, text in chunk
                    ]

                    for future in tqdm(
                        chunk_futures,
                        total=len(chunk),
                        desc=f"Processing chunk {chunk_index}/{total_chunks}",
                    ):
                        try:
                            result = future.result()
                            if result is not None:
                                results.append(result)
                        except Exception as e:
                            print(f"Error processing file: {e}")
        finally:
            # Always drop the reference, even when the loop is interrupted,
            # so later handlers never touch a pool that is already closed.
            executor = None

    if not results:
        raise RuntimeError("No valid audio files were processed!")

    # Convert all texts in batches, then pair them back with their audio.
    raw_texts = [text for _, text, _ in results]
    converted_texts = batch_convert_texts(raw_texts, polyphone, batch_size=BATCH_SIZE)

    sub_result = []
    durations = []
    vocab_set = set()

    for (audio_path, _, duration), conv_text in zip(results, converted_texts):
        sub_result.append({"audio_path": audio_path, "text": conv_text, "duration": duration})
        durations.append(duration)
        vocab_set.update(conv_text)

    return sub_result, durations, vocab_set
|
|
|
|
|
def get_audio_duration(audio_path, timeout=5):
    """
    Get the duration of an audio file in seconds using ffmpeg's ffprobe.
    Falls back to torchaudio.load() if ffprobe fails.

    The fallback is taken when ffprobe cannot be started (for example it is
    not installed), exits with an error, times out, or prints something that
    is not a number.

    Args:
        audio_path: Path of the audio file.
        timeout: Seconds to wait for ffprobe before giving up on it.

    Returns:
        The duration in seconds as a float.

    Raises:
        RuntimeError: If both ffprobe and torchaudio fail.
    """
    try:
        cmd = [
            "ffprobe",
            "-v",
            "error",
            "-show_entries",
            "format=duration",
            "-of",
            "default=noprint_wrappers=1:nokey=1",
            audio_path,
        ]
        result = subprocess.run(
            cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, check=True, timeout=timeout
        )
        duration_str = result.stdout.strip()
        if duration_str:
            return float(duration_str)
        raise ValueError("Empty duration string from ffprobe.")
    # OSError covers a missing or non-executable ffprobe binary;
    # SubprocessError covers both a non-zero exit status and a timeout.
    except (OSError, subprocess.SubprocessError, ValueError) as e:
        print(f"Warning: ffprobe failed for {audio_path} with error: {e}. Falling back to torchaudio.")
        try:
            audio, sample_rate = torchaudio.load(audio_path)
            return audio.shape[1] / sample_rate
        except Exception as load_error:
            raise RuntimeError(
                f"Both ffprobe and torchaudio failed for {audio_path}: {load_error}"
            ) from load_error
|
|
|
|
|
def read_audio_text_pairs(csv_file_path):
    """Read ``(audio_path, text)`` pairs from a ``|``-delimited metadata file.

    The first row is treated as a header and skipped.  Rows with fewer than
    two columns are ignored.  Audio paths are resolved relative to the
    directory containing the CSV file and returned as POSIX-style strings.

    Args:
        csv_file_path: Path of the metadata CSV file.

    Returns:
        A list of ``(audio_file_path, text)`` tuples; empty when the file has
        no data rows (including a completely empty file).
    """
    audio_text_pairs = []

    parent = Path(csv_file_path).parent
    with open(csv_file_path, mode="r", newline="", encoding="utf-8-sig") as csvfile:
        reader = csv.reader(csvfile, delimiter="|")
        # Skip the header row; the default keeps an empty file from raising
        # StopIteration.
        next(reader, None)
        for row in reader:
            if len(row) >= 2:
                audio_file = row[0].strip()
                text = row[1].strip()
                audio_file_path = parent / audio_file
                audio_text_pairs.append((audio_file_path.as_posix(), text))

    return audio_text_pairs
|
|
|
|
|
def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
    """Write the prepared dataset files into ``out_dir``.

    Creates ``out_dir`` if needed and writes:
      * ``raw.arrow``     - one record per entry of ``result``;
      * ``duration.json`` - ``{"duration": duration_list}``;
      * ``vocab.txt``     - a copy of the pretrained vocabulary when
        ``is_finetune`` is true, otherwise the sorted ``text_vocab_set``,
        one token per line.

    Args:
        out_dir: Output directory.
        result: Iterable of record dicts to store in ``raw.arrow``.
        duration_list: Durations matching ``result``.
        text_vocab_set: Set of vocabulary tokens found in the texts.
        is_finetune: Whether to reuse the pretrained vocabulary file.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)
    print(f"\nSaving to {out_dir} ...")

    raw_arrow_path = out_dir / "raw.arrow"
    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=100) as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)
        # write() only flushes full batches of writer_batch_size records;
        # finalize() writes the remaining buffered records to the file.
        writer.finalize()

    dur_json_path = out_dir / "duration.json"
    with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    voca_out_path = out_dir / "vocab.txt"
    if is_finetune:
        file_vocab_finetune = PRETRAINED_VOCAB_PATH.as_posix()
        shutil.copy2(file_vocab_finetune, voca_out_path)
    else:
        # Explicit UTF-8: tokens may be non-ASCII and the platform default
        # encoding is not guaranteed to be able to represent them.
        with open(voca_out_path.as_posix(), "w", encoding="utf-8") as f:
            for vocab in sorted(text_vocab_set):
                f.write(vocab + "\n")

    dataset_name = out_dir.stem
    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
|
|
|
|
|
def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True, num_workers: int = None):
    """Prepare the csv_wavs dataset in ``inp_dir`` and store the result in ``out_dir``.

    For fine-tuning runs the pretrained vocabulary file must exist, which is
    checked before any audio is processed.
    """
    if is_finetune:
        assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"

    # prepare_csv_wavs_dir returns (records, durations, vocabulary).
    prepared = prepare_csv_wavs_dir(inp_dir, num_workers=num_workers)
    records, duration_list, vocabulary = prepared
    save_prepped_dataset(out_dir, records, duration_list, vocabulary, is_finetune)
|
|
|
|
|
def cli():
    """Command-line entry point: parse arguments and prepare the dataset.

    Prints a warning when ``ffprobe`` is not on PATH, then runs
    ``prepare_and_save_set`` with the parsed arguments.  A keyboard interrupt
    shuts down the active executor (if any) and exits with status 1.
    """
    try:
        # Warn early; duration extraction then has to fall back to torchaudio.
        if shutil.which("ffprobe") is None:
            print(
                "Warning: ffprobe is not available. Duration extraction will rely on torchaudio (which may be slower)."
            )

        parser = argparse.ArgumentParser(
            description="Prepare and save dataset.",
            # Keep the line breaks of the examples below; the default
            # formatter would reflow them into a single paragraph.
            formatter_class=argparse.RawDescriptionHelpFormatter,
            epilog="""
Examples:
    # For fine-tuning (default):
    python prepare_csv_wavs.py /input/dataset/path /output/dataset/path

    # For pre-training:
    python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --pretrain

    # With custom worker count:
    python prepare_csv_wavs.py /input/dataset/path /output/dataset/path --workers 4
""",
        )
        parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
        parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
        parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")
        parser.add_argument("--workers", type=int, help=f"Number of worker threads (default: {MAX_WORKERS})")
        args = parser.parse_args()

        prepare_and_save_set(args.inp_dir, args.out_dir, is_finetune=not args.pretrain, num_workers=args.workers)
    except KeyboardInterrupt:
        print("\nOperation cancelled by user. Cleaning up...")
        if executor is not None:
            executor.shutdown(wait=False, cancel_futures=True)
        sys.exit(1)
|
|
|
|
|
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    cli()
|
|