ancv committed (verified)
Commit 1de7b9d · Parent(s): 72fe226

Delete processing_spark_tts.py

Files changed (1)
  1. processing_spark_tts.py +0 -345
processing_spark_tts.py DELETED
@@ -1,345 +0,0 @@
-# coding=utf-8
-# Copyright 2024 The SparkAudio Authors and The HuggingFace Inc. team. All rights reserved.
-# ... (license) ...
-"""Processor class for SparkTTS."""
-
-import torch
-import re
-import numpy as np
-import warnings
-from typing import Optional, Dict, Any, Union, List, Tuple
-from pathlib import Path
-
-from transformers.processing_utils import ProcessorMixin
-from transformers.feature_extraction_utils import FeatureExtractionMixin
-from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
-from transformers import AutoTokenizer, Wav2Vec2FeatureExtractor
-from transformers.utils import logging
-
-# Import necessary items directly or ensure they are available via the model reference.
-# Note: Avoid direct model imports here if possible; rely on the model reference.
-# from .modeling_spark_tts import SparkTTSModel  # Avoid direct model import if possible
-from .configuration_spark_tts import SparkTTSConfig  # Config is okay
-
-# Import utils needed for prompt formatting (assuming they are merged into modeling).
-# We'll access them via the model reference if needed, or duplicate simple ones like token maps.
-
-logger = logging.get_logger(__name__)
-
-# --- Token Maps (duplicated here for direct use in the processor) ---
-TASK_TOKEN_MAP = {
-    "tts": "<|task_tts|>",
-    "controllable_tts": "<|task_controllable_tts|>",
-    # Add other tasks if needed by processor logic
-}
-LEVELS_MAP = {"very_low": 0, "low": 1, "moderate": 2, "high": 3, "very_high": 4}
-GENDER_MAP = {"female": 0, "male": 1}
-# --- End Token Maps ---
-
-
-class SparkTTSProcessor(ProcessorMixin):
-    r"""
-    Constructs a SparkTTS processor which wraps a text tokenizer and an audio feature extractor
-    into a single processor.
-
-    [`SparkTTSProcessor`] offers all the functionalities of [`AutoTokenizer`] and [`Wav2Vec2FeatureExtractor`].
-    It processes text input for the LLM and prepares audio inputs if needed (delegating actual audio tokenization
-    to the model). It also handles decoding the final output.
-
-    Args:
-        tokenizer (`PreTrainedTokenizerBase`):
-            An instance of [`AutoTokenizer`]. The tokenizer is used to encode the prompt text.
-        feature_extractor (`Wav2Vec2FeatureExtractor`):
53
- An instance of [`Wav2Vec2FeatureExtractor`]. The feature extractor is used to processor reference audio
54
- (though the main processing happens inside the model).
55
- model (`PreTrainedModel`, *optional*):
56
- A reference to the loaded `SparkTTSModel`. This is REQUIRED for voice cloning (prompt audio processing)
57
- and final audio decoding, as these steps rely on the model's internal BiCodec and Wav2Vec2 components.
58
- Set this using `processor.model = model` after loading both.
59
- config (`SparkTTSConfig`, *optional*):
60
- The configuration object, needed for parameters like sample_rate. Can often be inferred from the model.
61
- """
-    attributes = ["tokenizer", "feature_extractor"]
-    tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")  # (slow, fast) underlying tokenizer types
-    feature_extractor_class = "Wav2Vec2FeatureExtractor"  # Underlying feature extractor type
-
-    def __init__(self, tokenizer=None, feature_extractor=None, model=None, config=None, **kwargs):
-        if tokenizer is None:
-            raise ValueError("SparkTTSProcessor requires a `tokenizer`.")
-        if feature_extractor is None:
-            # Attempt to load default if path is known or provide clearer error
-            raise ValueError("SparkTTSProcessor requires a `feature_extractor` (Wav2Vec2FeatureExtractor).")
-
-        super().__init__(tokenizer, feature_extractor)
-        self.model = model  # Store model reference (can be None initially)
-        self.config = config  # Store config reference
-
-        # Get sampling rate from config if available
-        self.sampling_rate = None
-        if self.config and hasattr(self.config, 'sample_rate'):
-            self.sampling_rate = self.config.sample_rate
-        elif self.model and hasattr(self.model, 'config') and hasattr(self.model.config, 'sample_rate'):
-            self.sampling_rate = self.model.config.sample_rate
-        else:
-            # Try the feature extractor's default, or warn and fall back
-            if hasattr(self.feature_extractor, 'sampling_rate'):
-                self.sampling_rate = self.feature_extractor.sampling_rate
-            else:
-                logger.warning("Could not determine sampling rate. Defaulting to 16000. Set `processor.sampling_rate` manually if needed.")
-                self.sampling_rate = 16000
-
-    @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
-        """
-        Instantiate a [`SparkTTSProcessor`] from a pretrained processor configuration.
-
-        Args:
-            pretrained_model_name_or_path (`str` or `os.PathLike`):
-                This can be either:
-                - a string, the *model id* of a pretrained processor hosted inside a model repo on huggingface.co.
-                - a path to a *directory* containing processor files saved using the `save_pretrained()` method,
-                  e.g., `./my_model_directory/`.
-            **kwargs:
-                Additional keyword arguments passed along to both `AutoTokenizer.from_pretrained()` and
105
- `AutoFeatureExtractor.from_pretrained()`.
106
- """
107
- config = kwargs.pop("config", None)
108
- if config is None:
109
- # Try loading the specific config first
110
- try:
111
- config = SparkTTSConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
112
- except Exception:
113
- logger.warning(f"Could not load SparkTTSConfig from {pretrained_model_name_or_path}. Processor might lack some config values.")
114
- config = None
115
-
116
-
117
- # Resolve component paths relative to the main path
118
- def _resolve_path(sub_path):
119
- p = Path(sub_path)
120
- if p.is_absolute():
121
- return str(p)
122
- # Try resolving relative to the main path if it's a directory
123
- main_path = Path(pretrained_model_name_or_path)
124
- if main_path.is_dir():
125
- resolved = main_path / p
126
- if resolved.exists():
127
- return str(resolved)
128
- # Fallback to assuming sub_path is relative within a repo structure (might fail for local non-dirs)
129
- return sub_path
130
-
131
- # Determine paths from config or assume defaults
132
- llm_tokenizer_path = "./LLM"
133
- w2v_processor_path = "./wav2vec2-large-xlsr-53"
134
- if config:
135
- llm_tokenizer_path = getattr(config, 'llm_model_name_or_path', llm_tokenizer_path)
136
- w2v_processor_path = getattr(config, 'wav2vec2_model_name_or_path', w2v_processor_path)
137
-
138
- resolved_tokenizer_path = _resolve_path(llm_tokenizer_path)
139
- resolved_w2v_path = _resolve_path(w2v_processor_path)
140
-
141
- try:
142
- tokenizer = AutoTokenizer.from_pretrained(resolved_tokenizer_path, **kwargs)
143
- except Exception as e:
-            raise OSError(f"Could not load tokenizer from {resolved_tokenizer_path}. Ensure the path is correct and files exist. Original error: {e}") from e
-
-        try:
-            feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(resolved_w2v_path, **kwargs)
-        except Exception as e:
-            raise OSError(f"Could not load feature extractor from {resolved_w2v_path}. Ensure the path is correct and files exist. Original error: {e}") from e
-
-        # The 'model' attribute will be set later externally
-        return cls(tokenizer=tokenizer, feature_extractor=feature_extractor, config=config)
-
-    def __call__(self, text: Optional[str] = None,
-                 prompt_speech_path: Optional[str] = None,
-                 prompt_text: Optional[str] = None,
-                 gender: Optional[str] = None,
-                 pitch: Optional[str] = None,
-                 speed: Optional[str] = None,
-                 return_tensors: Optional[str] = "pt",
-                 **kwargs) -> BatchEncoding:
-        """
-        Main method to process inputs for the SparkTTS model.
-
-        Args:
-            text (`str`): The text to be synthesized.
-            prompt_speech_path (`str`, *optional*): Path to prompt audio for voice cloning.
-            prompt_text (`str`, *optional*): Transcript of the prompt audio.
-            gender (`str`, *optional*): Target gender ('male' or 'female') for voice creation.
-            pitch (`str`, *optional*): Target pitch level ('very_low'...'very_high') for voice creation.
-            speed (`str`, *optional*): Target speed level ('very_low'...'very_high') for voice creation.
-            return_tensors (`str`, *optional*, defaults to `"pt"`):
-                Framework of the returned tensors (`"pt"` for PyTorch, `"np"` for NumPy).
-            **kwargs: Additional arguments (currently ignored).
-
-        Returns:
-            `BatchEncoding`: A dictionary containing the `input_ids`, `attention_mask`, and optionally
-            `global_token_ids_prompt`, ready for the model's `.generate()` method.
-        """
-        if text is None:
-            raise ValueError("`text` input must be provided.")
-
-        global_token_ids_prompt = None
-        llm_prompt_string = ""
-
-        if prompt_speech_path is not None:
-            # --- Voice Cloning Mode ---
-            if self.model is None:
-                raise ValueError("Processor requires a loaded `model` reference (`processor.model = model`) for voice cloning.")
-            if not hasattr(self.model, '_tokenize_audio'):
-                raise AttributeError("The provided model object does not have the required '_tokenize_audio' method.")
-
-            logger.info(f"Processing prompt audio: {prompt_speech_path}")
-            # Delegate audio tokenization to the model
-            try:
-                # _tokenize_audio returns (global_tokens, semantic_tokens)
-                global_tokens, semantic_tokens = self.model._tokenize_audio(prompt_speech_path)
-                global_token_ids_prompt = global_tokens  # Keep for the decoding stage
-            except Exception as e:
-                logger.error(f"Error tokenizing prompt audio: {e}", exc_info=True)
-                raise RuntimeError(f"Failed to process prompt audio file: {prompt_speech_path}. Check file integrity and model compatibility.") from e
-
-            # Format the prompt string using the token maps
-            global_tokens_str = "".join([f"<|bicodec_global_{i}|>" for i in global_tokens.squeeze().tolist()])
-
-            if prompt_text and len(prompt_text) > 1:
-                semantic_tokens_str = "".join([f"<|bicodec_semantic_{i}|>" for i in semantic_tokens.squeeze().tolist()])
-                llm_prompt_parts = [
-                    TASK_TOKEN_MAP["tts"], "<|start_content|>", prompt_text, text, "<|end_content|>",
-                    "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
-                    "<|start_semantic_token|>", semantic_tokens_str,
-                ]
-            else:
-                llm_prompt_parts = [
-                    TASK_TOKEN_MAP["tts"], "<|start_content|>", text, "<|end_content|>",
-                    "<|start_global_token|>", global_tokens_str, "<|end_global_token|>",
-                ]
-            llm_prompt_string = "".join(llm_prompt_parts)
-
-        elif gender is not None and pitch is not None and speed is not None:
-            # --- Voice Creation Mode ---
-            if gender not in GENDER_MAP:
-                raise ValueError(f"Invalid gender '{gender}'.")
-            if pitch not in LEVELS_MAP:
-                raise ValueError(f"Invalid pitch '{pitch}'.")
-            if speed not in LEVELS_MAP:
-                raise ValueError(f"Invalid speed '{speed}'.")
-
-            gender_id = GENDER_MAP[gender]
-            pitch_level_id = LEVELS_MAP[pitch]
-            speed_level_id = LEVELS_MAP[speed]
-
-            attribute_tokens = f"<|gender_{gender_id}|><|pitch_label_{pitch_level_id}|><|speed_label_{speed_level_id}|>"
-
-            llm_prompt_parts = [
-                TASK_TOKEN_MAP["controllable_tts"], "<|start_content|>", text, "<|end_content|>",
-                "<|start_style_label|>", attribute_tokens, "<|end_style_label|>",
-            ]
-            llm_prompt_string = "".join(llm_prompt_parts)
-            # No global_token_ids_prompt needed in this mode
-
-        else:
-            raise ValueError("Processor requires either 'prompt_speech_path' (for cloning) or 'gender', 'pitch', and 'speed' (for creation).")
-
-        # Tokenize the final LLM prompt string
-        inputs = self.tokenizer(llm_prompt_string, return_tensors=return_tensors, padding=False, truncation=False)
-
-        # Add prompt global tokens to the output if they exist (for passing to decode)
-        if global_token_ids_prompt is not None:
-            inputs["global_token_ids_prompt"] = global_token_ids_prompt
-
-        return inputs
-
-    def decode(self,
-               generated_ids: Union[List[int], np.ndarray, torch.Tensor],
-               global_token_ids_prompt: Optional[torch.Tensor] = None,
-               input_ids_len: Optional[int] = None,
-               skip_special_tokens: bool = True) -> Dict[str, Any]:
-        """
-        Decodes the raw token IDs generated by the model into an audio waveform.
-
-        Args:
-            generated_ids (`Union[List[int], np.ndarray, torch.Tensor]`):
-                The token IDs generated by the `model.generate()` method. Assumed to be a single sequence (batch size 1).
-            global_token_ids_prompt (`torch.Tensor`, *optional*):
-                The global tokens obtained from the prompt audio during preprocessing (needed for voice cloning).
-                Should be passed from the `__call__` output.
-            input_ids_len (`int`, *optional*):
-                The length of the original prompt `input_ids`. If provided, the prompt part will be stripped from
-                `generated_ids` before decoding the text representation. If None, assumes `generated_ids` contains
-                *only* the generated part.
-            skip_special_tokens (`bool`, *optional*, defaults to `True`):
-                Whether to skip special tokens when decoding the text representation for parsing.
-
-        Returns:
-            `Dict[str, Any]`: A dictionary containing:
-            - `audio` (`np.ndarray`): The generated audio waveform.
-            - `sampling_rate` (`int`): The sampling rate of the audio.
-        """
-        if self.model is None:
-            raise ValueError("Processor requires a loaded `model` reference (`processor.model = model`) for decoding.")
-        if not hasattr(self.model, '_detokenize_audio'):
-            raise AttributeError("The provided model object does not have the required '_detokenize_audio' method.")
-        if self.sampling_rate is None:
-            raise ValueError("Processor could not determine sampling_rate. Set `processor.sampling_rate`.")
-
-        # Ensure generated_ids is a tensor
-        if isinstance(generated_ids, (list, np.ndarray)):
-            output_ids_tensor = torch.tensor(generated_ids)
-        else:
-            output_ids_tensor = generated_ids
-
-        # Remove the prompt if input_ids_len is provided
-        if input_ids_len is not None:
-            # Handle a potential batch dimension if present (though usually not for decode)
-            if output_ids_tensor.ndim > 1:
-                output_ids = output_ids_tensor[0, input_ids_len:]
-            else:
-                output_ids = output_ids_tensor[input_ids_len:]
-        else:
-            if output_ids_tensor.ndim > 1:
-                output_ids = output_ids_tensor[0]
-            else:
-                output_ids = output_ids_tensor
-
-        if output_ids.numel() == 0:
-            logger.warning("Received empty generated IDs after removing the prompt. Returning empty audio.")
-            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}
-
-        # Decode the text representation to parse tokens
-        predicts_text = self.tokenizer.decode(output_ids, skip_special_tokens=skip_special_tokens)
-
-        # Extract semantic tokens
-        semantic_matches = re.findall(r"bicodec_semantic_(\d+)", predicts_text)
-        if not semantic_matches:
-            logger.warning("No semantic tokens found in the generated output text. Cannot synthesize audio.")
-            return {"audio": np.array([], dtype=np.float32), "sampling_rate": self.sampling_rate}
-
-        # Use the model's device for tensors
-        device = self.model.device
-        pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches], dtype=torch.long, device=device).unsqueeze(0)  # Add batch dim
-
-        # Determine global tokens
-        if global_token_ids_prompt is not None:
-            # Voice cloning: use the prompt's global tokens
-            global_token_ids = global_token_ids_prompt.to(device)
-            # Ensure correct shape (B, T_token, Q) or (B, D) - BiCodec detokenize needs to handle this
-            if global_token_ids.ndim == 2:  # If (B, D), maybe unsqueeze? Check BiCodec.detokenize expectation
-                global_token_ids = global_token_ids.unsqueeze(1)  # Assume (B, 1, D) might be needed
-        else:
-            # Voice creation: parse global tokens from the generated text
-            global_matches = re.findall(r"bicodec_global_(\d+)", predicts_text)
-            if not global_matches:
-                logger.error("Voice creation failed: no global tokens found in the generated text.")
-                raise ValueError("Voice creation failed: could not find bicodec_global tokens in the LLM output.")
-            global_token_ids = torch.tensor([int(token) for token in global_matches], dtype=torch.long, device=device).unsqueeze(0)  # Add batch dim
-            # Add a sequence dimension if needed (check BiCodec.detokenize)
-            if global_token_ids.ndim == 2:
-                global_token_ids = global_token_ids.unsqueeze(1)  # Assume (B, 1, D)
-
-        # Detokenize audio using the model's method
-        try:
-            wav_np = self.model._detokenize_audio(global_token_ids, pred_semantic_ids)
-        except Exception as e:
-            logger.error(f"Error during audio detokenization: {e}", exc_info=True)
-            raise RuntimeError("Failed to synthesize audio waveform from generated tokens.") from e
-
-        return {"audio": wav_np, "sampling_rate": self.sampling_rate}