diff --git "a/modeling_spark_tts.py" "b/modeling_spark_tts.py" new file mode 100644--- /dev/null +++ "b/modeling_spark_tts.py" @@ -0,0 +1,3270 @@ +# coding=utf-8 +# Copyright 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch SparkTTS model.""" + +import os +import re +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.utils import weight_norm, remove_weight_norm # Needed for modules +import torchaudio # Needed for mel transformer in BiCodec +import numpy as np # Needed for BiCodecTokenizer logic + +from pathlib import Path +from typing import Optional, Union, Tuple, List, Dict, Any +from collections import namedtuple # For Perceiver +from functools import wraps, partial # For Perceiver/FSQ +from contextlib import nullcontext # For FSQ + +from huggingface_hub import snapshot_download +from safetensors.torch import load_file # For BiCodec loading + +from transformers.modeling_utils import PreTrainedModel +from transformers.modeling_outputs import CausalLMOutputWithPast # LLM output type +from transformers.generation import GenerationMixin +from transformers.configuration_utils import PretrainedConfig +from transformers.models.auto.modeling_auto import AutoModelForCausalLM +from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model +from transformers.models.wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor # Needed for from_pretrained +from transformers.utils import logging +from transformers import AutoTokenizer # Needed for token parser test +from einops import rearrange, repeat, pack, unpack # Needed for modules +from einops.layers.torch import Rearrange # Needed for modules +from packaging import version # Needed for Perceiver + +from torch import Tensor, int32, einsum +from torch.amp import autocast +from einops import rearrange, reduce, pack, unpack +from numpy.lib.stride_tricks import sliding_window_view +import soxr +import soundfile + +# Import custom config +from .configuration_spark_tts import SparkTTSConfig, SparkTTSBiCodecConfig + +logger = logging.get_logger(__name__) + +# ============================================================================= +# >> START: PASTE CODE FROM sparktts/modules/* HERE << +# ============================================================================= +# IMPORTANT: All classes defined in sparktts/modules/* (layers, samper, vocos, +# fsq, residual_fsq, ecapa_tdnn, pooling_layers, perceiver_encoder, +# speaker_encoder, feat_encoder, feat_decoder, wave_generator, +# factorized_vector_quantize) need to be pasted or defined *within* this file +# so they can be found when `trust_remote_code=True` is used. + +# Example placeholder comment: +# --- Paste sparktts/modules/blocks/layers.py content here --- + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) + + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + return x + y + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +# --- Paste sparktts/modules/blocks/samper.py content here --- +class SamplingBlock(nn.Module): + """Sampling block for upsampling or downsampling""" + + def __init__( + self, + dim: int, + groups: int = 1, + upsample_scale: int = 1, + downsample_scale: int = 1, + ) -> None: + """ + Args: + dim: input dimension + groups: number of groups + upsample_scale: upsampling scale + downsample_scale: downsampling scale + """ + super(SamplingBlock, self).__init__() + + self.upsample_scale = upsample_scale + self.downsample_scale = downsample_scale + + if self.upsample_scale > 1: + self.de_conv_upsampler = nn.Sequential( + nn.LeakyReLU(0.2), + nn.ConvTranspose1d( + dim, + dim, + kernel_size=upsample_scale * 2, + stride=upsample_scale, + padding=upsample_scale // 2 + upsample_scale % 2, + output_padding=upsample_scale % 2, + groups=groups, + ), + ) + + if self.downsample_scale > 1: + self.conv_downsampler = nn.Sequential( + nn.LeakyReLU(0.2), + nn.Conv1d( + dim, + dim, + kernel_size=2 * downsample_scale, + stride=downsample_scale, + padding=downsample_scale // 2 + downsample_scale % 2, + groups=groups, + ), + ) + + @staticmethod + def repeat_upsampler(x, upsample_scale): + return x.repeat_interleave(upsample_scale, dim=2) + + @staticmethod + def skip_downsampler(x, downsample_scale): + return F.avg_pool1d(x, kernel_size=downsample_scale, stride=downsample_scale) + + def forward(self, x): + x = x.transpose(1, 2) + if self.upsample_scale > 1: + repeat_res = self.repeat_upsampler(x, self.upsample_scale) + deconv_res = self.de_conv_upsampler(x) + upmerge_res = repeat_res + deconv_res + else: + upmerge_res = x + repeat_res = x + + if self.downsample_scale > 1: + conv_res = self.conv_downsampler(upmerge_res) + skip2_res = self.skip_downsampler(upmerge_res, self.downsample_scale) + skip1_res = self.skip_downsampler(repeat_res, self.downsample_scale) + else: + conv_res = upmerge_res + skip2_res = upmerge_res + skip1_res = repeat_res + + final_res = conv_res + skip1_res + skip2_res + + return final_res + +# --- Paste sparktts/modules/blocks/vocos.py content here --- +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal. + + Args: + dim (int): Number of input channels. + intermediate_dim (int): Dimensionality of the intermediate layer. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional LayerNorm. Defaults to None. + """ + + def __init__( + self, + dim: int, + intermediate_dim: int, + layer_scale_init_value: float, + condition_dim: Optional[int] = None, + ): + super().__init__() + self.dwconv = nn.Conv1d( + dim, dim, kernel_size=7, padding=3, groups=dim + ) # depthwise conv + self.adanorm = condition_dim is not None + if condition_dim: + self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, intermediate_dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(intermediate_dim, dim) + self.gamma = ( + nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) + if layer_scale_init_value > 0 + else None + ) + + def forward( + self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None + ) -> torch.Tensor: + residual = x + x = self.dwconv(x) + x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) + if self.adanorm: + assert cond_embedding_id is not None + x = self.norm(x, cond_embedding_id) + else: + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + + x = residual + x + return x + + +class AdaLayerNorm(nn.Module): + """ + Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes + + Args: + condition_dim (int): Dimension of the condition. + embedding_dim (int): Dimension of the embeddings. + """ + + def __init__(self, condition_dim: int, embedding_dim: int, eps: float = 1e-6): + super().__init__() + self.eps = eps + self.dim = embedding_dim + self.scale = nn.Linear(condition_dim, embedding_dim) + self.shift = nn.Linear(condition_dim, embedding_dim) + torch.nn.init.ones_(self.scale.weight) + torch.nn.init.zeros_(self.shift.weight) + + def forward(self, x: torch.Tensor, cond_embedding: torch.Tensor) -> torch.Tensor: + scale = self.scale(cond_embedding) + shift = self.shift(cond_embedding) + x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps) + x = x * scale.unsqueeze(1) + shift.unsqueeze(1) + return x + + +class ResBlock1(nn.Module): + """ + ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions, + but without upsampling layers. + + Args: + dim (int): Number of input channels. + kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3. + dilation (tuple[int], optional): Dilation factors for the dilated convolutions. + Defaults to (1, 3, 5). + lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function. + Defaults to 0.1. + layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling. + Defaults to None. + """ + + def __init__( + self, + dim: int, + kernel_size: int = 3, + dilation: Tuple[int, int, int] = (1, 3, 5), + lrelu_slope: float = 0.1, + layer_scale_init_value: Optional[float] = None, + ): + super().__init__() + self.lrelu_slope = lrelu_slope + self.convs1 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[0], + padding=self.get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[1], + padding=self.get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=dilation[2], + padding=self.get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + ), + weight_norm( + nn.Conv1d( + dim, + dim, + kernel_size, + 1, + dilation=1, + padding=self.get_padding(kernel_size, 1), + ) + ), + ] + ) + + self.gamma = nn.ParameterList( + [ + ( + nn.Parameter( + layer_scale_init_value * torch.ones(dim, 1), requires_grad=True + ) + if layer_scale_init_value is not None + else None + ), + ( + nn.Parameter( + layer_scale_init_value * torch.ones(dim, 1), requires_grad=True + ) + if layer_scale_init_value is not None + else None + ), + ( + nn.Parameter( + layer_scale_init_value * torch.ones(dim, 1), requires_grad=True + ) + if layer_scale_init_value is not None + else None + ), + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma): + xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope) + xt = c1(xt) + xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope) + xt = c2(xt) + if gamma is not None: + xt = gamma * xt + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + @staticmethod + def get_padding(kernel_size: int, dilation: int = 1) -> int: + return int((kernel_size * dilation - dilation) / 2) + + +class Backbone(nn.Module): + """Base class for the generator's backbone. It preserves the same temporal resolution across all layers.""" + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + """ + Args: + x (Tensor): Input tensor of shape (B, C, L), where B is the batch size, + C denotes output features, and L is the sequence length. + + Returns: + Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length, + and H denotes the model dimension. + """ + raise NotImplementedError("Subclasses must implement the forward method.") + + +class VocosBackbone(Backbone): + """ + Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock. + num_layers (int): Number of ConvNeXtBlock layers. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`. + adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm. + None means non-conditional model. Defaults to None. + """ + + def __init__( + self, + input_channels: int, + dim: int, + intermediate_dim: int, + num_layers: int, + layer_scale_init_value: Optional[float] = None, + condition_dim: Optional[int] = None, + ): + super().__init__() + self.input_channels = input_channels + self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3) + self.adanorm = condition_dim is not None + if condition_dim: + self.norm = AdaLayerNorm(condition_dim, dim, eps=1e-6) + else: + self.norm = nn.LayerNorm(dim, eps=1e-6) + layer_scale_init_value = layer_scale_init_value or 1 / num_layers + self.convnext = nn.ModuleList( + [ + ConvNeXtBlock( + dim=dim, + intermediate_dim=intermediate_dim, + layer_scale_init_value=layer_scale_init_value, + condition_dim=condition_dim, + ) + for _ in range(num_layers) + ] + ) + self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv1d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor, condition: torch.Tensor = None) -> torch.Tensor: + x = self.embed(x) + if self.adanorm: + assert condition is not None + x = self.norm(x.transpose(1, 2), condition) + else: + x = self.norm(x.transpose(1, 2)) + x = x.transpose(1, 2) + for conv_block in self.convnext: + x = conv_block(x, condition) + x = self.final_layer_norm(x.transpose(1, 2)) + return x + + +class VocosResNetBackbone(Backbone): + """ + Vocos backbone module built with ResBlocks. + + Args: + input_channels (int): Number of input features channels. + dim (int): Hidden dimension of the model. + num_blocks (int): Number of ResBlock1 blocks. + layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None. + """ + + def __init__( + self, + input_channels, + dim, + num_blocks, + layer_scale_init_value=None, + ): + super().__init__() + self.input_channels = input_channels + self.embed = weight_norm( + nn.Conv1d(input_channels, dim, kernel_size=3, padding=1) + ) + layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3 + self.resnet = nn.Sequential( + *[ + ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value) + for _ in range(num_blocks) + ] + ) + + def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor: + x = self.embed(x) + x = self.resnet(x) + x = x.transpose(1, 2) + return x + + +# --- Paste sparktts/modules/fsq/finite_scalar_quantization.py content here --- +def exists(v): + return v is not None + + +def default(*args): + for arg in args: + if exists(arg): + return arg + return None + + +def maybe(fn): + @wraps(fn) + def inner(x, *args, **kwargs): + if not exists(x): + return x + return fn(x, *args, **kwargs) + + return inner + + +def pack_one(t, pattern): + return pack([t], pattern) + + +def unpack_one(t, ps, pattern): + return unpack(t, ps, pattern)[0] + + +# tensor helpers + + +def round_ste(z: Tensor) -> Tensor: + """Round with straight through gradients.""" + zhat = z.round() + return z + (zhat - z).detach() + + +# main class + + +class FSQ(nn.Module): + def __init__( + self, + levels: List[int], + dim: int | None = None, + num_codebooks=1, + keep_num_codebooks_dim: bool | None = None, + scale: float | None = None, + allowed_dtypes: Tuple[torch.dtype, ...] = (torch.float32, torch.float64), + channel_first: bool = False, + projection_has_bias: bool = True, + return_indices=True, + force_quantization_f32=True, + ): + super().__init__() + _levels = torch.tensor(levels, dtype=int32) + self.register_buffer("_levels", _levels, persistent=False) + + _basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=int32) + self.register_buffer("_basis", _basis, persistent=False) + + self.scale = scale + + codebook_dim = len(levels) + self.codebook_dim = codebook_dim + + effective_codebook_dim = codebook_dim * num_codebooks + self.num_codebooks = num_codebooks + self.effective_codebook_dim = effective_codebook_dim + + keep_num_codebooks_dim = default(keep_num_codebooks_dim, num_codebooks > 1) + assert not (num_codebooks > 1 and not keep_num_codebooks_dim) + self.keep_num_codebooks_dim = keep_num_codebooks_dim + + self.dim = default(dim, len(_levels) * num_codebooks) + + self.channel_first = channel_first + + has_projections = self.dim != effective_codebook_dim + self.project_in = ( + nn.Linear(self.dim, effective_codebook_dim, bias=projection_has_bias) + if has_projections + else nn.Identity() + ) + self.project_out = ( + nn.Linear(effective_codebook_dim, self.dim, bias=projection_has_bias) + if has_projections + else nn.Identity() + ) + + self.has_projections = has_projections + + self.return_indices = return_indices + if return_indices: + self.codebook_size = self._levels.prod().item() + implicit_codebook = self._indices_to_codes(torch.arange(self.codebook_size)) + self.register_buffer( + "implicit_codebook", implicit_codebook, persistent=False + ) + + self.allowed_dtypes = allowed_dtypes + self.force_quantization_f32 = force_quantization_f32 + + def bound(self, z, eps: float = 1e-3): + """Bound `z`, an array of shape (..., d).""" + half_l = (self._levels - 1) * (1 + eps) / 2 + offset = torch.where(self._levels % 2 == 0, 0.5, 0.0) + shift = (offset / half_l).atanh() + return (z + shift).tanh() * half_l - offset + + def quantize(self, z): + """Quantizes z, returns quantized zhat, same shape as z.""" + quantized = round_ste(self.bound(z)) + half_width = self._levels // 2 # Renormalize to [-1, 1]. + return quantized / half_width + + def _scale_and_shift(self, zhat_normalized): + half_width = self._levels // 2 + return (zhat_normalized * half_width) + half_width + + def _scale_and_shift_inverse(self, zhat): + half_width = self._levels // 2 + return (zhat - half_width) / half_width + + def _indices_to_codes(self, indices): + level_indices = self.indices_to_level_indices(indices) + codes = self._scale_and_shift_inverse(level_indices) + return codes + + def codes_to_indices(self, zhat): + """Converts a `code` to an index in the codebook.""" + assert zhat.shape[-1] == self.codebook_dim + zhat = self._scale_and_shift(zhat) + return (zhat * self._basis).sum(dim=-1).to(int32) + + def indices_to_level_indices(self, indices): + """Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings""" + indices = rearrange(indices, "... -> ... 1") + codes_non_centered = (indices // self._basis) % self._levels + return codes_non_centered + + def indices_to_codes(self, indices): + """Inverse of `codes_to_indices`.""" + assert exists(indices) + + is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim)) + + codes = self._indices_to_codes(indices) + + if self.keep_num_codebooks_dim: + codes = rearrange(codes, "... c d -> ... (c d)") + + codes = self.project_out(codes) + + if is_img_or_video or self.channel_first: + codes = rearrange(codes, "b ... d -> b d ...") + + return codes + + def forward(self, z): + """ + einstein notation + b - batch + n - sequence (or flattened spatial dimensions) + d - feature dimension + c - number of codebook dim + """ + + is_img_or_video = z.ndim >= 4 + need_move_channel_last = is_img_or_video or self.channel_first + + # standardize image or video into (batch, seq, dimension) + + if need_move_channel_last: + z = rearrange(z, "b d ... -> b ... d") + z, ps = pack_one(z, "b * d") + + assert ( + z.shape[-1] == self.dim + ), f"expected dimension of {self.dim} but found dimension of {z.shape[-1]}" + + z = self.project_in(z) + + z = rearrange(z, "b n (c d) -> b n c d", c=self.num_codebooks) + + # whether to force quantization step to be full precision or not + + force_f32 = self.force_quantization_f32 + quantization_context = ( + partial(autocast, "cuda", enabled=False) if force_f32 else nullcontext + ) + + with quantization_context(): + orig_dtype = z.dtype + + if force_f32 and orig_dtype not in self.allowed_dtypes: + z = z.float() + + codes = self.quantize(z) + + # returning indices could be optional + + indices = None + + if self.return_indices: + indices = self.codes_to_indices(codes) + + codes = rearrange(codes, "b n c d -> b n (c d)") + + codes = codes.type(orig_dtype) + + # project out + + out = self.project_out(codes) + + # reconstitute image or video dimensions + + if need_move_channel_last: + out = unpack_one(out, ps, "b * d") + out = rearrange(out, "b ... d -> b d ...") + + indices = maybe(unpack_one)(indices, ps, "b * c") + + if not self.keep_num_codebooks_dim and self.return_indices: + indices = maybe(rearrange)(indices, "... 1 -> ...") + + # return quantized output and indices + + return out, indices + + +# --- Paste sparktts/modules/fsq/residual_fsq.py content here --- +import random +import torch.distributed as dist +from einx import get_at + +def round_up_multiple(num, mult): + return ceil(num / mult) * mult + +def is_distributed(): + return dist.is_initialized() and dist.get_world_size() > 1 + + +def get_maybe_sync_seed(device, max_size=10_000): + rand_int = torch.randint(0, max_size, (), device=device) + + if is_distributed(): + dist.all_reduce(rand_int) + + return rand_int.item() + + +class ResidualFSQ(nn.Module): + """Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf""" + + def __init__( + self, + *, + levels: List[int], + num_quantizers, + dim=None, + is_channel_first=False, + quantize_dropout=False, + quantize_dropout_cutoff_index=0, + quantize_dropout_multiple_of=1, + **kwargs, + ): + super().__init__() + codebook_dim = len(levels) + dim = default(dim, codebook_dim) + + requires_projection = codebook_dim != dim + self.project_in = ( + nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity() + ) + self.project_out = ( + nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity() + ) + self.has_projections = requires_projection + + self.is_channel_first = is_channel_first + self.num_quantizers = num_quantizers + + self.levels = levels + self.layers = nn.ModuleList([]) + + levels_tensor = torch.Tensor(levels) + + scales = [] + + for ind in range(num_quantizers): + scales.append((levels_tensor - 1) ** -ind) + + fsq = FSQ(levels=levels, dim=codebook_dim, **kwargs) + + self.layers.append(fsq) + + assert all([not fsq.has_projections for fsq in self.layers]) + + self.codebook_size = self.layers[0].codebook_size + + self.register_buffer("scales", torch.stack(scales), persistent=False) + + self.quantize_dropout = quantize_dropout and num_quantizers > 1 + + assert quantize_dropout_cutoff_index >= 0 + + self.quantize_dropout_cutoff_index = quantize_dropout_cutoff_index + self.quantize_dropout_multiple_of = quantize_dropout_multiple_of # encodec paper proposes structured dropout, believe this was set to 4 + + @property + def codebooks(self): + codebooks = [layer.implicit_codebook for layer in self.layers] + codebooks = torch.stack(codebooks, dim=0) + return codebooks + + def get_codes_from_indices(self, indices): + + batch, quantize_dim = indices.shape[0], indices.shape[-1] + + # may also receive indices in the shape of 'b h w q' (accept_image_fmap) + + indices, ps = pack([indices], "b * q") + + # because of quantize dropout, one can pass in indices that are coarse + # and the network should be able to reconstruct + + if quantize_dim < self.num_quantizers: + assert ( + self.quantize_dropout > 0.0 + ), "quantize dropout must be greater than 0 if you wish to reconstruct from a signal with less fine quantizations" + indices = F.pad(indices, (0, self.num_quantizers - quantize_dim), value=-1) + + # take care of quantizer dropout + + mask = indices == -1 + indices = indices.masked_fill( + mask, 0 + ) # have it fetch a dummy code to be masked out later + + all_codes = get_at("q [c] d, b n q -> q b n d", self.codebooks, indices) + + # mask out any codes that were dropout-ed + + all_codes = all_codes.masked_fill(rearrange(mask, "b n q -> q b n 1"), 0.0) + + # scale the codes + + scales = rearrange(self.scales, "q d -> q 1 1 d") + all_codes = all_codes * scales + + # if (accept_image_fmap = True) then return shape (quantize, batch, height, width, dimension) + + (all_codes,) = unpack(all_codes, ps, "q b * d") + + return all_codes + + def get_output_from_indices(self, indices): + codes = self.get_codes_from_indices(indices) + codes_summed = reduce(codes, "q ... -> ...", "sum") + return self.project_out(codes_summed) + + def forward(self, x, return_all_codes=False, rand_quantize_dropout_fixed_seed=None): + num_quant, quant_dropout_multiple_of, device = ( + self.num_quantizers, + self.quantize_dropout_multiple_of, + x.device, + ) + + # handle channel first + + if self.is_channel_first: + x = rearrange(x, "b d ... -> b ... d") + x, ps = pack([x], "b * d") + + # maybe project in + + x = self.project_in(x) + + quantized_out = 0.0 + residual = x + + all_indices = [] + + should_quantize_dropout = self.training and self.quantize_dropout + + # sample a layer index at which to dropout further residual quantization + # also prepare null indices + + if should_quantize_dropout: + + # check if seed is manually passed in + + if not exists(rand_quantize_dropout_fixed_seed): + rand_quantize_dropout_fixed_seed = get_maybe_sync_seed(device) + + rand = random.Random(rand_quantize_dropout_fixed_seed) + + rand_quantize_dropout_index = rand.randrange( + self.quantize_dropout_cutoff_index, num_quant + ) + + if quant_dropout_multiple_of != 1: + rand_quantize_dropout_index = ( + round_up_multiple( + rand_quantize_dropout_index + 1, quant_dropout_multiple_of + ) + - 1 + ) + + null_indices = torch.full( + x.shape[:2], -1.0, device=device, dtype=torch.long + ) + + # go through the layers + + with autocast("cuda", enabled=False): + for quantizer_index, (layer, scale) in enumerate( + zip(self.layers, self.scales) + ): + + if ( + should_quantize_dropout + and quantizer_index > rand_quantize_dropout_index + ): + all_indices.append(null_indices) + continue + + quantized, indices = layer(residual / scale) + + quantized = quantized * scale + + residual = residual - quantized.detach() + quantized_out = quantized_out + quantized + + all_indices.append(indices) + + # project out, if needed + + quantized_out = self.project_out(quantized_out) + + # stack all indices + + all_indices = torch.stack(all_indices, dim=-1) + + # channel first out + + if self.is_channel_first: + (quantized_out,) = unpack(quantized_out, ps, "b * d") + (all_indices,) = unpack(all_indices, ps, "b * d") + + quantized_out = rearrange(quantized_out, "b ... d -> b d ...") + all_indices = rearrange(all_indices, "b ... d -> b d ...") + + # return + + ret = (quantized_out, all_indices) + + if not return_all_codes: + return ret + + # whether to return all codes from all codebooks across layers + + all_codes = self.get_codes_from_indices(all_indices) + + # will return all codes in shape (quantizer, batch, sequence length, codebook dimension) + + return (*ret, all_codes) + + +# grouped residual fsq + + +class GroupedResidualFSQ(nn.Module): + def __init__(self, *, dim, groups=1, accept_image_fmap=False, **kwargs): + super().__init__() + self.dim = dim + self.groups = groups + assert (dim % groups) == 0 + dim_per_group = dim // groups + + self.accept_image_fmap = accept_image_fmap + + self.rvqs = nn.ModuleList([]) + + for _ in range(groups): + self.rvqs.append(ResidualFSQ(dim=dim_per_group, **kwargs)) + + self.codebook_size = self.rvqs[0].codebook_size + + @property + def codebooks(self): + return torch.stack(tuple(rvq.codebooks for rvq in self.rvqs)) + + @property + def split_dim(self): + return 1 if self.accept_image_fmap else -1 + + def get_codes_from_indices(self, indices): + codes = tuple( + rvq.get_codes_from_indices(chunk_indices) + for rvq, chunk_indices in zip(self.rvqs, indices) + ) + return torch.stack(codes) + + def get_output_from_indices(self, indices): + outputs = tuple( + rvq.get_output_from_indices(chunk_indices) + for rvq, chunk_indices in zip(self.rvqs, indices) + ) + return torch.cat(outputs, dim=self.split_dim) + + def forward(self, x, return_all_codes=False): + shape, split_dim, device = x.shape, self.split_dim, x.device + assert shape[split_dim] == self.dim + + # split the feature dimension into groups + + x = x.chunk(self.groups, dim=split_dim) + + forward_kwargs = dict( + return_all_codes=return_all_codes, + rand_quantize_dropout_fixed_seed=( + get_maybe_sync_seed(device) if self.training else None + ), + ) + + # invoke residual vq on each group + + out = tuple(rvq(chunk, **forward_kwargs) for rvq, chunk in zip(self.rvqs, x)) + out = tuple(zip(*out)) + + # otherwise, get all the zipped outputs and combine them + + quantized, all_indices, *maybe_all_codes = out + + quantized = torch.cat(quantized, dim=split_dim) + all_indices = torch.stack(all_indices) + + ret = (quantized, all_indices, *maybe_all_codes) + return ret + + +# --- Paste sparktts/modules/speaker/pooling_layers.py content here --- + +class TAP(nn.Module): + """ + Temporal average pooling, only first-order mean is considered + """ + + def __init__(self, in_dim=0, **kwargs): + super(TAP, self).__init__() + self.in_dim = in_dim + + def forward(self, x): + pooling_mean = x.mean(dim=-1) + # To be compatable with 2D input + pooling_mean = pooling_mean.flatten(start_dim=1) + return pooling_mean + + def get_out_dim(self): + self.out_dim = self.in_dim + return self.out_dim + + +class TSDP(nn.Module): + """ + Temporal standard deviation pooling, only second-order std is considered + """ + + def __init__(self, in_dim=0, **kwargs): + super(TSDP, self).__init__() + self.in_dim = in_dim + + def forward(self, x): + # The last dimension is the temporal axis + pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7) + pooling_std = pooling_std.flatten(start_dim=1) + return pooling_std + + def get_out_dim(self): + self.out_dim = self.in_dim + return self.out_dim + + +class TSTP(nn.Module): + """ + Temporal statistics pooling, concatenate mean and std, which is used in + x-vector + Comment: simple concatenation can not make full use of both statistics + """ + + def __init__(self, in_dim=0, **kwargs): + super(TSTP, self).__init__() + self.in_dim = in_dim + + def forward(self, x): + # The last dimension is the temporal axis + pooling_mean = x.mean(dim=-1) + pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-7) + pooling_mean = pooling_mean.flatten(start_dim=1) + pooling_std = pooling_std.flatten(start_dim=1) + stats = torch.cat((pooling_mean, pooling_std), 1) + return stats + + def get_out_dim(self): + self.out_dim = self.in_dim * 2 + return self.out_dim + + +class ASTP(nn.Module): + """ Attentive statistics pooling: Channel- and context-dependent + statistics pooling, first used in ECAPA_TDNN. + """ + + def __init__(self, + in_dim, + bottleneck_dim=128, + global_context_att=False, + **kwargs): + super(ASTP, self).__init__() + self.in_dim = in_dim + self.global_context_att = global_context_att + + # Use Conv1d with stride == 1 rather than Linear, then we don't + # need to transpose inputs. + if global_context_att: + self.linear1 = nn.Conv1d( + in_dim * 3, bottleneck_dim, + kernel_size=1) # equals W and b in the paper + else: + self.linear1 = nn.Conv1d( + in_dim, bottleneck_dim, + kernel_size=1) # equals W and b in the paper + self.linear2 = nn.Conv1d(bottleneck_dim, in_dim, + kernel_size=1) # equals V and k in the paper + + def forward(self, x): + """ + x: a 3-dimensional tensor in tdnn-based architecture (B,F,T) + or a 4-dimensional tensor in resnet architecture (B,C,F,T) + 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) + """ + if len(x.shape) == 4: + x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3]) + assert len(x.shape) == 3 + + if self.global_context_att: + context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x) + context_std = torch.sqrt( + torch.var(x, dim=-1, keepdim=True) + 1e-7).expand_as(x) + x_in = torch.cat((x, context_mean, context_std), dim=1) + else: + x_in = x + + # DON'T use ReLU here! ReLU may be hard to converge. + alpha = torch.tanh( + self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in)) + alpha = torch.softmax(self.linear2(alpha), dim=2) + mean = torch.sum(alpha * x, dim=2) + var = torch.sum(alpha * (x**2), dim=2) - mean**2 + std = torch.sqrt(var.clamp(min=1e-7)) + return torch.cat([mean, std], dim=1) + + def get_out_dim(self): + self.out_dim = 2 * self.in_dim + return self.out_dim + + +class MHASTP(torch.nn.Module): + """ Multi head attentive statistics pooling + Reference: + Self Multi-Head Attention for Speaker Recognition + https://arxiv.org/pdf/1906.09890.pdf + """ + + def __init__(self, + in_dim, + layer_num=2, + head_num=2, + d_s=1, + bottleneck_dim=64, + **kwargs): + super(MHASTP, self).__init__() + assert (in_dim % head_num + ) == 0 # make sure that head num can be divided by input_dim + self.in_dim = in_dim + self.head_num = head_num + d_model = int(in_dim / head_num) + channel_dims = [bottleneck_dim for i in range(layer_num + 1)] + if d_s > 1: + d_s = d_model + else: + d_s = 1 + self.d_s = d_s + channel_dims[0], channel_dims[-1] = d_model, d_s + heads_att_trans = [] + for i in range(self.head_num): + att_trans = nn.Sequential() + for i in range(layer_num - 1): + att_trans.add_module( + 'att_' + str(i), + nn.Conv1d(channel_dims[i], channel_dims[i + 1], 1, 1)) + att_trans.add_module('tanh' + str(i), nn.Tanh()) + att_trans.add_module( + 'att_' + str(layer_num - 1), + nn.Conv1d(channel_dims[layer_num - 1], channel_dims[layer_num], + 1, 1)) + heads_att_trans.append(att_trans) + self.heads_att_trans = nn.ModuleList(heads_att_trans) + + def forward(self, input): + """ + input: a 3-dimensional tensor in xvector architecture + or a 4-dimensional tensor in resnet architecture + 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) + """ + if len(input.shape) == 4: # B x F x T + input = input.reshape(input.shape[0], + input.shape[1] * input.shape[2], + input.shape[3]) + assert len(input.shape) == 3 + bs, f_dim, t_dim = input.shape + chunks = torch.chunk(input, self.head_num, 1) + # split + chunks_out = [] + # for i in range(self.head_num): + # att_score = self.heads_att_trans[i](chunks[i]) + for i, layer in enumerate(self.heads_att_trans): + att_score = layer(chunks[i]) + alpha = F.softmax(att_score, dim=-1) + mean = torch.sum(alpha * chunks[i], dim=2) + var = torch.sum(alpha * chunks[i]**2, dim=2) - mean**2 + std = torch.sqrt(var.clamp(min=1e-7)) + chunks_out.append(torch.cat((mean, std), dim=1)) + out = torch.cat(chunks_out, dim=1) + return out + + def get_out_dim(self): + self.out_dim = 2 * self.in_dim + return self.out_dim + + +class MQMHASTP(torch.nn.Module): + """ An attentive pooling + Reference: + multi query multi head attentive statistics pooling + https://arxiv.org/pdf/2110.05042.pdf + Args: + in_dim: the feature dimension of input + layer_num: the number of layer in the pooling layer + query_num: the number of querys + head_num: the number of heads + bottleneck_dim: the bottleneck dimension + + SA (H = 1, Q = 1, n = 2, d_s = 1) ref: + https://www.danielpovey.com/files/2018_interspeech_xvector_attention.pdf + MHA (H > 1, Q = 1, n = 1, d_s = 1) ref: + https://arxiv.org/pdf/1906.09890.pdf + AS (H = 1, Q > 1, n = 2, d_s = 1) ref: + https://arxiv.org/pdf/1803.10963.pdf + VSA (H = 1, Q > 1, n = 2, d_s = d_h) ref: + http://www.interspeech2020.org/uploadfile/pdf/Mon-2-10-5.pdf + """ + + def __init__(self, + in_dim, + layer_num=2, + query_num=2, + head_num=8, + d_s=2, + bottleneck_dim=64, + **kwargs): + super(MQMHASTP, self).__init__() + self.n_query = nn.ModuleList([ + MHASTP(in_dim, + layer_num=layer_num, + head_num=head_num, + d_s=d_s, + bottleneck_dim=bottleneck_dim) for i in range(query_num) + ]) + self.query_num = query_num + self.in_dim = in_dim + + def forward(self, input): + """ + input: a 3-dimensional tensor in xvector architecture + or a 4-dimensional tensor in resnet architecture + 0-dim: batch-dimension, last-dim: time-dimension (frame-dimension) + """ + if len(input.shape) == 4: # B x F x T + input = input.reshape(input.shape[0], + input.shape[1] * input.shape[2], + input.shape[3]) + assert len(input.shape) == 3 + res = [] + for i, layer in enumerate(self.n_query): + res.append(layer(input)) + out = torch.cat(res, dim=-1) + return out + + def get_out_dim(self): + self.out_dim = self.in_dim * 2 * self.query_num + return self.out_dim + + + +# --- Paste sparktts/modules/speaker/ecapa_tdnn.py content here --- + +class Res2Conv1dReluBn(nn.Module): + """ + in_channels == out_channels == channels + """ + + def __init__( + self, + channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + bias=True, + scale=4, + ): + super().__init__() + assert channels % scale == 0, "{} % {} != 0".format(channels, scale) + self.scale = scale + self.width = channels // scale + self.nums = scale if scale == 1 else scale - 1 + + self.convs = [] + self.bns = [] + for i in range(self.nums): + self.convs.append( + nn.Conv1d( + self.width, + self.width, + kernel_size, + stride, + padding, + dilation, + bias=bias, + ) + ) + self.bns.append(nn.BatchNorm1d(self.width)) + self.convs = nn.ModuleList(self.convs) + self.bns = nn.ModuleList(self.bns) + + def forward(self, x): + out = [] + spx = torch.split(x, self.width, 1) + sp = spx[0] + for i, (conv, bn) in enumerate(zip(self.convs, self.bns)): + # Order: conv -> relu -> bn + if i >= 1: + sp = sp + spx[i] + sp = conv(sp) + sp = bn(F.relu(sp)) + out.append(sp) + if self.scale != 1: + out.append(spx[self.nums]) + out = torch.cat(out, dim=1) + + return out + + +""" Conv1d + BatchNorm1d + ReLU +""" + + +class Conv1dReluBn(nn.Module): + + def __init__( + self, + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + dilation=1, + bias=True, + ): + super().__init__() + self.conv = nn.Conv1d( + in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias + ) + self.bn = nn.BatchNorm1d(out_channels) + + def forward(self, x): + return self.bn(F.relu(self.conv(x))) + + +""" The SE connection of 1D case. +""" + + +class SE_Connect(nn.Module): + + def __init__(self, channels, se_bottleneck_dim=128): + super().__init__() + self.linear1 = nn.Linear(channels, se_bottleneck_dim) + self.linear2 = nn.Linear(se_bottleneck_dim, channels) + + def forward(self, x): + out = x.mean(dim=2) + out = F.relu(self.linear1(out)) + out = torch.sigmoid(self.linear2(out)) + out = x * out.unsqueeze(2) + + return out + + +""" SE-Res2Block of the ECAPA-TDNN architecture. +""" + + +class SE_Res2Block(nn.Module): + + def __init__(self, channels, kernel_size, stride, padding, dilation, scale): + super().__init__() + self.se_res2block = nn.Sequential( + Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0), + Res2Conv1dReluBn( + channels, kernel_size, stride, padding, dilation, scale=scale + ), + Conv1dReluBn(channels, channels, kernel_size=1, stride=1, padding=0), + SE_Connect(channels), + ) + + def forward(self, x): + return x + self.se_res2block(x) + + +class ECAPA_TDNN(nn.Module): + + def __init__( + self, + channels=512, + feat_dim=80, + embed_dim=192, + pooling_func="ASTP", + global_context_att=False, + emb_bn=False, + ): + super().__init__() + + self.layer1 = Conv1dReluBn(feat_dim, channels, kernel_size=5, padding=2) + self.layer2 = SE_Res2Block( + channels, kernel_size=3, stride=1, padding=2, dilation=2, scale=8 + ) + self.layer3 = SE_Res2Block( + channels, kernel_size=3, stride=1, padding=3, dilation=3, scale=8 + ) + self.layer4 = SE_Res2Block( + channels, kernel_size=3, stride=1, padding=4, dilation=4, scale=8 + ) + + cat_channels = channels * 3 + out_channels = 512 * 3 + self.conv = nn.Conv1d(cat_channels, out_channels, kernel_size=1) + self.pool = globals()[pooling_func]( + in_dim=out_channels, global_context_att=global_context_att + ) + self.pool_out_dim = self.pool.get_out_dim() + self.bn = nn.BatchNorm1d(self.pool_out_dim) + self.linear = nn.Linear(self.pool_out_dim, embed_dim) + self.emb_bn = emb_bn + if emb_bn: # better in SSL for SV + self.bn2 = nn.BatchNorm1d(embed_dim) + else: + self.bn2 = nn.Identity() + + def forward(self, x, return_latent=False): + x = x.permute(0, 2, 1) # (B,T,F) -> (B,F,T) + + out1 = self.layer1(x) + out2 = self.layer2(out1) + out3 = self.layer3(out2) + out4 = self.layer4(out3) + + out = torch.cat([out2, out3, out4], dim=1) + latent = F.relu(self.conv(out)) + out = self.bn(self.pool(latent)) + out = self.linear(out) + if self.emb_bn: + out = self.bn2(out) + + if return_latent: + return out, latent + return out + + +def ECAPA_TDNN_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): + return ECAPA_TDNN( + channels=1024, + feat_dim=feat_dim, + embed_dim=embed_dim, + pooling_func=pooling_func, + emb_bn=emb_bn, + ) + + +def ECAPA_TDNN_GLOB_c1024(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): + return ECAPA_TDNN( + channels=1024, + feat_dim=feat_dim, + embed_dim=embed_dim, + pooling_func=pooling_func, + global_context_att=True, + emb_bn=emb_bn, + ) + + +def ECAPA_TDNN_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): + return ECAPA_TDNN( + channels=512, + feat_dim=feat_dim, + embed_dim=embed_dim, + pooling_func=pooling_func, + emb_bn=emb_bn, + ) + + +def ECAPA_TDNN_GLOB_c512(feat_dim, embed_dim, pooling_func="ASTP", emb_bn=False): + return ECAPA_TDNN( + channels=512, + feat_dim=feat_dim, + embed_dim=embed_dim, + pooling_func=pooling_func, + global_context_att=True, + emb_bn=emb_bn, + ) + + +# --- Paste sparktts/modules/speaker/perceiver_encoder.py content here --- + +def once(fn): + called = False + + @wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + + return inner + + +print_once = once(print) + +# main class + + +class Attend(nn.Module): + def __init__(self, dropout=0.0, causal=False, use_flash=False): + super().__init__() + self.dropout = dropout + self.attn_dropout = nn.Dropout(dropout) + + self.causal = causal + self.register_buffer("mask", None, persistent=False) + + self.use_flash = use_flash + assert not ( + use_flash and version.parse(torch.__version__) < version.parse("2.0.0") + ), "in order to use flash attention, you must be using pytorch 2.0 or above" + + # determine efficient attention configs for cuda and cpu + self.config = namedtuple( + "EfficientAttentionConfig", + ["enable_flash", "enable_math", "enable_mem_efficient"], + ) + self.cpu_config = self.config(True, True, True) + self.cuda_config = None + + if not torch.cuda.is_available() or not use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device("cuda")) + + if device_properties.major == 8 and device_properties.minor == 0: + print_once( + "A100 GPU detected, using flash attention if input tensor is on cuda" + ) + self.cuda_config = self.config(True, False, False) + else: + print_once( + "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda" + ) + self.cuda_config = self.config(False, True, True) + + def get_mask(self, n, device): + if exists(self.mask) and self.mask.shape[-1] >= n: + return self.mask[:n, :n] + + mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1) + self.register_buffer("mask", mask, persistent=False) + return mask + + def flash_attn(self, q, k, v, mask=None): + _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda + + # Recommended for multi-query single-key-value attention by Tri Dao + # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) + + if k.ndim == 3: + k = rearrange(k, "b ... -> b 1 ...").expand_as(q) + + if v.ndim == 3: + v = rearrange(v, "b ... -> b 1 ...").expand_as(q) + + # Check if mask exists and expand to compatible shape + # The mask is B L, so it would have to be expanded to B H N L + + if exists(mask): + mask = rearrange(mask, "b j -> b 1 1 j") + mask = mask.expand(-1, heads, q_len, -1) + + # Check if there is a compatible device for flash attention + + config = self.cuda_config if is_cuda else self.cpu_config + + # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale + + with torch.backends.cuda.sdp_kernel(**config._asdict()): + out = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=self.causal, + ) + + return out + + def forward(self, q, k, v, mask=None): + """ + einstein notation + b - batch + h - heads + n, i, j - sequence length (base sequence length, source, target) + d - feature dimension + """ + + n, device = q.shape[-2], q.device + + scale = q.shape[-1] ** -0.5 + + if self.use_flash: + return self.flash_attn(q, k, v, mask=mask) + + kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d" + + # similarity + + sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale + + # key padding mask + + if exists(mask): + mask = rearrange(mask, "b j -> b 1 1 j") + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + # causal mask + + if self.causal: + causal_mask = self.get_mask(n, device) + sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) + + # attention + + attn = sim.softmax(dim=-1) + attn = self.attn_dropout(attn) + + # aggregate values + + out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v) + + return out + + +def Sequential(*mods): + return nn.Sequential(*filter(exists, mods)) + + +class RMSNorm(nn.Module): + def __init__(self, dim, scale=True, dim_cond=None): + super().__init__() + self.cond = exists(dim_cond) + self.to_gamma_beta = nn.Linear(dim_cond, dim * 2) if self.cond else None + + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(dim)) if scale else None + + def forward(self, x, cond=None): + gamma = default(self.gamma, 1) + out = F.normalize(x, dim=-1) * self.scale * gamma + + if not self.cond: + return out + + assert exists(cond) + gamma, beta = self.to_gamma_beta(cond).chunk(2, dim=-1) + gamma, beta = map(lambda t: rearrange(t, "b d -> b 1 d"), (gamma, beta)) + return out * gamma + beta + + +class CausalConv1d(nn.Conv1d): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + (kernel_size,) = self.kernel_size + (dilation,) = self.dilation + (stride,) = self.stride + + assert stride == 1 + self.causal_padding = dilation * (kernel_size - 1) + + def forward(self, x): + causal_padded_x = F.pad(x, (self.causal_padding, 0), value=0.0) + return super().forward(causal_padded_x) + + +class GEGLU(nn.Module): + def forward(self, x): + x, gate = x.chunk(2, dim=-1) + return F.gelu(gate) * x + + +def FeedForward(dim, mult=4, causal_conv=False): + dim_inner = int(dim * mult * 2 / 3) + + conv = None + if causal_conv: + conv = nn.Sequential( + Rearrange("b n d -> b d n"), + CausalConv1d(dim_inner, dim_inner, 3), + Rearrange("b d n -> b n d"), + ) + + return Sequential( + nn.Linear(dim, dim_inner * 2), GEGLU(), conv, nn.Linear(dim_inner, dim) + ) + + +class Attention(nn.Module): + def __init__( + self, + dim, + *, + dim_context=None, + causal=False, + dim_head=64, + heads=8, + dropout=0.0, + use_flash=False, + cross_attn_include_queries=False, + ): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + self.cross_attn_include_queries = cross_attn_include_queries + + dim_inner = dim_head * heads + dim_context = default(dim_context, dim) + + self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash) + self.to_q = nn.Linear(dim, dim_inner, bias=False) + self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False) + self.to_out = nn.Linear(dim_inner, dim, bias=False) + + def forward(self, x, context=None, mask=None): + h, has_context = self.heads, exists(context) + + context = default(context, x) + + if has_context and self.cross_attn_include_queries: + context = torch.cat((x, context), dim=-2) + + q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = self.attend(q, k, v, mask=mask) + + out = rearrange(out, "b h n d -> b n (h d)") + return self.to_out(out) + + +class PerceiverResampler(nn.Module): + def __init__( + self, + *, + dim, + depth=2, + dim_context=None, + num_latents=32, + dim_head=64, + heads=8, + ff_mult=4, + use_flash_attn=False, + ): + super().__init__() + dim_context = default(dim_context, dim) + + self.proj_context = ( + nn.Linear(dim_context, dim) if dim_context != dim else nn.Identity() + ) + + self.latents = nn.Parameter(torch.randn(num_latents, dim)) + nn.init.normal_(self.latents, std=0.02) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + Attention( + dim=dim, + dim_head=dim_head, + heads=heads, + use_flash=use_flash_attn, + cross_attn_include_queries=True, + ), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + self.norm = RMSNorm(dim) + + def forward(self, x, mask=None): + batch = x.shape[0] + + x = self.proj_context(x) + + latents = repeat(self.latents, "n d -> b n d", b=batch) + + for attn, ff in self.layers: + latents = attn(latents, x, mask=mask) + latents + latents = ff(latents) + latents + + return self.norm(latents) + + +# --- Paste sparktts/modules/speaker/speaker_encoder.py content here --- + +class SpeakerEncoder(nn.Module): + """ + + Args: + input_dim (int): acoustic feature dimension + out_dim (int): output dimension of x-vector and d-vector + latent_dim (int): latent dimension before quantization + token_num (int): sequence length of speaker tokens + fsq_levels (List[int]): number of levels for each quantizer + fsq_num_quantizers (int): number of quantizers + + Return: + speaker_embs: (B, T2, out_dim) + """ + + def __init__( + self, + input_dim: int = 100, + out_dim: int = 512, + latent_dim: int = 128, + token_num: int = 32, + fsq_levels: List[int] = [4, 4, 4, 4, 4, 4], + fsq_num_quantizers: int = 1, + ): + super(SpeakerEncoder, self).__init__() + + self.speaker_encoder = ECAPA_TDNN_GLOB_c512( + feat_dim=input_dim, embed_dim=out_dim + ) + self.perceiver_sampler = PerceiverResampler( + dim=latent_dim, dim_context=512 * 3, num_latents=token_num + ) + self.quantizer = ResidualFSQ( + levels=fsq_levels, + num_quantizers=fsq_num_quantizers, + dim=latent_dim, + is_channel_first=True, + quantize_dropout=False, + ) + + self.project = nn.Linear(latent_dim * token_num, out_dim) + + def get_codes_from_indices(self, indices: torch.Tensor) -> torch.Tensor: + zq = self.quantizer.get_codes_from_indices(indices.transpose(1, 2)) + return zq.transpose(1, 2) + + def get_indices(self, mels: torch.Tensor) -> torch.Tensor: + mels = mels.transpose(1, 2) + x = self.perceiver_sampler(mels).transpose(1, 2) + zq, indices = self.quantizer(x) + return indices + + def forward(self, mels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + mels: (B, D_mel, T1) + + Return: + x_vector: (B, out_dim) + d_vector: (B, out_dim) + """ + # mels = mels.transpose(1,2) + + x_vector, features = self.speaker_encoder(mels, True) + x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2) + zq, indices = self.quantizer(x) # zq: (B, latent_dim, T2, latent_dim) + x = zq.reshape(zq.shape[0], -1) + d_vector = self.project(x) + + return x_vector, d_vector + + def tokenize(self, mels: torch.Tensor) -> torch.Tensor: + """tokenize the input mel spectrogram""" + _, features = self.speaker_encoder(mels, True) + x = self.perceiver_sampler(features.transpose(1, 2)).transpose(1, 2) + zq, indices = self.quantizer(x) + return indices + + def detokenize(self, indices: torch.Tensor) -> torch.Tensor: + """detokenize the input indices to d-vector""" + zq = self.quantizer.get_output_from_indices(indices.transpose(1, 2)).transpose(1, 2) + x = zq.reshape(zq.shape[0], -1) + d_vector = self.project(x) + return d_vector + + +# --- Paste sparktts/modules/encoder_decoder/feat_encoder.py content here --- + +class Encoder(nn.Module): + """Encoder module with convnext and downsampling blocks""" + + def __init__( + self, + input_channels: int, + vocos_dim: int, + vocos_intermediate_dim: int, + vocos_num_layers: int, + out_channels: int, + sample_ratios: List[int] = [1, 1], + ): + super().__init__() + """ + Encoder module with VocosBackbone and sampling blocks. + + Args: + sample_ratios (List[int]): sample ratios + example: [2, 2] means downsample by 2x and then upsample by 2x + """ + self.encoder = VocosBackbone( + input_channels=input_channels, + dim=vocos_dim, + intermediate_dim=vocos_intermediate_dim, + num_layers=vocos_num_layers, + condition_dim=None, + ) + + modules = [ + nn.Sequential( + SamplingBlock( + dim=vocos_dim, + groups=vocos_dim, + downsample_scale=ratio, + ), + VocosBackbone( + input_channels=vocos_dim, + dim=vocos_dim, + intermediate_dim=vocos_intermediate_dim, + num_layers=2, + condition_dim=None, + ), + ) + for ratio in sample_ratios + ] + + self.downsample = nn.Sequential(*modules) + + self.project = nn.Linear(vocos_dim, out_channels) + + def forward(self, x: torch.Tensor, *args): + """ + Args: + x (torch.Tensor): (batch_size, input_channels, length) + + Returns: + x (torch.Tensor): (batch_size, encode_channels, length) + """ + x = self.encoder(x) + x = self.downsample(x) + x = self.project(x) + return x.transpose(1, 2) + + + +# --- Paste sparktts/modules/encoder_decoder/feat_decoder.py content here --- + +class Decoder(nn.Module): + """Decoder module with convnext and upsampling blocks + + Args: + sample_ratios (List[int]): sample ratios + example: [2, 2] means downsample by 2x and then upsample by 2x + """ + + def __init__( + self, + input_channels: int, + vocos_dim: int, + vocos_intermediate_dim: int, + vocos_num_layers: int, + out_channels: int, + condition_dim: int = None, + sample_ratios: List[int] = [1, 1], + use_tanh_at_final: bool = False, + ): + super().__init__() + + self.linear_pre = nn.Linear(input_channels, vocos_dim) + modules = [ + nn.Sequential( + SamplingBlock( + dim=vocos_dim, + groups=vocos_dim, + upsample_scale=ratio, + ), + VocosBackbone( + input_channels=vocos_dim, + dim=vocos_dim, + intermediate_dim=vocos_intermediate_dim, + num_layers=2, + condition_dim=None, + ), + ) + for ratio in sample_ratios + ] + + self.downsample = nn.Sequential(*modules) + + self.vocos_backbone = VocosBackbone( + input_channels=vocos_dim, + dim=vocos_dim, + intermediate_dim=vocos_intermediate_dim, + num_layers=vocos_num_layers, + condition_dim=condition_dim, + ) + self.linear = nn.Linear(vocos_dim, out_channels) + self.use_tanh_at_final = use_tanh_at_final + + def forward(self, x: torch.Tensor, c: torch.Tensor = None): + """encoder forward. + + Args: + x (torch.Tensor): (batch_size, input_channels, length) + + Returns: + x (torch.Tensor): (batch_size, encode_channels, length) + """ + x = self.linear_pre(x.transpose(1, 2)) + x = self.downsample(x).transpose(1, 2) + x = self.vocos_backbone(x, condition=c) + x = self.linear(x).transpose(1, 2) + if self.use_tanh_at_final: + x = torch.tanh(x) + + return x + + +# --- Paste sparktts/modules/encoder_decoder/wave_generator.py content here --- + +class DecoderBlock(nn.Module): + def __init__( + self, + input_dim: int = 16, + output_dim: int = 8, + kernel_size: int = 2, + stride: int = 1, + ): + super().__init__() + self.block = nn.Sequential( + Snake1d(input_dim), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=kernel_size, + stride=stride, + padding=(kernel_size - stride) // 2, + ), + ResidualUnit(output_dim, dilation=1), + ResidualUnit(output_dim, dilation=3), + ResidualUnit(output_dim, dilation=9), + ) + + def forward(self, x): + return self.block(x) + + +class WaveGenerator(nn.Module): + def __init__( + self, + input_channel, + channels, + rates, + kernel_sizes, + d_out: int = 1, + ): + super().__init__() + + # Add first conv layer + layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] + + # Add upsampling + MRF blocks + for i, (kernel_size, stride) in enumerate(zip(kernel_sizes, rates)): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [DecoderBlock(input_dim, output_dim, kernel_size, stride)] + + # Add final conv layer + layers += [ + Snake1d(output_dim), + WNConv1d(output_dim, d_out, kernel_size=7, padding=3), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + self.apply(init_weights) + + def forward(self, x): + return self.model(x) + + +# --- Paste sparktts/modules/vq/factorized_vector_quantize.py content here --- + +def ema_inplace(moving_avg, new, decay): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +class FactorizedVectorQuantize(nn.Module): + def __init__( + self, + input_dim: int, + codebook_size: int, + codebook_dim: int, + commitment: float, + codebook_loss_weight: float = 1.0, + decay: float = 0.99, + threshold_ema_dead_code: float = 2, + momentum: float = 0.99, + **kwargs, + ): + super().__init__() + self.input_dim = input_dim + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.commitment = commitment + self.codebook_loss_weight = codebook_loss_weight + self.decay = decay + self.threshold_ema_dead_code = threshold_ema_dead_code + self.momentum = momentum + + if input_dim != self.codebook_dim: + self.in_project = WNConv1d(input_dim, self.codebook_dim, kernel_size=1) + self.out_project = WNConv1d(self.codebook_dim, input_dim, kernel_size=1) + + else: + self.in_project = nn.Identity() + self.out_project = nn.Identity() + + self.codebook = nn.Embedding(self.codebook_size, self.codebook_dim) + self.register_buffer("cluster_size", torch.zeros(self.codebook_size)) + + def forward(self, z: torch.Tensor) -> Dict[str, Any]: + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + # transpose since we use linear + + # Factorized codes project input into low-dimensional space if self.input_dim != self.codebook_dim + z_e = self.in_project(z) + z_q, indices, dists = self.decode_latents(z_e) + + # statistic the usage of codes + embed_onehot = F.one_hot(indices, self.codebook_size).type(z_e.dtype) + avg_probs = torch.mean(embed_onehot.reshape(-1, self.codebook_size), dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + active_num = (embed_onehot.sum(0).sum(0) > 0).sum() + if self.training: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. + ema_inplace(self.cluster_size, embed_onehot.sum(0).sum(0), self.decay) + active_num = sum(self.cluster_size > self.threshold_ema_dead_code) + + if self.training: + commit_loss = ( + F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + * self.commitment + ) + + codebook_loss = ( + F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + * self.codebook_loss_weight + ) + + else: + commit_loss = torch.zeros(0, device=z.device) + codebook_loss = torch.zeros(0, device=z.device) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_project(z_q) + + vq_loss = (commit_loss + codebook_loss).mean() + + return { + "z_q": z_q, + "indices": indices, + "dists": dists, + "vq_loss": vq_loss, + "perplexity": perplexity, + "active_num": active_num.float(), + } + + def vq2emb(self, vq, out_proj=True): + emb = self.embed_code(vq) + if out_proj: + emb = self.out_project(emb) + return emb + + def tokenize(self, z: torch.Tensor) -> torch.Tensor: + """tokenize the input tensor""" + z_e = self.in_project(z) + _, indices, _ = self.decode_latents(z_e) + return indices + + def detokenize(self, indices): + """detokenize the input indices""" + z_q = self.decode_code(indices) + z_q = self.out_project(z_q) + return z_q + + def get_emb(self): + return self.codebook.weight + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight + + # L2 normalize encodings and codebook + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance between encodings and codebook, + # with L2 normalization, the distance is equal to cosine distance + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + + return z_q, indices, dist + + +# ============================================================================= +# >> END: PASTE CODE FROM sparktts/modules/* HERE << +# ============================================================================= + + +# ============================================================================= +# >> START: PASTE CODE FROM sparktts/models/bicodec.py HERE << +# ============================================================================= +# IMPORTANT: The BiCodec class definition needs to be here. +# Modify its loading mechanism as suggested. + + +class BiCodec(nn.Module): + def __init__( + self, + mel_params: Dict[str, Any], + encoder: nn.Module, + decoder: nn.Module, + quantizer: nn.Module, + speaker_encoder: nn.Module, + prenet: nn.Module, + postnet: nn.Module, + **kwargs + ) -> None: + super().__init__() + self.encoder = encoder + self.decoder = decoder + self.quantizer = quantizer + self.speaker_encoder = speaker_encoder + self.prenet = prenet + self.postnet = postnet + self.init_mel_transformer(mel_params) + + @classmethod + def load_from_config_and_checkpoint(cls, model_dir: Path, bicodec_config_object: SparkTTSBiCodecConfig) -> "BiCodec": + """ + Loads the BiCodec model using a SparkTTSBiCodecConfig object and a checkpoint file. + Args: + model_dir (Path): Path to the directory containing the model checkpoint ('model.safetensors'). + bicodec_config_object (SparkTTSBiCodecConfig): The nested config object from SparkTTSConfig. + Returns: + BiCodec: The initialized BiCodec model. + """ + ckpt_path = model_dir / 'model.safetensors' + if not ckpt_path.exists(): + ckpt_path_bin = model_dir / 'pytorch_model.bin' + if ckpt_path_bin.exists(): + ckpt_path = ckpt_path_bin + else: + raise FileNotFoundError(f"BiCodec checkpoint not found at {model_dir / 'model.safetensors'} or potential fallbacks.") + + # Instantiate components using specific attributes from the nested config objects + mel_params_config = bicodec_config_object.mel_params + encoder_cfg = bicodec_config_object.encoder_config + decoder_cfg = bicodec_config_object.decoder_config # WaveGenerator config + quantizer_cfg = bicodec_config_object.quantizer_config + speaker_encoder_cfg = bicodec_config_object.speaker_encoder_config + prenet_cfg = bicodec_config_object.prenet_config + postnet_cfg = bicodec_config_object.postnet_config + + # Pass only the arguments expected by each module's __init__ + mel_params = mel_params_config.to_dict() # Mel params might be needed as dict + + encoder = Encoder( + input_channels=encoder_cfg.input_channels, + vocos_dim=encoder_cfg.vocos_dim, + vocos_intermediate_dim=encoder_cfg.vocos_intermediate_dim, + vocos_num_layers=encoder_cfg.vocos_num_layers, + out_channels=encoder_cfg.out_channels, + sample_ratios=encoder_cfg.sample_ratios, + ) + quantizer = FactorizedVectorQuantize( + input_dim=quantizer_cfg.input_dim, + codebook_size=quantizer_cfg.codebook_size, + codebook_dim=quantizer_cfg.codebook_dim, + commitment=quantizer_cfg.commitment, + codebook_loss_weight=quantizer_cfg.codebook_loss_weight, + decay=quantizer_cfg.decay, + threshold_ema_dead_code=quantizer_cfg.threshold_ema_dead_code, + # Add any other kwargs FactorizedVectorQuantize expects from its config + ) + prenet = Decoder( # Assuming Prenet uses the Decoder class structure + input_channels=prenet_cfg.input_channels, + vocos_dim=prenet_cfg.vocos_dim, + vocos_intermediate_dim=prenet_cfg.vocos_intermediate_dim, + vocos_num_layers=prenet_cfg.vocos_num_layers, + out_channels=prenet_cfg.out_channels, + condition_dim=prenet_cfg.condition_dim, + sample_ratios=prenet_cfg.sample_ratios, + use_tanh_at_final=prenet_cfg.use_tanh_at_final, + ) + postnet = Decoder( # Assuming Postnet uses the Decoder class structure + input_channels=postnet_cfg.input_channels, + vocos_dim=postnet_cfg.vocos_dim, + vocos_intermediate_dim=postnet_cfg.vocos_intermediate_dim, + vocos_num_layers=postnet_cfg.vocos_num_layers, + out_channels=postnet_cfg.out_channels, + # condition_dim=postnet_cfg.condition_dim, # Postnet might not have condition_dim + # sample_ratios=postnet_cfg.sample_ratios, # Postnet might not have sample_ratios + use_tanh_at_final=postnet_cfg.use_tanh_at_final, + ) + decoder = WaveGenerator( # This is the actual audio decoder + input_channel=decoder_cfg.input_channel, + channels=decoder_cfg.channels, + rates=decoder_cfg.rates, + kernel_sizes=decoder_cfg.kernel_sizes, + # d_out is likely fixed to 1 internally in WaveGenerator, not configured + ) + speaker_encoder = SpeakerEncoder( + input_dim=speaker_encoder_cfg.input_dim, + out_dim=speaker_encoder_cfg.out_dim, + latent_dim=speaker_encoder_cfg.latent_dim, + token_num=speaker_encoder_cfg.token_num, + fsq_levels=speaker_encoder_cfg.fsq_levels, + fsq_num_quantizers=speaker_encoder_cfg.fsq_num_quantizers, + ) + + # Instantiate the BiCodec model itself + model = cls( + mel_params=mel_params, # Pass the dict here + encoder=encoder, + decoder=decoder, + quantizer=quantizer, + speaker_encoder=speaker_encoder, + prenet=prenet, + postnet=postnet, + ) + + # --- State Dict Loading --- + logger.info(f"Loading BiCodec state dict from: {ckpt_path}") + if str(ckpt_path).endswith(".safetensors"): + state_dict = load_file(ckpt_path, device="cpu") # Load to CPU first + else: + state_dict = torch.load(ckpt_path, map_location="cpu") + + missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False) + + if missing_keys: + logger.warning(f"BiCodec Missing keys: {missing_keys}") + if unexpected_keys: + logger.warning(f"BiCodec Unexpected keys: {unexpected_keys}") + + model.eval() + model.remove_weight_norm() # Important step from original code + + logger.info("BiCodec loaded successfully.") + return model +# +# # --- Paste the rest of the BiCodec methods here --- +# # forward, tokenize, detokenize, init_mel_transformer, remove_weight_norm + + def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]: + """ + Performs a forward pass through the model. + + Args: + batch (dict): A dictionary containing features, reference waveform, and target waveform. + + Returns: + dict: A dictionary containing the reconstruction, features, and other metrics. + """ + feat = batch["feat"] + mel = self.mel_transformer(batch["ref_wav"]).squeeze(1) + + z = self.encoder(feat.transpose(1, 2)) + vq_outputs = self.quantizer(z) + + x_vector, d_vector = self.speaker_encoder(mel.transpose(1, 2)) + + conditions = d_vector + with_speaker_loss = False + + x = self.prenet(vq_outputs["z_q"], conditions) + pred_feat = self.postnet(x) + x = x + conditions.unsqueeze(-1) + wav_recon = self.decoder(x) + + return { + "vq_loss": vq_outputs["vq_loss"], + "perplexity": vq_outputs["perplexity"], + "cluster_size": vq_outputs["active_num"], + "recons": wav_recon, + "pred_feat": pred_feat, + "x_vector": x_vector, + "d_vector": d_vector, + "audios": batch["wav"].unsqueeze(1), + "with_speaker_loss": with_speaker_loss, + } + + + @torch.no_grad() + def tokenize(self, batch: Dict[str, Any]): + """ + Tokenizes the input audio into semantic and global tokens. + + Args: + batch (dict): The input audio features and reference waveform. + + Returns: + tuple: Semantic tokens and global tokens. + """ + feat = batch["feat"] + mel = self.mel_transformer(batch["ref_wav"]).squeeze(1) + z = self.encoder(feat.transpose(1, 2)) + semantic_tokens = self.quantizer.tokenize(z) + global_tokens = self.speaker_encoder.tokenize(mel.transpose(1, 2)) + + return semantic_tokens, global_tokens + + @torch.no_grad() + def detokenize(self, semantic_tokens, global_tokens): + """ + Detokenizes the semantic and global tokens into a waveform. + + Args: + semantic_tokens (tensor): Semantic tokens. + global_tokens (tensor): Global tokens. + + Returns: + tensor: Reconstructed waveform. + """ + z_q = self.quantizer.detokenize(semantic_tokens) + d_vector = self.speaker_encoder.detokenize(global_tokens) + x = self.prenet(z_q, d_vector) + x = x + d_vector.unsqueeze(-1) + wav_recon = self.decoder(x) + + return wav_recon + + def init_mel_transformer(self, config: Dict[str, Any]): + """ + Initializes the MelSpectrogram transformer based on the provided configuration. + + Args: + config (dict): Configuration parameters for MelSpectrogram. + """ + import torchaudio.transforms as TT + + self.mel_transformer = TT.MelSpectrogram( + config["sample_rate"], + config["n_fft"], + config["win_length"], + config["hop_length"], + config["mel_fmin"], + config["mel_fmax"], + n_mels=config["num_mels"], + power=1, + norm="slaney", + mel_scale="slaney", + ) + + def remove_weight_norm(self): + """Removes weight normalization from all layers.""" + def _remove_weight_norm(m): + try: + torch.nn.utils.remove_weight_norm(m) + except ValueError: + pass # The module didn't have weight norm + + self.apply(_remove_weight_norm) + +# ============================================================================= +# >> END: PASTE CODE FROM sparktts/models/bicodec.py HERE << +# ============================================================================= + + +# ============================================================================= +# >> START: PASTE CODE FROM sparktts/utils/audio.py HERE (if needed by model) << +# ============================================================================= +# Functions like audio_volume_normalize, load_audio, etc., are typically part +# of the Processor. However, if any are directly used *within* the BiCodec or +# other model components pasted above, they need to be defined here too. +# It seems `get_ref_clip` logic might be needed if `BiCodecTokenizer` logic is embedded. + +# Example placeholder comment: + +def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray: + """ + Normalize the volume of an audio signal. + + Parameters: + audio (numpy array): Input audio signal array. + coeff (float): Target coefficient for normalization, default is 0.2. + + Returns: + numpy array: The volume-normalized audio signal. + """ + # Sort the absolute values of the audio signal + temp = np.sort(np.abs(audio)) + + # If the maximum value is less than 0.1, scale the array to have a maximum of 0.1 + if temp[-1] < 0.1: + scaling_factor = max( + temp[-1], 1e-3 + ) # Prevent division by zero with a small constant + audio = audio / scaling_factor * 0.1 + + # Filter out values less than 0.01 from temp + temp = temp[temp > 0.01] + L = temp.shape[0] # Length of the filtered array + + # If there are fewer than or equal to 10 significant values, return the audio without further processing + if L <= 10: + return audio + + # Compute the average of the top 10% to 1% of values in temp + volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)]) + + # Normalize the audio to the target coefficient level, clamping the scale factor between 0.1 and 10 + audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10) + + # Ensure the maximum absolute value in the audio does not exceed 1 + max_value = np.max(np.abs(audio)) + if max_value > 1: + audio = audio / max_value + + return audio + + +def load_audio( + adfile: Path, + sampling_rate: int = None, + length: int = None, + volume_normalize: bool = False, + segment_duration: int = None, +) -> np.ndarray: + r"""Load audio file with target sampling rate and lsength + + Args: + adfile (Path): path to audio file. + sampling_rate (int, optional): target sampling rate. Defaults to None. + length (int, optional): target audio length. Defaults to None. + volume_normalize (bool, optional): whether perform volume normalization. Defaults to False. + segment_duration (int): random select a segment with duration of {segment_duration}s. + Defualt to None which means the whole audio will be used. + + Returns: + audio (np.ndarray): audio + """ + + audio, sr = soundfile.read(adfile) + if len(audio.shape) > 1: + audio = audio[:, 0] + + if sampling_rate is not None and sr != sampling_rate: + audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ") + sr = sampling_rate + + if segment_duration is not None: + seg_length = int(sr * segment_duration) + audio = random_select_audio_segment(audio, seg_length) + + # Audio volume normalize + if volume_normalize: + audio = audio_volume_normalize(audio) + # check the audio length + if length is not None: + assert abs(audio.shape[0] - length) < 1000 + if audio.shape[0] > length: + audio = audio[:length] + else: + audio = np.pad(audio, (0, int(length - audio.shape[0]))) + return audio + + +def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray: + """get an audio segment given the length + + Args: + audio (np.ndarray): + length (int): audio length = sampling_rate * duration + """ + if audio.shape[0] < length: + audio = np.pad(audio, (0, int(length - audio.shape[0]))) + start_index = random.randint(0, audio.shape[0] - length) + end_index = int(start_index + length) + + return audio[start_index:end_index] + + +# ============================================================================= +# >> END: PASTE CODE FROM sparktts/utils/audio.py HERE (if needed by model) << +# ============================================================================= + + +class SparkTTSModel(PreTrainedModel, GenerationMixin): + """ + Spark-TTS model integrating a Language Model (LLM) for sequence generation, + a Wav2Vec2 model for feature extraction, and a BiCodec model for audio + tokenization and synthesis. Designed for compatibility with the Hugging Face ecosystem. + """ + config_class = SparkTTSConfig + base_model_prefix = "spark_tts" # Or perhaps "llm" if generation focuses there + main_input_name = "input_ids" # Crucial for GenerationMixin + + def __init__( + self, + config: SparkTTSConfig, + llm: Optional[PreTrainedModel] = None, + wav2vec2_model: Optional[PreTrainedModel] = None, + wav2vec2_processor: Optional[Wav2Vec2FeatureExtractor] = None, # Store processor too + bicodec: Optional[nn.Module] = None, # Should be the loaded BiCodec instance + ): + super().__init__(config) + self.config = config # Stores the main SparkTTSConfig + + # Store the sub-components + self.llm = llm + self.wav2vec2_model = wav2vec2_model + self.wav2vec2_processor = wav2vec2_processor # Store the processor used for features + self.bicodec = bicodec + + # Ensure Wav2Vec2 is configured for hidden states needed by BiCodec's feature extractor + if self.wav2vec2_model: + self.wav2vec2_model.config.output_hidden_states = True + + # Post initialization checks (optional but good practice) + if not all([self.llm, self.wav2vec2_model, self.wav2vec2_processor, self.bicodec]): + logger.warning( + "SparkTTSModel initialized without all sub-components. " + "Ensure `from_pretrained` is used for loading a complete model." + ) + + def get_input_embeddings(self): + """Returns the input embeddings of the LLM.""" + if self.llm: + return self.llm.get_input_embeddings() + return None + + def set_input_embeddings(self, value): + """Sets the input embeddings of the LLM.""" + if self.llm: + self.llm.set_input_embeddings(value) + + def _prepare_wav2vec2_features(self, wav: torch.Tensor) -> torch.Tensor: + """ + Extracts Wav2Vec2 features required by BiCodec. + Input wav should be a batch of waveforms [B, T_audio]. + """ + if not self.wav2vec2_model or not self.wav2vec2_processor: + raise ValueError("Wav2Vec2 model or processor not loaded.") + + # Get target device and dtype from the Wav2Vec2 model + target_device = self.wav2vec2_model.device + target_dtype = self.wav2vec2_model.dtype # Get the model's dtype (e.g., bfloat16) + + # Input wav tensor might be float32, processor usually expects float32 + wav_for_processor = wav.to(device=target_device, dtype=torch.float32) + + # Process using the Wav2Vec2FeatureExtractor + # The processor typically outputs float32 + inputs = self.wav2vec2_processor( + wav_for_processor, + sampling_rate=self.config.sample_rate, # Use config SR + return_tensors="pt", + padding=True, + ) + input_values = inputs.input_values.to(target_device) # Move to device + + # --- Cast the input_values to the model's expected dtype --- + input_values = input_values.to(dtype=target_dtype) + # ---------------------------------------------------------- + + # --- CRITICAL CHECK AND FIX --- + # Ensure input_values is 2D [Batch, Length] before passing to the model + if input_values.ndim == 3 and input_values.shape[1] == 1: + logger.warning(f"Processor returned 3D input_values {input_values.shape}. Squeezing the channel dimension.") + input_values = input_values.squeeze(1) + elif input_values.ndim != 2: + raise ValueError(f"Expected input_values from processor to be 2D [Batch, Length], but got shape {input_values.shape}") + # --- END CHECK AND FIX --- + + # Extract features using the Wav2Vec2Model + with torch.no_grad(): # Feature extraction should not require gradients here + # Now the input dtype matches the model's parameter dtype + feat_outputs = self.wav2vec2_model(input_values) + + # Combine specific hidden states as per original BiCodecTokenizer logic + if not feat_outputs.hidden_states: + raise ValueError("Wav2Vec2 model did not return hidden states. Ensure config.output_hidden_states=True.") + if len(feat_outputs.hidden_states) < 17: + # Wav2Vec2-large-xlsr has 24 layers + initial embeddings = 25 states + logger.warning(f"Wav2Vec2 model returned {len(feat_outputs.hidden_states)} hidden states. Expected at least 17 for default BiCodec indices (11, 14, 16). Check model architecture or BiCodec indices if this is unexpected.") + # Attempt to proceed if possible, otherwise raise error if indices are out of bounds + idx1, idx2, idx3 = 11, 14, 16 + if not (0 <= idx1 < len(feat_outputs.hidden_states) and \ + 0 <= idx2 < len(feat_outputs.hidden_states) and \ + 0 <= idx3 < len(feat_outputs.hidden_states)): + raise ValueError(f"Required hidden state indices ({idx1}, {idx2}, {idx3}) are out of bounds for the {len(feat_outputs.hidden_states)} hidden states returned.") + else: + idx1, idx2, idx3 = 11, 14, 16 + + + feats_mix = ( + feat_outputs.hidden_states[idx1] + + feat_outputs.hidden_states[idx2] + + feat_outputs.hidden_states[idx3] + ) / 3 + + # Ensure the output features also match the expected downstream dtype (e.g., bicodec) + # Usually okay if subsequent layers also use the same target_dtype + return feats_mix.to(dtype=target_dtype) # Return features in the target dtype # Shape: [B, T_feats, D_feats] + + @torch.no_grad() + def tokenize_audio(self, wav: torch.Tensor, ref_wav: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Tokenizes audio using the BiCodec model. + Args: + wav (torch.Tensor): The main audio waveform [B, T_audio]. (Should be float32 initially) + ref_wav (torch.Tensor): The reference audio waveform [B, T_ref_audio]. (Should be float32 initially) + + Returns: + Tuple[torch.Tensor, torch.Tensor]: global_tokens, semantic_tokens + """ + if not self.bicodec: + raise ValueError("BiCodec model not loaded.") + + # 1. Extract Wav2Vec2 features for the main audio + # _prepare_wav2vec2_features now handles internal dtype casting for w2v model + feats = self._prepare_wav2vec2_features(wav) # Returns features in model's target dtype + + # 2. Prepare batch for BiCodec + # Ensure tensors are on the BiCodec's device AND correct dtype + # Get device and dtype from a BiCodec submodule parameter + bicodec_param = next(self.bicodec.parameters()) + target_device = bicodec_param.device + target_dtype = bicodec_param.dtype # Get BiCodec's dtype + + batch = { + # Cast inputs to BiCodec's expected dtype + "wav": wav.to(device=target_device, dtype=target_dtype), + "ref_wav": ref_wav.to(device=target_device, dtype=target_dtype), + "feat": feats.to(device=target_device, dtype=target_dtype), # Ensure feats are also correct dtype + } + + # 3. Call BiCodec's tokenize method + semantic_tokens, global_tokens = self.bicodec.tokenize(batch) + + return global_tokens, semantic_tokens + + @torch.no_grad() + def detokenize_audio(self, global_tokens: torch.Tensor, semantic_tokens: torch.Tensor) -> np.ndarray: + """ + Detokenizes audio tokens back to a waveform using BiCodec. + Args: + global_tokens (torch.Tensor): Global tokens [B, ...]. + semantic_tokens (torch.Tensor): Semantic tokens [B, ...]. + + Returns: + np.ndarray: The reconstructed waveform [T_audio_out] if B=1, or [B, T_audio_out] if B > 1, + with dtype float32 and values clipped to [-1, 1]. + """ + if not self.bicodec: + raise ValueError("BiCodec model not loaded.") + + target_device = next(self.bicodec.parameters()).device + + # Adjust shapes as expected by BiCodec.detokenize if needed + if global_tokens.ndim == 2: # Example adjustment + global_tokens = global_tokens.unsqueeze(1) + + logger.debug(f"DEBUG: Detokenizing audio with global tokens {global_tokens.shape}, semantic tokens {semantic_tokens.shape}") + + wav_rec = self.bicodec.detokenize( + semantic_tokens.to(target_device), + global_tokens.to(target_device) + ) # Output tensor likely float32 or model's dtype + + # Convert to numpy, ensure float32, clip + wav_rec_np = wav_rec.detach().cpu().numpy().astype(np.float32) # Ensure float32 + wav_rec_np = np.clip(wav_rec_np, -1.0, 1.0) # Clip values + + logger.debug(f"DEBUG: Wav rec shape after detach and clip: {wav_rec_np.shape}") # Shape is likely (B, C, T) e.g., (1, 1, 24640) + + # ============================================================== + # CORRECTED SQUEEZE LOGIC + # ============================================================== + # Remove all dimensions of size 1 (batch and channel if they are 1) + # This handles both B=1, C=1 -> (T,) and potentially B>1, C=1 -> (B, T) + # If C > 1, it would return (B, C, T) or (C, T) if B=1. + # soundfile handles (T,) and (T, C) correctly. + output_wav = wav_rec_np.squeeze() + # ============================================================== + + logger.debug(f"DEBUG: Final output wav shape after squeeze: {output_wav.shape}") + + # Ensure the output is at least 1D even if squeeze removes everything (e.g., single sample output) + if output_wav.ndim == 0: + output_wav = np.expand_dims(output_wav, axis=0) + + return output_wav + + def prepare_inputs_for_generation( + self, input_ids: torch.LongTensor, past_key_values: Optional[list] = None, attention_mask: Optional[torch.Tensor] = None, **kwargs + ) -> dict: + """ + Prepares inputs for the generation process (standard method for GenerationMixin). + """ + # Add position_ids and handle past_key_values for causal LM generation + # This is a standard implementation for causal LMs. + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "position_ids": position_ids, + # Add any other inputs the LLM's forward method expects + } + + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + # Add other potential inputs for the LLM (position_ids, past_key_values, etc.) + position_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + """ + The forward pass primarily delegates to the underlying LLM. + It takes tokenized text/audio prompts as input_ids. + """ + if not self.llm: + raise ValueError("LLM component not loaded.") + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Pass arguments directly to the LLM's forward method + outputs = self.llm( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + return outputs # Should be CausalLMOutputWithPast or tuple + + @classmethod + @torch.no_grad() # Decorator often used for loading, though internal ops might need grads later + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + revision: str = "main", + use_safetensors: Optional[bool] = None, + # New args from base class signature to pass down if relevant + state_dict = None, # Pass state_dict explicitly is usually avoided with component loading + device_map = None, # Simplified handling + low_cpu_mem_usage = None, # Simplified handling + torch_dtype = "auto", # Keep "auto" as default + quantization_config = None, # Pass down if needed by components + trust_remote_code = None, # Default to None, will be set below + # Add other relevant args from base class if needed: subfolder, variant, etc. + subfolder: str = "", + variant: Optional[str] = None, + **kwargs, + ): + # --- Argument Handling & Initial Setup --- + # Pop device map and dtype early - handle placement later + if device_map: + logger.warning("`device_map` is not directly supported for this composite model. Use .to(device) after loading.") + if low_cpu_mem_usage: + logger.info("`low_cpu_mem_usage` is set, but simplified loading is used. Memory usage might not be optimized.") + + # Handle trust_remote_code explicitly for custom code loading + if trust_remote_code is None: + logger.warning( + "Loading SparkTTSModel requires custom code. Setting `trust_remote_code=True`. " + "Make sure you trust the source of the code you are loading." + ) + trust_remote_code = True + elif not trust_remote_code: + raise ValueError("Loading SparkTTSModel requires `trust_remote_code=True`.") + + # Pop unused kwargs specific to base class loading logic if not handled here + kwargs.pop("output_loading_info", None) + kwargs.pop("_from_auto", None) + kwargs.pop("attn_implementation", None) # LLM loader might handle this + + # --- 1. Resolve the main model directory --- + if state_dict is not None: + raise ValueError("Explicitly passing `state_dict` is not supported for this composite model. Load components individually if needed.") + if pretrained_model_name_or_path is None: + raise ValueError("`pretrained_model_name_or_path` must be provided.") + + is_local = Path(pretrained_model_name_or_path).is_dir() + if local_files_only and not is_local: + raise ValueError(f"Cannot find local directory at {pretrained_model_name_or_path} when `local_files_only=True`.") + + if is_local: + resolved_model_path = Path(pretrained_model_name_or_path) + logger.info(f"Loading model from local directory: {resolved_model_path}") + else: + logger.info(f"{pretrained_model_name_or_path} is not a local directory. Assuming Hub ID and downloading.") + try: + # Use snapshot_download to get all necessary files + resolved_model_path_str = snapshot_download( + repo_id=str(pretrained_model_name_or_path), + cache_dir=cache_dir, + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + allow_patterns=[ # Be more specific if possible + "*.json", "*.safetensors", "*.bin", "*.yaml", "*.txt", + "README.md", ".gitattributes", # Common files + "LLM/*", "BiCodec/*", "wav2vec2-large-xlsr-53/*" # Component folders + ], + ignore_patterns=["*.git*", "*.h5", "*.ot", "*.msgpack"], # Ignore unnecessary files + subfolder=subfolder, # Pass subfolder to snapshot_download + repo_type="model", # Specify repo type + ) + resolved_model_path = Path(resolved_model_path_str) + logger.info(f"Model files downloaded to cache: {resolved_model_path}") + except Exception as e: + raise OSError( + f"Failed to download model '{pretrained_model_name_or_path}' (subfolder: '{subfolder}') from Hugging Face Hub. " + f"Error: {e}" + ) + + if not resolved_model_path.is_dir(): + raise EnvironmentError(f"Resolved model path is not a directory: {resolved_model_path}") + + # If subfolder is used, update resolved_model_path to point inside it + if subfolder: + resolved_model_path = resolved_model_path / subfolder + if not resolved_model_path.is_dir(): + raise EnvironmentError(f"Subfolder '{subfolder}' not found within the resolved path: {resolved_model_path.parent}") + + + # --- 2. Load the main configuration --- + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else resolved_model_path + try: + loaded_config, model_kwargs = SparkTTSConfig.from_pretrained( + config_path, + *model_args, # Pass model_args here + cache_dir=cache_dir, + force_download=force_download if not is_local else False, + local_files_only=local_files_only or is_local, + token=token, + revision=revision, + trust_remote_code=trust_remote_code, # Crucial if config class is remote + subfolder="", # Config is usually at the root, not subfolder + return_unused_kwargs=True, + **kwargs, # Pass remaining kwargs for config loading + ) + config = loaded_config + kwargs = model_kwargs # Update kwargs with unused ones + except OSError as e: + raise OSError(f"Cannot load config for '{pretrained_model_name_or_path}'. Check `config.json` exists and is correctly formatted. Error: {e}") + # else: config object was passed directly + + # --- Determine final torch_dtype --- + final_torch_dtype = torch_dtype # Explicit arg has highest prio + if final_torch_dtype == "auto": + final_torch_dtype = getattr(config, "torch_dtype", None) # Use config value if present + # Convert string to torch.dtype object if needed + if isinstance(final_torch_dtype, str) and final_torch_dtype != "auto": + try: + final_torch_dtype = getattr(torch, final_torch_dtype) + except AttributeError: + logger.warning(f"Invalid torch_dtype string: {final_torch_dtype}. Falling back to default.") + final_torch_dtype = None # Fallback to None (which means float32 usually) + elif final_torch_dtype == "auto": + final_torch_dtype = None # Treat "auto" as None for component loading + + # --- Helper function to resolve paths relative to the main model directory --- + # (This handles components potentially being in subfolders specified in config) + def _resolve_sub_path(sub_path_str): + p = Path(sub_path_str) + if p.is_absolute(): + if not p.exists(): logger.warning(f"Absolute path specified for sub-component does not exist: {p}") + return str(p) + else: + # Resolve relative to the main model path (which might be in cache or local) + resolved = resolved_model_path / p + if not resolved.exists(): + # Check if the path exists without the leading './' often found in configs + resolved_alt = resolved_model_path / sub_path_str.lstrip('./') + if resolved_alt.exists(): + resolved = resolved_alt + else: + raise FileNotFoundError(f"Could not resolve sub-component path: {resolved} (relative to {resolved_model_path})") + return str(resolved) + + # --- Component Loading Arguments --- + component_loading_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "local_files_only": local_files_only, + "token": token, + "revision": revision, + "trust_remote_code": trust_remote_code, # Pass this down + "torch_dtype": final_torch_dtype, # Pass resolved dtype + "use_safetensors": use_safetensors, + # Pass quantization config if provided and relevant to component + "quantization_config": quantization_config if quantization_config else None, + # Pass variant if needed for specific component checkpoints + "variant": variant, + # Filter kwargs? For now, pass all remaining, component loaders should ignore unused ones. + **kwargs, + } + + # --- 3. Load Sub-components --- + + # --- Load LLM --- + llm_path = _resolve_sub_path(config.llm_model_name_or_path) + logger.info(f"Loading LLM from resolved path: {llm_path}") + try: + llm = AutoModelForCausalLM.from_pretrained( + llm_path, **component_loading_kwargs + ) + except Exception as e: + raise OSError(f"Failed to load LLM from {llm_path}: {e}") + + # --- Load Wav2Vec2 --- + w2v_path = _resolve_sub_path(config.wav2vec2_model_name_or_path) + logger.info(f"Loading Wav2Vec2 components from resolved path: {w2v_path}") + try: + # Use specific class for extractor, Auto* might not work if only config is present + wav2vec2_processor = Wav2Vec2FeatureExtractor.from_pretrained( + w2v_path, + cache_dir=cache_dir, # Pass relevant args + force_download=force_download, + local_files_only=local_files_only, + token=token, + revision=revision, + # No trust_remote_code needed usually for feature extractors + ) + wav2vec2_model = Wav2Vec2Model.from_pretrained( + w2v_path, **component_loading_kwargs # Pass full kwargs here + ) + wav2vec2_model.config.output_hidden_states = True # Ensure this is set + except Exception as e: + raise OSError(f"Failed to load Wav2Vec2 components from {w2v_path}: {e}") + + # --- Load BiCodec --- + bicodec_path = _resolve_sub_path(config.bicodec_model_name_or_path) + logger.info(f"Loading BiCodec from resolved path: {bicodec_path}") + if not config.bicodec_config: # Check if the nested config object exists + raise ValueError("BiCodec configuration (`bicodec_config`) not found or properly instantiated in SparkTTSConfig.") + try: + # Pass the SparkTTSBiCodecConfig *object* directly + bicodec = BiCodec.load_from_config_and_checkpoint( + model_dir=Path(bicodec_path), + bicodec_config_object=config.bicodec_config # Pass the object + ) + if not isinstance(bicodec, torch.nn.Module): + logger.warning("Loaded BiCodec component is not an instance of torch.nn.Module.") + # Apply torch_dtype to BiCodec if it's an nn.Module and dtype is set + if isinstance(bicodec, torch.nn.Module) and final_torch_dtype: + bicodec = bicodec.to(dtype=final_torch_dtype) + + except FileNotFoundError as e: + raise OSError(f"Failed to load BiCodec: A required file was not found in {bicodec_path}. Original error: {e}") + except Exception as e: + logger.error(f"Raw error loading BiCodec: {type(e).__name__}: {e}") + import traceback + traceback.print_exc() + raise OSError(f"Failed to load BiCodec from {bicodec_path}. Check BiCodec implementation, config, and file paths. Error: {e}") + + # --- 4. Instantiate the main model wrapper --- + model = cls( + config, + llm=llm, + wav2vec2_model=wav2vec2_model, + wav2vec2_processor=wav2vec2_processor, + bicodec=bicodec + ) + + # --- 5. Handle device placement (Simplified) --- + # Determine target device (simple logic: CUDA > MPS > CPU) + if torch.cuda.is_available(): + final_device = torch.device("cuda") + elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): # Check MPS availability + final_device = torch.device("mps") + else: + final_device = torch.device("cpu") + + logger.info(f"Placing SparkTTSModel and components on device: {final_device}") + try: + model.to(final_device) + except Exception as e: + logger.error(f"Failed to move model to device {final_device}. Error: {e}") + logger.warning("Device placement might be incomplete. Check component types and implementations.") + + # --- 6. Return the loaded and prepared model --- + return model \ No newline at end of file