ancv committed · Commit f6c1e7e · verified · 1 Parent(s): 1de7b9d

Upload 4 files

config.json ADDED
@@ -0,0 +1,83 @@
+ {
+   "model_type": "spark-tts",
+   "architectures": [
+     "SparkTTSModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_spark_tts.SparkTTSConfig",
+     "AutoModel": "modeling_spark_tts.SparkTTSModel",
+     "AutoProcessor": "processing_spark_tts.SparkTTSProcessor"
+   },
+   "processor_class": "processing_spark_tts.SparkTTSProcessor",
+   "llm_model_name_or_path": "./LLM",
+   "bicodec_model_name_or_path": "./BiCodec",
+   "wav2vec2_model_name_or_path": "./wav2vec2-large-xlsr-53",
+   "sample_rate": 16000,
+   "highpass_cutoff_freq": 40,
+   "latent_hop_length": 320,
+   "ref_segment_duration": 6.0,
+   "volume_normalize": true,
+   "torch_dtype": "bfloat16",
+   "transformers_version": "4.43.1",
+   "_commit_hash": null,
+   "bicodec_config": {
+     "mel_params": {
+       "sample_rate": 16000,
+       "n_fft": 1024,
+       "win_length": 640,
+       "hop_length": 320,
+       "mel_fmin": 10,
+       "mel_fmax": null,
+       "num_mels": 128
+     },
+     "encoder_config": {
+       "input_channels": 1024,
+       "vocos_dim": 384,
+       "vocos_intermediate_dim": 2048,
+       "vocos_num_layers": 12,
+       "out_channels": 1024,
+       "sample_ratios": [1, 1]
+     },
+     "decoder_config": {
+       "input_channel": 1024,
+       "channels": 1536,
+       "rates": [8, 5, 4, 2],
+       "kernel_sizes": [16, 11, 8, 4]
+     },
+     "quantizer_config": {
+       "input_dim": 1024,
+       "codebook_size": 8192,
+       "codebook_dim": 8,
+       "commitment": 0.25,
+       "codebook_loss_weight": 2.0,
+       "decay": 0.99,
+       "threshold_ema_dead_code": 0.2
+     },
+     "speaker_encoder_config": {
+       "input_dim": 128,
+       "out_dim": 1024,
+       "latent_dim": 128,
+       "token_num": 32,
+       "fsq_levels": [4, 4, 4, 4, 4, 4],
+       "fsq_num_quantizers": 1
+     },
+     "prenet_config": {
+       "input_channels": 1024,
+       "vocos_dim": 384,
+       "vocos_intermediate_dim": 2048,
+       "vocos_num_layers": 12,
+       "out_channels": 1024,
+       "condition_dim": 1024,
+       "sample_ratios": [1, 1],
+       "use_tanh_at_final": false
+     },
+     "postnet_config": {
+       "input_channels": 1024,
+       "vocos_dim": 384,
+       "vocos_intermediate_dim": 2048,
+       "vocos_num_layers": 6,
+       "out_channels": 1024,
+       "use_tanh_at_final": false
+     }
+   }
+ }
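
For reference, a config in this shape is meant to be loaded through the `auto_map` entries above; a minimal sketch, assuming a placeholder path/repo id and `trust_remote_code=True` so that `configuration_spark_tts.SparkTTSConfig` is picked up:

from transformers import AutoConfig

# Placeholder path/repo id; trust_remote_code is needed because the config class
# ships inside the repository (see "auto_map" above).
config = AutoConfig.from_pretrained("path/to/this/repo", trust_remote_code=True)
print(config.model_type)                                      # "spark-tts"
print(config.bicodec_config.quantizer_config.codebook_size)   # 8192
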
configuration_spark_tts.py ADDED
@@ -0,0 +1,233 @@
+ # coding=utf-8
+ # Copyright 2025 SparkAudio & The HuggingFace Inc. team. All rights reserved.
+ # ... (License headers remain the same) ...
+ """ SparkTTS model configuration"""
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+ from typing import List, Optional  # Added typing
+
+
+ logger = logging.get_logger(__name__)
+
+ # --- Define Individual Sub-Component Config Classes ---
+
+ class SparkTTSMelParamsConfig(PretrainedConfig):
+     """Configuration for Mel Spectrogram parameters."""
+     model_type = "spark-tts-mel-params"
+     def __init__(self, sample_rate=16000, n_fft=1024, win_length=640, hop_length=320,
+                  mel_fmin=10, mel_fmax=None, num_mels=128, **kwargs):
+         super().__init__(**kwargs)
+         self.sample_rate = sample_rate
+         self.n_fft = n_fft
+         self.win_length = win_length
+         self.hop_length = hop_length
+         self.mel_fmin = mel_fmin
+         self.mel_fmax = mel_fmax
+         self.num_mels = num_mels
+
+ class SparkTTSEncoderConfig(PretrainedConfig):
+     """Configuration for the BiCodec Feature Encoder."""
+     model_type = "spark-tts-encoder"
+     def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
+                  vocos_num_layers=12, out_channels=1024, sample_ratios=[1, 1], **kwargs):
+         super().__init__(**kwargs)
+         self.input_channels = input_channels
+         self.vocos_dim = vocos_dim
+         self.vocos_intermediate_dim = vocos_intermediate_dim
+         self.vocos_num_layers = vocos_num_layers
+         self.out_channels = out_channels
+         self.sample_ratios = sample_ratios
+
+ class SparkTTSDecoderConfig(PretrainedConfig):
+     """Configuration for the BiCodec Wave Generator (Decoder)."""
+     model_type = "spark-tts-decoder"
+     def __init__(self, input_channel=1024, channels=1536, rates=[8, 5, 4, 2],
+                  kernel_sizes=[16, 11, 8, 4], **kwargs):
+         super().__init__(**kwargs)
+         self.input_channel = input_channel
+         self.channels = channels
+         self.rates = rates
+         self.kernel_sizes = kernel_sizes
+
+ class SparkTTSQuantizerConfig(PretrainedConfig):
+     """Configuration for the BiCodec Factorized Vector Quantizer."""
+     model_type = "spark-tts-quantizer"
+     def __init__(self, input_dim=1024, codebook_size=8192, codebook_dim=8,
+                  commitment=0.25, codebook_loss_weight=2.0, decay=0.99,
+                  threshold_ema_dead_code=0.2, **kwargs):
+         # Note: Removed use_l2_normlize as it wasn't in the original class __init__ args
+         # Add it back if it's actually used by the FactorizedVectorQuantize class init
+         super().__init__(**kwargs)
+         self.input_dim = input_dim
+         self.codebook_size = codebook_size
+         self.codebook_dim = codebook_dim
+         self.commitment = commitment
+         self.codebook_loss_weight = codebook_loss_weight
+         self.decay = decay
+         self.threshold_ema_dead_code = threshold_ema_dead_code
+
+ class SparkTTSSpeakerEncoderConfig(PretrainedConfig):
+     """Configuration for the BiCodec Speaker Encoder."""
+     model_type = "spark-tts-speaker-encoder"
+     def __init__(self, input_dim=128, out_dim=1024, latent_dim=128, token_num=32,
+                  fsq_levels=[4, 4, 4, 4, 4, 4], fsq_num_quantizers=1, **kwargs):
+         super().__init__(**kwargs)
+         self.input_dim = input_dim
+         self.out_dim = out_dim
+         self.latent_dim = latent_dim
+         self.token_num = token_num
+         self.fsq_levels = fsq_levels
+         self.fsq_num_quantizers = fsq_num_quantizers
+
+ class SparkTTSPrenetConfig(PretrainedConfig):
+     """Configuration for the BiCodec Prenet."""
+     model_type = "spark-tts-prenet"
+     def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
+                  vocos_num_layers=12, out_channels=1024, condition_dim=1024,
+                  sample_ratios=[1, 1], use_tanh_at_final=False, **kwargs):
+         super().__init__(**kwargs)
+         self.input_channels = input_channels
+         self.vocos_dim = vocos_dim
+         self.vocos_intermediate_dim = vocos_intermediate_dim
+         self.vocos_num_layers = vocos_num_layers
+         self.out_channels = out_channels
+         self.condition_dim = condition_dim
+         self.sample_ratios = sample_ratios
+         self.use_tanh_at_final = use_tanh_at_final
+
+ class SparkTTSPostnetConfig(PretrainedConfig):
+     """Configuration for the BiCodec Postnet."""
+     model_type = "spark-tts-postnet"
+     def __init__(self, input_channels=1024, vocos_dim=384, vocos_intermediate_dim=2048,
+                  vocos_num_layers=6, out_channels=1024, use_tanh_at_final=False, **kwargs):
+         # Note: Removed condition_dim as it wasn't in the original config example for postnet
+         super().__init__(**kwargs)
+         self.input_channels = input_channels
+         self.vocos_dim = vocos_dim
+         self.vocos_intermediate_dim = vocos_intermediate_dim
+         self.vocos_num_layers = vocos_num_layers
+         self.out_channels = out_channels
+         self.use_tanh_at_final = use_tanh_at_final
+
+
+ # --- Define the Intermediate BiCodec Config Class ---
+
+ class SparkTTSBiCodecConfig(PretrainedConfig):
+     """
+     Intermediate configuration class for the BiCodec component within SparkTTS.
+     It holds instances of the individual sub-component configurations.
+     """
+     model_type = "spark-tts-bicodec"
+     # Map keys in the 'bicodec_config' dict to their respective classes
+     sub_configs = {
+         "mel_params": SparkTTSMelParamsConfig,
+         "encoder_config": SparkTTSEncoderConfig,
+         "decoder_config": SparkTTSDecoderConfig,
+         "quantizer_config": SparkTTSQuantizerConfig,
+         "speaker_encoder_config": SparkTTSSpeakerEncoderConfig,
+         "prenet_config": SparkTTSPrenetConfig,
+         "postnet_config": SparkTTSPostnetConfig,
+     }
+
+     def __init__(
+         self,
+         mel_params=None,
+         encoder_config=None,
+         decoder_config=None,
+         quantizer_config=None,
+         speaker_encoder_config=None,
+         prenet_config=None,
+         postnet_config=None,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+
+         # Instantiate sub-configs from dictionaries or use defaults/provided instances
+         self.mel_params = self._init_sub_config(mel_params, "mel_params")
+         self.encoder_config = self._init_sub_config(encoder_config, "encoder_config")
+         self.decoder_config = self._init_sub_config(decoder_config, "decoder_config")
+         self.quantizer_config = self._init_sub_config(quantizer_config, "quantizer_config")
+         self.speaker_encoder_config = self._init_sub_config(speaker_encoder_config, "speaker_encoder_config")
+         self.prenet_config = self._init_sub_config(prenet_config, "prenet_config")
+         self.postnet_config = self._init_sub_config(postnet_config, "postnet_config")
+
+     def _init_sub_config(self, config_input, config_key):
+         """Helper to initialize sub-configs."""
+         config_cls = self.sub_configs[config_key]
+         if isinstance(config_input, dict):
+             return config_cls(**config_input)
+         elif config_input is None:
+             return config_cls()  # Initialize with defaults
+         elif isinstance(config_input, config_cls):
+             return config_input  # Already an instance
+         else:
+             raise TypeError(f"Invalid type for {config_key}: {type(config_input)}. Expected dict, None, or {config_cls.__name__}.")
+
+
+ # --- Define the Main SparkTTS Config Class ---
+
+ class SparkTTSConfig(PretrainedConfig):
+     r"""
+     Main configuration class for SparkTTSModel, including nested BiCodec configuration.
+     Args:
+         llm_model_name_or_path (`str`, *optional*, defaults to `"./LLM"`): Path/ID for LLM.
+         bicodec_model_name_or_path (`str`, *optional*, defaults to `"./BiCodec"`): Path/ID for BiCodec checkpoint.
+         wav2vec2_model_name_or_path (`str`, *optional*, defaults to `"./wav2vec2-large-xlsr-53"`): Path/ID for Wav2Vec2.
+         sample_rate (`int`, *optional*, defaults to 16000): Audio sample rate.
+         # ... (other top-level args: highpass_cutoff_freq, latent_hop_length, ref_segment_duration, volume_normalize) ...
+         bicodec_config (`dict`, *optional*): Dictionary to initialize `SparkTTSBiCodecConfig`.
+         torch_dtype (`str`, *optional*, defaults to `"auto"`): Torch dtype.
+         kwargs (*optional*): Dictionary of keyword arguments.
+     """
+     model_type = "spark-tts"
+     # Map the key in config.json to the intermediate BiCodec config class
+     sub_configs = {"bicodec_config": SparkTTSBiCodecConfig}
+     attribute_map = {"hidden_size": "d_model"}  # Example
+
+     def __init__(
+         self,
+         llm_model_name_or_path="./LLM",
+         bicodec_model_name_or_path="./BiCodec",
+         wav2vec2_model_name_or_path="./wav2vec2-large-xlsr-53",
+         sample_rate=16000,
+         highpass_cutoff_freq=40,
+         latent_hop_length=320,
+         ref_segment_duration=6.0,
+         volume_normalize=True,
+         bicodec_config=None,  # Expects a dictionary or None
+         torch_dtype="auto",
+         **kwargs,
+     ):
+         # --- Top-level parameters ---
+         self.llm_model_name_or_path = llm_model_name_or_path
+         self.bicodec_model_name_or_path = bicodec_model_name_or_path
+         self.wav2vec2_model_name_or_path = wav2vec2_model_name_or_path
+         self.sample_rate = sample_rate
+         self.highpass_cutoff_freq = highpass_cutoff_freq
+         self.latent_hop_length = latent_hop_length
+         self.ref_segment_duration = ref_segment_duration
+         self.volume_normalize = volume_normalize
+         self.torch_dtype = torch_dtype
+
+         # --- Nested BiCodec Configuration ---
+         # Instantiate the intermediate BiCodec config class, which will handle its own sub-configs
+         if isinstance(bicodec_config, dict):
+             self.bicodec_config = self.sub_configs["bicodec_config"](**bicodec_config)
+         elif bicodec_config is None:
+             logger.info("`bicodec_config` not provided. Initializing `SparkTTSBiCodecConfig` with its defaults.")
+             self.bicodec_config = self.sub_configs["bicodec_config"]()
+         elif isinstance(bicodec_config, self.sub_configs["bicodec_config"]):
+             self.bicodec_config = bicodec_config  # Use existing instance
+         else:
+             raise TypeError(f"Invalid type for bicodec_config: {type(bicodec_config)}. Expected dict, None, or SparkTTSBiCodecConfig.")
+
+
+         # Set processor class and auto_map
+         kwargs["processor_class"] = kwargs.get("processor_class", "SparkTTSProcessor")
+         kwargs["auto_map"] = kwargs.get("auto_map", {
+             "AutoConfig": "configuration_spark_tts.SparkTTSConfig",
+             "AutoModel": "modeling_spark_tts.SparkTTSModel",
+             "AutoProcessor": "processing_spark_tts.SparkTTSProcessor"
+         })
+         super().__init__(**kwargs)
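
A minimal sketch of how the nested classes above compose, assuming configuration_spark_tts.py is importable from the working directory; any sub-config key left out of `bicodec_config` falls back to the defaults defined in its class:

from configuration_spark_tts import SparkTTSBiCodecConfig, SparkTTSConfig

config = SparkTTSConfig(
    sample_rate=16000,
    bicodec_config={"quantizer_config": {"codebook_size": 8192}},  # partial dict is fine
)
assert isinstance(config.bicodec_config, SparkTTSBiCodecConfig)
assert config.bicodec_config.quantizer_config.codebook_size == 8192
assert config.bicodec_config.mel_params.n_fft == 1024  # default from SparkTTSMelParamsConfig
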
modeling_spark_tts.py ADDED
The diff for this file is too large to render. See raw diff
 
processing_spark_tts.py ADDED
@@ -0,0 +1,889 @@
+ # coding=utf-8
+ # Copyright 2025 SparkAudio & The HuggingFace Inc. team.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """
+ Processor class for SparkTTS. Combines text tokenization and audio feature extraction/processing.
+ """
+
+ import os  # Needed for save_pretrained
+ import re  # For decoding
+ import torch
+ import numpy as np
+ import soundfile as sf  # For audio loading
+ import soxr  # For resampling
+
+ from pathlib import Path
+ from typing import Optional, Union, List, Dict, Tuple, Any
+
+ from transformers.processing_utils import ProcessorMixin
+ from transformers.tokenization_utils_base import BatchEncoding  # Return type hint
+ from transformers.feature_extraction_utils import BatchFeature  # Return type hint
+ from transformers.models.auto.tokenization_auto import AutoTokenizer
+ from transformers.models.wav2vec2.feature_extraction_wav2vec2 import Wav2Vec2FeatureExtractor
+ from transformers.utils import logging, PushToHubMixin  # Added PushToHubMixin
+ from numpy.lib.stride_tricks import sliding_window_view
+ import soxr
+ import soundfile
+ import random
+
+ # Import custom config if needed for defaults
+ from .configuration_spark_tts import SparkTTSConfig
+
+ logger = logging.get_logger(__name__)
+
+
+ # =============================================================================
+ # >> START: PASTE CODE FROM sparktts/utils/* HERE <<
+ # =============================================================================
+ # IMPORTANT: Utility functions needed for processing (audio loading, token parsing)
+ # must be defined or imported here.
+
+ # --- Paste sparktts/utils/audio.py content here ---
+
+ def audio_volume_normalize(audio: np.ndarray, coeff: float = 0.2) -> np.ndarray:
+     """
+     Normalize the volume of an audio signal.
+
+     Parameters:
+         audio (numpy array): Input audio signal array.
+         coeff (float): Target coefficient for normalization, default is 0.2.
+
+     Returns:
+         numpy array: The volume-normalized audio signal.
+     """
+     # Sort the absolute values of the audio signal
+     temp = np.sort(np.abs(audio))
+
+     # If the maximum value is less than 0.1, scale the array to have a maximum of 0.1
+     if temp[-1] < 0.1:
+         scaling_factor = max(
+             temp[-1], 1e-3
+         )  # Prevent division by zero with a small constant
+         audio = audio / scaling_factor * 0.1
+
+     # Filter out values less than 0.01 from temp
+     temp = temp[temp > 0.01]
+     L = temp.shape[0]  # Length of the filtered array
+
+     # If there are fewer than or equal to 10 significant values, return the audio without further processing
+     if L <= 10:
+         return audio
+
+     # Compute the average of the top 10% to 1% of values in temp
+     volume = np.mean(temp[int(0.9 * L) : int(0.99 * L)])
+
+     # Normalize the audio to the target coefficient level, clamping the scale factor between 0.1 and 10
+     audio = audio * np.clip(coeff / volume, a_min=0.1, a_max=10)
+
+     # Ensure the maximum absolute value in the audio does not exceed 1
+     max_value = np.max(np.abs(audio))
+     if max_value > 1:
+         audio = audio / max_value
+
+     return audio
+
+
+ def load_audio(
+     adfile: Path,
+     sampling_rate: int = None,
+     length: int = None,
+     volume_normalize: bool = False,
+     segment_duration: int = None,
+ ) -> np.ndarray:
+     r"""Load an audio file with a target sampling rate and length.
+
+     Args:
+         adfile (Path): path to audio file.
+         sampling_rate (int, optional): target sampling rate. Defaults to None.
+         length (int, optional): target audio length. Defaults to None.
+         volume_normalize (bool, optional): whether to perform volume normalization. Defaults to False.
+         segment_duration (int, optional): randomly select a segment of {segment_duration} seconds.
+             Defaults to None, which means the whole audio will be used.
+
+     Returns:
+         audio (np.ndarray): audio
+     """
+
+     audio, sr = soundfile.read(adfile)
+     if len(audio.shape) > 1:
+         audio = audio[:, 0]
+
+     if sampling_rate is not None and sr != sampling_rate:
+         audio = soxr.resample(audio, sr, sampling_rate, quality="VHQ")
+         sr = sampling_rate
+
+     if segment_duration is not None:
+         seg_length = int(sr * segment_duration)
+         audio = random_select_audio_segment(audio, seg_length)
+
+     # Audio volume normalize
+     if volume_normalize:
+         audio = audio_volume_normalize(audio)
+     # check the audio length
+     if length is not None:
+         assert abs(audio.shape[0] - length) < 1000
+         if audio.shape[0] > length:
+             audio = audio[:length]
+         else:
+             audio = np.pad(audio, (0, int(length - audio.shape[0])))
+     return audio
+
+
+ def random_select_audio_segment(audio: np.ndarray, length: int) -> np.ndarray:
+     """Get an audio segment of the given length.
+
+     Args:
+         audio (np.ndarray):
+         length (int): audio length = sampling_rate * duration
+     """
+     if audio.shape[0] < length:
+         audio = np.pad(audio, (0, int(length - audio.shape[0])))
+     start_index = random.randint(0, audio.shape[0] - length)
+     end_index = int(start_index + length)
+
+     return audio[start_index:end_index]
+
+ def get_ref_clip(wav: np.ndarray, config) -> np.ndarray:  # Needs access to config attributes
+     """Get reference audio clip for speaker embedding."""
+     # Make sure config has sample_rate, ref_segment_duration, latent_hop_length
+     if not all(hasattr(config, attr) for attr in ['sample_rate', 'ref_segment_duration', 'latent_hop_length']):
+         raise AttributeError("Config object missing required attributes for get_ref_clip")
+     ref_segment_length = (
+         int(config.sample_rate * config.ref_segment_duration)
+         // config.latent_hop_length
+         * config.latent_hop_length
+     )
+     wav_length = len(wav)
+     if ref_segment_length > wav_length:
+         wav = np.tile(wav, ref_segment_length // wav_length + 1)
+     return wav[:ref_segment_length]
+
+
+ # --- Paste sparktts/utils/token_parser.py content here ---
+
+ TASK_TOKEN_MAP = {
+     "vc": "<|task_vc|>",
+     "tts": "<|task_tts|>",
+     "asr": "<|task_asr|>",
+     "s2s": "<|task_s2s|>",
+     "t2s": "<|task_t2s|>",
+     "understand": "<|task_understand|>",
+     "caption": "<|task_cap|>",
+     "controllable_tts": "<|task_controllable_tts|>",
+     "prompt_tts": "<|task_prompt_tts|>",
+     "speech_edit": "<|task_edit|>",
+ }
+
+ LEVELS_MAP = {
+     "very_low": 0,
+     "low": 1,
+     "moderate": 2,
+     "high": 3,
+     "very_high": 4,
+ }
+
+ LEVELS_MAP_UI = {
+     1: 'very_low',
+     2: 'low',
+     3: 'moderate',
+     4: 'high',
+     5: 'very_high'
+ }
+
+ GENDER_MAP = {
+     "female": 0,
+     "male": 1,
+ }
+
+ AGE_MAP = {"Child": 0, "Teenager": 1, "Youth-Adult": 2, "Middle-aged": 3, "Elderly": 4}
+
+ EMO_MAP = {
+     "UNKNOWN": 0,
+     "NEUTRAL": 1,
+     "ANGRY": 2,
+     "HAPPY": 3,
+     "SAD": 4,
+     "FEARFUL": 5,
+     "DISGUSTED": 6,
+     "SURPRISED": 7,
+     "SARCASTIC": 8,
+     "EXCITED": 9,
+     "SLEEPY": 10,
+     "CONFUSED": 11,
+     "EMPHASIS": 12,
+     "LAUGHING": 13,
+     "SINGING": 14,
+     "WORRIED": 15,
+     "WHISPER": 16,
+     "ANXIOUS": 17,
+     "NO-AGREEMENT": 18,
+     "APOLOGETIC": 19,
+     "CONCERNED": 20,
+     "ENUNCIATED": 21,
+     "ASSERTIVE": 22,
+     "ENCOURAGING": 23,
+     "CONTEMPT": 24,
+ }
+
+
+ class TokenParser:
+     """Turn labels into special tokens and parse the attributes of a person."""
+
+     def __init__(self):
+         pass
+
+     @staticmethod
+     def age(age: str) -> str:
+         """Turn age label into a special token."""
+         age_id = AGE_MAP[age]
+         return f"<|age_{age_id}|>"
+
+     @staticmethod
+     def gender(gender: str) -> str:
+         """Turn gender label into a special token."""
+         gender_id = GENDER_MAP[gender]
+         return f"<|gender_{gender_id}|>"
+
+     @staticmethod
+     def mel_value(mel: int):
+         """Turn special token of mel scale pitch."""
+         mel = max(0, int(mel))
+         mel = min(1000, int(mel))
+         return f"<|pitch_value_{mel}|>"
+
+     @staticmethod
+     def mel_level(level: str):
+         """Turn special token of mel level."""
+         level_tag = LEVELS_MAP[level]
+         return f"<|pitch_label_{level_tag}|>"
+
+     @staticmethod
+     def pitch_var_value(pitch_std: int):
+         """Turn special token of pitch_std value."""
+         assert isinstance(pitch_std, int)
+         pitch_std = max(0, int(pitch_std))
+         pitch_std = min(10, int(pitch_std))
+         return f"<|pitch_var_value_{pitch_std}|>"
+
+     @staticmethod
+     def pitch_var_level(level: str):
+         """Turn special token of pitch std level."""
+         level_tag = LEVELS_MAP[level]
+         return f"<|pitch_var_label_{level_tag}|>"
+
+     @staticmethod
+     def loudness_value(loudness: int):
+         """Turn special token of loudness value [0, 30]."""
+         assert loudness >= 0
+         loudness = max(0, int(loudness))
+         loudness = min(30, int(loudness))
+         return f"<|loudness_value_{loudness}|>"
+
+     @staticmethod
+     def loudness_level(level: str):
+         """Turn special token of loudness level."""
+         level_tag = LEVELS_MAP[level]
+         return f"<|loudness_label_{level_tag}|>"
+
+     @staticmethod
+     def speed_value(speed: int):
+         """Turn special token of speed value."""
+         speed = max(0, int(speed))
+         speed = min(10, int(speed))
+         return f"<|speed_value_{speed}|>"
+
+     @staticmethod
+     def speed_level(level: str):
+         """Turn special token of speed level."""
+         level_tag = LEVELS_MAP[level]
+         return f"<|speed_label_{level_tag}|>"
+
+     @staticmethod
+     def task(task: str) -> str:
+         """Turn special token of task."""
+         assert task in TASK_TOKEN_MAP.keys()
+
+         return TASK_TOKEN_MAP[task]
+
+     @staticmethod
+     def emotion(emotion: str):
+         emo_id = EMO_MAP[emotion]
+
+         return f"<|emotion_{emo_id}|>"
+
+ # =============================================================================
+ # >> END: PASTE CODE FROM sparktts/utils/* HERE <<
+ # =============================================================================
+
+
+ class SparkTTSProcessor(ProcessorMixin, PushToHubMixin):  # Added PushToHubMixin
+     r"""
+     Constructs a SparkTTS processor which wraps a text tokenizer and relevant audio processing logic.
+
+     Args:
+         tokenizer ([`PreTrainedTokenizer`]):
+             An instance of [`PreTrainedTokenizer`]. This handles the text tokenization for the LLM.
+         feature_extractor ([`Wav2Vec2FeatureExtractor`]):
+             An instance of [`Wav2Vec2FeatureExtractor`]. Although Wav2Vec2 features are extracted
+             within the model's `tokenize_audio`, the extractor's configuration (like sampling rate)
+             is useful, and it aligns with the ProcessorMixin pattern.
+         config ([`SparkTTSConfig`], *optional*):
+             An instance of [`SparkTTSConfig`] to access configuration parameters like sample rate.
+     """
+     attributes = ["tokenizer", "feature_extractor"]
+     tokenizer_class = "AutoTokenizer"
+     feature_extractor_class = "Wav2Vec2FeatureExtractor"  # Keep for consistency
+
+     def __init__(self, tokenizer, feature_extractor, config: Optional[SparkTTSConfig] = None, **kwargs):
+         super().__init__(tokenizer=tokenizer, feature_extractor=feature_extractor, **kwargs)
+         self.model = None
+         self.config = config
+         # Set sampling rate
+         if config and hasattr(config, 'sample_rate'):
+             self.sampling_rate = config.sample_rate
+         elif feature_extractor and hasattr(feature_extractor, 'sampling_rate'):
+             self.sampling_rate = feature_extractor.sampling_rate
+         else:
+             self.sampling_rate = 16000
+             logger.warning(f"Could not determine sampling rate. Defaulting to {self.sampling_rate} Hz.")
+
+         # # Ensure tokenizer pad token
+         # if self.tokenizer.pad_token is None:
+         #     if self.tokenizer.eos_token is not None:
+         #         logger.warning("Tokenizer does not have a pad token. Setting pad_token to eos_token.")
+         #         self.tokenizer.pad_token = self.tokenizer.eos_token
+         #     else:
+         #         logger.warning("Tokenizer lacks pad and eos token. Adding default pad token '<|pad|>'.")
+         #         self.tokenizer.add_special_tokens({'pad_token': '<|pad|>'})
+
+     def link_model(self, model):
+         """Links the processor to a SparkTTSModel instance for audio processing calls."""
+         if not hasattr(model, 'tokenize_audio') or not hasattr(model, 'detokenize_audio'):
+             raise TypeError("The provided model instance does not have the required 'tokenize_audio' and 'detokenize_audio' methods.")
+         if not hasattr(model, 'config'):
+             logger.warning("Linked model does not have a 'config' attribute. Some processor functionalities might rely on it.")
+
+         self.model = model
+         logger.info("SparkTTSModel successfully linked to the processor.")
+         # Update sampling rate based on linked model's config if available
+         if hasattr(model, 'config') and hasattr(model.config, 'sample_rate'):
+             if self.sampling_rate != model.config.sample_rate:
+                 logger.info(f"Updating processor sampling rate from {self.sampling_rate} to {model.config.sample_rate} based on linked model config.")
+                 self.sampling_rate = model.config.sample_rate
+                 # Also update feature extractor sampling rate if it differs
+                 if hasattr(self, 'feature_extractor') and self.feature_extractor.sampling_rate != model.config.sample_rate:
+                     logger.info(f"Updating feature_extractor sampling rate from {self.feature_extractor.sampling_rate} to {model.config.sample_rate}.")
+                     self.feature_extractor.sampling_rate = model.config.sample_rate
+
+
+     def __call__(
+         self,
+         text: str,
+         prompt_speech_path: Optional[Union[str, Path]] = None,
+         prompt_text: Optional[str] = None,
+         gender: Optional[str] = None,
+         pitch: Optional[str] = None,
+         speed: Optional[str] = None,
+         return_tensors: Optional[str] = "pt",
+         **kwargs,  # Allow passing other args like padding, truncation to tokenizer
+     ) -> BatchEncoding:
+         """
+         Processes the input text and optional prompt audio/control parameters into a format suitable for [`SparkTTSModel`].
+
+         Args:
+             text (`str`):
+                 The main text to be synthesized.
+             prompt_speech_path (`str` or `Path`, *optional*):
+                 Path to the prompt audio file for voice cloning. Required if `gender` is not set.
+             prompt_text (`str`, *optional*):
+                 Transcript of the prompt audio. Used only in voice cloning mode.
+             gender (`str`, *optional*):
+                 Target gender ("male" or "female") for controllable synthesis. If set, enables control mode.
+             pitch (`str`, *optional*):
+                 Target pitch level ("very_low", "low", "moderate", "high", "very_high") for control mode. Required if `gender` is set.
+             speed (`str`, *optional*):
+                 Target speed level ("very_low", "low", "moderate", "high", "very_high") for control mode. Required if `gender` is set.
+             return_tensors (`str`, *optional*, defaults to `"pt"`):
+                 If set, will return tensors instead of lists of Python integers. Only "pt" (PyTorch) is supported currently.
+             **kwargs:
+                 Additional arguments passed to the underlying tokenizer's `__call__` method.
+
+         Returns:
+             [`BatchEncoding`]: A dictionary containing the `input_ids` and `attention_mask` for the LLM.
+                 In voice cloning mode, it also includes `global_token_ids_prompt` (torch.Tensor) representing the
+                 global tokens extracted from the prompt audio.
+         """
+
+         global_token_ids_prompt = None  # Initialize
+
+         # Determine mode: Control TTS or Voice Cloning (Prompt TTS)
+         is_control_mode = gender is not None
+         is_cloning_mode = prompt_speech_path is not None and not is_control_mode
+
+         if is_control_mode:
+             # --- Controllable TTS Prompt Construction ---
+             if not all([pitch, speed]):
+                 raise ValueError("For controllable TTS, 'gender', 'pitch', and 'speed' must all be provided.")
+             if prompt_speech_path is not None:
+                 logger.warning("`prompt_speech_path` provided but ignored because `gender` is set (controllable TTS mode).")
+
+             if not all(k in GENDER_MAP for k in [gender]):  # Basic check
+                 raise ValueError(f"Invalid gender provided: {gender}. Must be one of {list(GENDER_MAP.keys())}")
+             if not all(k in LEVELS_MAP for k in [pitch, speed]):  # Basic check
+                 raise ValueError(f"Invalid pitch or speed level provided. Must be one of {list(LEVELS_MAP.keys())}")
+
+             gender_id = GENDER_MAP[gender]
+             pitch_level_id = LEVELS_MAP[pitch]
+             speed_level_id = LEVELS_MAP[speed]
+
+             pitch_label_tokens = f"<|pitch_label_{pitch_level_id}|>"
+             speed_label_tokens = f"<|speed_label_{speed_level_id}|>"
+             gender_tokens = f"<|gender_{gender_id}|>"
+
+             attribute_tokens = "".join([gender_tokens, pitch_label_tokens, speed_label_tokens])
+
+             prompt_list = [
+                 TASK_TOKEN_MAP["controllable_tts"],
+                 "<|start_content|>",
+                 text,
+                 "<|end_content|>",
+                 "<|start_style_label|>",
+                 attribute_tokens,
+                 "<|end_style_label|>",
+             ]
+             prompt_string = "".join(prompt_list)
+
+         elif is_cloning_mode:
+             # --- Voice Cloning Prompt Construction ---
+             if self.model is None:
+                 raise RuntimeError("Processor must be linked to a SparkTTSModel instance via `processor.link_model(model)` before performing voice cloning.")
+             prompt_speech_path = Path(prompt_speech_path)  # Ensure it's a Path object
+             if not prompt_speech_path.exists():
+                 raise FileNotFoundError(f"Prompt audio file not found: {prompt_speech_path}")
+
+             # Load and process prompt audio
+             try:
+                 model_config = self.model.config if self.model and hasattr(self.model, 'config') else self.config
+                 if model_config is None:
+                     raise ValueError("Configuration not available in processor or linked model.")
+
+                 # Load main wav
+                 wav = load_audio(
+                     prompt_speech_path,
+                     sampling_rate=self.sampling_rate,
+                     volume_normalize=getattr(model_config, 'volume_normalize', True),  # Use getattr for safety
+                 )
+                 # Get reference clip
+                 wav_ref_np = get_ref_clip(wav, model_config)  # Pass config object
+                 wav_ref = torch.from_numpy(wav_ref_np).unsqueeze(0).float()
+                 wav_tensor = torch.from_numpy(wav).unsqueeze(0).float()
+
+                 # Tokenize using the linked model's method
+                 # Assuming tokenize_audio returns tensors with batch dim 1: [1, N_global], [1, N_semantic]
+                 global_tokens_tensor, semantic_tokens_tensor = self.model.tokenize_audio(wav_tensor, wav_ref)
+
+                 # Store the global tokens tensor (with batch dim) for the output dict
+                 global_token_ids_prompt = global_tokens_tensor  # Keep batch dim [1, N_global]
+
+                 # Convert tensors to lists of ints for string formatting
+                 global_token_list = global_tokens_tensor.squeeze().tolist()  # Remove batch dim -> list
+                 semantic_token_list = semantic_tokens_tensor.squeeze().tolist()  # Remove batch dim -> list
+
+             except Exception as e:
+                 logger.error(f"Error processing prompt audio {prompt_speech_path}: {e}")
+                 import traceback
+                 traceback.print_exc()
+                 raise
+
+             # ==============================================================
+             # CORRECTED TOKEN STRING FORMATTING
+             # ==============================================================
+             # Create individual token strings for each ID
+             global_tokens_str = "".join([f"<|bicodec_global_{gid}|>" for gid in global_token_list])
+             semantic_tokens_str = "".join([f"<|bicodec_semantic_{sid}|>" for sid in semantic_token_list])
+             # ==============================================================
+
+             # Construct prompt list based on presence of prompt_text
+             if prompt_text is not None and prompt_text.strip():  # Check if prompt_text is meaningful
+                 logger.info("Using prompt text in voice cloning prompt.")
+                 prompt_list = [
+                     TASK_TOKEN_MAP["tts"],  # Or maybe TASK_TOKEN_MAP["prompt_tts"]? Check original logic. Assuming "tts".
+                     "<|start_content|>",
+                     prompt_text,  # Transcript first
+                     text,  # Then target text
+                     "<|end_content|>",
+                     "<|start_global_token|>",
+                     global_tokens_str,
+                     "<|end_global_token|>",
+                     "<|start_semantic_token|>",
+                     semantic_tokens_str,
+                     # "<|end_semantic_token|>",  # Original code didn't have this marker here
+                 ]
+             else:
+                 # Simpler prompt without semantic tokens if no transcript provided
+                 logger.info("No prompt text provided, using text-only voice cloning prompt.")
+                 prompt_list = [
+                     TASK_TOKEN_MAP["tts"],  # Or maybe TASK_TOKEN_MAP["prompt_tts"]?
+                     "<|start_content|>",
+                     text,  # Only target text
+                     "<|end_content|>",
+                     "<|start_global_token|>",
+                     global_tokens_str,
+                     "<|end_global_token|>",
+                 ]
+             prompt_string = "".join(prompt_list)
+             logger.debug(f"Generated prompt string (cloning): {prompt_string[:200]}...")  # Log start of prompt
+
+         else:
+             raise ValueError("Invalid input combination. Either provide `prompt_speech_path` for cloning or (`gender`, `pitch`, `speed`) for control.")
+
+         # --- Tokenize the final prompt string ---
+         # print(f"Tokenizing prompt: {prompt_string}")
+         inputs = self.tokenizer(
+             prompt_string,
+             return_tensors=return_tensors,
+             padding=kwargs.get("padding", False),  # Often False for generation prompts unless batching > 1
+             truncation=kwargs.get("truncation", True),
+             max_length=kwargs.get("max_length", self.tokenizer.model_max_length),
+             add_special_tokens=kwargs.get("add_special_tokens", True),  # Usually True unless handled manually
+             return_attention_mask=kwargs.get("return_attention_mask", True),  # Need attention mask
+             **{k: v for k, v in kwargs.items() if k not in ["padding", "truncation", "max_length", "add_special_tokens", "return_attention_mask"]}
+         )
+         logger.debug(f"Tokenized input_ids shape: {inputs['input_ids'].shape}")
+
+
+         # Add the prompt's global tokens (as tensor with batch dim) to the output if in cloning mode
+         if is_cloning_mode and global_token_ids_prompt is not None:
+             if return_tensors == "pt":
+                 inputs["global_token_ids_prompt"] = global_token_ids_prompt  # Already has batch dim [1, N_global]
+             else:
+                 # Handle non-tensor return if necessary
+                 inputs["global_token_ids_prompt"] = global_token_ids_prompt.tolist()
+
+         return inputs
+
+
+     def decode(
+         self,
+         generated_ids: torch.Tensor,
+         global_token_ids_prompt: Optional[torch.Tensor] = None,
+         input_ids_len: Optional[int] = None,
+         skip_special_tokens: bool = True,
+     ) -> Dict[str, Any]:
+         """
+         Decodes the generated token IDs from [`SparkTTSModel`] into an audio waveform.
+
+         Args:
+             generated_ids (`torch.Tensor`):
+                 Tensor of token IDs generated by `model.generate()`, including the input prompt part. Shape [B, seq_len].
+             global_token_ids_prompt (`torch.Tensor`, *optional*):
+                 The global tokens extracted from the prompt audio during the `__call__` step (for voice cloning).
+                 Shape [B, N_global]. Required if the generation was for voice cloning.
+             input_ids_len (`int`, *optional*):
+                 The length of the original input prompt `input_ids` fed to `model.generate()`. Required to
+                 correctly isolate the newly generated tokens.
+             skip_special_tokens (`bool`, *optional*, defaults to `True`):
+                 Whether to skip special tokens during the text decoding step (used to extract audio tokens).
+
+         Returns:
+             Dict[str, Any]: A dictionary containing:
+                 - "audio": The decoded audio waveform as a NumPy array. Shape [T_audio] (if B=1) or [B, T_audio].
+                 - "sampling_rate": The sampling rate of the audio.
+         """
+         if self.model is None:
+             raise RuntimeError("Processor must be linked to a SparkTTSModel instance via `processor.link_model(model)` before decoding.")
+         if input_ids_len is None:
+             raise ValueError("`input_ids_len` (length of the prompt input_ids) must be provided for decoding.")
+
+         # --- Isolate generated part and decode text ---
+         # Assumes generated_ids has shape [B, full_seq_len]
+         # Handle case where generated sequence is shorter than prompt (shouldn't happen with max_new_tokens > 0)
+         if generated_ids.shape[1] < input_ids_len:
+             logger.warning(f"Generated sequence length ({generated_ids.shape[1]}) is shorter than input prompt length ({input_ids_len}). Decoding might be incorrect.")
+             output_only_ids = generated_ids[:, input_ids_len:]  # Will be empty if equal
+         else:
+             output_only_ids = generated_ids[:, input_ids_len:]
+
+
+         # Decode the generated part to find audio tokens
+         # Need to handle batch decoding if B > 1
+         # print("decode token", self.tokenizer.batch_decode(output_only_ids, skip_special_tokens=False))
+         decoded_texts = self.tokenizer.batch_decode(output_only_ids, skip_special_tokens=skip_special_tokens)
+
+         # --- Extract Audio Tokens ---
+         # Handle batch processing correctly
+         batch_size = generated_ids.shape[0]
+         all_semantic_ids = []
+         all_global_tokens = []
+         successful_indices = []  # Keep track of which batch items were successful
+
+         for i in range(batch_size):
+             decoded_text = decoded_texts[i]
+             current_semantic_ids = None
+             current_global_tokens = None
+
+             # Extract semantic tokens
+             try:
+                 pred_semantic_indices = [int(token) for token in re.findall(r"bicodec_semantic_(\d+)", decoded_text)]
+                 if not pred_semantic_indices:
+                     logger.warning(f"Batch item {i}: No semantic tokens found in decoded text: '{decoded_text[:200]}...'")
+                     continue  # Skip this item
+
+                 current_semantic_ids = torch.tensor(pred_semantic_indices).long()  # Shape [N_semantic]
+             except Exception as e:
+                 logger.error(f"Batch item {i}: Error parsing semantic tokens from: '{decoded_text[:200]}...'. Error: {e}")
+                 continue  # Skip this item
+
+             # Determine global tokens
+             if global_token_ids_prompt is not None:
+                 # Cloning mode: Use the provided prompt global tokens for this batch item
+                 if global_token_ids_prompt.shape[0] != batch_size:
+                     raise ValueError(f"Batch size mismatch: generated_ids has {batch_size}, but global_token_ids_prompt has {global_token_ids_prompt.shape[0]}.")
+                 current_global_tokens = global_token_ids_prompt[i]  # Shape [N_global]
+             else:
+                 # Control mode: Extract global tokens from the generated text
+                 try:
+                     pred_global_indices = [int(token) for token in re.findall(r"bicodec_global_(\d+)", decoded_text)]
+                     if not pred_global_indices:
+                         logger.warning(f"Batch item {i}: No global tokens found in decoded text for control mode: '{decoded_text[:200]}...'")
+                         continue  # Skip this item
+
+                     current_global_tokens = torch.tensor(pred_global_indices).long()  # Shape [N_global]
+
+                 except Exception as e:
+                     logger.error(f"Batch item {i}: Error parsing global tokens from: '{decoded_text[:200]}...'. Error: {e}")
+                     continue  # Skip this item
+
+             # If both tokens extracted successfully
+             all_semantic_ids.append(current_semantic_ids)
+             all_global_tokens.append(current_global_tokens)
+             successful_indices.append(i)
+
+         if not successful_indices:
+             logger.error("Failed to extract audio tokens for any item in the batch.")
+             return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}
+
+         # Pad sequences to the max length within the successful batch items for batch detokenization
+         # Note: BiCodec might not support batching if sequences have different lengths. Check its implementation.
+         # Assuming BiCodec *can* handle batches if padded (or if lengths are naturally equal).
+         # This padding might be unnecessary if BiCodec handles variable lengths or if B=1 anyway.
+         # For now, let's assume B=1 was handled correctly and skip complex padding.
+         if batch_size > 1 and len(successful_indices) < batch_size:
+             logger.warning(f"Only successfully decoded {len(successful_indices)} out of {batch_size} batch items.")
+             # Further processing might need to handle only the successful items.
+
+         # Let's proceed assuming B=1 or BiCodec handles batches appropriately.
+         # Stack the successful tokens.
+         try:
+             # Need to ensure tensors have the same length before stacking if BiCodec requires it.
+             # If BiCodec handles variable length, stacking might not be needed, just loop and call detokenize.
+             # Let's assume B=1 for simplicity of the example, matching original code's likely behavior.
+             if len(successful_indices) != 1:
+                 raise NotImplementedError("Batch decoding (B > 1) requires verification of BiCodec's batch handling and potentially padding.")
+
+             final_semantic_ids = all_semantic_ids[0].unsqueeze(0)  # Add batch dim [1, N_semantic]
+             final_global_tokens = all_global_tokens[0].unsqueeze(0)  # Add batch dim [1, N_global]
+
+         except IndexError:  # Should not happen if successful_indices is not empty
+             logger.error("Internal error during token batch preparation.")
+             return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}
+
+
+         # --- Detokenize Audio ---
+         try:
+             # Call the linked model's detokenize method
+             # print(f"DEBUG: Detokenizing audio with global tokens {final_global_tokens.shape}, semantic tokens {final_semantic_ids.shape}")
+             output_wav = self.model.detokenize_audio(final_global_tokens, final_semantic_ids)
+             # detokenize_audio now returns numpy array float32 in [-1, 1]
+
+             # Optional: Double-check dtype here if needed, but should be handled by detokenize_audio now
+             # if output_wav.dtype != np.float32:
+             #     logger.warning(f"Audio dtype after detokenize is {output_wav.dtype}. Converting to float32.")
+             #     output_wav = output_wav.astype(np.float32)
+             # output_wav = np.clip(output_wav, -1.0, 1.0)  # Clipping done in detokenize_audio
+
+         except Exception as e:
+             logger.error(f"Error during audio detokenization: {e}")
+             import traceback
+             traceback.print_exc()
+             raise RuntimeError("Audio detokenization failed.") from e
+
+         return {"audio": output_wav, "sampling_rate": self.sampling_rate}
+
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         pretrained_model_name_or_path: Union[str, os.PathLike],
+         cache_dir: Optional[Union[str, os.PathLike]] = None,
+         force_download: bool = False,
+         local_files_only: bool = False,
+         token: Optional[Union[str, bool]] = None,
+         revision: str = "main",
+         trust_remote_code: bool = False,  # Allow passing this, needed for config potentially
+         **kwargs,
+     ):
+         r"""
+         Instantiate a SparkTTSProcessor from pretrained components.
+         """
+         # Pop specific kwargs for this method
+         config = kwargs.pop("config", None)  # Allow passing config explicitly
+
+         # --- 1. Load Config (to find component paths) ---
+         # We need the config even if the processor doesn't store it permanently,
+         # just to find where the tokenizer/feature_extractor live.
+         loaded_config = None
+         if not isinstance(config, SparkTTSConfig):
+             try:
+                 # Load the specific config class
+                 loaded_config = SparkTTSConfig.from_pretrained(
+                     pretrained_model_name_or_path,
+                     cache_dir=cache_dir,
+                     force_download=force_download,
+                     local_files_only=local_files_only,
+                     token=token,
+                     revision=revision,
+                     trust_remote_code=trust_remote_code,  # Config might be custom
+                     **kwargs,  # Pass relevant kwargs
+                 )
+             except Exception as e:
+                 logger.warning(
+                     f"Could not load SparkTTSConfig from {pretrained_model_name_or_path}. "
+                     f"Attempting to load components from default relative paths ('LLM', 'wav2vec2-large-xlsr-53'). Error: {e}"
+                 )
+                 loaded_config = None  # Fallback
+         else:
+             # Config object was passed directly
+             loaded_config = config
+
+
+         # --- 2. Determine Component Paths ---
+         llm_tokenizer_path_or_id = "./LLM"  # Default relative path
+         w2v_processor_path_or_id = "./wav2vec2-large-xlsr-53"  # Default relative path
+
+         if loaded_config:
+             llm_tokenizer_path_or_id = getattr(loaded_config, 'llm_model_name_or_path', llm_tokenizer_path_or_id)
+             w2v_processor_path_or_id = getattr(loaded_config, 'wav2vec2_model_name_or_path', w2v_processor_path_or_id)
+
+         # The component `from_pretrained` methods handle resolving these paths/IDs
+         # whether they are relative subfolders of `pretrained_model_name_or_path`
+         # or separate Hub IDs.
+
+         # --- 3. Load Components ---
+         # Pass down relevant kwargs for loading components
+         component_loading_kwargs = {
+             "cache_dir": cache_dir,
+             "force_download": force_download,
+             "local_files_only": local_files_only,
+             "token": token,
+             "revision": revision,
+             **kwargs  # Pass other user kwargs
+         }
+         try:
+             # Tokenizer might require trust_remote_code if its class is custom
+             tokenizer = AutoTokenizer.from_pretrained(
+                 pretrained_model_name_or_path,  # Main path
+                 subfolder=llm_tokenizer_path_or_id.lstrip('./'),  # Specify subfolder relative to main path
+                 trust_remote_code=trust_remote_code,
+                 **component_loading_kwargs
+             )
+         except Exception as e:
+             # Fallback: try loading directly using the path/id from config if different
+             if llm_tokenizer_path_or_id != "./LLM":
+                 try:
+                     logger.info(f"Retrying tokenizer load directly from: {llm_tokenizer_path_or_id}")
+                     tokenizer = AutoTokenizer.from_pretrained(
+                         llm_tokenizer_path_or_id,
+                         trust_remote_code=trust_remote_code,
+                         **component_loading_kwargs
+                     )
+                 except Exception as e2:
+                     raise OSError(f"Could not load tokenizer using main path + subfolder or directly from '{llm_tokenizer_path_or_id}'. Error: {e2}") from e
+             else:
+                 raise OSError(f"Could not load tokenizer from subfolder '{llm_tokenizer_path_or_id}' within '{pretrained_model_name_or_path}'. Error: {e}")
+
+
+         try:
+             # Feature extractor usually doesn't need trust_remote_code
+             feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+                 pretrained_model_name_or_path,  # Main path
+                 subfolder=w2v_processor_path_or_id.lstrip('./'),  # Specify subfolder relative to main path
+                 **component_loading_kwargs
+             )
+         except Exception as e:
+             # Fallback: try loading directly using the path/id from config if different
+             if w2v_processor_path_or_id != "./wav2vec2-large-xlsr-53":
+                 try:
+                     logger.info(f"Retrying feature extractor load directly from: {w2v_processor_path_or_id}")
+                     feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+                         w2v_processor_path_or_id,
+                         **component_loading_kwargs
+                     )
+                 except Exception as e2:
+                     raise OSError(f"Could not load feature extractor using main path + subfolder or directly from '{w2v_processor_path_or_id}'. Error: {e2}") from e
+             else:
+                 raise OSError(f"Could not load feature extractor from subfolder '{w2v_processor_path_or_id}' within '{pretrained_model_name_or_path}'. Error: {e}")
+
+
+         # --- 4. Instantiate processor ---
+         # Pass the potentially loaded config object (or None)
+         return cls(tokenizer=tokenizer, feature_extractor=feature_extractor, config=loaded_config)
+
+
+     def save_pretrained(
+         self,
+         save_directory: Union[str, os.PathLike],
+         push_to_hub: bool = False,
+         **kwargs,
+     ):
+         """
+         Save the processor's state (tokenizer and feature extractor files) to a directory.
+
+         Args:
+             save_directory (`str` or `os.PathLike`):
+                 Directory where the processor files will be saved.
+             push_to_hub (`bool`, *optional*, defaults to `False`):
+                 Whether or not to push your model to the Hugging Face Hub after saving it.
+             **kwargs:
+                 Additional keyword arguments passed along to the `push_to_hub` method.
+         """
+         save_directory = Path(save_directory)
+         save_directory.mkdir(parents=True, exist_ok=True)
+
+         # Save tokenizer
+         self.tokenizer.save_pretrained(str(save_directory), **kwargs)
+
+         # Save feature extractor
+         self.feature_extractor.save_pretrained(str(save_directory), **kwargs)
+
+         # Save the main processor config (if it exists and has relevant info)
+         # Note: The SparkTTSConfig is usually saved with the *model*, not the processor.
+         # However, if the processor holds specific config needed for reloading *itself*,
+         # it could be saved here. Usually, relying on the model's config is sufficient.
+         # if self.config:
+         #     self.config.save_pretrained(str(save_directory))  # Example if needed
+
+         logger.info(f"Processor components saved in {save_directory}")
+
+         if push_to_hub:
+             # Commit message and other hub kwargs can be passed via **kwargs
+             commit_message = kwargs.pop("commit_message", "Save processor")
+             return self.push_to_hub(save_directory, commit_message=commit_message, **kwargs)
+
+         return str(save_directory)  # Return path consistent with Mixin
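
A rough end-to-end sketch of the intended workflow, assuming the checkpoint is loaded with `trust_remote_code=True` and that `SparkTTSModel` (defined in modeling_spark_tts.py, whose diff is not rendered above) exposes a standard `generate()` method; the repo path and generation arguments below are placeholders:

import soundfile as sf
import torch
from transformers import AutoModel, AutoProcessor

repo = "path/to/this/repo"  # placeholder path or Hub id
model = AutoModel.from_pretrained(repo, trust_remote_code=True)
processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
processor.link_model(model)  # required before voice cloning and decode()

# Controllable TTS mode (no prompt audio): gender/pitch/speed select style tokens.
inputs = processor(text="Hello world.", gender="female", pitch="moderate", speed="moderate")

with torch.no_grad():
    generated_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=2048,  # assumption, tune as needed
    )

out = processor.decode(
    generated_ids,
    global_token_ids_prompt=inputs.get("global_token_ids_prompt"),  # present only in cloning mode
    input_ids_len=inputs["input_ids"].shape[1],
)
sf.write("output.wav", out["audio"], out["sampling_rate"])
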