# Copyright 2025 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MegaTTS3 inference: duration LM + Diffusion Transformer + WavVAE.

:class:`MegaTTS3DiTInfer` loads every sub-model from ``ckpt_root``.
``preprocess`` analyses a reference prompt recording, and ``forward``
synthesizes text in that voice, returning WAV bytes. Running the module
as a script does both for one prompt/text pair (see ``__main__``).
"""

import json
import os
import argparse

import librosa
import numpy as np
import torch
from tn.chinese.normalizer import Normalizer as ZhNormalizer
from tn.english.normalizer import Normalizer as EnNormalizer
from langdetect import detect as classify_language
from pydub import AudioSegment
import pyloudnorm as pyln

from tts.modules.ar_dur.commons.nar_tts_modules import LengthRegulator
from tts.frontend_function import g2p, align, make_dur_prompt, dur_pred, prepare_inputs_for_dit
from tts.utils.audio_utils.io import save_wav, to_wav_bytes, convert_to_wav_bytes, combine_audio_segments
from tts.utils.commons.ckpt_utils import load_ckpt
from tts.utils.commons.hparams import set_hparams, hparams
from tts.utils.text_utils.text_encoder import TokenTextEncoder
from tts.utils.text_utils.split_text import chunk_text_chinese, chunk_text_english

try:
    # Provided inside Hugging Face ZeroGPU Spaces; used by `forward_zerogpu`.
    import spaces
except ImportError:
    # Outside a Space, `@spaces.GPU(...)` would raise NameError at class
    # definition time. Substitute a decorator that returns the function as-is.
    import types

    def _gpu_noop(func=None, **_kwargs):
        if callable(func):
            return func
        return lambda f: f

    spaces = types.SimpleNamespace(GPU=_gpu_noop)

# Default TOKENIZERS_PARALLELISM to "false" unless the user already set it.
if "TOKENIZERS_PARALLELISM" not in os.environ:
    os.environ["TOKENIZERS_PARALLELISM"] = "false"


def convert_to_wav(wav_path):
    """Convert a non-WAV audio file to a sibling ``.wav`` file.

    The output path is ``wav_path`` with its extension replaced by ``.wav``.
    Paths already ending in ``.wav`` are left untouched. A missing input is
    reported on stdout and the function returns ``None`` without raising.

    Args:
        wav_path: Path to the source audio file (any format pydub can read).
    """
    if not os.path.exists(wav_path):
        print(f"The file '{wav_path}' does not exist.")
        return

    if not wav_path.endswith(".wav"):
        out_path = os.path.splitext(wav_path)[0] + ".wav"
        audio = AudioSegment.from_file(wav_path)
        audio.export(out_path, format="wav")
        print(f"Converted '{wav_path}' to '{out_path}'")


def cut_wav(wav_path, max_len=28):
    """Truncate an audio file in place to at most ``max_len`` seconds.

    The result is re-exported in WAV format to the same path.
    """
    audio = AudioSegment.from_file(wav_path)
    audio = audio[:int(max_len * 1000)]  # pydub slices in milliseconds
    audio.export(wav_path, format="wav")


class MegaTTS3DiTInfer():
    """Zero-shot TTS pipeline: prompt preprocessing, then text-to-speech.

    Typical use::

        infer = MegaTTS3DiTInfer()
        ctx = infer.preprocess(prompt_audio_bytes, latent_file=...)
        wav_bytes = infer.forward(ctx, text, time_step=32, p_w=1.6, t_w=2.5)
    """

    def __init__(
        self,
        device=None,
        ckpt_root='./checkpoints',
        dit_exp_name='diffusion_transformer',
        frontend_exp_name='aligner_lm',
        wavvae_exp_name='wavvae',
        dur_ckpt_path='duration_lm',
        g2p_exp_name='g2p',
        precision=torch.float16,
        **kwargs
    ):
        """Build all sub-models, text normalizers and the loudness meter.

        Args:
            device: Torch device; defaults to ``'cuda'`` if available else ``'cpu'``.
            ckpt_root: Directory containing the checkpoint sub-directories.
            dit_exp_name, frontend_exp_name, wavvae_exp_name, dur_ckpt_path,
                g2p_exp_name: Sub-directory names (joined onto ``ckpt_root``)
                for the DiT, aligner, WavVAE, duration LM and G2P LM.
            precision: dtype used for autocast during DiT inference.
            **kwargs: Accepted and ignored.
        """
        self.sr = 24000  # sample rate for loading, loudness metering and output
        self.fm = 8

        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.precision = precision

        # Resolve checkpoint directories, then build models.
        self.dit_exp_name = os.path.join(ckpt_root, dit_exp_name)
        self.frontend_exp_name = os.path.join(ckpt_root, frontend_exp_name)
        self.wavvae_exp_name = os.path.join(ckpt_root, wavvae_exp_name)
        self.dur_exp_name = os.path.join(ckpt_root, dur_ckpt_path)
        self.g2p_exp_name = os.path.join(ckpt_root, g2p_exp_name)
        self.build_model(self.device)

        # Text normalizers for the two supported language paths.
        self.zh_normalizer = ZhNormalizer(overwrite_cache=False, remove_erhua=False, remove_interjections=False)
        self.en_normalizer = EnNormalizer(overwrite_cache=False)
        # Loudness meter shared by preprocess (prompt) and forward (generated segments).
        self.loudness_meter = pyln.Meter(self.sr)

    def build_model(self, device):
        """Load every sub-model onto ``device`` and switch it to eval mode.

        ``set_hparams`` is first called for the DiT experiment without
        ``global_hparams=False``. The global ``hparams`` read afterwards
        (e.g. ``frames_multiple`` here, ``win_size`` in ``preprocess``)
        presumably come from that config. The duration-LM and WavVAE configs
        are loaded with ``global_hparams=False``.
        """
        set_hparams(exp_name=self.dit_exp_name, print_hparams=False)

        # --- Phone / tone dictionaries ---
        current_dir = os.path.dirname(os.path.abspath(__file__))
        # Use a context manager so the file handle is closed deterministically.
        with open(f"{current_dir}/utils/text_utils/dict.json", encoding='utf-8-sig') as f:
            ling_dict = json.load(f)
        self.ling_dict = {k: TokenTextEncoder(None, vocab_list=ling_dict[k], replace_oov='') for k in ['phone', 'tone']}
        self.token_encoder = token_encoder = self.ling_dict['phone']
        ph_dict_size = len(token_encoder)

        # --- Duration LM ---
        from tts.modules.ar_dur.ar_dur_predictor import ARDurPredictor
        hp_dur_model = self.hp_dur_model = set_hparams(f'{self.dur_exp_name}/config.yaml', global_hparams=False)
        hp_dur_model['frames_multiple'] = hparams['frames_multiple']
        self.dur_model = ARDurPredictor(
            hp_dur_model, hp_dur_model['dur_txt_hs'], hp_dur_model['dur_model_hidden_size'],
            hp_dur_model['dur_model_layers'], ph_dict_size,
            hp_dur_model['dur_code_size'],
            use_rot_embed=hp_dur_model.get('use_rot_embed', False))
        self.length_regulator = LengthRegulator()
        load_ckpt(self.dur_model, f'{self.dur_exp_name}', 'dur_model')
        self.dur_model.eval()
        self.dur_model.to(device)

        # --- Diffusion Transformer ---
        from tts.modules.llm_dit.dit import Diffusion
        self.dit = Diffusion()
        load_ckpt(self.dit, f'{self.dit_exp_name}', 'dit', strict=False)
        self.dit.eval()
        self.dit.to(device)
        # Presumably the last id of a 302-entry phone / 32-entry tone vocabulary,
        # used as the CFG mask token — TODO confirm against the DiT config.
        self.cfg_mask_token_phone = 302 - 1
        self.cfg_mask_token_tone = 32 - 1

        # --- Frontend (aligner) LM ---
        from tts.modules.aligner.whisper_small import Whisper
        self.aligner_lm = Whisper()
        load_ckpt(self.aligner_lm, f'{self.frontend_exp_name}', 'model')
        self.aligner_lm.eval()
        self.aligner_lm.to(device)
        self.kv_cache = None
        self.hooks = None

        # --- G2P LM ---
        from transformers import AutoTokenizer, AutoModelForCausalLM
        g2p_tokenizer = AutoTokenizer.from_pretrained(self.g2p_exp_name, padding_side="right")
        g2p_tokenizer.padding_side = "right"
        self.g2p_model = AutoModelForCausalLM.from_pretrained(self.g2p_exp_name).eval().to(device)
        self.g2p_tokenizer = g2p_tokenizer
        # Id of the reserved speech-start token. Encoding an empty string
        # typically yields no tokens, so `encode('')[0]` cannot be correct.
        # '<Reserved_TTS_0>' is the token used by the upstream MegaTTS3 release.
        self.speech_start_idx = g2p_tokenizer.encode('<Reserved_TTS_0>')[0]

        # --- WavVAE ---
        self.hp_wavvae = hp_wavvae = set_hparams(f'{self.wavvae_exp_name}/config.yaml', global_hparams=False)
        from tts.modules.wavvae.decoder.wavvae_v3 import WavVAE_V3
        self.wavvae = WavVAE_V3(hparams=hp_wavvae)
        if os.path.exists(f'{self.wavvae_exp_name}/model_only_last.ckpt'):
            # Full checkpoint: preprocess() encodes the prompt latent from audio.
            load_ckpt(self.wavvae, f'{self.wavvae_exp_name}/model_only_last.ckpt', 'model_gen', strict=True)
            self.has_vae_encoder = True
        else:
            # Decoder-only checkpoint: preprocess() requires a precomputed latent_file.
            load_ckpt(self.wavvae, f'{self.wavvae_exp_name}/decoder.ckpt', 'model_gen', strict=False)
            self.has_vae_encoder = False
        self.wavvae.eval()
        self.wavvae.to(device)
        self.vae_stride = hp_wavvae.get('vae_stride', 4)
        self.hop_size = hp_wavvae.get('hop_size', 4)

    def preprocess(self, audio_bytes, latent_file=None, topk_dur=1, **kwargs):
        """Analyse a reference prompt and build the context consumed by :meth:`forward`.

        Args:
            audio_bytes: Raw bytes of the prompt audio (normalized to WAV via
                ``convert_to_wav_bytes``).
            latent_file: Path to a ``.npy`` holding the prompt's VAE latent.
                Required when the WavVAE checkpoint is decoder-only.
            topk_dur: Top-k for duration-LM sampling. Values ``<= 1`` set
                ``infer_top_k`` to ``None``.
            **kwargs: Accepted and ignored.

        Returns:
            dict with keys ``ph_ref``, ``tone_ref``, ``mel2ph_ref``,
            ``vae_latent``, ``incremental_state_dur_prompt``,
            ``ctx_dur_tokens`` and ``loudness_prompt``.

        Side effects:
            Sets ``self.loudness_prompt`` and ``self.dur_model.hparams["infer_top_k"]``.
        """
        wav_bytes = convert_to_wav_bytes(audio_bytes)

        # Load the prompt at the model sample rate.
        wav, _ = librosa.core.load(wav_bytes, sr=self.sr)
        # Pad so that len(wav) % win_size == win_size - 1 ...
        ws = hparams['win_size']
        if len(wav) % ws < ws - 1:
            wav = np.pad(wav, (0, ws - 1 - (len(wav) % ws)), mode='constant', constant_values=0.0).astype(np.float32)
        # ... then append 12000 zero samples (0.5 s at 24 kHz).
        wav = np.pad(wav, (0, 12000), mode='constant', constant_values=0.0).astype(np.float32)
        loudness_prompt = self.loudness_meter.integrated_loudness(wav.astype(float))
        # Kept on the instance for callers that build their own context dicts.
        self.loudness_prompt = loudness_prompt

        # Phone/tone alignment of the prompt via the aligner LM.
        ph_ref, tone_ref, mel2ph_ref = align(self, wav)

        with torch.inference_mode():
            # Prompt latent: encode with the VAE, or load the precomputed one.
            if self.has_vae_encoder:
                wav = torch.FloatTensor(wav)[None].to(self.device)
                vae_latent = self.wavvae.encode_latent(wav)
                vae_latent = vae_latent[:, :mel2ph_ref.size(1)//4]
            else:
                assert latent_file is not None, "Please provide latent_file in WaveVAE decoder-only mode"
                vae_latent = torch.from_numpy(np.load(latent_file)).to(self.device)
                vae_latent = vae_latent[:, :mel2ph_ref.size(1)//4]

            # Duration prompting.
            self.dur_model.hparams["infer_top_k"] = topk_dur if topk_dur > 1 else None
            incremental_state_dur_prompt, ctx_dur_tokens = make_dur_prompt(self, mel2ph_ref, ph_ref, tone_ref)

        return {
            'ph_ref': ph_ref,
            'tone_ref': tone_ref,
            'mel2ph_ref': mel2ph_ref,
            'vae_latent': vae_latent,
            'incremental_state_dur_prompt': incremental_state_dur_prompt,
            'ctx_dur_tokens': ctx_dur_tokens,
            # Carried per-request so concurrent prompts don't overwrite each other.
            'loudness_prompt': loudness_prompt,
        }

    def forward(self, resource_context, input_text, time_step, p_w, t_w, dur_disturb=0.1, dur_alpha=1.0, **kwargs):
        """Synthesize ``input_text`` in the voice of a preprocessed prompt.

        Args:
            resource_context: dict returned by :meth:`preprocess`.
            input_text: Text to speak. ``langdetect`` result ``'en'`` selects the
                English normalizer/chunker; anything else uses the Chinese path.
            time_step: Number of DiT inference steps.
            p_w, t_w: CFG weights passed to the DiT as ``seq_cfg_w=[p_w, t_w]``
                (described as intelligibility / similarity weights in the CLI).
            dur_disturb, dur_alpha: Forwarded to ``dur_pred``.
            **kwargs: Accepted and ignored.

        Returns:
            WAV-encoded bytes (via ``to_wav_bytes``) at ``self.sr``.
        """
        device = self.device

        ph_ref = resource_context['ph_ref'].to(device)
        tone_ref = resource_context['tone_ref'].to(device)
        mel2ph_ref = resource_context['mel2ph_ref'].to(device)
        vae_latent = resource_context['vae_latent'].to(device)
        ctx_dur_tokens = resource_context['ctx_dur_tokens'].to(device)
        incremental_state_dur_prompt = resource_context['incremental_state_dur_prompt']
        loudness_prompt = resource_context.get('loudness_prompt')
        if loudness_prompt is None:
            # Context built without this key: fall back to the last preprocess() call.
            loudness_prompt = self.loudness_prompt

        with torch.inference_mode():
            wav_pred_ = []
            language_type = classify_language(input_text)
            if language_type == 'en':
                input_text = self.en_normalizer.normalize(input_text)
                text_segs = chunk_text_english(input_text, max_chars=130)
            else:
                input_text = self.zh_normalizer.normalize(input_text)
                text_segs = chunk_text_chinese(input_text, limit=60)

            for seg_i, text in enumerate(text_segs):
                # G2P.
                ph_pred, tone_pred = g2p(self, text)
                # Duration prediction.
                mel2ph_pred = dur_pred(self, ctx_dur_tokens, incremental_state_dur_prompt,
                                       ph_pred, tone_pred, seg_i, dur_disturb, dur_alpha,
                                       is_first=seg_i == 0, is_final=seg_i == len(text_segs) - 1)
                inputs = prepare_inputs_for_dit(self, mel2ph_ref, mel2ph_pred, ph_ref, tone_ref, ph_pred, tone_pred, vae_latent)
                # Speech DiT inference (same as the deprecated torch.cuda.amp.autocast).
                with torch.autocast('cuda', dtype=self.precision, enabled=True):
                    x = self.dit.inference(inputs, timesteps=time_step, seq_cfg_w=[p_w, t_w]).float()

                # WavVAE decode, with the leading frames replaced by the real prompt latent.
                x[:, :vae_latent.size(1)] = vae_latent
                wav_pred = self.wavvae.decode(x)[0, 0].to(torch.float32)

                # Trim the samples corresponding to the prompt latent frames.
                wav_pred = wav_pred[vae_latent.size(1) * self.vae_stride * self.hop_size:].cpu().numpy()
                # Normalize the generated segment to the prompt's integrated loudness.
                loudness_pred = self.loudness_meter.integrated_loudness(wav_pred.astype(float))
                wav_pred = pyln.normalize.loudness(wav_pred, loudness_pred, loudness_prompt)
                peak = np.abs(wav_pred).max()
                if peak >= 1:
                    # Rescale to a 0.95 peak to avoid clipping after the loudness gain.
                    wav_pred = wav_pred / peak * 0.95
                wav_pred_.append(wav_pred)

            # Join the per-chunk segments into one waveform.
            wav_pred = combine_audio_segments(wav_pred_, sr=self.sr).astype(float)
            return to_wav_bytes(wav_pred, self.sr)

    @spaces.GPU(duration=120)
    def forward_zerogpu(self, file_content, latent_file, inp_text, time_step, p_w, t_w):
        """Run :meth:`preprocess` and then :meth:`forward` under the ``spaces.GPU`` decorator.

        Returns:
            WAV-encoded bytes of the synthesized speech.
        """
        resource_context = self.preprocess(file_content, latent_file)
        wav_bytes = self.forward(resource_context, inp_text, time_step=time_step, p_w=p_w, t_w=t_w)
        return wav_bytes


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_wav', type=str, required=True)
    parser.add_argument('--input_text', type=str, required=True)
    parser.add_argument('--output_dir', type=str, required=True)
    parser.add_argument('--time_step', type=int, default=32, help='Inference steps of Diffusion Transformer')
    parser.add_argument('--p_w', type=float, default=1.6, help='Intelligibility Weight')
    parser.add_argument('--t_w', type=float, default=2.5, help='Similarity Weight')
    args = parser.parse_args()
    wav_path, input_text, out_path, time_step, p_w, t_w = args.input_wav, args.input_text, args.output_dir, args.time_step, args.p_w, args.t_w

    infer_ins = MegaTTS3DiTInfer()

    with open(wav_path, 'rb') as file:
        file_content = file.read()

    print(f"| Start processing {wav_path}+{input_text}")
    # Sibling .npy with the same stem. str.replace('.wav', ...) broke on
    # non-.wav inputs and on directories whose names contain ".wav".
    latent_file = os.path.splitext(wav_path)[0] + '.npy'
    resource_context = infer_ins.preprocess(file_content, latent_file=latent_file)
    wav_bytes = infer_ins.forward(resource_context, input_text, time_step=time_step, p_w=p_w, t_w=t_w)

    print(f"| Saving results to {out_path}/[P]{input_text[:20]}.wav")
    os.makedirs(out_path, exist_ok=True)
    save_wav(wav_bytes, f'{out_path}/[P]{input_text[:20]}.wav')