Upload 4 files
Browse files- config.json +83 -0
- configuration_spark_tts.py +233 -0
- modeling_spark_tts.py +0 -0
- processing_spark_tts.py +889 -0
config.json
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"model_type": "spark-tts",
|
3 |
+
"architectures": [
|
4 |
+
"SparkTTSModel"
|
5 |
+
],
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "configuration_spark_tts.SparkTTSConfig",
|
8 |
+
"AutoModel": "modeling_spark_tts.SparkTTSModel",
|
9 |
+
"AutoProcessor": "processing_spark_tts.SparkTTSProcessor"
|
10 |
+
},
|
11 |
+
"processor_class": "processing_spark_tts.SparkTTSProcessor",
|
12 |
+
"llm_model_name_or_path": "./LLM",
|
13 |
+
"bicodec_model_name_or_path": "./BiCodec",
|
14 |
+
"wav2vec2_model_name_or_path": "./wav2vec2-large-xlsr-53",
|
15 |
+
"sample_rate": 16000,
|
16 |
+
"highpass_cutoff_freq": 40,
|
17 |
+
"latent_hop_length": 320,
|
18 |
+
"ref_segment_duration": 6.0,
|
19 |
+
"volume_normalize": true,
|
20 |
+
"torch_dtype": "bfloat16",
|
21 |
+
"transformers_version": "4.43.1",
|
22 |
+
"_commit_hash": null,
|
23 |
+
"bicodec_config": {
|
24 |
+
"mel_params": {
|
25 |
+
"sample_rate": 16000,
|
26 |
+
"n_fft": 1024,
|
27 |
+
"win_length": 640,
|
28 |
+
"hop_length": 320,
|
29 |
+
"mel_fmin": 10,
|
30 |
+
"mel_fmax": null,
|
31 |
+
"num_mels": 128
|
32 |
+
},
|
33 |
+
"encoder_config": {
|
34 |
+
"input_channels": 1024,
|
35 |
+
"vocos_dim": 384,
|
36 |
+
"vocos_intermediate_dim": 2048,
|
37 |
+
"vocos_num_layers": 12,
|
38 |
+
"out_channels": 1024,
|
39 |
+
"sample_ratios": [1, 1]
|
40 |
+
},
|
41 |
+
"decoder_config": {
|
42 |
+
"input_channel": 1024,
|
43 |
+
"channels": 1536,
|
44 |
+
"rates": [8, 5, 4, 2],
|
45 |
+
"kernel_sizes": [16, 11, 8, 4]
|
46 |
+
},
|
47 |
+
"quantizer_config": {
|
48 |
+
"input_dim": 1024,
|
49 |
+
"codebook_size": 8192,
|
50 |
+
"codebook_dim": 8,
|
51 |
+
"commitment": 0.25,
|
52 |
+
"codebook_loss_weight": 2.0,
|
53 |
+
"decay": 0.99,
|
54 |
+
"threshold_ema_dead_code": 0.2
|
55 |
+
},
|
56 |
+
"speaker_encoder_config": {
|
57 |
+
"input_dim": 128,
|
58 |
+
"out_dim": 1024,
|
59 |
+
"latent_dim": 128,
|
60 |
+
"token_num": 32,
|
61 |
+
"fsq_levels": [4, 4, 4, 4, 4, 4],
|
62 |
+
"fsq_num_quantizers": 1
|
63 |
+
},
|
64 |
+
"prenet_config": {
|
65 |
+
"input_channels": 1024,
|
66 |
+
"vocos_dim": 384,
|
67 |
+
"vocos_intermediate_dim": 2048,
|
68 |
+
"vocos_num_layers": 12,
|
69 |
+
"out_channels": 1024,
|
70 |
+
"condition_dim": 1024,
|
71 |
+
"sample_ratios": [1, 1],
|
72 |
+
"use_tanh_at_final": false
|
73 |
+
},
|
74 |
+
"postnet_config": {
|
75 |
+
"input_channels": 1024,
|
76 |
+
"vocos_dim": 384,
|
77 |
+
"vocos_intermediate_dim": 2048,
|
78 |
+
"vocos_num_layers": 6,
|
79 |
+
"out_channels": 1024,
|
80 |
+
"use_tanh_at_final": false
|
81 |
+
}
|
82 |
+
}
|
83 |
+
}
|
configuration_spark_tts.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
|
3 |
+
# ... (License headers remain the same) ...
|
4 |
+
""" SparkTTS model configuration"""
|
5 |
+
|
6 |
+
from transformers.configuration_utils import PretrainedConfig
|
7 |
+
from transformers.utils import logging
|
8 |
+
from typing import List, Optional # Added typing
|
9 |
+
|
10 |
+
|
11 |
+
logger = logging.get_logger(__name__)
|
12 |
+
|
13 |
+
# --- Define Individual Sub-Component Config Classes ---
|
14 |
+
|
15 |
+
class SparkTTSMelParamsConfig(PretrainedConfig):
|
16 |
+
"""Configuration for Mel Spectrogram parameters."""
|
17 |
+
model_type = "spark-tts-mel-params"
|
18 |
+
def __init__(self, sample_rate=16000, n_fft=1024, win_length=640, hop_length=320,
|
19 |
+
mel_fmin=10, mel_fmax=None, num_mels=128, **kwargs):
|
20 |
+
super().__init__(**kwargs)
|
21 |
+
self.sample_rate = sample_rate
|
22 |
+
self.n_fft = n_fft
|
23 |
+
self.win_length = win_length
|
24 |
+
self.hop_length = hop_length
|
25 |
+
self.mel_fmin = mel_fmin
|
26 |
+
self.mel_fmax = mel_fmax
|
27 |
+
self.num_mels = num_mels
|
28 |
+
|
29 |
+
class SparkTTSEncoderConfig(PretrainedConfig):
|
30 |
+
"""Configuration for the BiCodec Feature Encoder."""
|
31 |
+
model_type = "spark-tts-encoder"
|
32 |
+
def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
|
33 |
+
vocos_num_layers=12, out_channels=1024, sample_ratios=[1, 1], **kwargs):
|
34 |
+
super().__init__(**kwargs)
|
35 |
+
self.input_channels = input_channels
|
36 |
+
self.vocos_dim = vocos_dim
|
37 |
+
self.vocos_intermediate_dim = vocos_intermediate_dim
|
38 |
+
self.vocos_num_layers = vocos_num_layers
|
39 |
+
self.out_channels = out_channels
|
40 |
+
self.sample_ratios = sample_ratios
|
41 |
+
|
42 |
+
class SparkTTSDecoderConfig(PretrainedConfig):
|
43 |
+
"""Configuration for the BiCodec Wave Generator (Decoder)."""
|
44 |
+
model_type = "spark-tts-decoder"
|
45 |
+
def __init__(self, input_channel=1024, channels=1536, rates=[8, 5, 4, 2],
|
46 |
+
kernel_sizes=[16, 11, 8, 4], **kwargs):
|
47 |
+
super().__init__(**kwargs)
|
48 |
+
self.input_channel = input_channel
|
49 |
+
self.channels = channels
|
50 |
+
self.rates = rates
|
51 |
+
self.kernel_sizes = kernel_sizes
|
52 |
+
|
53 |
+
class SparkTTSQuantizerConfig(PretrainedConfig):
|
54 |
+
"""Configuration for the BiCodec Factorized Vector Quantizer."""
|
55 |
+
model_type = "spark-tts-quantizer"
|
56 |
+
def __init__(self, input_dim=1024, codebook_size=8192, codebook_dim=8,
|
57 |
+
commitment=0.25, codebook_loss_weight=2.0, decay=0.99,
|
58 |
+
threshold_ema_dead_code=0.2, **kwargs):
|
59 |
+
# Note: Removed use_l2_normlize as it wasn't in the original class __init__ args
|
60 |
+
# Add it back if it's actually used by the FactorizedVectorQuantize class init
|
61 |
+
super().__init__(**kwargs)
|
62 |
+
self.input_dim = input_dim
|
63 |
+
self.codebook_size = codebook_size
|
64 |
+
self.codebook_dim = codebook_dim
|
65 |
+
self.commitment = commitment
|
66 |
+
self.codebook_loss_weight = codebook_loss_weight
|
67 |
+
self.decay = decay
|
68 |
+
self.threshold_ema_dead_code = threshold_ema_dead_code
|
69 |
+
|
70 |
+
class SparkTTSSpeakerEncoderConfig(PretrainedConfig):
|
71 |
+
"""Configuration for the BiCodec Speaker Encoder."""
|
72 |
+
model_type = "spark-tts-speaker-encoder"
|
73 |
+
def __init__(self, input_dim=128, out_dim=1024, latent_dim=128, token_num=32,
|
74 |
+
fsq_levels=[4, 4, 4, 4, 4, 4], fsq_num_quantizers=1, **kwargs):
|
75 |
+
super().__init__(**kwargs)
|
76 |
+
self.input_dim = input_dim
|
77 |
+
self.out_dim = out_dim
|
78 |
+
self.latent_dim = latent_dim
|
79 |
+
self.token_num = token_num
|
80 |
+
self.fsq_levels = fsq_levels
|
81 |
+
self.fsq_num_quantizers = fsq_num_quantizers
|
82 |
+
|
83 |
+
class SparkTTSPrenetConfig(PretrainedConfig):
|
84 |
+
"""Configuration for the BiCodec Prenet."""
|
85 |
+
model_type = "spark-tts-prenet"
|
86 |
+
def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
|
87 |
+
vocos_num_layers=12, out_channels=1024, condition_dim=1024,
|
88 |
+
sample_ratios=[1, 1], use_tanh_at_final=False, **kwargs):
|
89 |
+
super().__init__(**kwargs)
|
90 |
+
self.input_channels = input_channels
|
91 |
+
self.vocos_dim = vocos_dim
|
92 |
+
self.vocos_intermediate_dim = vocos_intermediate_dim
|
93 |
+
self.vocos_num_layers = vocos_num_layers
|
94 |
+
self.out_channels = out_channels
|
95 |
+
self.condition_dim = condition_dim
|
96 |
+
self.sample_ratios = sample_ratios
|
97 |
+
self.use_tanh_at_final = use_tanh_at_final
|
98 |
+
|
99 |
+
class SparkTTSPostnetConfig(PretrainedConfig):
|
100 |
+
"""Configuration for the BiCodec Postnet."""
|
101 |
+
model_type = "spark-tts-postnet"
|
102 |
+
def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
|
103 |
+
vocos_num_layers=6, out_channels=1024, use_tanh_at_final=False, **kwargs):
|
104 |
+
# Note: Removed condition_dim as it wasn't in the original config example for postnet
|
105 |
+
super().__init__(**kwargs)
|
106 |
+
self.input_channels = input_channels
|
107 |
+
self.vocos_dim = vocos_dim
|
108 |
+
self.vocos_intermediate_dim = vocos_intermediate_dim
|
109 |
+
self.vocos_num_layers = vocos_num_layers
|
110 |
+
self.out_channels = out_channels
|
111 |
+
self.use_tanh_at_final = use_tanh_at_final
|
112 |
+
|
113 |
+
|
114 |
+
# --- Define the Intermediate BiCodec Config Class ---
|
115 |
+
|
116 |
+
class SparkTTSBiCodecConfig(PretrainedConfig):
|
117 |
+
"""
|
118 |
+
Intermediate configuration class for the BiCodec component within SparkTTS.
|
119 |
+
It holds instances of the individual sub-component configurations.
|
120 |
+
"""
|
121 |
+
model_type = "spark-tts-bicodec"
|
122 |
+
# Map keys in the 'bicodec_config' dict to their respective classes
|
123 |
+
sub_configs = {
|
124 |
+
"mel_params": SparkTTSMelParamsConfig,
|
125 |
+
"encoder_config": SparkTTSEncoderConfig,
|
126 |
+
"decoder_config": SparkTTSDecoderConfig,
|
127 |
+
"quantizer_config": SparkTTSQuantizerConfig,
|
128 |
+
"speaker_encoder_config": SparkTTSSpeakerEncoderConfig,
|
129 |
+
"prenet_config": SparkTTSPrenetConfig,
|
130 |
+
"postnet_config": SparkTTSPostnetConfig,
|
131 |
+
}
|
132 |
+
|
133 |
+
def __init__(
|
134 |
+
self,
|
135 |
+
mel_params=None,
|
136 |
+
encoder_config=None,
|
137 |
+
decoder_config=None,
|
138 |
+
quantizer_config=None,
|
139 |
+
speaker_encoder_config=None,
|
140 |
+
prenet_config=None,
|
141 |
+
postnet_config=None,
|
142 |
+
**kwargs,
|
143 |
+
):
|
144 |
+
super().__init__(**kwargs)
|
145 |
+
|
146 |
+
# Instantiate sub-configs from dictionaries or use defaults/provided instances
|
147 |
+
self.mel_params = self._init_sub_config(mel_params, "mel_params")
|
148 |
+
self.encoder_config = self._init_sub_config(encoder_config, "encoder_config")
|
149 |
+
self.decoder_config = self._init_sub_config(decoder_config, "decoder_config")
|
150 |
+
self.quantizer_config = self._init_sub_config(quantizer_config, "quantizer_config")
|
151 |
+
self.speaker_encoder_config = self._init_sub_config(speaker_encoder_config, "speaker_encoder_config")
|
152 |
+
self.prenet_config = self._init_sub_config(prenet_config, "prenet_config")
|
153 |
+
self.postnet_config = self._init_sub_config(postnet_config, "postnet_config")
|
154 |
+
|
155 |
+
def _init_sub_config(self, config_input, config_key):
|
156 |
+
"""Helper to initialize sub-configs."""
|
157 |
+
config_cls = self.sub_configs[config_key]
|
158 |
+
if isinstance(config_input, dict):
|
159 |
+
return config_cls(**config_input)
|
160 |
+
elif config_input is None:
|
161 |
+
return config_cls() # Initialize with defaults
|
162 |
+
elif isinstance(config_input, config_cls):
|
163 |
+
return config_input # Already an instance
|
164 |
+
else:
|
165 |
+
raise TypeError(f"Invalid type for {config_key}: {type(config_input)}. Expected dict, None, or {config_cls.__name__}.")
|
166 |
+
|
167 |
+
|
168 |
+
# --- Define the Main SparkTTS Config Class ---
|
169 |
+
|
170 |
+
class SparkTTSConfig(PretrainedConfig):
|
171 |
+
r"""
|
172 |
+
Main configuration class for SparkTTSModel, including nested BiCodec configuration.
|
173 |
+
Args:
|
174 |
+
llm_model_name_or_path (`str`, *optional*, defaults to `"./LLM"`): Path/ID for LLM.
|
175 |
+
bicodec_model_name_or_path (`str`, *optional*, defaults to `"./BiCodec"`): Path/ID for BiCodec checkpoint.
|
176 |
+
wav2vec2_model_name_or_path (`str`, *optional*, defaults to `"./wav2vec2-large-xlsr-53"`): Path/ID for Wav2Vec2.
|
177 |
+
sample_rate (`int`, *optional*, defaults to 16000): Audio sample rate.
|
178 |
+
# ... (other top-level args: highpass_cutoff_freq, latent_hop_length, ref_segment_duration, volume_normalize) ...
|
179 |
+
bicodec_config (`dict`, *optional*): Dictionary to initialize `SparkTTSBiCodecConfig`.
|
180 |
+
torch_dtype (`str`, *optional*, defaults to `"auto"`): Torch dtype.
|
181 |
+
kwargs (*optional*): Dictionary of keyword arguments.
|
182 |
+
"""
|
183 |
+
model_type = "spark-tts"
|
184 |
+
# Map the key in config.json to the intermediate BiCodec config class
|
185 |
+
sub_configs = {"bicodec_config": SparkTTSBiCodecConfig}
|
186 |
+
attribute_map = {"hidden_size": "d_model"} # Example
|
187 |
+
|
188 |
+
def __init__(
|
189 |
+
self,
|
190 |
+
llm_model_name_or_path="./LLM",
|
191 |
+
bicodec_model_name_or_path="./BiCodec",
|
192 |
+
wav2vec2_model_name_or_path="./wav2vec2-large-xlsr-53",
|
193 |
+
sample_rate=16000,
|
194 |
+
highpass_cutoff_freq=40,
|
195 |
+
latent_hop_length=320,
|
196 |
+
ref_segment_duration=6.0,
|
197 |
+
volume_normalize=True,
|
198 |
+
bicodec_config=None, # Expects a dictionary or None
|
199 |
+
torch_dtype="auto",
|
200 |
+
**kwargs,
|
201 |
+
):
|
202 |
+
# --- Top-level parameters ---
|
203 |
+
self.llm_model_name_or_path = llm_model_name_or_path
|
204 |
+
self.bicodec_model_name_or_path = bicodec_model_name_or_path
|
205 |
+
self.wav2vec2_model_name_or_path = wav2vec2_model_name_or_path
|
206 |
+
self.sample_rate = sample_rate
|
207 |
+
self.highpass_cutoff_freq = highpass_cutoff_freq
|
208 |
+
self.latent_hop_length = latent_hop_length
|
209 |
+
self.ref_segment_duration = ref_segment_duration
|
210 |
+
self.volume_normalize = volume_normalize
|
211 |
+
self.torch_dtype = torch_dtype
|
212 |
+
|
213 |
+
# --- Nested BiCodec Configuration ---
|
214 |
+
# Instantiate the intermediate BiCodec config class, which will handle its own sub-configs
|
215 |
+
if isinstance(bicodec_config, dict):
|
216 |
+
self.bicodec_config = self.sub_configs["bicodec_config"](**bicodec_config)
|
217 |
+
elif bicodec_config is None:
|
218 |
+
logger.info("`bicodec_config` not provided. Initializing `SparkTTSBiCodecConfig` with its defaults.")
|
219 |
+
self.bicodec_config = self.sub_configs["bicodec_config"]()
|
220 |
+
elif isinstance(bicodec_config, self.sub_configs["bicodec_config"]):
|
221 |
+
self.bicodec_config = bicodec_config # Use existing instance
|
222 |
+
else:
|
223 |
+
raise TypeError(f"Invalid type for bicodec_config: {type(bicodec_config)}. Expected dict, None, or SparkTTSBiCodecConfig.")
|
224 |
+
|
225 |
+
|
226 |
+
# Set processor class and auto_map
|
227 |
+
kwargs["processor_class"] = kwargs.get("processor_class", "SparkTTSProcessor")
|
228 |
+
kwargs["auto_map"] = kwargs.get("auto_map", {
|
229 |
+
"AutoConfig": "configuration_spark_tts.SparkTTSConfig",
|
230 |
+
"AutoModel": "modeling_spark_tts.SparkTTSModel",
|
231 |
+
"AutoProcessor": "processing_spark_tts.SparkTTSProcessor"
|
232 |
+
})
|
233 |
+
super().__init__(**kwargs)
|
modeling_spark_tts.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
processing_spark_tts.py
ADDED
@@ -0,0 +1,889 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2025 SparkAudio & The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""
|
16 |
+
Processor class for SparkTTS. Combines text tokenization and audio feature extraction/processing.
|
17 |
+
"""
|
18 |
+
|
19 |
+
import os # Needed for save_pretrained
|
20 |
+
import re # For decoding
|
21 |
+
import torch
|
22 |
+
import numpy as np
|
23 |
+
import soundfile as sf # For audio loading
|
24 |
+
import soxr # For resampling
|
25 |
+
|
26 |
+
from pathlib import Path
|
27 |
+
from typing import Optional, Union, List, Dict, Tuple, Any
|
28 |
+
|
29 |
+
from transformers.processing_utils import ProcessorMixin
|
30 |
+
from transformers.tokenization_utils_base import BatchEncoding # Return type hint
|
31 |
+
from transformers.feature_extraction_utils import BatchFeature # Return type hint
|
32 |
+
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
33 |
+
from transformers.models.wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
|
34 |
+
from transformers.utils import logging, PushToHubMixin # Added PushToHubMixin
|
35 |
+
from numpy.lib.stride_tricks import sliding_window_view
|
36 |
+
import soxr
|
37 |
+
import soundfile
|
38 |
+
import random
|
39 |
+
|
40 |
+
# Import custom config if needed for defaults
|
41 |
+
from .configuration_spark_tts import SparkTTSConfig
|
42 |
+
|
43 |
+
logger = logging.get_logger(__name__)
|
44 |
+
|
45 |
+
|
46 |
+
# =============================================================================
|
47 |
+
# >> START: PASTE CODE FROM sparktts/utils/* HERE <<
|
48 |
+
# =============================================================================
|
49 |
+
# IMPORTANT: Utility functions needed for processing (audio loading, token parsing)
|
50 |
+
# must be defined or imported here.
|
51 |
+
|
52 |
+
# --- Paste sparktts/utils/audio.py content here ---
|
53 |
+
|
54 |
+
def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
|
55 |
+
"""
|
56 |
+
Normalize the volume of an audio signal.
|
57 |
+
|
58 |
+
Parameters:
|
59 |
+
audio (numpy array): Input audio signal array.
|
60 |
+
coeff (float): Target coefficient for normalization, default is 0.2.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
numpy array: The volume-normalized audio signal.
|
64 |
+
"""
|
65 |
+
# Sort the absolute values of the audio signal
|
66 |
+
temp = np.sort(np.abs(audio))
|
67 |
+
|
68 |
+
# If the maximum value is less than 0.1, scale the array to have a maximum of 0.1
|
69 |
+
if temp[-1] < 0.1:
|
70 |
+
scaling_factor = max(
|
71 |
+
temp[-1], 1e-3
|
72 |
+
) # Prevent division by zero with a small constant
|
73 |
+
audio = audio / scaling_factor * 0.1
|
74 |
+
|
75 |
+
# Filter out values less than 0.01 from temp
|
76 |
+
temp = temp[temp > 0.01]
|
77 |
+
L = temp.shape[0] # Length of the filtered array
|
78 |
+
|
79 |
+
# If there are fewer than or equal to 10 significant values, return the audio without further processing
|
80 |
+
if L <= 10:
|
81 |
+
return audio
|
82 |
+
|
83 |
+
# Compute the average of the top 10% to 1% of values in temp
|
84 |
+
volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])
|
85 |
+
|
86 |
+
# Normalize the audio to the target coefficient level, clamping the scale factor between 0.1 and 10
|
87 |
+
audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
|
88 |
+
|
89 |
+
# Ensure the maximum absolute value in the audio does not exceed 1
|
90 |
+
max_value = np.max(np.abs(audio))
|
91 |
+
if max_value > 1:
|
92 |
+
audio = audio / max_value
|
93 |
+
|
94 |
+
return audio
|
95 |
+
|
96 |
+
|
97 |
+
def load_audio(
|
98 |
+
adfile: Path,
|
99 |
+
sampling_rate: int = None,
|
100 |
+
length: int = None,
|
101 |
+
volume_normalize: bool = False,
|
102 |
+
segment_duration: int = None,
|
103 |
+
) -> np.ndarray:
|
104 |
+
r"""Load audio file with target sampling rate and lsength
|
105 |
+
|
106 |
+
Args:
|
107 |
+
adfile (Path): path to audio file.
|
108 |
+
sampling_rate (int, optional): target sampling rate. Defaults to None.
|
109 |
+
length (int, optional): target audio length. Defaults to None.
|
110 |
+
volume_normalize (bool, optional): whether perform volume normalization. Defaults to False.
|
111 |
+
segment_duration (int): random select a segment with duration of {segment_duration}s.
|
112 |
+
Defualt to None which means the whole audio will be used.
|
113 |
+
|
114 |
+
Returns:
|
115 |
+
audio (np.ndarray): audio
|
116 |
+
"""
|
117 |
+
|
118 |
+
audio, sr = soundfile.read(adfile)
|
119 |
+
if len(audio.shape) > 1:
|
120 |
+
audio = audio[:, 0]
|
121 |
+
|
122 |
+
if sampling_rate is not None and sr != sampling_rate:
|
123 |
+
audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
|
124 |
+
sr = sampling_rate
|
125 |
+
|
126 |
+
if segment_duration is not None:
|
127 |
+
seg_length = int(sr * segment_duration)
|
128 |
+
audio = random_select_audio_segment(audio, seg_length)
|
129 |
+
|
130 |
+
# Audio volume normalize
|
131 |
+
if volume_normalize:
|
132 |
+
audio = audio_volume_normalize(audio)
|
133 |
+
# check the audio length
|
134 |
+
if length is not None:
|
135 |
+
assert abs(audio.shape[0] - length) < 1000
|
136 |
+
if audio.shape[0] > length:
|
137 |
+
audio = audio[:length]
|
138 |
+
else:
|
139 |
+
audio = np.pad(audio, (0, int(length - audio.shape[0])))
|
140 |
+
return audio
|
141 |
+
|
142 |
+
|
143 |
+
def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
|
144 |
+
"""get an audio segment given the length
|
145 |
+
|
146 |
+
Args:
|
147 |
+
audio (np.ndarray):
|
148 |
+
length (int): audio length = sampling_rate * duration
|
149 |
+
"""
|
150 |
+
if audio.shape[0] < length:
|
151 |
+
audio = np.pad(audio, (0, int(length - audio.shape[0])))
|
152 |
+
start_index = random.randint(0, audio.shape[0] - length)
|
153 |
+
end_index = int(start_index + length)
|
154 |
+
|
155 |
+
return audio[start_index:end_index]
|
156 |
+
|
157 |
+
def get_ref_clip(wav: np.ndarray, config) -> np.ndarray: # Needs access to config attributes
|
158 |
+
"""Get reference audio clip for speaker embedding."""
|
159 |
+
# Make sure config has sample_rate, ref_segment_duration, latent_hop_length
|
160 |
+
if not all(hasattr(config, attr) for attr in ['sample_rate', 'ref_segment_duration', 'latent_hop_length']):
|
161 |
+
raise AttributeError("Config object missing required attributes for get_ref_clip")
|
162 |
+
ref_segment_length = (
|
163 |
+
int(config.sample_rate * config.ref_segment_duration)
|
164 |
+
// config.latent_hop_length
|
165 |
+
* config.latent_hop_length
|
166 |
+
)
|
167 |
+
wav_length = len(wav)
|
168 |
+
if ref_segment_length > wav_length:
|
169 |
+
wav = np.tile(wav, ref_segment_length // wav_length + 1)
|
170 |
+
return wav[:ref_segment_length]
|
171 |
+
|
172 |
+
|
173 |
+
# --- Paste sparktts/utils/token_parser.py content here ---
|
174 |
+
|
175 |
+
TASK_TOKEN_MAP = {
|
176 |
+
"vc": "<|task_vc|>",
|
177 |
+
"tts": "<|task_tts|>",
|
178 |
+
"asr": "<|task_asr|>",
|
179 |
+
"s2s": "<|task_s2s|>",
|
180 |
+
"t2s": "<|task_t2s|>",
|
181 |
+
"understand": "<|task_understand|>",
|
182 |
+
"caption": "<|task_cap|>",
|
183 |
+
"controllable_tts": "<|task_controllable_tts|>",
|
184 |
+
"prompt_tts": "<|task_prompt_tts|>",
|
185 |
+
"speech_edit": "<|task_edit|>",
|
186 |
+
}
|
187 |
+
|
188 |
+
LEVELS_MAP = {
|
189 |
+
"very_low": 0,
|
190 |
+
"low": 1,
|
191 |
+
"moderate": 2,
|
192 |
+
"high": 3,
|
193 |
+
"very_high": 4,
|
194 |
+
}
|
195 |
+
|
196 |
+
LEVELS_MAP_UI = {
|
197 |
+
1: 'very_low',
|
198 |
+
2: 'low',
|
199 |
+
3: 'moderate',
|
200 |
+
4: 'high',
|
201 |
+
5: 'very_high'
|
202 |
+
}
|
203 |
+
|
204 |
+
GENDER_MAP = {
|
205 |
+
"female": 0,
|
206 |
+
"male": 1,
|
207 |
+
}
|
208 |
+
|
209 |
+
AGE_MAP = {"Child": 0, "Teenager": 1, "Youth-Adult": 2, "Middle-aged": 3, "Elderly": 4}
|
210 |
+
|
211 |
+
EMO_MAP = {
|
212 |
+
"UNKNOWN": 0,
|
213 |
+
"NEUTRAL": 1,
|
214 |
+
"ANGRY": 2,
|
215 |
+
"HAPPY": 3,
|
216 |
+
"SAD": 4,
|
217 |
+
"FEARFUL": 5,
|
218 |
+
"DISGUSTED": 6,
|
219 |
+
"SURPRISED": 7,
|
220 |
+
"SARCASTIC": 8,
|
221 |
+
"EXCITED": 9,
|
222 |
+
"SLEEPY": 10,
|
223 |
+
"CONFUSED": 11,
|
224 |
+
"EMPHASIS": 12,
|
225 |
+
"LAUGHING": 13,
|
226 |
+
"SINGING": 14,
|
227 |
+
"WORRIED": 15,
|
228 |
+
"WHISPER": 16,
|
229 |
+
"ANXIOUS": 17,
|
230 |
+
"NO-AGREEMENT": 18,
|
231 |
+
"APOLOGETIC": 19,
|
232 |
+
"CONCERNED": 20,
|
233 |
+
"ENUNCIATED": 21,
|
234 |
+
"ASSERTIVE": 22,
|
235 |
+
"ENCOURAGING": 23,
|
236 |
+
"CONTEMPT": 24,
|
237 |
+
}
|
238 |
+
|
239 |
+
|
240 |
+
class TokenParser:
|
241 |
+
"""Turn label to special token"""
|
242 |
+
|
243 |
+
def __init__(self):
|
244 |
+
pass
|
245 |
+
|
246 |
+
"""Parse the attributes of a person."""
|
247 |
+
|
248 |
+
def __init__(self):
|
249 |
+
pass
|
250 |
+
|
251 |
+
@staticmethod
|
252 |
+
def age(age: str) -> str:
|
253 |
+
"""Turn age token."""
|
254 |
+
age_id = AGE_MAP[age]
|
255 |
+
return f"<|age_{age_id}|>"
|
256 |
+
|
257 |
+
@staticmethod
|
258 |
+
def gender(gender: str) -> str:
|
259 |
+
"""Turn gender token."""
|
260 |
+
gender_id = GENDER_MAP[gender]
|
261 |
+
return f"<|gender_{gender_id}|>"
|
262 |
+
|
263 |
+
@staticmethod
|
264 |
+
def mel_value(mel: int):
|
265 |
+
"""Turn special token of mel scale pitch."""
|
266 |
+
mel = max(0, int(mel))
|
267 |
+
mel = min(1000, int(mel))
|
268 |
+
return f"<|pitch_value_{mel}|>"
|
269 |
+
|
270 |
+
@staticmethod
|
271 |
+
def mel_level(level: str):
|
272 |
+
"""Turn special token of mel level."""
|
273 |
+
level_tag = LEVELS_MAP[level]
|
274 |
+
return f"<|pitch_label_{level_tag}|>"
|
275 |
+
|
276 |
+
@staticmethod
|
277 |
+
def pitch_var_value(pitch_std: int):
|
278 |
+
"""Turn special token of pitch_std value."""
|
279 |
+
assert isinstance(pitch_std, int)
|
280 |
+
pitch_std = max(0, int(pitch_std))
|
281 |
+
pitch_std = min(10, int(pitch_std))
|
282 |
+
return f"<|pitch_var_value_{pitch_std}|>"
|
283 |
+
|
284 |
+
@staticmethod
|
285 |
+
def pitch_var_level(level: str):
|
286 |
+
"""Turn special token of pitch std level."""
|
287 |
+
level_tag = LEVELS_MAP[level]
|
288 |
+
return f"<|pitch_var_label_{level_tag}|>"
|
289 |
+
|
290 |
+
@staticmethod
|
291 |
+
def loudness_value(loudness: int):
|
292 |
+
"""Turn special toak of loudness value [0, 30]"""
|
293 |
+
assert loudness >= 0
|
294 |
+
loudness = max(0, int(loudness))
|
295 |
+
loudness = min(30, int(loudness))
|
296 |
+
return f"<|loudness_value_{loudness}|>"
|
297 |
+
|
298 |
+
@staticmethod
|
299 |
+
def loudness_level(level: str):
|
300 |
+
"""Turn special token of loudness level."""
|
301 |
+
level_tag = LEVELS_MAP[level]
|
302 |
+
return f"<|loudness_label_{level_tag}|>"
|
303 |
+
|
304 |
+
@staticmethod
|
305 |
+
def speed_value(speed: int):
|
306 |
+
"""Turn special token of speed value."""
|
307 |
+
speed = max(0, int(speed))
|
308 |
+
speed = min(10, int(speed))
|
309 |
+
return f"<|speed_value_{speed}|>"
|
310 |
+
|
311 |
+
@staticmethod
|
312 |
+
def speed_level(level: str):
|
313 |
+
"""Turn special token of speed level."""
|
314 |
+
level_tag = LEVELS_MAP[level]
|
315 |
+
return f"<|speed_label_{level_tag}|>"
|
316 |
+
|
317 |
+
@staticmethod
|
318 |
+
def task(task: str) -> str:
|
319 |
+
"""Turn special token of task."""
|
320 |
+
assert task in TASK_TOKEN_MAP.keys()
|
321 |
+
|
322 |
+
return TASK_TOKEN_MAP[task]
|
323 |
+
|
324 |
+
@staticmethod
|
325 |
+
def emotion(emotion: str):
|
326 |
+
emo_id = EMO_MAP[emotion]
|
327 |
+
|
328 |
+
return f"<|emotion_{emo_id}|>"
|
329 |
+
|
330 |
+
# =============================================================================
|
331 |
+
# >> END: PASTE CODE FROM sparktts/utils/* HERE <<
|
332 |
+
# =============================================================================
|
333 |
+
|
334 |
+
|
335 |
+
class SparkTTSProcessor(ProcessorMixin, PushToHubMixin): # Added PushToHubMixin
|
336 |
+
r"""
|
337 |
+
Constructs a SparkTTS processor which wraps a text tokenizer and relevant audio processing logic.
|
338 |
+
|
339 |
+
Args:
|
340 |
+
tokenizer ([`PreTrainedTokenizer`]):
|
341 |
+
An instance of [`PreTrainedTokenizer`]. This handles the text tokenization for the LLM.
|
342 |
+
feature_extractor ([`Wav2Vec2FeatureExtractor`]):
|
343 |
+
An instance of [`Wav2Vec2FeatureExtractor`]. Although Wav2Vec2 features are extracted
|
344 |
+
within the model's `tokenize_audio`, the extractor's configuration (like sampling rate)
|
345 |
+
is useful, and it aligns with the ProcessorMixin pattern.
|
346 |
+
config ([`SparkTTSConfig`], *optional*):
|
347 |
+
An instance of [`SparkTTSConfig`] to access configuration parameters like sample rate.
|
348 |
+
"""
|
349 |
+
attributes = ["tokenizer", "feature_extractor"]
|
350 |
+
tokenizer_class = "AutoTokenizer"
|
351 |
+
feature_extractor_class = "Wav2Vec2FeatureExtractor" # Keep for consistency
|
352 |
+
|
353 |
+
def __init__(self, tokenizer, feature_extractor, config: Optional[SparkTTSConfig] = None, **kwargs):
|
354 |
+
super().__init__(tokenizer=tokenizer, feature_extractor=feature_extractor, **kwargs)
|
355 |
+
self.model = None
|
356 |
+
self.config = config
|
357 |
+
# Set sampling rate
|
358 |
+
if config and hasattr(config, 'sample_rate'):
|
359 |
+
self.sampling_rate = config.sample_rate
|
360 |
+
elif feature_extractor and hasattr(feature_extractor, 'sampling_rate'):
|
361 |
+
self.sampling_rate = feature_extractor.sampling_rate
|
362 |
+
else:
|
363 |
+
self.sampling_rate = 16000
|
364 |
+
logger.warning(f"Could not determine sampling rate. Defaulting to {self.sampling_rate} Hz.")
|
365 |
+
|
366 |
+
# # Ensure tokenizer pad token
|
367 |
+
# if self.tokenizer.pad_token is None:
|
368 |
+
# if self.tokenizer.eos_token is not None:
|
369 |
+
# logger.warning("Tokenizer does not have a pad token. Setting pad_token to eos_token.")
|
370 |
+
# self.tokenizer.pad_token = self.tokenizer.eos_token
|
371 |
+
# else:
|
372 |
+
# logger.warning("Tokenizer lacks pad and eos token. Adding default pad token '<|pad|>'.")
|
373 |
+
# self.tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
|
374 |
+
|
375 |
+
def link_model(self, model):
|
376 |
+
"""Links the processor to a SparkTTSModel instance for audio processing calls."""
|
377 |
+
if not hasattr(model, 'tokenize_audio') or not hasattr(model, 'detokenize_audio'):
|
378 |
+
raise TypeError("The provided model instance does not have the required 'tokenize_audio' and 'detokenize_audio' methods.")
|
379 |
+
if not hasattr(model, 'config'):
|
380 |
+
logger.warning("Linked model does not have a 'config' attribute. Some processor functionalities might rely on it.")
|
381 |
+
|
382 |
+
self.model = model
|
383 |
+
logger.info("SparkTTSModel successfully linked to the processor.")
|
384 |
+
# Update sampling rate based on linked model's config if available
|
385 |
+
if hasattr(model, 'config') and hasattr(model.config, 'sample_rate'):
|
386 |
+
if self.sampling_rate != model.config.sample_rate:
|
387 |
+
logger.info(f"Updating processor sampling rate from {self.sampling_rate} to {model.config.sample_rate} based on linked model config.")
|
388 |
+
self.sampling_rate = model.config.sample_rate
|
389 |
+
# Also update feature extractor sampling rate if it differs
|
390 |
+
if hasattr(self, 'feature_extractor') and self.feature_extractor.sampling_rate != model.config.sample_rate:
|
391 |
+
logger.info(f"Updating feature_extractor sampling rate from {self.feature_extractor.sampling_rate} to {model.config.sample_rate}.")
|
392 |
+
self.feature_extractor.sampling_rate = model.config.sample_rate
|
393 |
+
|
394 |
+
|
395 |
+
def __call__(
|
396 |
+
self,
|
397 |
+
text: str,
|
398 |
+
prompt_speech_path: Optional[Union[str, Path]] = None,
|
399 |
+
prompt_text: Optional[str] = None,
|
400 |
+
gender: Optional[str] = None,
|
401 |
+
pitch: Optional[str] = None,
|
402 |
+
speed: Optional[str] = None,
|
403 |
+
return_tensors: Optional[str] = "pt",
|
404 |
+
**kwargs, # Allow passing other args like padding, truncation to tokenizer
|
405 |
+
) -> BatchEncoding:
|
406 |
+
"""
|
407 |
+
Processes the input text and optional prompt audio/control parameters into a format suitable for [`SparkTTSModel`].
|
408 |
+
|
409 |
+
Args:
|
410 |
+
text (`str`):
|
411 |
+
The main text to be synthesized.
|
412 |
+
prompt_speech_path (`str` or `Path`, *optional*):
|
413 |
+
Path to the prompt audio file for voice cloning. Required if `gender` is not set.
|
414 |
+
prompt_text (`str`, *optional*):
|
415 |
+
Transcript of the prompt audio. Used only in voice cloning mode.
|
416 |
+
gender (`str`, *optional*):
|
417 |
+
Target gender ("male" or "female") for controllable synthesis. If set, enables control mode.
|
418 |
+
pitch (`str`, *optional*):
|
419 |
+
Target pitch level ("very_low", "low", "moderate", "high", "very_high") for control mode. Required if `gender` is set.
|
420 |
+
speed (`str`, *optional*):
|
421 |
+
Target speed level ("very_low", "low", "moderate", "high", "very_high") for control mode. Required if `gender` is set.
|
422 |
+
return_tensors (`str`, *optional*, defaults to `"pt"`):
|
423 |
+
If set, will return tensors instead of list of python integers. Only "pt" (PyTorch) is supported currently.
|
424 |
+
**kwargs:
|
425 |
+
Additional arguments passed to the underlying tokenizer's `__call__` method.
|
426 |
+
|
427 |
+
Returns:
|
428 |
+
[`BatchEncoding`]: A dictionary containing the `input_ids` and `attention_mask` for the LLM.
|
429 |
+
In voice cloning mode, it also includes `global_token_ids_prompt` (torch.Tensor) representing the
|
430 |
+
global tokens extracted from the prompt audio.
|
431 |
+
"""
|
432 |
+
|
433 |
+
global_token_ids_prompt = None # Initialize
|
434 |
+
|
435 |
+
# Determine mode: Control TTS or Voice Cloning (Prompt TTS)
|
436 |
+
is_control_mode = gender is not None
|
437 |
+
is_cloning_mode = prompt_speech_path is not None and not is_control_mode
|
438 |
+
|
439 |
+
if is_control_mode:
|
440 |
+
# --- Controllable TTS Prompt Construction ---
|
441 |
+
if not all([pitch, speed]):
|
442 |
+
raise ValueError("For controllable TTS, 'gender', 'pitch', and 'speed' must all be provided.")
|
443 |
+
if prompt_speech_path is not None:
|
444 |
+
logger.warning("`prompt_speech_path` provided but ignored because `gender` is set (controllable TTS mode).")
|
445 |
+
|
446 |
+
if not all(k in GENDER_MAP for k in [gender]): # Basic check
|
447 |
+
raise ValueError(f"Invalid gender provided: {gender}. Must be one of {list(GENDER_MAP.keys())}")
|
448 |
+
if not all(k in LEVELS_MAP for k in [pitch, speed]): # Basic check
|
449 |
+
raise ValueError(f"Invalid pitch or speed level provided. Must be one of {list(LEVELS_MAP.keys())}")
|
450 |
+
|
451 |
+
gender_id = GENDER_MAP[gender]
|
452 |
+
pitch_level_id = LEVELS_MAP[pitch]
|
453 |
+
speed_level_id = LEVELS_MAP[speed]
|
454 |
+
|
455 |
+
pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
|
456 |
+
speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
|
457 |
+
gender_tokens = f"<|gender_{gender_id}|>"
|
458 |
+
|
459 |
+
attribute_tokens = "".join([gender_tokens, pitch_label_tokens, speed_label_tokens])
|
460 |
+
|
461 |
+
prompt_list = [
|
462 |
+
TASK_TOKEN_MAP["controllable_tts"],
|
463 |
+
"<|start_content|>",
|
464 |
+
text,
|
465 |
+
"<|end_content|>",
|
466 |
+
"<|start_style_label|>",
|
467 |
+
attribute_tokens,
|
468 |
+
"<|end_style_label|>",
|
469 |
+
]
|
470 |
+
prompt_string = "".join(prompt_list)
|
471 |
+
|
472 |
+
elif is_cloning_mode:
|
473 |
+
# --- Voice Cloning Prompt Construction ---
|
474 |
+
if self.model is None:
|
475 |
+
raise RuntimeError("Processor must be linked to a SparkTTSModel instance via `processor.link_model(model)` before performing voice cloning.")
|
476 |
+
prompt_speech_path = Path(prompt_speech_path) # Ensure it's a Path object
|
477 |
+
if not prompt_speech_path.exists():
|
478 |
+
raise FileNotFoundError(f"Prompt audio file not found: {prompt_speech_path}")
|
479 |
+
|
480 |
+
# Load and process prompt audio
|
481 |
+
try:
|
482 |
+
model_config = self.model.config if self.model and hasattr(self.model, 'config') else self.config
|
483 |
+
if model_config is None:
|
484 |
+
raise ValueError("Configuration not available in processor or linked model.")
|
485 |
+
|
486 |
+
# Load main wav
|
487 |
+
wav = load_audio(
|
488 |
+
prompt_speech_path,
|
489 |
+
sampling_rate=self.sampling_rate,
|
490 |
+
volume_normalize=getattr(model_config, 'volume_normalize', True), # Use getattr for safety
|
491 |
+
)
|
492 |
+
# Get reference clip
|
493 |
+
wav_ref_np = get_ref_clip(wav, model_config) # Pass config object
|
494 |
+
wav_ref = torch.from_numpy(wav_ref_np).unsqueeze(0).float()
|
495 |
+
wav_tensor = torch.from_numpy(wav).unsqueeze(0).float()
|
496 |
+
|
497 |
+
# Tokenize using the linked model's method
|
498 |
+
# Assuming tokenize_audio returns tensors with batch dim 1: [1, N_global], [1, N_semantic]
|
499 |
+
global_tokens_tensor, semantic_tokens_tensor = self.model.tokenize_audio(wav_tensor, wav_ref)
|
500 |
+
|
501 |
+
# Store the global tokens tensor (with batch dim) for the output dict
|
502 |
+
global_token_ids_prompt = global_tokens_tensor # Keep batch dim [1, N_global]
|
503 |
+
|
504 |
+
# Convert tensors to lists of ints for string formatting
|
505 |
+
global_token_list = global_tokens_tensor.squeeze().tolist() # Remove batch dim -> list
|
506 |
+
semantic_token_list = semantic_tokens_tensor.squeeze().tolist() # Remove batch dim -> list
|
507 |
+
|
508 |
+
except Exception as e:
|
509 |
+
logger.error(f"Error processing prompt audio {prompt_speech_path}: {e}")
|
510 |
+
import traceback
|
511 |
+
traceback.print_exc()
|
512 |
+
raise
|
513 |
+
|
514 |
+
# ==============================================================
|
515 |
+
# CORRECTED TOKEN STRING FORMATTING
|
516 |
+
# ==============================================================
|
517 |
+
# Create individual token strings for each ID
|
518 |
+
global_tokens_str = "".join([f"<|bicodec_global_{gid}|>" for gid in global_token_list])
|
519 |
+
semantic_tokens_str = "".join([f"<|bicodec_semantic_{sid}|>" for sid in semantic_token_list])
|
520 |
+
# ==============================================================
|
521 |
+
|
522 |
+
# Construct prompt list based on presence of prompt_text
|
523 |
+
if prompt_text is not None and prompt_text.strip(): # Check if prompt_text is meaningful
|
524 |
+
logger.info("Using prompt text in voice cloning prompt.")
|
525 |
+
prompt_list = [
|
526 |
+
TASK_TOKEN_MAP["tts"], # Or maybe TASK_TOKEN_MAP["prompt_tts"]? Check original logic. Assuming "tts".
|
527 |
+
"<|start_content|>",
|
528 |
+
prompt_text, # Transcript first
|
529 |
+
text, # Then target text
|
530 |
+
"<|end_content|>",
|
531 |
+
"<|start_global_token|>",
|
532 |
+
global_tokens_str,
|
533 |
+
"<|end_global_token|>",
|
534 |
+
"<|start_semantic_token|>",
|
535 |
+
semantic_tokens_str,
|
536 |
+
# "<|end_semantic_token|>", # Original code didn't have this marker here
|
537 |
+
]
|
538 |
+
else:
|
539 |
+
# Simpler prompt without semantic tokens if no transcript provided
|
540 |
+
logger.info("No prompt text provided, using text-only voice cloning prompt.")
|
541 |
+
prompt_list = [
|
542 |
+
TASK_TOKEN_MAP["tts"], # Or maybe TASK_TOKEN_MAP["prompt_tts"]?
|
543 |
+
"<|start_content|>",
|
544 |
+
text, # Only target text
|
545 |
+
"<|end_content|>",
|
546 |
+
"<|start_global_token|>",
|
547 |
+
global_tokens_str,
|
548 |
+
"<|end_global_token|>",
|
549 |
+
]
|
550 |
+
prompt_string = "".join(prompt_list)
|
551 |
+
logger.debug(f"Generated prompt string (cloning): {prompt_string[:200]}...") # Log start of prompt
|
552 |
+
|
553 |
+
else:
|
554 |
+
raise ValueError("Invalid input combination. Either provide `prompt_speech_path` for cloning or (`gender`, `pitch`, `speed`) for control.")
|
555 |
+
|
556 |
+
# --- Tokenize the final prompt string ---
|
557 |
+
# print(f"Tokenizing prompt: {prompt_string}")
|
558 |
+
inputs = self.tokenizer(
|
559 |
+
prompt_string,
|
560 |
+
return_tensors=return_tensors,
|
561 |
+
padding=kwargs.get("padding", False), # Often False for generation prompts unless batching > 1
|
562 |
+
truncation=kwargs.get("truncation", True),
|
563 |
+
max_length=kwargs.get("max_length", self.tokenizer.model_max_length),
|
564 |
+
add_special_tokens=kwargs.get("add_special_tokens", True), # Usually True unless handled manually
|
565 |
+
return_attention_mask=kwargs.get("return_attention_mask", True), # Need attention mask
|
566 |
+
**{k: v for k, v in kwargs.items() if k not in ["padding", "truncation", "max_length", "add_special_tokens", "return_attention_mask"]}
|
567 |
+
)
|
568 |
+
logger.debug(f"Tokenized input_ids shape: {inputs['input_ids'].shape}")
|
569 |
+
|
570 |
+
|
571 |
+
# Add the prompt's global tokens (as tensor with batch dim) to the output if in cloning mode
|
572 |
+
if is_cloning_mode and global_token_ids_prompt is not None:
|
573 |
+
if return_tensors == "pt":
|
574 |
+
inputs["global_token_ids_prompt"] = global_token_ids_prompt # Already has batch dim [1, N_global]
|
575 |
+
else:
|
576 |
+
# Handle non-tensor return if necessary
|
577 |
+
inputs["global_token_ids_prompt"] = global_token_ids_prompt.tolist()
|
578 |
+
|
579 |
+
return inputs
|
580 |
+
|
581 |
+
|
582 |
+
def decode(
|
583 |
+
self,
|
584 |
+
generated_ids: torch.Tensor,
|
585 |
+
global_token_ids_prompt: Optional[torch.Tensor] = None,
|
586 |
+
input_ids_len: Optional[int] = None,
|
587 |
+
skip_special_tokens: bool = True,
|
588 |
+
) -> Dict[str, Any]:
|
589 |
+
"""
|
590 |
+
Decodes the generated token IDs from [`SparkTTSModel`] into an audio waveform.
|
591 |
+
|
592 |
+
Args:
|
593 |
+
generated_ids (`torch.Tensor`):
|
594 |
+
Tensor of token IDs generated by `model.generate()`, including the input prompt part. Shape [B, seq_len].
|
595 |
+
global_token_ids_prompt (`torch.Tensor`, *optional*):
|
596 |
+
The global tokens extracted from the prompt audio during the `__call__` step (for voice cloning).
|
597 |
+
Shape [B, N_global]. Required if the generation was for voice cloning.
|
598 |
+
input_ids_len (`int`, *optional*):
|
599 |
+
The length of the original input prompt `input_ids` fed to `model.generate()`. Required to
|
600 |
+
correctly isolate the newly generated tokens.
|
601 |
+
skip_special_tokens (`bool`, *optional*, defaults to `True`):
|
602 |
+
Whether to skip special tokens during the text decoding step (used to extract audio tokens).
|
603 |
+
|
604 |
+
Returns:
|
605 |
+
Dict[str, Any]: A dictionary containing:
|
606 |
+
- "audio": The decoded audio waveform as a NumPy array. Shape [T_audio] (if B=1) or [B, T_audio].
|
607 |
+
- "sampling_rate": The sampling rate of the audio.
|
608 |
+
"""
|
609 |
+
if self.model is None:
|
610 |
+
raise RuntimeError("Processor must be linked to a SparkTTSModel instance via `processor.link_model(model)` before decoding.")
|
611 |
+
if input_ids_len is None:
|
612 |
+
raise ValueError("`input_ids_len` (length of the prompt input_ids) must be provided for decoding.")
|
613 |
+
|
614 |
+
# --- Isolate generated part and decode text ---
|
615 |
+
# Assumes generated_ids has shape [B, full_seq_len]
|
616 |
+
# Handle case where generated sequence is shorter than prompt (shouldn't happen with max_new_tokens > 0)
|
617 |
+
if generated_ids.shape[1] < input_ids_len:
|
618 |
+
logger.warning(f"Generated sequence length ({generated_ids.shape[1]}) is shorter than input prompt length ({input_ids_len}). Decoding might be incorrect.")
|
619 |
+
output_only_ids = generated_ids[:, input_ids_len:] # Will be empty if equal
|
620 |
+
else:
|
621 |
+
output_only_ids = generated_ids[:, input_ids_len:]
|
622 |
+
|
623 |
+
|
624 |
+
# Decode the generated part to find audio tokens
|
625 |
+
# Need to handle batch decoding if B > 1
|
626 |
+
# print("decode token", self.tokenizer.batch_decode(output_only_ids, skip_special_tokens=False))
|
627 |
+
decoded_texts = self.tokenizer.batch_decode(output_only_ids, skip_special_tokens=skip_special_tokens)
|
628 |
+
|
629 |
+
# --- Extract Audio Tokens ---
|
630 |
+
# Handle batch processing correctly
|
631 |
+
batch_size = generated_ids.shape[0]
|
632 |
+
all_semantic_ids = []
|
633 |
+
all_global_tokens = []
|
634 |
+
successful_indices = [] # Keep track of which batch items were successful
|
635 |
+
|
636 |
+
for i in range(batch_size):
|
637 |
+
decoded_text = decoded_texts[i]
|
638 |
+
current_semantic_ids = None
|
639 |
+
current_global_tokens = None
|
640 |
+
|
641 |
+
# Extract semantic tokens
|
642 |
+
try:
|
643 |
+
pred_semantic_indices = [int(token) for token in re.findall(r"bicodec_semantic_(\d+)", decoded_text)]
|
644 |
+
if not pred_semantic_indices:
|
645 |
+
logger.warning(f"Batch item {i}: No semantic tokens found in decoded text: '{decoded_text[:200]}...'")
|
646 |
+
continue # Skip this item
|
647 |
+
|
648 |
+
current_semantic_ids = torch.tensor(pred_semantic_indices).long() # Shape [N_semantic]
|
649 |
+
except Exception as e:
|
650 |
+
logger.error(f"Batch item {i}: Error parsing semantic tokens from: '{decoded_text[:200]}...'. Error: {e}")
|
651 |
+
continue # Skip this item
|
652 |
+
|
653 |
+
# Determine global tokens
|
654 |
+
if global_token_ids_prompt is not None:
|
655 |
+
# Cloning mode: Use the provided prompt global tokens for this batch item
|
656 |
+
if global_token_ids_prompt.shape[0] != batch_size:
|
657 |
+
raise ValueError(f"Batch size mismatch: generated_ids has {batch_size}, but global_token_ids_prompt has {global_token_ids_prompt.shape[0]}.")
|
658 |
+
current_global_tokens = global_token_ids_prompt[i] # Shape [N_global]
|
659 |
+
else:
|
660 |
+
# Control mode: Extract global tokens from the generated text
|
661 |
+
try:
|
662 |
+
pred_global_indices = [int(token) for token in re.findall(r"bicodec_global_(\d+)", decoded_text)]
|
663 |
+
if not pred_global_indices:
|
664 |
+
logger.warning(f"Batch item {i}: No global tokens found in decoded text for control mode: '{decoded_text[:200]}...'")
|
665 |
+
continue # Skip this item
|
666 |
+
|
667 |
+
current_global_tokens = torch.tensor(pred_global_indices).long() # Shape [N_global]
|
668 |
+
|
669 |
+
except Exception as e:
|
670 |
+
logger.error(f"Batch item {i}: Error parsing global tokens from: '{decoded_text[:200]}...'. Error: {e}")
|
671 |
+
continue # Skip this item
|
672 |
+
|
673 |
+
# If both tokens extracted successfully
|
674 |
+
all_semantic_ids.append(current_semantic_ids)
|
675 |
+
all_global_tokens.append(current_global_tokens)
|
676 |
+
successful_indices.append(i)
|
677 |
+
|
678 |
+
if not successful_indices:
|
679 |
+
logger.error("Failed to extract audio tokens for any item in the batch.")
|
680 |
+
return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}
|
681 |
+
|
682 |
+
# Pad sequences to the max length within the successful batch items for batch detokenization
|
683 |
+
# Note: BiCodec might not support batching if sequences have different lengths. Check its implementation.
|
684 |
+
# Assuming BiCodec *can* handle batches if padded (or if lengths are naturally equal).
|
685 |
+
# This padding might be unnecessary if BiCodec handles variable lengths or if B=1 anyway.
|
686 |
+
# For now, let's assume B=1 was handled correctly and skip complex padding.
|
687 |
+
if batch_size > 1 and len(successful_indices) < batch_size:
|
688 |
+
logger.warning(f"Only successfully decoded {len(successful_indices)} out of {batch_size} batch items.")
|
689 |
+
# Further processing might need to handle only the successful items.
|
690 |
+
|
691 |
+
# Let's proceed assuming B=1 or BiCodec handles batches appropriately.
|
692 |
+
# Stack the successful tokens.
|
693 |
+
try:
|
694 |
+
# Need to ensure tensors have the same length before stacking if BiCodec requires it.
|
695 |
+
# If BiCodec handles variable length, stacking might not be needed, just loop and call detokenize.
|
696 |
+
# Let's assume B=1 for simplicity of the example, matching original code's likely behavior.
|
697 |
+
if len(successful_indices) != 1:
|
698 |
+
raise NotImplementedError("Batch decoding (B > 1) requires verification of BiCodec's batch handling and potentially padding.")
|
699 |
+
|
700 |
+
final_semantic_ids = all_semantic_ids[0].unsqueeze(0) # Add batch dim [1, N_semantic]
|
701 |
+
final_global_tokens = all_global_tokens[0].unsqueeze(0) # Add batch dim [1, N_global]
|
702 |
+
|
703 |
+
except IndexError: # Should not happen if successful_indices is not empty
|
704 |
+
logger.error("Internal error during token batch preparation.")
|
705 |
+
return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}
|
706 |
+
|
707 |
+
|
708 |
+
# --- Detokenize Audio ---
|
709 |
+
try:
|
710 |
+
# Call the linked model's detokenize method
|
711 |
+
# print(f"DEBUG: Detokenizing audio with global tokens {final_global_tokens.shape}, semantic tokens {final_semantic_ids.shape}")
|
712 |
+
output_wav = self.model.detokenize_audio(final_global_tokens, final_semantic_ids)
|
713 |
+
# detokenize_audio now returns numpy array float32 in [-1, 1]
|
714 |
+
|
715 |
+
# Optional: Double-check dtype here if needed, but should be handled by detokenize_audio now
|
716 |
+
# if output_wav.dtype != np.float32:
|
717 |
+
# logger.warning(f"Audio dtype after detokenize is {output_wav.dtype}. Converting to float32.")
|
718 |
+
# output_wav = output_wav.astype(np.float32)
|
719 |
+
# output_wav = np.clip(output_wav, -1.0, 1.0) # Clipping done in detokenize_audio
|
720 |
+
|
721 |
+
except Exception as e:
|
722 |
+
logger.error(f"Error during audio detokenization: {e}")
|
723 |
+
import traceback
|
724 |
+
traceback.print_exc()
|
725 |
+
raise RuntimeError("Audio detokenization failed.") from e
|
726 |
+
|
727 |
+
return {"audio": output_wav, "sampling_rate": self.sampling_rate}
|
728 |
+
|
729 |
+
|
730 |
+
@classmethod
|
731 |
+
def from_pretrained(
|
732 |
+
cls,
|
733 |
+
pretrained_model_name_or_path: Union[str, os.PathLike],
|
734 |
+
cache_dir: Optional[Union[str, os.PathLike]] = None,
|
735 |
+
force_download: bool = False,
|
736 |
+
local_files_only: bool = False,
|
737 |
+
token: Optional[Union[str, bool]] = None,
|
738 |
+
revision: str = "main",
|
739 |
+
trust_remote_code: bool = False, # Allow passing this, needed for config potentially
|
740 |
+
**kwargs,
|
741 |
+
):
|
742 |
+
r"""
|
743 |
+
Instantiate a SparkTTSProcessor from pretrained components.
|
744 |
+
"""
|
745 |
+
# Pop specific kwargs for this method
|
746 |
+
config = kwargs.pop("config", None) # Allow passing config explicitly
|
747 |
+
|
748 |
+
# --- 1. Load Config (to find component paths) ---
|
749 |
+
# We need the config even if the processor doesn't store it permanently,
|
750 |
+
# just to find where the tokenizer/feature_extractor live.
|
751 |
+
loaded_config = None
|
752 |
+
if not isinstance(config, SparkTTSConfig):
|
753 |
+
try:
|
754 |
+
# Load the specific config class
|
755 |
+
loaded_config = SparkTTSConfig.from_pretrained(
|
756 |
+
pretrained_model_name_or_path,
|
757 |
+
cache_dir=cache_dir,
|
758 |
+
force_download=force_download,
|
759 |
+
local_files_only=local_files_only,
|
760 |
+
token=token,
|
761 |
+
revision=revision,
|
762 |
+
trust_remote_code=trust_remote_code, # Config might be custom
|
763 |
+
**kwargs, # Pass relevant kwargs
|
764 |
+
)
|
765 |
+
except Exception as e:
|
766 |
+
logger.warning(
|
767 |
+
f"Could not load SparkTTSConfig from {pretrained_model_name_or_path}. "
|
768 |
+
f"Attempting to load components from default relative paths ('LLM', 'wav2vec2-large-xlsr-53'). Error: {e}"
|
769 |
+
)
|
770 |
+
loaded_config = None # Fallback
|
771 |
+
else:
|
772 |
+
# Config object was passed directly
|
773 |
+
loaded_config = config
|
774 |
+
|
775 |
+
|
776 |
+
# --- 2. Determine Component Paths ---
|
777 |
+
llm_tokenizer_path_or_id = "./LLM" # Default relative path
|
778 |
+
w2v_processor_path_or_id = "./wav2vec2-large-xlsr-53" # Default relative path
|
779 |
+
|
780 |
+
if loaded_config:
|
781 |
+
llm_tokenizer_path_or_id = getattr(loaded_config, 'llm_model_name_or_path', llm_tokenizer_path_or_id)
|
782 |
+
w2v_processor_path_or_id = getattr(loaded_config, 'wav2vec2_model_name_or_path', w2v_processor_path_or_id)
|
783 |
+
|
784 |
+
# The component `from_pretrained` methods handle resolving these paths/IDs
|
785 |
+
# whether they are relative subfolders of `pretrained_model_name_or_path`
|
786 |
+
# or separate Hub IDs.
|
787 |
+
|
788 |
+
# --- 3. Load Components ---
|
789 |
+
# Pass down relevant kwargs for loading components
|
790 |
+
component_loading_kwargs = {
|
791 |
+
"cache_dir": cache_dir,
|
792 |
+
"force_download": force_download,
|
793 |
+
"local_files_only": local_files_only,
|
794 |
+
"token": token,
|
795 |
+
"revision": revision,
|
796 |
+
**kwargs # Pass other user kwargs
|
797 |
+
}
|
798 |
+
try:
|
799 |
+
# Tokenizer might require trust_remote_code if its class is custom
|
800 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
801 |
+
pretrained_model_name_or_path, # Main path
|
802 |
+
subfolder=llm_tokenizer_path_or_id.lstrip('./'), # Specify subfolder relative to main path
|
803 |
+
trust_remote_code=trust_remote_code,
|
804 |
+
**component_loading_kwargs
|
805 |
+
)
|
806 |
+
except Exception as e:
|
807 |
+
# Fallback: try loading directly using the path/id from config if different
|
808 |
+
if llm_tokenizer_path_or_id != "./LLM":
|
809 |
+
try:
|
810 |
+
logger.info(f"Retrying tokenizer load directly from: {llm_tokenizer_path_or_id}")
|
811 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
812 |
+
llm_tokenizer_path_or_id,
|
813 |
+
trust_remote_code=trust_remote_code,
|
814 |
+
**component_loading_kwargs
|
815 |
+
)
|
816 |
+
except Exception as e2:
|
817 |
+
raise OSError(f"Could not load tokenizer using main path + subfolder or directly from '{llm_tokenizer_path_or_id}'. Error: {e2}") from e
|
818 |
+
else:
|
819 |
+
raise OSError(f"Could not load tokenizer from subfolder '{llm_tokenizer_path_or_id}' within '{pretrained_model_name_or_path}'. Error: {e}")
|
820 |
+
|
821 |
+
|
822 |
+
try:
|
823 |
+
# Feature extractor usually doesn't need trust_remote_code
|
824 |
+
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
825 |
+
pretrained_model_name_or_path, # Main path
|
826 |
+
subfolder=w2v_processor_path_or_id.lstrip('./'), # Specify subfolder relative to main path
|
827 |
+
**component_loading_kwargs
|
828 |
+
)
|
829 |
+
except Exception as e:
|
830 |
+
# Fallback: try loading directly using the path/id from config if different
|
831 |
+
if w2v_processor_path_or_id != "./wav2vec2-large-xlsr-53":
|
832 |
+
try:
|
833 |
+
logger.info(f"Retrying feature extractor load directly from: {w2v_processor_path_or_id}")
|
834 |
+
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
835 |
+
w2v_processor_path_or_id,
|
836 |
+
**component_loading_kwargs
|
837 |
+
)
|
838 |
+
except Exception as e2:
|
839 |
+
raise OSError(f"Could not load feature extractor using main path + subfolder or directly from '{w2v_processor_path_or_id}'. Error: {e2}") from e
|
840 |
+
else:
|
841 |
+
raise OSError(f"Could not load feature extractor from subfolder '{w2v_processor_path_or_id}' within '{pretrained_model_name_or_path}'. Error: {e}")
|
842 |
+
|
843 |
+
|
844 |
+
# --- 4. Instantiate processor ---
|
845 |
+
# Pass the potentially loaded config object (or None)
|
846 |
+
return cls(tokenizer=tokenizer, feature_extractor=feature_extractor, config=loaded_config)
|
847 |
+
|
848 |
+
|
849 |
+
def save_pretrained(
|
850 |
+
self,
|
851 |
+
save_directory: Union[str, os.PathLike],
|
852 |
+
push_to_hub: bool = False,
|
853 |
+
**kwargs,
|
854 |
+
):
|
855 |
+
"""
|
856 |
+
Save the processor's state (tokenizer and feature extractor files) to a directory.
|
857 |
+
|
858 |
+
Args:
|
859 |
+
save_directory (`str` or `os.PathLike`):
|
860 |
+
Directory where the processor files will be saved.
|
861 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
862 |
+
Whether or not to push your model to the Hugging Face Hub after saving it.
|
863 |
+
**kwargs:
|
864 |
+
Additional key word arguments passed along to the `push_to_hub` method.
|
865 |
+
"""
|
866 |
+
save_directory = Path(save_directory)
|
867 |
+
save_directory.mkdir(parents=True, exist_ok=True)
|
868 |
+
|
869 |
+
# Save tokenizer
|
870 |
+
self.tokenizer.save_pretrained(str(save_directory), **kwargs)
|
871 |
+
|
872 |
+
# Save feature extractor
|
873 |
+
self.feature_extractor.save_pretrained(str(save_directory), **kwargs)
|
874 |
+
|
875 |
+
# Save the main processor config (if it exists and has relevant info)
|
876 |
+
# Note: The SparkTTSConfig is usually saved with the *model*, not the processor.
|
877 |
+
# However, if the processor holds specific config needed for reloading *itself*,
|
878 |
+
# it could be saved here. Usually, relying on the model's config is sufficient.
|
879 |
+
# if self.config:
|
880 |
+
# self.config.save_pretrained(str(save_directory)) # Example if needed
|
881 |
+
|
882 |
+
logger.info(f"Processor components saved in {save_directory}")
|
883 |
+
|
884 |
+
if push_to_hub:
|
885 |
+
# Commit message and other hub kwargs can be passed via **kwargs
|
886 |
+
commit_message = kwargs.pop("commit_message", "Save processor")
|
887 |
+
return self.push_to_hub(save_directory, commit_message=commit_message, **kwargs)
|
888 |
+
|
889 |
+
return str(save_directory) # Return path consistent with Mixin
|